Browse Source

Make labeled polygons sortable

Kentaro Wada 7 years ago
parent
commit
ca783362e2
2 changed files with 186 additions and 133 deletions
  1. 128 106
      examples/tutorial/apc2016_obj3.json
  2. 58 27
      labelme/app.py

File diff suppressed because it is too large
+ 128 - 106
examples/tutorial/apc2016_obj3.json


+ 58 - 27
labelme/app.py

@@ -84,6 +84,50 @@ class EscapableQListWidget(QtWidgets.QListWidget):
             self.clearSelection()
             self.clearSelection()
 
 
 
 
+class LabelQListWidget(QtWidgets.QListWidget):
+
+    def __init__(self, *args, **kwargs):
+        super(LabelQListWidget, self).__init__(*args, **kwargs)
+        self.canvas = None
+        self.itemsToShapes = []
+
+    def get_shape_from_item(self, item):
+        for index, (item_, shape) in enumerate(self.itemsToShapes):
+            if item_ == item:
+                return shape
+
+    def get_item_from_shape(self, shape):
+        for index, (item, shape_) in enumerate(self.itemsToShapes):
+            if shape_ == shape:
+                return item
+
+    def clear(self):
+        super(LabelQListWidget, self).clear()
+        self.itemsToShapes = []
+
+    def setParent(self, parent):
+        self.parent = parent
+
+    def dropEvent(self, event):
+        shapes = self.shapes
+        super(LabelQListWidget, self).dropEvent(event)
+        if self.shapes == shapes:
+            return
+        if self.canvas is None:
+            raise RuntimeError('self.canvas must be set beforehand.')
+        self.parent.setDirty()
+        self.canvas.shapes = self.shapes
+
+    @property
+    def shapes(self):
+        shapes = []
+        for i in range(self.count()):
+            item = self.item(i)
+            shape = self.get_shape_from_item(item)
+            shapes.append(shape)
+        return shapes
+
+
 class MainWindow(QtWidgets.QMainWindow, WindowMixin):
 class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = 0, 1, 2
     FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = 0, 1, 2
 
 
@@ -104,8 +148,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.labelDialog = LabelDialog(parent=self, labels=labels,
         self.labelDialog = LabelDialog(parent=self, labels=labels,
                                        sort_labels=sort_labels)
                                        sort_labels=sort_labels)
 
 
-        self.labelList = QtWidgets.QListWidget()
-        self.itemsToShapes = []
+        self.labelList = LabelQListWidget()
         self.lastOpenDir = None
         self.lastOpenDir = None
 
 
         self.labelList.itemActivated.connect(self.labelSelectionChanged)
         self.labelList.itemActivated.connect(self.labelSelectionChanged)
@@ -113,6 +156,9 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.labelList.itemDoubleClicked.connect(self.editLabel)
         self.labelList.itemDoubleClicked.connect(self.editLabel)
         # Connect to itemChanged to detect checkbox changes.
         # Connect to itemChanged to detect checkbox changes.
         self.labelList.itemChanged.connect(self.labelItemChanged)
         self.labelList.itemChanged.connect(self.labelItemChanged)
+        self.labelList.setDragDropMode(
+            QtWidgets.QAbstractItemView.InternalMove)
+        self.labelList.setParent(self)
 
 
         listLayout = QtWidgets.QVBoxLayout()
         listLayout = QtWidgets.QVBoxLayout()
         listLayout.setContentsMargins(0, 0, 0, 0)
         listLayout.setContentsMargins(0, 0, 0, 0)
@@ -154,7 +200,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.zoomWidget = ZoomWidget()
         self.zoomWidget = ZoomWidget()
         self.colorDialog = ColorDialog(parent=self)
         self.colorDialog = ColorDialog(parent=self)
 
 
-        self.canvas = Canvas()
+        self.canvas = self.labelList.canvas = Canvas()
         self.canvas.zoomRequest.connect(self.zoomRequest)
         self.canvas.zoomRequest.connect(self.zoomRequest)
 
 
         scroll = QtWidgets.QScrollArea()
         scroll = QtWidgets.QScrollArea()
@@ -420,7 +466,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     # Support Functions
     # Support Functions
 
 
     def noShapes(self):
     def noShapes(self):
-        return not self.itemsToShapes
+        return not self.labelList.itemsToShapes
 
 
     def toggleAdvancedMode(self, value=True):
     def toggleAdvancedMode(self, value=True):
         self._beginner = not value
         self._beginner = not value
@@ -487,7 +533,6 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.statusBar().showMessage(message, delay)
         self.statusBar().showMessage(message, delay)
 
 
     def resetState(self):
     def resetState(self):
-        self.itemsToShapes = []
         self.labelList.clear()
         self.labelList.clear()
         self.filename = None
         self.filename = None
         self.imageData = None
         self.imageData = None
@@ -603,8 +648,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         else:
         else:
             shape = self.canvas.selectedShape
             shape = self.canvas.selectedShape
             if shape:
             if shape:
-                index = self._find_item_from_shape(shape)
-                item, _ = self.itemsToShapes[index]
+                item = self.labelList.get_item_from_shape(shape)
                 item.setSelected(True)
                 item.setSelected(True)
             else:
             else:
                 self.labelList.clearSelection()
                 self.labelList.clearSelection()
@@ -618,7 +662,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         item = QtWidgets.QListWidgetItem(shape.label)
         item = QtWidgets.QListWidgetItem(shape.label)
         item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
         item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
         item.setCheckState(Qt.Checked)
         item.setCheckState(Qt.Checked)
-        self.itemsToShapes.append((item, shape))
+        self.labelList.itemsToShapes.append((item, shape))
         self.labelList.addItem(item)
         self.labelList.addItem(item)
         if not self.uniqLabelList.findItems(shape.label, Qt.MatchExactly):
         if not self.uniqLabelList.findItems(shape.label, Qt.MatchExactly):
             self.uniqLabelList.addItem(shape.label)
             self.uniqLabelList.addItem(shape.label)
@@ -627,19 +671,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         for action in self.actions.onShapesPresent:
         for action in self.actions.onShapesPresent:
             action.setEnabled(True)
             action.setEnabled(True)
 
 
-    def _find_item_from_shape(self, shape):
-        for index, (item, shape_) in enumerate(self.itemsToShapes):
-            if shape_ == shape:
-                return index
-
-    def _find_shape_from_item(self, item):
-        for index, (item_, shape) in enumerate(self.itemsToShapes):
-            if item_ == item:
-                return index
-
     def remLabel(self, shape):
     def remLabel(self, shape):
-        index = self._find_item_from_shape(shape)
-        item, _ = self.itemsToShapes.pop(index)
+        item = self.labelList.get_item_from_shape(shape)
         self.labelList.takeItem(self.labelList.row(item))
         self.labelList.takeItem(self.labelList.row(item))
 
 
     def loadLabels(self, shapes):
     def loadLabels(self, shapes):
@@ -668,7 +701,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
                         if s.fill_color != self.fillColor else None,
                         if s.fill_color != self.fillColor else None,
                         points=[(p.x(), p.y()) for p in s.points])
                         points=[(p.x(), p.y()) for p in s.points])
 
 
-        shapes = [format_shape(shape) for shape in self.canvas.shapes]
+        shapes = [format_shape(shape) for shape in self.labelList.shapes]
         try:
         try:
             imagePath = os.path.relpath(
             imagePath = os.path.relpath(
                 self.imagePath, os.path.dirname(filename))
                 self.imagePath, os.path.dirname(filename))
@@ -692,13 +725,11 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         item = self.currentItem()
         item = self.currentItem()
         if item and self.canvas.editing():
         if item and self.canvas.editing():
             self._noSelectionSlot = True
             self._noSelectionSlot = True
-            index = self._find_shape_from_item(item)
-            _, shape = self.itemsToShapes[index]
+            shape = self.labelList.get_shape_from_item(item)
             self.canvas.selectShape(shape)
             self.canvas.selectShape(shape)
 
 
     def labelItemChanged(self, item):
     def labelItemChanged(self, item):
-        index = self._find_shape_from_item(item)
-        _, shape = self.itemsToShapes[index]
+        shape = self.labelList.get_shape_from_item(item)
         label = str(item.text())
         label = str(item.text())
         if label != shape.label:
         if label != shape.label:
             shape.label = str(item.text())
             shape.label = str(item.text())
@@ -760,7 +791,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.adjustScale()
         self.adjustScale()
 
 
     def togglePolygons(self, value):
     def togglePolygons(self, value):
-        for item, shape in self.itemsToShapes:
+        for item, shape in self.labelList.itemsToShapes:
             item.setCheckState(Qt.Checked if value else Qt.Unchecked)
             item.setCheckState(Qt.Checked if value else Qt.Unchecked)
 
 
     def loadFile(self, filename=None):
     def loadFile(self, filename=None):
@@ -995,7 +1026,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
 
 
     # Message Dialogs. #
     # Message Dialogs. #
     def hasLabels(self):
     def hasLabels(self):
-        if not self.itemsToShapes:
+        if not self.labelList.itemsToShapes:
             self.errorMessage(
             self.errorMessage(
                 'No objects labeled',
                 'No objects labeled',
                 'You must label at least one object to save the file.')
                 'You must label at least one object to save the file.')

Some files were not shown because too many files changed in this diff