Browse Source

Add image level flag annotation

Kentaro Wada 7 years ago
parent
commit
68166666af
3 changed files with 55 additions and 15 deletions
  1. 39 13
      labelme/app.py
  2. 1 0
      labelme/config/default_config.yaml
  3. 15 2
      labelme/labelFile.py

+ 39 - 13
labelme/app.py

@@ -165,6 +165,15 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.labelListContainer = QtWidgets.QWidget()
         self.labelListContainer.setLayout(listLayout)
 
+        self.flag_dock = self.flag_widget = None
+        self.flag_dock = QtWidgets.QDockWidget('Flags', self)
+        self.flag_dock.setObjectName('Flags')
+        self.flag_widget = QtWidgets.QListWidget()
+        if config['flags']:
+            self.loadFlags({k: False for k in config['flags']})
+        self.flag_dock.setWidget(self.flag_widget)
+        self.flag_widget.itemChanged.connect(self.setDirty)
+
         self.uniqLabelList = EscapableQListWidget()
         self.uniqLabelList.setToolTip(
             "Select label to start annotating for it. "
@@ -214,6 +223,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
 
         self.setCentralWidget(scrollArea)
 
+        self.addDockWidget(Qt.RightDockWidgetArea, self.flag_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.labelsdock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.filedock)
@@ -686,6 +696,14 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
                 shape.fill_color = QtGui.QColor(*fill_color)
         self.loadShapes(s)
 
+    def loadFlags(self, flags):
+        self.flag_widget.clear()
+        for key, flag in flags.items():
+            item = QtWidgets.QListWidgetItem(key)
+            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
+            item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
+            self.flag_widget.addItem(item)
+
     def saveLabels(self, filename):
         lf = LabelFile()
 
@@ -698,13 +716,26 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
                         points=[(p.x(), p.y()) for p in s.points])
 
         shapes = [format_shape(shape) for shape in self.labelList.shapes]
+        flags = {}
+        for i in range(self.flag_widget.count()):
+            item = self.flag_widget.item(i)
+            key = item.text()
+            flag = item.checkState() == Qt.Checked
+            flags[key] = flag
         try:
             imagePath = os.path.relpath(
                 self.imagePath, os.path.dirname(filename))
             imageData = self.imageData if self._config['store_data'] else None
-            lf.save(filename, shapes, imagePath, imageData,
-                    self.lineColor.getRgb(), self.fillColor.getRgb(),
-                    self.otherData)
+            lf.save(
+                filename=filename,
+                shapes=shapes,
+                imagePath=imagePath,
+                imageData=imageData,
+                lineColor=self.lineColor.getRgb(),
+                fillColor=self.fillColor.getRgb(),
+                otherData=self.otherData,
+                flags=flags,
+            )
             self.labelFile = lf
             # disable allows next and previous image to proceed
             # self.filename = filename
@@ -880,6 +911,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.canvas.loadPixmap(QtGui.QPixmap.fromImage(image))
         if self.labelFile:
             self.loadLabels(self.labelFile.shapes)
+            if self.labelFile.flags is not None:
+                self.loadFlags(self.labelFile.flags)
         self.setClean()
         self.canvas.setEnabled(True)
         self.adjustScale(initial=True)
@@ -1243,16 +1276,9 @@ def main():
     output = config_from_args.pop('output')
     config_file = config_from_args.pop('config_file')
     # drop the default config
-    if not config_from_args['auto_save']:
-        config_from_args.pop('auto_save')
-    if config_from_args['store_data']:
-        config_from_args.pop('store_data')
-    if not config_from_args['labels']:
-        config_from_args.pop('labels')
-    if not config_from_args['sort_labels']:
-        config_from_args.pop('sort_labels')
-    if not config_from_args['validate_label']:
-        config_from_args.pop('validate_label')
+    for k, v in list(config_from_args.items()):
+        if v is None:
+            config_from_args.pop(k)
     config = get_config(config_from_args, config_file)
 
     app = QtWidgets.QApplication(sys.argv)

+ 1 - 0
labelme/config/default_config.yaml

@@ -1,6 +1,7 @@
 auto_save: false
 store_data: true
 
+flags: null
 labels: null
 sort_labels: true
 validate_label: null

+ 15 - 2
labelme/labelFile.py

@@ -24,7 +24,14 @@ class LabelFile(object):
         self.filename = filename
 
     def load(self, filename):
-        keys = ['imageData', 'imagePath', 'lineColor', 'fillColor', 'shapes']
+        keys = [
+            'imageData',
+            'imagePath',
+            'lineColor',
+            'fillColor',
+            'shapes',  # polygonal annotations
+            'flags',   # image level flags
+        ]
         try:
             with open(filename, 'rb' if PY2 else 'r') as f:
                 data = json.load(f)
@@ -36,6 +43,7 @@ class LabelFile(object):
                                          data['imagePath'])
                 with open(imagePath, 'rb') as f:
                     imageData = f.read()
+            flags = data.get('flags')
             imagePath = data['imagePath']
             lineColor = data['lineColor']
             fillColor = data['fillColor']
@@ -52,6 +60,7 @@ class LabelFile(object):
                 otherData[key] = value
 
         # Only replace data after everything is loaded.
+        self.flags = flags
         self.shapes = shapes
         self.imagePath = imagePath
         self.imageData = imageData
@@ -61,12 +70,16 @@ class LabelFile(object):
         self.otherData = otherData
 
     def save(self, filename, shapes, imagePath, imageData=None,
-             lineColor=None, fillColor=None, otherData=None):
+             lineColor=None, fillColor=None, otherData=None,
+             flags=None):
         if imageData is not None:
             imageData = base64.b64encode(imageData).decode('utf-8')
         if otherData is None:
             otherData = {}
+        if flags is None:
+            flags = []
         data = dict(
+            flags=flags,
             shapes=shapes,
             lineColor=lineColor,
             fillColor=fillColor,