Quellcode durchsuchen

Refactoring to merge the feature of label-specific flags

Kentaro Wada vor 6 Jahren
Ursprung
Commit
7ad8d5c5cf
6 geänderte Dateien mit 74 neuen und 76 gelöschten Zeilen
  1. 29 20
      labelme/app.py
  2. 1 1
      labelme/label_file.py
  3. 5 3
      labelme/main.py
  4. 2 1
      labelme/shape.py
  5. 2 1
      labelme/widgets/canvas.py
  6. 35 50
      labelme/widgets/label_dialog.py

+ 29 - 20
labelme/app.py

@@ -847,12 +847,20 @@ class MainWindow(QtWidgets.QMainWindow):
                     return True
         return False
 
-    def editLabel(self, item=None):
+    def editLabel(self, item=False):
+        if item and not isinstance(item, QtWidgets.QListWidgetItem):
+            raise TypeError('unsupported type of item: {}'.format(type(item)))
+
         if not self.canvas.editing():
             return
-        item = item if item else self.currentItem()
+        if not item:
+            item = self.currentItem()
+        if item is None:
+            return
         shape = self.labelList.get_shape_from_item(item)
-        text, flags = self.labelDialog.popUp(item.text() if item else None, flags=(shape.flags if shape else None))
+        if shape is None:
+            return
+        text, flags = self.labelDialog.popUp(shape.label, flags=shape.flags)
         if text is None:
             return
         if not self.validateLabel(text):
@@ -860,8 +868,8 @@ class MainWindow(QtWidgets.QMainWindow):
                               "Invalid label '{}' with validation type '{}'"
                               .format(text, self._config['validate_label']))
             return
-        shape.flags = flags
         shape.label = text
+        shape.flags = flags
         item.setText(text)
         self.setDirty()
         if not self.uniqLabelList.findItems(text, Qt.MatchExactly):
@@ -908,13 +916,6 @@ class MainWindow(QtWidgets.QMainWindow):
         self.actions.shapeFillColor.setEnabled(selected)
 
     def addLabel(self, shape):
-        if not shape.flags:
-            shape.flags = {}
-            if self._config['label_flags']:
-                for label in ["__all__", shape.label]:
-                    if label in self._config['label_flags']:
-                        for key in self._config['label_flags'][label]:
-                            shape.flags[key] = False
         item = QtWidgets.QListWidgetItem(shape.label)
         item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
         item.setCheckState(Qt.Checked)
@@ -943,12 +944,22 @@ class MainWindow(QtWidgets.QMainWindow):
             for x, y in points:
                 shape.addPoint(QtCore.QPoint(x, y))
             shape.close()
-            s.append(shape)
+
             if line_color:
                 shape.line_color = QtGui.QColor(*line_color)
+
             if fill_color:
                 shape.fill_color = QtGui.QColor(*fill_color)
-            shape.flags = flags
+
+            default_flags = {}
+            if self._config['label_flags']:
+                for l in ['__all__', label]:
+                    for k in self._config['label_flags'].get(l, []):
+                        default_flags[k] = False
+            shape.flags = default_flags
+            shape.flags.update(flags)
+
+            s.append(shape)
         self.loadShapes(s)
 
     def loadFlags(self, flags):
@@ -1021,11 +1032,10 @@ class MainWindow(QtWidgets.QMainWindow):
 
     def labelSelectionChanged(self):
         item = self.currentItem()
-        if item:
+        if item and self.canvas.editing():
+            self._noSelectionSlot = True
             shape = self.labelList.get_shape_from_item(item)
-            if self.canvas.editing():
-                self._noSelectionSlot = True
-                self.canvas.selectShape(shape)
+            self.canvas.selectShape(shape)
 
     def labelItemChanged(self, item):
         shape = self.labelList.get_shape_from_item(item)
@@ -1045,6 +1055,7 @@ class MainWindow(QtWidgets.QMainWindow):
         """
         items = self.uniqLabelList.selectedItems()
         text = None
+        flags = None
         if items:
             text = items[0].text()
         if self._config['display_label_popup'] or not text:
@@ -1066,9 +1077,7 @@ class MainWindow(QtWidgets.QMainWindow):
             self.canvas.undoLastLine()
             self.canvas.shapesBackups.pop()
         else:
-            shape = self.canvas.setLastLabel(text)
-            shape.flags = flags
-            self.addLabel(shape)
+            self.addLabel(self.canvas.setLastLabel(text, flags))
             self.actions.editMode.setEnabled(True)
             self.actions.undoLastPoint.setEnabled(False)
             self.actions.undo.setEnabled(True)

+ 1 - 1
labelme/label_file.py

@@ -89,7 +89,7 @@ class LabelFile(object):
                     s['line_color'],
                     s['fill_color'],
                     s.get('shape_type', 'polygon'),
-                    s['flags'] if 'flags' in s else None
+                    s.get('flags', {}),
                 )
                 for s in data['shapes']
             )

+ 5 - 3
labelme/main.py

@@ -81,7 +81,9 @@ def _main():
     parser.add_argument(
         '--labelflags',
         dest='label_flags',
-        help='yaml string of label specific flags OR file containing json string of label specific flags (ex. {human:[male,female],dog:[big],__all__:[occluded]} )',
+        help='yaml string of label specific flags OR file containing json '
+             'string of label specific flags (ex. {person: [male, tall], '
+             'dog: [big, black, brown, white], __all__: [occluded]})',
         default=argparse.SUPPRESS,
     )
     parser.add_argument(
@@ -137,11 +139,11 @@ def _main():
         else:
             args.label_flags = yaml.load(args.label_flags)
 
-        # Add not overlapping labels from label flags
+        # add not overlapping labels from label flags
         if not hasattr(args, 'labels'):
             args.labels = []
         for label in args.label_flags.keys():
-            if label != "__all__" and label not in args.labels:
+            if label != '__all__' and label not in args.labels:
                 args.labels.append(label)
 
     config_from_args = args.__dict__

+ 2 - 1
labelme/shape.py

@@ -36,7 +36,8 @@ class Shape(object):
     point_size = 8
     scale = 1.0
 
-    def __init__(self, label=None, line_color=None, shape_type=None, flags=None):
+    def __init__(self, label=None, line_color=None, shape_type=None,
+                 flags=None):
         self.label = label
         self.points = []
         self.fill = False

+ 2 - 1
labelme/widgets/canvas.py

@@ -643,9 +643,10 @@ class Canvas(QtWidgets.QWidget):
         elif key == QtCore.Qt.Key_Return and self.canCloseShape():
             self.finalise()
 
-    def setLastLabel(self, text):
+    def setLastLabel(self, text, flags):
         assert text
         self.shapes[-1].label = text
+        self.shapes[-1].flags = flags
         self.shapesBackups.pop()
         self.storeShapes()
         return self.shapes[-1]

+ 35 - 50
labelme/widgets/label_dialog.py

@@ -77,12 +77,13 @@ class LabelDialog(QtWidgets.QDialog):
         self.edit.setListWidget(self.labelList)
         layout.addWidget(self.labelList)
         # label_flags
-        self.flags = flags
-        self.label_flags = None
-        if flags:
-            self.label_flags = QtWidgets.QVBoxLayout()
-            self.resetFlags()
-            layout.addItem(self.label_flags)
+        if flags is None:
+            flags = {}
+        self._flags = flags
+        self.flagsLayout = QtWidgets.QVBoxLayout()
+        self.resetFlags()
+        layout.addItem(self.flagsLayout)
+        self.edit.textChanged.connect(self.updateFlags)
         self.setLayout(layout)
         # completion
         completer = QtWidgets.QCompleter()
@@ -114,15 +115,6 @@ class LabelDialog(QtWidgets.QDialog):
     def labelSelected(self, item):
         self.edit.setText(item.text())
 
-    def updateFlags(self, text):
-        flags = self.getFlags()
-        newFlags = {}
-        for label in ["__all__", text]:
-            if label in self.flags:
-                for key in self.flags[label]:
-                    newFlags[key] = False if key not in flags else flags[key]
-        self.setFlags(newFlags)
-
     def validate(self):
         text = self.edit.text()
         if hasattr(text, 'strip'):
@@ -140,47 +132,41 @@ class LabelDialog(QtWidgets.QDialog):
             text = text.trimmed()
         self.edit.setText(text)
 
+    def updateFlags(self, label_new):
+        # keep state of shared flags
+        flags_old = self.getFlags()
+
+        flags_new = {}
+        for label in ['__all__', label_new]:
+            for key in self._flags.get(label, []):
+                flags_new[key] = flags_old.get(key, False)
+        self.setFlags(flags_new)
+
     def deleteFlags(self):
-        for i in reversed(range(self.label_flags.count())):
-            item = self.label_flags.itemAt(i).widget()
-            self.label_flags.removeWidget(item)
+        for i in reversed(range(self.flagsLayout.count())):
+            item = self.flagsLayout.itemAt(i).widget()
+            self.flagsLayout.removeWidget(item)
             item.setParent(None)
 
-    def resetFlags(self, text=''):
-        self.deleteFlags()
-
-        # Add all flags
-        for label in ["__all__", text]:
-            if label in self.flags:
-                for key in self.flags[label]:
-                    item = QtWidgets.QCheckBox(key, self)
-                    self.label_flags.addWidget(item)
-                    item.show()
+    def resetFlags(self, label=None):
+        flags = {k: False for k in self._flags.get('__all__', [])}
+        if label:
+            flags.update({k: False for k in self._flags.get(label, [])})
+        self.setFlags(flags)
 
-    def setFlags(self, flags, text=''):
+    def setFlags(self, flags):
         self.deleteFlags()
-
-        # Add flags not set
-        for label in ["__all__", text]:
-            if label in self.flags:
-                for key in self.flags[label]:
-                    if key not in flags:
-                        item = QtWidgets.QCheckBox(key, self)
-                        self.label_flags.addWidget(item)
-                        item.show()
-
-        # Add set flags
         for key in flags:
             item = QtWidgets.QCheckBox(key, self)
             item.setChecked(flags[key])
-            self.label_flags.addWidget(item)
+            self.flagsLayout.addWidget(item)
             item.show()
 
     def getFlags(self):
         flags = {}
-        for i in range(self.label_flags.count()):
-            item = self.label_flags.itemAt(i).widget()
-            flags[item.text()] = True if item.isChecked() else False
+        for i in range(self.flagsLayout.count()):
+            item = self.flagsLayout.itemAt(i).widget()
+            flags[item.text()] = item.isChecked()
         return flags
 
     def popUp(self, text=None, move=True, flags=None):
@@ -195,11 +181,10 @@ class LabelDialog(QtWidgets.QDialog):
         # if text is None, the previous label in self.edit is kept
         if text is None:
             text = self.edit.text()
-        if self.label_flags:
-            if flags:
-                self.setFlags(flags)
-            else:
-                self.resetFlags(text)
+        if flags:
+            self.setFlags(flags)
+        else:
+            self.resetFlags(text)
         self.edit.setText(text)
         self.edit.setSelection(0, len(text))
         items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
@@ -213,6 +198,6 @@ class LabelDialog(QtWidgets.QDialog):
         if move:
             self.move(QtGui.QCursor.pos())
         if self.exec_():
-            return self.edit.text(), self.getFlags() if self.flags else None
+            return self.edit.text(), self.getFlags()
         else:
             return None, None