Jelajahi Sumber

Preserve other keys on shape level

Grzegorz Ruciński 5 tahun lalu
induk
melakukan
9bc8a754e0
4 mengubah file dengan 23 tambahan dan 6 penghapusan
  1. 8 4
      labelme/app.py
  2. 11 1
      labelme/label_file.py
  3. 1 0
      labelme/shape.py
  4. 3 1
      tests/labelme_tests/test_app.py

+ 8 - 4
labelme/app.py

@@ -1039,7 +1039,8 @@ class MainWindow(QtWidgets.QMainWindow):
             points = shape['points']
             shape_type = shape['shape_type']
             flags = shape['flags']
-            group_id = shape.get('group_id')
+            group_id = shape['group_id']
+            other_data = shape['other_data']
 
             shape = Shape(
                 label=label,
@@ -1058,6 +1059,7 @@ class MainWindow(QtWidgets.QMainWindow):
                             default_flags[key] = False
             shape.flags = default_flags
             shape.flags.update(flags)
+            shape.other_data = other_data
 
             s.append(shape)
         self.loadShapes(s)
@@ -1074,13 +1076,15 @@ class MainWindow(QtWidgets.QMainWindow):
         lf = LabelFile()
 
         def format_shape(s):
-            return dict(
+            data = s.other_data.copy()
+            data.update(dict(
                 label=s.label.encode('utf-8') if PY2 else s.label,
                 points=[(p.x(), p.y()) for p in s.points],
                 group_id=s.group_id,
                 shape_type=s.shape_type,
-                flags=s.flags
-            )
+                flags=s.flags,
+            ))
+            return data
 
         shapes = [format_shape(shape) for shape in self.labelList.shapes]
         flags = {}

+ 11 - 1
labelme/label_file.py

@@ -64,6 +64,13 @@ class LabelFile(object):
             'imageHeight',
             'imageWidth',
         ]
+        shape_keys = [
+            'label',
+            'points',
+            'group_id',
+            'shape_type',
+            'flags',
+        ]
         try:
             with open(filename, 'rb' if PY2 else 'r') as f:
                 data = json.load(f)
@@ -103,7 +110,10 @@ class LabelFile(object):
                     points=s['points'],
                     shape_type=s.get('shape_type', 'polygon'),
                     flags=s.get('flags', {}),
-                    group_id=s.get('group_id')
+                    group_id=s.get('group_id'),
+                    other_data={
+                        k: v for k, v in s.items() if k not in shape_keys
+                    }
                 )
                 for s in data['shapes']
             ]

+ 1 - 0
labelme/shape.py

@@ -46,6 +46,7 @@ class Shape(object):
         self.selected = False
         self.shape_type = shape_type
         self.flags = flags
+        self.other_data = {}
 
         self._highlightIndex = None
         self._highlightMode = self.NEAR_VERTEX

+ 3 - 1
tests/labelme_tests/test_app.py

@@ -91,9 +91,11 @@ def test_MainWindow_annotate_jpg(qtbot):
     ]
     shapes = [dict(
         label=label,
+        group_id=None,
         points=points,
         shape_type='polygon',
-        flags={}
+        flags={},
+        other_data={}
     )]
     win.loadLabels(shapes)
     win.saveFile()