Browse Source

Added label specific flags

cmerchant 6 years ago
parent
commit
ecdd186347
5 changed files with 76 additions and 10 deletions
  1. 49 5
      labelme/app.py
  2. 6 0
      labelme/config/default_config.yaml
  3. 1 0
      labelme/label_file.py
  4. 18 4
      labelme/main.py
  5. 2 1
      labelme/shape.py

+ 49 - 5
labelme/app.py

@@ -95,6 +95,15 @@ class MainWindow(QtWidgets.QMainWindow):
         self.flag_dock.setWidget(self.flag_widget)
         self.flag_dock.setWidget(self.flag_widget)
         self.flag_widget.itemChanged.connect(self.setDirty)
         self.flag_widget.itemChanged.connect(self.setDirty)
 
 
+        self.label_flag_dock = self.label_flag_widget = None
+        self.label_flag_dock = QtWidgets.QDockWidget('Label Flags', self)
+        self.label_flag_dock.setObjectName('Label Flags')
+        self.label_flag_widget = QtWidgets.QListWidget()
+        if config['label_flags']:
+            self.loadLabelFlags({k: False for k in config['label_flags']})
+        self.label_flag_dock.setWidget(self.label_flag_widget)
+        self.label_flag_widget.itemChanged.connect(self.labelFlagChanged)
+
         self.labelList.itemActivated.connect(self.labelSelectionChanged)
         self.labelList.itemActivated.connect(self.labelSelectionChanged)
         self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
         self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
         self.labelList.itemDoubleClicked.connect(self.editLabel)
         self.labelList.itemDoubleClicked.connect(self.editLabel)
@@ -161,7 +170,7 @@ class MainWindow(QtWidgets.QMainWindow):
         self.setCentralWidget(scrollArea)
         self.setCentralWidget(scrollArea)
 
 
         features = QtWidgets.QDockWidget.DockWidgetFeatures()
         features = QtWidgets.QDockWidget.DockWidgetFeatures()
-        for dock in ['flag_dock', 'label_dock', 'shape_dock', 'file_dock']:
+        for dock in ['flag_dock', 'label_flag_dock', 'label_dock', 'shape_dock', 'file_dock']:
             if self._config[dock]['closable']:
             if self._config[dock]['closable']:
                 features = features | QtWidgets.QDockWidget.DockWidgetClosable
                 features = features | QtWidgets.QDockWidget.DockWidgetClosable
             if self._config[dock]['floatable']:
             if self._config[dock]['floatable']:
@@ -173,6 +182,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 getattr(self, dock).setVisible(False)
                 getattr(self, dock).setVisible(False)
 
 
         self.addDockWidget(Qt.RightDockWidgetArea, self.flag_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.flag_dock)
+        self.addDockWidget(Qt.RightDockWidgetArea, self.label_flag_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.label_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.label_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.shape_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.shape_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.file_dock)
         self.addDockWidget(Qt.RightDockWidgetArea, self.file_dock)
@@ -494,6 +504,7 @@ class MainWindow(QtWidgets.QMainWindow):
         utils.addActions(
         utils.addActions(
             self.menus.view,
             self.menus.view,
             (
             (
+                self.label_flag_dock.toggleViewAction(),
                 self.flag_dock.toggleViewAction(),
                 self.flag_dock.toggleViewAction(),
                 self.label_dock.toggleViewAction(),
                 self.label_dock.toggleViewAction(),
                 self.shape_dock.toggleViewAction(),
                 self.shape_dock.toggleViewAction(),
@@ -904,6 +915,8 @@ class MainWindow(QtWidgets.QMainWindow):
         self.actions.shapeFillColor.setEnabled(selected)
         self.actions.shapeFillColor.setEnabled(selected)
 
 
     def addLabel(self, shape):
     def addLabel(self, shape):
+        if not shape.flags:
+            shape.flags =  {k: False for k in self._config['label_flags']}
         item = QtWidgets.QListWidgetItem(shape.label)
         item = QtWidgets.QListWidgetItem(shape.label)
         item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
         item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
         item.setCheckState(Qt.Checked)
         item.setCheckState(Qt.Checked)
@@ -924,10 +937,12 @@ class MainWindow(QtWidgets.QMainWindow):
         for shape in shapes:
         for shape in shapes:
             self.addLabel(shape)
             self.addLabel(shape)
         self.canvas.loadShapes(shapes, replace=replace)
         self.canvas.loadShapes(shapes, replace=replace)
+        if self._config['label_flags']:
+            self.loadLabelFlags({k: False for k in self._config['label_flags']})
 
 
     def loadLabels(self, shapes):
     def loadLabels(self, shapes):
         s = []
         s = []
-        for label, points, line_color, fill_color, shape_type in shapes:
+        for label, points, line_color, fill_color, shape_type, flags in shapes:
             shape = Shape(label=label, shape_type=shape_type)
             shape = Shape(label=label, shape_type=shape_type)
             for x, y in points:
             for x, y in points:
                 shape.addPoint(QtCore.QPoint(x, y))
                 shape.addPoint(QtCore.QPoint(x, y))
@@ -937,6 +952,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 shape.line_color = QtGui.QColor(*line_color)
                 shape.line_color = QtGui.QColor(*line_color)
             if fill_color:
             if fill_color:
                 shape.fill_color = QtGui.QColor(*fill_color)
                 shape.fill_color = QtGui.QColor(*fill_color)
+            shape.flags = flags
         self.loadShapes(s)
         self.loadShapes(s)
 
 
     def loadFlags(self, flags):
     def loadFlags(self, flags):
@@ -947,6 +963,14 @@ class MainWindow(QtWidgets.QMainWindow):
             item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
             item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
             self.flag_widget.addItem(item)
             self.flag_widget.addItem(item)
 
 
+    def loadLabelFlags(self, flags):
+        self.label_flag_widget.clear()
+        for key, flag in flags.items():
+            item = QtWidgets.QListWidgetItem(key)
+            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
+            item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
+            self.label_flag_widget.addItem(item)
+
     def saveLabels(self, filename):
     def saveLabels(self, filename):
         lf = LabelFile()
         lf = LabelFile()
 
 
@@ -959,6 +983,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 if s.fill_color != self.fillColor else None,
                 if s.fill_color != self.fillColor else None,
                 points=[(p.x(), p.y()) for p in s.points],
                 points=[(p.x(), p.y()) for p in s.points],
                 shape_type=s.shape_type,
                 shape_type=s.shape_type,
+                flags=s.flags
             )
             )
 
 
         shapes = [format_shape(shape) for shape in self.labelList.shapes]
         shapes = [format_shape(shape) for shape in self.labelList.shapes]
@@ -1008,10 +1033,12 @@ class MainWindow(QtWidgets.QMainWindow):
 
 
     def labelSelectionChanged(self):
     def labelSelectionChanged(self):
         item = self.currentItem()
         item = self.currentItem()
-        if item and self.canvas.editing():
-            self._noSelectionSlot = True
+        if item:
             shape = self.labelList.get_shape_from_item(item)
             shape = self.labelList.get_shape_from_item(item)
-            self.canvas.selectShape(shape)
+            self.loadLabelFlags(shape.flags)
+            if self.canvas.editing():
+                self._noSelectionSlot = True
+                self.canvas.selectShape(shape)
 
 
     def labelItemChanged(self, item):
     def labelItemChanged(self, item):
         shape = self.labelList.get_shape_from_item(item)
         shape = self.labelList.get_shape_from_item(item)
@@ -1022,6 +1049,21 @@ class MainWindow(QtWidgets.QMainWindow):
         else:  # User probably changed item visibility
         else:  # User probably changed item visibility
             self.canvas.setShapeVisible(shape, item.checkState() == Qt.Checked)
             self.canvas.setShapeVisible(shape, item.checkState() == Qt.Checked)
 
 
+    def labelFlagChanged(self):
+        item = self.currentItem()
+        if item:
+            index = 0
+            flags = {}
+            for key in self._config["label_flags"]:
+                checkBox = self.label_flag_widget.item(index)
+                index = index + 1
+                value = True if checkBox.checkState() else False
+                flags[key] = value
+            shape = self.labelList.get_shape_from_item(item)
+            if shape.flags != flags:
+                shape.flags = flags
+                self.setDirty()
+
     # Callback functions:
     # Callback functions:
 
 
     def newShape(self):
     def newShape(self):
@@ -1509,6 +1551,8 @@ class MainWindow(QtWidgets.QMainWindow):
             self.remLabel(self.canvas.deleteSelected())
             self.remLabel(self.canvas.deleteSelected())
             self.setDirty()
             self.setDirty()
             if self.noShapes():
             if self.noShapes():
+                if self._config['shape_flags']:
+                    self.loadLabelFlags({k: False for k in self._config['shape_flags']})
                 for action in self.actions.onShapesPresent:
                 for action in self.actions.onShapesPresent:
                     action.setEnabled(False)
                     action.setEnabled(False)
 
 

+ 6 - 0
labelme/config/default_config.yaml

@@ -6,6 +6,7 @@ keep_prev: false
 logger_level: info
 logger_level: info
 
 
 flags: null
 flags: null
+label_flags: null
 labels: null
 labels: null
 file_search: null
 file_search: null
 sort_labels: true
 sort_labels: true
@@ -17,6 +18,11 @@ flag_dock:
   closable: true
   closable: true
   movable: true
   movable: true
   floatable: true
   floatable: true
+label_flag_dock:
+  show: true
+  closable: true
+  movable: true
+  floatable: true
 label_dock:
 label_dock:
   show: true
   show: true
   closable: true
   closable: true

+ 1 - 0
labelme/label_file.py

@@ -89,6 +89,7 @@ class LabelFile(object):
                     s['line_color'],
                     s['line_color'],
                     s['fill_color'],
                     s['fill_color'],
                     s.get('shape_type', 'polygon'),
                     s.get('shape_type', 'polygon'),
+                    s['flags'] if 'flags' in s else None
                 )
                 )
                 for s in data['shapes']
                 for s in data['shapes']
             )
             )

+ 18 - 4
labelme/main.py

@@ -15,10 +15,11 @@ from labelme.utils import newIcon
 
 
 
 
 def main():
 def main():
-    try:
-        _main()
-    except Exception as e:
-        logger.error(e)
+    _main()
+    # try:
+    #     _main()
+    # except Exception as e:
+    #     logger.error(e)
 
 
 
 
 def _main():
 def _main():
@@ -77,6 +78,12 @@ def _main():
         help='comma separated list of flags OR file containing flags',
         help='comma separated list of flags OR file containing flags',
         default=argparse.SUPPRESS,
         default=argparse.SUPPRESS,
     )
     )
+    parser.add_argument(
+        '--labelflags',
+        dest='label_flags',
+        help='comma separated list of label specific flags OR file containing flags',
+        default=argparse.SUPPRESS,
+    )
     parser.add_argument(
     parser.add_argument(
         '--labels',
         '--labels',
         help='comma separated list of labels OR file containing labels',
         help='comma separated list of labels OR file containing labels',
@@ -115,6 +122,13 @@ def _main():
                 args.flags = [l.strip() for l in f if l.strip()]
                 args.flags = [l.strip() for l in f if l.strip()]
         else:
         else:
             args.flags = [l for l in args.flags.split(',') if l]
             args.flags = [l for l in args.flags.split(',') if l]
+            
+    if hasattr(args, 'label_flags'):
+        if os.path.isfile(args.label_flags):
+            with codecs.open(args.label_flags, 'r', encoding='utf-8') as f:
+                args.label_flags = [l.strip() for l in f if l.strip()]
+        else:
+            args.label_flags = [l for l in args.label_flags.split(',') if l]
 
 
     if hasattr(args, 'labels'):
     if hasattr(args, 'labels'):
         if os.path.isfile(args.labels):
         if os.path.isfile(args.labels):

+ 2 - 1
labelme/shape.py

@@ -36,12 +36,13 @@ class Shape(object):
     point_size = 8
     point_size = 8
     scale = 1.0
     scale = 1.0
 
 
-    def __init__(self, label=None, line_color=None, shape_type=None):
+    def __init__(self, label=None, line_color=None, shape_type=None, flags=None):
         self.label = label
         self.label = label
         self.points = []
         self.points = []
         self.fill = False
         self.fill = False
         self.selected = False
         self.selected = False
         self.shape_type = shape_type
         self.shape_type = shape_type
+        self.flags = flags
 
 
         self._highlightIndex = None
         self._highlightIndex = None
         self._highlightMode = self.NEAR_VERTEX
         self._highlightMode = self.NEAR_VERTEX