Browse Source

Make labeled polygons sortable

Kentaro Wada 7 years ago
parent
commit
ca783362e2
2 changed files with 186 additions and 133 deletions
  1. 128 106
      examples/tutorial/apc2016_obj3.json
  2. 58 27
      labelme/app.py

File diff suppressed because it is too large
+ 128 - 106
examples/tutorial/apc2016_obj3.json


+ 58 - 27
labelme/app.py

@@ -84,6 +84,50 @@ class EscapableQListWidget(QtWidgets.QListWidget):
             self.clearSelection()
 
 
+class LabelQListWidget(QtWidgets.QListWidget):
+
+    def __init__(self, *args, **kwargs):
+        super(LabelQListWidget, self).__init__(*args, **kwargs)
+        self.canvas = None
+        self.itemsToShapes = []
+
+    def get_shape_from_item(self, item):
+        for index, (item_, shape) in enumerate(self.itemsToShapes):
+            if item_ == item:
+                return shape
+
+    def get_item_from_shape(self, shape):
+        for index, (item, shape_) in enumerate(self.itemsToShapes):
+            if shape_ == shape:
+                return item
+
+    def clear(self):
+        super(LabelQListWidget, self).clear()
+        self.itemsToShapes = []
+
+    def setParent(self, parent):
+        self.parent = parent
+
+    def dropEvent(self, event):
+        shapes = self.shapes
+        super(LabelQListWidget, self).dropEvent(event)
+        if self.shapes == shapes:
+            return
+        if self.canvas is None:
+            raise RuntimeError('self.canvas must be set beforehand.')
+        self.parent.setDirty()
+        self.canvas.shapes = self.shapes
+
+    @property
+    def shapes(self):
+        shapes = []
+        for i in range(self.count()):
+            item = self.item(i)
+            shape = self.get_shape_from_item(item)
+            shapes.append(shape)
+        return shapes
+
+
 class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = 0, 1, 2
 
@@ -104,8 +148,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.labelDialog = LabelDialog(parent=self, labels=labels,
                                        sort_labels=sort_labels)
 
-        self.labelList = QtWidgets.QListWidget()
-        self.itemsToShapes = []
+        self.labelList = LabelQListWidget()
         self.lastOpenDir = None
 
         self.labelList.itemActivated.connect(self.labelSelectionChanged)
@@ -113,6 +156,9 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.labelList.itemDoubleClicked.connect(self.editLabel)
         # Connect to itemChanged to detect checkbox changes.
         self.labelList.itemChanged.connect(self.labelItemChanged)
+        self.labelList.setDragDropMode(
+            QtWidgets.QAbstractItemView.InternalMove)
+        self.labelList.setParent(self)
 
         listLayout = QtWidgets.QVBoxLayout()
         listLayout.setContentsMargins(0, 0, 0, 0)
@@ -154,7 +200,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.zoomWidget = ZoomWidget()
         self.colorDialog = ColorDialog(parent=self)
 
-        self.canvas = Canvas()
+        self.canvas = self.labelList.canvas = Canvas()
         self.canvas.zoomRequest.connect(self.zoomRequest)
 
         scroll = QtWidgets.QScrollArea()
@@ -420,7 +466,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     # Support Functions
 
     def noShapes(self):
-        return not self.itemsToShapes
+        return not self.labelList.itemsToShapes
 
     def toggleAdvancedMode(self, value=True):
         self._beginner = not value
@@ -487,7 +533,6 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.statusBar().showMessage(message, delay)
 
     def resetState(self):
-        self.itemsToShapes = []
         self.labelList.clear()
         self.filename = None
         self.imageData = None
@@ -603,8 +648,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         else:
             shape = self.canvas.selectedShape
             if shape:
-                index = self._find_item_from_shape(shape)
-                item, _ = self.itemsToShapes[index]
+                item = self.labelList.get_item_from_shape(shape)
                 item.setSelected(True)
             else:
                 self.labelList.clearSelection()
@@ -618,7 +662,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         item = QtWidgets.QListWidgetItem(shape.label)
         item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
         item.setCheckState(Qt.Checked)
-        self.itemsToShapes.append((item, shape))
+        self.labelList.itemsToShapes.append((item, shape))
         self.labelList.addItem(item)
         if not self.uniqLabelList.findItems(shape.label, Qt.MatchExactly):
             self.uniqLabelList.addItem(shape.label)
@@ -627,19 +671,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         for action in self.actions.onShapesPresent:
             action.setEnabled(True)
 
-    def _find_item_from_shape(self, shape):
-        for index, (item, shape_) in enumerate(self.itemsToShapes):
-            if shape_ == shape:
-                return index
-
-    def _find_shape_from_item(self, item):
-        for index, (item_, shape) in enumerate(self.itemsToShapes):
-            if item_ == item:
-                return index
-
     def remLabel(self, shape):
-        index = self._find_item_from_shape(shape)
-        item, _ = self.itemsToShapes.pop(index)
+        item = self.labelList.get_item_from_shape(shape)
         self.labelList.takeItem(self.labelList.row(item))
 
     def loadLabels(self, shapes):
@@ -668,7 +701,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
                         if s.fill_color != self.fillColor else None,
                         points=[(p.x(), p.y()) for p in s.points])
 
-        shapes = [format_shape(shape) for shape in self.canvas.shapes]
+        shapes = [format_shape(shape) for shape in self.labelList.shapes]
         try:
             imagePath = os.path.relpath(
                 self.imagePath, os.path.dirname(filename))
@@ -692,13 +725,11 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         item = self.currentItem()
         if item and self.canvas.editing():
             self._noSelectionSlot = True
-            index = self._find_shape_from_item(item)
-            _, shape = self.itemsToShapes[index]
+            shape = self.labelList.get_shape_from_item(item)
             self.canvas.selectShape(shape)
 
     def labelItemChanged(self, item):
-        index = self._find_shape_from_item(item)
-        _, shape = self.itemsToShapes[index]
+        shape = self.labelList.get_shape_from_item(item)
         label = str(item.text())
         if label != shape.label:
             shape.label = str(item.text())
@@ -760,7 +791,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.adjustScale()
 
     def togglePolygons(self, value):
-        for item, shape in self.itemsToShapes:
+        for item, shape in self.labelList.itemsToShapes:
             item.setCheckState(Qt.Checked if value else Qt.Unchecked)
 
     def loadFile(self, filename=None):
@@ -995,7 +1026,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
 
     # Message Dialogs. #
     def hasLabels(self):
-        if not self.itemsToShapes:
+        if not self.labelList.itemsToShapes:
             self.errorMessage(
                 'No objects labeled',
                 'You must label at least one object to save the file.')

Some files were not shown because too many files changed in this diff