Forráskód Böngészése

Added the ability to have different flags for each label plus shared flags for all labels

cmerchant 6 éve
szülő
commit
c2ba19759b
3 módosított fájl, 102 hozzáadás és 46 törlés
  1. 33 22
      labelme/app.py
  2. 16 8
      labelme/main.py
  3. 53 16
      labelme/widgets/label_dialog.py

+ 33 - 22
labelme/app.py

@@ -100,8 +100,6 @@ class MainWindow(QtWidgets.QMainWindow):
         self.label_flag_dock = QtWidgets.QDockWidget('Label Flags', self)
         self.label_flag_dock.setObjectName('Label Flags')
         self.label_flag_widget = QtWidgets.QListWidget()
-        if config['label_flags']:
-            self.loadLabelFlags({k: False for k in config['label_flags']})
         self.label_flag_dock.setWidget(self.label_flag_widget)
         self.label_flag_widget.itemChanged.connect(self.labelFlagChanged)
 
@@ -863,7 +861,7 @@ class MainWindow(QtWidgets.QMainWindow):
             return
         item = item if item else self.currentItem()
         shape = self.labelList.get_shape_from_item(item)
-        text, flags = self.labelDialog.popUp(item.text() if item else None, flags=(shape.flags if item else None))
+        text, flags = self.labelDialog.popUp(item.text() if item else None, flags=(shape.flags if shape else None))
         if text is None:
             return
         if not self.validateLabel(text):
@@ -872,7 +870,9 @@ class MainWindow(QtWidgets.QMainWindow):
                               .format(text, self._config['validate_label']))
             return
         shape.flags = flags
-        self.loadLabelFlags(flags)
+        shape.label = text
+        if self._config['label_flags']:
+            self.loadLabelFlags(flags, shape.label)
         item.setText(text)
         self.setDirty()
         if not self.uniqLabelList.findItems(text, Qt.MatchExactly):
@@ -920,7 +920,12 @@ class MainWindow(QtWidgets.QMainWindow):
 
     def addLabel(self, shape):
         if not shape.flags:
-            shape.flags =  {k: False for k in self._config['label_flags']}
+            shape.flags = {}
+            if self._config['label_flags']:
+                for label in ["__all__", shape.label]:
+                    if label in self._config['label_flags']:
+                        for key in self._config['label_flags'][label]:
+                            shape.flags[key] = False
         item = QtWidgets.QListWidgetItem(shape.label)
         item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
         item.setCheckState(Qt.Checked)
@@ -941,8 +946,6 @@ class MainWindow(QtWidgets.QMainWindow):
         for shape in shapes:
             self.addLabel(shape)
         self.canvas.loadShapes(shapes, replace=replace)
-        if self._config['label_flags']:
-            self.loadLabelFlags({k: False for k in self._config['label_flags']})
 
     def loadLabels(self, shapes):
         s = []
@@ -967,13 +970,16 @@ class MainWindow(QtWidgets.QMainWindow):
             item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
             self.flag_widget.addItem(item)
 
-    def loadLabelFlags(self, flags):
+    def loadLabelFlags(self, flags=None, label=''):
         self.label_flag_widget.clear()
-        for key, flag in flags.items():
-            item = QtWidgets.QListWidgetItem(key)
-            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
-            item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
-            self.label_flag_widget.addItem(item)
+        if flags:
+            for label in ["__all__", label]:
+                if label in self._config['label_flags']:
+                    for key in self._config['label_flags'][label]:
+                        item = QtWidgets.QListWidgetItem(key)
+                        item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
+                        item.setCheckState(Qt.Checked if (key in flags and flags[key]) else Qt.Unchecked)
+                        self.label_flag_widget.addItem(item)
 
     def saveLabels(self, filename):
         lf = LabelFile()
@@ -1039,7 +1045,8 @@ class MainWindow(QtWidgets.QMainWindow):
         item = self.currentItem()
         if item:
             shape = self.labelList.get_shape_from_item(item)
-            self.loadLabelFlags(shape.flags)
+            if self._config['label_flags']:
+                self.loadLabelFlags(shape.flags, shape.label)
             if self.canvas.editing():
                 self._noSelectionSlot = True
                 self.canvas.selectShape(shape)
@@ -1056,14 +1063,16 @@ class MainWindow(QtWidgets.QMainWindow):
     def labelFlagChanged(self):
         item = self.currentItem()
         if item:
+            shape = self.labelList.get_shape_from_item(item)
             index = 0
             flags = {}
-            for key in self._config["label_flags"]:
-                checkBox = self.label_flag_widget.item(index)
-                index = index + 1
-                value = True if checkBox.checkState() else False
-                flags[key] = value
-            shape = self.labelList.get_shape_from_item(item)
+            for label in ["__all__", shape.label]:
+                if label in self._config['label_flags']:
+                    for key in self._config['label_flags'][label]:
+                        checkBox = self.label_flag_widget.item(index)
+                        index = index + 1
+                        value = True if checkBox.checkState() else False
+                        flags[key] = value
             if shape.flags != flags:
                 shape.flags = flags
                 self.setDirty()
@@ -1226,6 +1235,8 @@ class MainWindow(QtWidgets.QMainWindow):
         self.canvas.loadPixmap(QtGui.QPixmap.fromImage(image))
         if self._config['flags']:
             self.loadFlags({k: False for k in self._config['flags']})
+        if self._config['label_flags']:
+            self.loadLabelFlags()
         if self.labelFile:
             self.loadLabels(self.labelFile.shapes)
             if self.labelFile.flags is not None:
@@ -1557,8 +1568,8 @@ class MainWindow(QtWidgets.QMainWindow):
             self.remLabel(self.canvas.deleteSelected())
             self.setDirty()
             if self.noShapes():
-                if self._config['shape_flags']:
-                    self.loadLabelFlags({k: False for k in self._config['shape_flags']})
+                if self._config['label_flags']:
+                    self.loadLabelFlags()
                 for action in self.actions.onShapesPresent:
                     action.setEnabled(False)
 

+ 16 - 8
labelme/main.py

@@ -3,6 +3,7 @@ import codecs
 import logging
 import os
 import sys
+import yaml
 
 from qtpy import QtWidgets
 
@@ -81,7 +82,7 @@ def _main():
     parser.add_argument(
         '--labelflags',
         dest='label_flags',
-        help='comma separated list of label specific flags OR file containing flags',
+        help='yaml string of label specific flags OR file containing json string of label specific flags (ex. {human:[male,female],dog:[big],__all__:[occluded]} )',
         default=argparse.SUPPRESS,
     )
     parser.add_argument(
@@ -122,13 +123,6 @@ def _main():
                 args.flags = [l.strip() for l in f if l.strip()]
         else:
             args.flags = [l for l in args.flags.split(',') if l]
-            
-    if hasattr(args, 'label_flags'):
-        if os.path.isfile(args.label_flags):
-            with codecs.open(args.label_flags, 'r', encoding='utf-8') as f:
-                args.label_flags = [l.strip() for l in f if l.strip()]
-        else:
-            args.label_flags = [l for l in args.label_flags.split(',') if l]
 
     if hasattr(args, 'labels'):
         if os.path.isfile(args.labels):
@@ -137,6 +131,20 @@ def _main():
         else:
             args.labels = [l for l in args.labels.split(',') if l]
 
+    if hasattr(args, 'label_flags'):
+        if os.path.isfile(args.label_flags):
+            with codecs.open(args.label_flags, 'r', encoding='utf-8') as f:
+                args.label_flags = yaml.load(f)
+        else:
+            args.label_flags = yaml.load(args.label_flags)
+
+        # Add not overlapping labels from label flags
+        if not hasattr(args, 'labels'):
+            args.labels = []
+        for label in args.label_flags.keys():
+            if label != "__all__" and label not in args.labels:
+                args.labels.append(label)
+
     config_from_args = args.__dict__
     config_from_args.pop('version')
     reset_config = config_from_args.pop('reset_config')

+ 53 - 16
labelme/widgets/label_dialog.py

@@ -29,7 +29,7 @@ class LabelDialog(QtWidgets.QDialog):
 
     def __init__(self, text="Enter object label", parent=None, labels=None,
                  sort_labels=True, show_text_field=True,
-                 completion='startswith', fit_to_content=None, flags=[]):
+                 completion='startswith', fit_to_content=None, flags=None):
         if fit_to_content is None:
             fit_to_content = {'row': False, 'column': True}
         self._fit_to_content = fit_to_content
@@ -39,6 +39,8 @@ class LabelDialog(QtWidgets.QDialog):
         self.edit.setPlaceholderText(text)
         self.edit.setValidator(labelme.utils.labelValidator())
         self.edit.editingFinished.connect(self.postProcess)
+        if flags:
+            self.edit.textChanged.connect(self.updateFlags)
         layout = QtWidgets.QVBoxLayout()
         if show_text_field:
             layout.addWidget(self.edit)
@@ -79,8 +81,7 @@ class LabelDialog(QtWidgets.QDialog):
         self.label_flags = None
         if flags:
             self.label_flags = QtWidgets.QVBoxLayout()
-            for flag in flags:
-                self.label_flags.addWidget(QtWidgets.QCheckBox(flag, self))
+            self.resetFlags()
             layout.addItem(self.label_flags)
         self.setLayout(layout)
         # completion
@@ -113,6 +114,15 @@ class LabelDialog(QtWidgets.QDialog):
     def labelSelected(self, item):
         self.edit.setText(item.text())
 
+    def updateFlags(self, text):
+        flags = self.getFlags()
+        newFlags = {}
+        for label in ["__all__", text]:
+            if label in self.flags:
+                for key in self.flags[label]:
+                    newFlags[key] = False if key not in flags else flags[key]
+        self.setFlags(newFlags)
+
     def validate(self):
         text = self.edit.text()
         if hasattr(text, 'strip'):
@@ -130,15 +140,41 @@ class LabelDialog(QtWidgets.QDialog):
             text = text.trimmed()
         self.edit.setText(text)
 
-    def resetFlags(self):
-        for i in range(self.label_flags.count()):
-            item = self.label_flags.itemAt(i).widget()
-            item.setChecked(False)
-
-    def setFlags(self, flags):
-        for i in range(self.label_flags.count()):
+    def deleteFlags(self):
+        for i in reversed(range(self.label_flags.count())):
             item = self.label_flags.itemAt(i).widget()
-            item.setChecked(flags[item.text()])
+            self.label_flags.removeWidget(item)
+            item.setParent(None)
+
+    def resetFlags(self, text=''):
+        self.deleteFlags()
+
+        # Add all flags
+        for label in ["__all__", text]:
+            if label in self.flags:
+                for key in self.flags[label]:
+                    item = QtWidgets.QCheckBox(key, self)
+                    self.label_flags.addWidget(item)
+                    item.show()
+
+    def setFlags(self, flags, text=''):
+        self.deleteFlags()
+
+        # Add flags not set
+        for label in ["__all__", text]:
+            if label in self.flags:
+                for key in self.flags[label]:
+                    if key not in flags:
+                        item = QtWidgets.QCheckBox(key, self)
+                        self.label_flags.addWidget(item)
+                        item.show()
+
+        # Add set flags
+        for key in flags:
+            item = QtWidgets.QCheckBox(key, self)
+            item.setChecked(flags[key])
+            self.label_flags.addWidget(item)
+            item.show()
 
     def getFlags(self):
         flags = {}
@@ -159,10 +195,11 @@ class LabelDialog(QtWidgets.QDialog):
         # if text is None, the previous label in self.edit is kept
         if text is None:
             text = self.edit.text()
-        if flags:
-            self.setFlags(flags)
-        else:
-            self.resetFlags()
+        if self.label_flags:
+            if flags:
+                self.setFlags(flags)
+            else:
+                self.resetFlags(text)
         self.edit.setText(text)
         self.edit.setSelection(0, len(text))
         items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
@@ -175,4 +212,4 @@ class LabelDialog(QtWidgets.QDialog):
         self.edit.setFocus(QtCore.Qt.PopupFocusReason)
         if move:
             self.move(QtGui.QCursor.pos())
-        return (self.edit.text(), self.getFlags()) if self.exec_() else None
+        return (self.edit.text(), self.getFlags() if self.flags else None) if self.exec_() else (None, None)