ソースを参照

Preserve custom JSON keys

Kentaro Wada 7 年 前
コミット
588dc993ec
2 ファイル変更41 行追加23 行削除
  1. 6 1
      labelme/app.py
  2. 35 22
      labelme/labelFile.py

+ 6 - 1
labelme/app.py

@@ -422,6 +422,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.maxRecent = 7
         self.lineColor = None
         self.fillColor = None
+        self.otherData = None
         self.zoom_level = 100
         self.fit_window = False
 
@@ -551,8 +552,10 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     def resetState(self):
         self.labelList.clear()
         self.filename = None
+        self.imagePath = None
         self.imageData = None
         self.labelFile = None
+        self.otherData = None
         self.canvas.resetState()
 
     def currentItem(self):
@@ -727,7 +730,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
                 self.imagePath, os.path.dirname(filename))
             imageData = self.imageData if self._store_data else None
             lf.save(filename, shapes, imagePath, imageData,
-                    self.lineColor.getRgb(), self.fillColor.getRgb())
+                    self.lineColor.getRgb(), self.fillColor.getRgb(),
+                    self.otherData)
             self.labelFile = lf
             # disable allows next and previous image to proceed
             # self.filename = filename
@@ -868,6 +872,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
                                           self.labelFile.imagePath)
             self.lineColor = QtGui.QColor(*self.labelFile.lineColor)
             self.fillColor = QtGui.QColor(*self.labelFile.fillColor)
+            self.otherData = self.labelFile.otherData
         else:
             # Load image:
             # read data first and store for saving into label file.

+ 35 - 22
labelme/labelFile.py

@@ -24,37 +24,48 @@ class LabelFile(object):
         self.filename = filename
 
     def load(self, filename):
+        keys = ['imageData', 'imagePath', 'lineColor', 'fillColor', 'shapes']
         try:
             with open(filename, 'rb' if PY2 else 'r') as f:
                 data = json.load(f)
-                if data['imageData'] is not None:
-                    imageData = base64.b64decode(data['imageData'])
-                else:
-                    # relative path from label file to relative path from cwd
-                    imagePath = os.path.join(os.path.dirname(filename),
-                                             data['imagePath'])
-                    with open(imagePath, 'rb') as f:
-                        imageData = f.read()
-                lineColor = data['lineColor']
-                fillColor = data['fillColor']
-                shapes = (
-                    (s['label'], s['points'], s['line_color'], s['fill_color'])
-                    for s in data['shapes']
-                )
-                # Only replace data after everything is loaded.
-                self.shapes = shapes
-                self.imagePath = data['imagePath']
-                self.imageData = imageData
-                self.lineColor = lineColor
-                self.fillColor = fillColor
-                self.filename = filename
+            if data['imageData'] is not None:
+                imageData = base64.b64decode(data['imageData'])
+            else:
+                # relative path from label file to relative path from cwd
+                imagePath = os.path.join(os.path.dirname(filename),
+                                         data['imagePath'])
+                with open(imagePath, 'rb') as f:
+                    imageData = f.read()
+            imagePath = data['imagePath']
+            lineColor = data['lineColor']
+            fillColor = data['fillColor']
+            shapes = (
+                (s['label'], s['points'], s['line_color'], s['fill_color'])
+                for s in data['shapes']
+            )
         except Exception as e:
             raise LabelFileError(e)
 
+        otherData = {}
+        for key, value in data.items():
+            if key not in keys:
+                otherData[key] = value
+
+        # Only replace data after everything is loaded.
+        self.shapes = shapes
+        self.imagePath = imagePath
+        self.imageData = imageData
+        self.lineColor = lineColor
+        self.fillColor = fillColor
+        self.filename = filename
+        self.otherData = otherData
+
     def save(self, filename, shapes, imagePath, imageData=None,
-             lineColor=None, fillColor=None):
+             lineColor=None, fillColor=None, otherData=None):
         if imageData is not None:
             imageData = base64.b64encode(imageData).decode('utf-8')
+        if otherData is None:
+            otherData = {}
         data = dict(
             shapes=shapes,
             lineColor=lineColor,
@@ -62,6 +73,8 @@ class LabelFile(object):
             imagePath=imagePath,
             imageData=imageData,
         )
+        for key, value in otherData.items():
+            data[key] = value
         try:
             with open(filename, 'wb' if PY2 else 'w') as f:
                 json.dump(data, f, ensure_ascii=True, indent=2)