소스 검색

Replace LabelQListWidget with custom LabelListWidget

Kentaro Wada 5 년 전
부모
커밋
38e289b7a0

+ 25 - 33
labelme/app.py

@@ -25,7 +25,8 @@ from labelme.shape import Shape
 from labelme.widgets import Canvas
 from labelme.widgets import ColorDialog
 from labelme.widgets import LabelDialog
-from labelme.widgets import LabelQListWidget
+from labelme.widgets import LabelListWidget
+from labelme.widgets import LabelListWidgetItem
 from labelme.widgets import ToolBar
 from labelme.widgets import UniqueLabelQListWidget
 from labelme.widgets import ZoomWidget
@@ -88,7 +89,7 @@ class MainWindow(QtWidgets.QMainWindow):
             flags=self._config['label_flags']
         )
 
-        self.labelList = LabelQListWidget()
+        self.labelList = LabelListWidget()
         self.lastOpenDir = None
 
         self.flag_dock = self.flag_widget = None
@@ -100,14 +101,10 @@ class MainWindow(QtWidgets.QMainWindow):
         self.flag_dock.setWidget(self.flag_widget)
         self.flag_widget.itemChanged.connect(self.setDirty)
 
-        self.labelList.itemActivated.connect(self.labelSelectionChanged)
         self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
         self.labelList.itemDoubleClicked.connect(self.editLabel)
-        # Connect to itemChanged to detect checkbox changes.
         self.labelList.itemChanged.connect(self.labelItemChanged)
-        self.labelList.setDragDropMode(
-            QtWidgets.QAbstractItemView.InternalMove)
-        self.labelList.setParent(self)
+        self.labelList.itemDropped.connect(self.labelOrderChanged)
         self.shape_dock = QtWidgets.QDockWidget(
             self.tr('Polygon Labels'),
             self
@@ -679,7 +676,7 @@ class MainWindow(QtWidgets.QMainWindow):
     # Support Functions
 
     def noShapes(self):
-        return not self.labelList.itemsToShapes
+        return not len(self.labelList)
 
     def populateModeActions(self):
         tool, menu = self.actions.tool, self.actions.menu
@@ -887,9 +884,9 @@ class MainWindow(QtWidgets.QMainWindow):
                     return True
         return False
 
-    def editLabel(self, item=False):
-        if item and not isinstance(item, QtWidgets.QListWidgetItem):
-            raise TypeError('unsupported type of item: {}'.format(type(item)))
+    def editLabel(self, item=None):
+        if item and not isinstance(item, LabelListWidgetItem):
+            raise TypeError('item must be LabelListWidgetItem type')
 
         if not self.canvas.editing():
             return
@@ -897,7 +894,7 @@ class MainWindow(QtWidgets.QMainWindow):
             item = self.currentItem()
         if item is None:
             return
-        shape = self.labelList.get_shape_from_item(item)
+        shape = item.shape()
         if shape is None:
             return
         text, flags, group_id = self.labelDialog.popUp(
@@ -957,8 +954,8 @@ class MainWindow(QtWidgets.QMainWindow):
         self.canvas.selectedShapes = selected_shapes
         for shape in self.canvas.selectedShapes:
             shape.selected = True
-            item = self.labelList.get_item_from_shape(shape)
-            item.setSelected(True)
+            item = self.labelList.findItemByShape(shape)
+            self.labelList.selectItem(item)
             self.labelList.scrollToItem(item)
         self._noSelectionSlot = False
         n_selected = len(selected_shapes)
@@ -971,16 +968,8 @@ class MainWindow(QtWidgets.QMainWindow):
             text = shape.label
         else:
             text = '{} ({})'.format(shape.label, shape.group_id)
-        item = QtWidgets.QListWidgetItem()
-        item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
-        item.setCheckState(Qt.Checked)
-        self.labelList.itemsToShapes.append((item, shape))
-        self.labelList.addItem(item)
-        qlabel = QtWidgets.QLabel()
-        qlabel.setText(text)
-        qlabel.setAlignment(QtCore.Qt.AlignBottom)
-        item.setSizeHint(qlabel.sizeHint())
-        self.labelList.setItemWidget(item, qlabel)
+        label_list_item = LabelListWidgetItem(text, shape)
+        self.labelList.addItem(label_list_item)
         if not self.uniqLabelList.findItemsByLabel(shape.label):
             item = self.uniqLabelList.createItemFromLabel(shape.label)
             self.uniqLabelList.addItem(item)
@@ -995,7 +984,7 @@ class MainWindow(QtWidgets.QMainWindow):
             return
 
         r, g, b = rgb
-        qlabel.setText(
+        label_list_item.setText(
             '{} <font color="#{:02x}{:02x}{:02x}">●</font>'
             .format(text, r, g, b)
         )
@@ -1021,7 +1010,7 @@ class MainWindow(QtWidgets.QMainWindow):
 
     def remLabels(self, shapes):
         for shape in shapes:
-            item = self.labelList.get_item_from_shape(shape)
+            item = self.labelList.findItemByShape(shape)
             self.labelList.takeItem(self.labelList.row(item))
 
     def loadShapes(self, shapes, replace=True):
@@ -1086,7 +1075,7 @@ class MainWindow(QtWidgets.QMainWindow):
             ))
             return data
 
-        shapes = [format_shape(shape) for shape in self.labelList.shapes]
+        shapes = [format_shape(item.shape()) for item in self.labelList]
         flags = {}
         for i in range(self.flag_widget.count()):
             item = self.flag_widget.item(i)
@@ -1140,17 +1129,20 @@ class MainWindow(QtWidgets.QMainWindow):
         if self.canvas.editing():
             selected_shapes = []
             for item in self.labelList.selectedItems():
-                shape = self.labelList.get_shape_from_item(item)
-                selected_shapes.append(shape)
+                selected_shapes.append(item.shape())
             if selected_shapes:
                 self.canvas.selectShapes(selected_shapes)
             else:
                 self.canvas.deSelectShape()
 
     def labelItemChanged(self, item):
-        shape = self.labelList.get_shape_from_item(item)
+        shape = item.shape()
         self.canvas.setShapeVisible(shape, item.checkState() == Qt.Checked)
 
+    def labelOrderChanged(self):
+        self.setDirty()
+        self.canvas.loadShapes([item.shape() for item in self.labelList])
+
     # Callback functions:
 
     def newShape(self):
@@ -1247,7 +1239,7 @@ class MainWindow(QtWidgets.QMainWindow):
         self.adjustScale()
 
     def togglePolygons(self, value):
-        for item, shape in self.labelList.itemsToShapes:
+        for item in self.labelList:
             item.setCheckState(Qt.Checked if value else Qt.Unchecked)
 
     def loadFile(self, filename=None):
@@ -1327,7 +1319,7 @@ class MainWindow(QtWidgets.QMainWindow):
             if self.labelFile.flags is not None:
                 flags.update(self.labelFile.flags)
         self.loadFlags(flags)
-        if self._config['keep_prev'] and not self.labelList.shapes:
+        if self._config['keep_prev'] and self.noShapes():
             self.loadShapes(prev_shapes, replace=False)
             self.setDirty()
         else:
@@ -1600,7 +1592,7 @@ class MainWindow(QtWidgets.QMainWindow):
 
     # Message Dialogs. #
     def hasLabels(self):
-        if not self.labelList.itemsToShapes:
+        if self.noShapes():
             self.errorMessage(
                 'No objects labeled',
                 'You must label at least one object to save the file.')

+ 2 - 1
labelme/widgets/__init__.py

@@ -7,7 +7,8 @@ from .color_dialog import ColorDialog
 from .label_dialog import LabelDialog
 from .label_dialog import LabelQLineEdit
 
-from .label_qlist_widget import LabelQListWidget
+from .label_list_widget import LabelListWidget
+from .label_list_widget import LabelListWidgetItem
 
 from .tool_bar import ToolBar
 

+ 173 - 0
labelme/widgets/label_list_widget.py

@@ -0,0 +1,173 @@
+from qtpy import QtCore
+from qtpy.QtCore import Qt
+from qtpy import QtGui
+from qtpy.QtGui import QPalette
+from qtpy import QtWidgets
+from qtpy.QtWidgets import QStyle
+
+
+# https://stackoverflow.com/a/2039745/4158863
+class HTMLDelegate(QtWidgets.QStyledItemDelegate):
+    def __init__(self, parent=None):
+        super(HTMLDelegate, self).__init__()
+        self.doc = QtGui.QTextDocument(self)
+
+    def paint(self, painter, option, index):
+        painter.save()
+
+        options = QtWidgets.QStyleOptionViewItem(option)
+
+        self.initStyleOption(options, index)
+        self.doc.setHtml(options.text)
+        options.text = ""
+
+        style = (
+            QtWidgets.QApplication.style()
+            if options.widget is None
+            else options.widget.style()
+        )
+        style.drawControl(QStyle.CE_ItemViewItem, options, painter)
+
+        ctx = QtGui.QAbstractTextDocumentLayout.PaintContext()
+
+        if option.state & QStyle.State_Selected:
+            ctx.palette.setColor(
+                QPalette.Text,
+                option.palette.color(
+                    QPalette.Active, QPalette.HighlightedText
+                ),
+            )
+        else:
+            ctx.palette.setColor(
+                QPalette.Text,
+                option.palette.color(QPalette.Active, QPalette.Text),
+            )
+
+        textRect = style.subElementRect(QStyle.SE_ItemViewItemText, options)
+
+        if index.column() != 0:
+            textRect.adjust(5, 0, 0, 0)
+
+        thefuckyourshitup_constant = 4
+        margin = (option.rect.height() - options.fontMetrics.height()) // 2
+        margin = margin - thefuckyourshitup_constant
+        textRect.setTop(textRect.top() + margin)
+
+        painter.translate(textRect.topLeft())
+        painter.setClipRect(textRect.translated(-textRect.topLeft()))
+        self.doc.documentLayout().draw(painter, ctx)
+
+        painter.restore()
+
+    def sizeHint(self, option, index):
+        thefuckyourshitup_constant = 4
+        return QtCore.QSize(
+            self.doc.idealWidth(),
+            self.doc.size().height() - thefuckyourshitup_constant,
+        )
+
+
+class LabelListWidgetItem(QtGui.QStandardItem):
+    def __init__(self, text=None, shape=None):
+        super(LabelListWidgetItem, self).__init__()
+        self.setText(text)
+        self.setShape(shape)
+
+        self.setCheckable(True)
+        self.setCheckState(Qt.Checked)
+        self.setEditable(False)
+        self.setTextAlignment(Qt.AlignBottom)
+
+    def clone(self):
+        return LabelListWidgetItem(self.text(), self.shape())
+
+    def setShape(self, shape):
+        self.setData(shape, Qt.UserRole)
+
+    def shape(self):
+        return self.data(Qt.UserRole)
+
+    def __hash__(self):
+        return id(self)
+
+    def __repr__(self):
+        return '{}("{}")'.format(self.__class__.__name__, self.text())
+
+
+class LabelListWidget(QtWidgets.QListView):
+
+    itemDoubleClicked = QtCore.Signal(LabelListWidgetItem)
+    itemSelectionChanged = QtCore.Signal(list, list)
+    itemDropped = QtCore.Signal()
+
+    def __init__(self):
+        super(LabelListWidget, self).__init__()
+        self._selectedItems = []
+
+        self.setWindowFlags(Qt.Window)
+        self.setModel(QtGui.QStandardItemModel())
+        self.model().setItemPrototype(LabelListWidgetItem())
+        self.setItemDelegate(HTMLDelegate())
+        self.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection)
+        self.setDragDropMode(QtWidgets.QAbstractItemView.InternalMove)
+        self.setDefaultDropAction(Qt.MoveAction)
+
+        self.doubleClicked.connect(self.itemDoubleClickedEvent)
+        self.selectionModel().selectionChanged.connect(
+            self.itemSelectionChangedEvent
+        )
+
+    def __len__(self):
+        return self.model().rowCount()
+
+    def __getitem__(self, i):
+        return self.model().item(i)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield self[i]
+
+    def dropEvent(self, event):
+        super(LabelListWidget, self).dropEvent(event)
+        self.itemDropped.emit()
+
+    @property
+    def itemChanged(self):
+        return self.model().itemChanged
+
+    def itemSelectionChangedEvent(self, selected, deselected):
+        selected = [self.model().itemFromIndex(i) for i in selected.indexes()]
+        deselected = [
+            self.model().itemFromIndex(i) for i in deselected.indexes()
+        ]
+        self.itemSelectionChanged.emit(selected, deselected)
+
+    def itemDoubleClickedEvent(self, index):
+        self.itemDoubleClicked.emit(self.model().itemFromIndex(index))
+
+    def selectedItems(self):
+        return [self.model().itemFromIndex(i) for i in self.selectedIndexes()]
+
+    def scrollToItem(self, item):
+        self.scrollTo(self.model().indexFromItem(item))
+
+    def addItem(self, item):
+        if not isinstance(item, LabelListWidgetItem):
+            raise TypeError("item must be LabelListWidgetItem")
+        self.model().setItem(self.model().rowCount(), 0, item)
+        item.setSizeHint(self.itemDelegate().sizeHint(None, None))
+
+    def selectItem(self, item):
+        index = self.model().indexFromItem(item)
+        self.selectionModel().select(
+            index, QtCore.QItemSelectionModel.SelectionFlag.Select
+        )
+
+    def findItemByShape(self, shape):
+        for row in range(self.model().rowCount()):
+            item = self.model().item(row, 0)
+            if item.shape() == shape:
+                return item
+
+    def clear(self):
+        self.model().clear()

+ 0 - 48
labelme/widgets/label_qlist_widget.py

@@ -1,48 +0,0 @@
-from qtpy import QtWidgets
-
-from .escapable_qlist_widget import EscapableQListWidget
-
-
-class LabelQListWidget(EscapableQListWidget):
-
-    def __init__(self, *args, **kwargs):
-        super(LabelQListWidget, self).__init__(*args, **kwargs)
-        self.canvas = None
-        self.itemsToShapes = []
-        self.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection)
-
-    def get_shape_from_item(self, item):
-        for index, (item_, shape) in enumerate(self.itemsToShapes):
-            if item_ is item:
-                return shape
-
-    def get_item_from_shape(self, shape):
-        for index, (item, shape_) in enumerate(self.itemsToShapes):
-            if shape_ is shape:
-                return item
-
-    def clear(self):
-        super(LabelQListWidget, self).clear()
-        self.itemsToShapes = []
-
-    def setParent(self, parent):
-        self.parent = parent
-
-    def dropEvent(self, event):
-        shapes = self.shapes
-        super(LabelQListWidget, self).dropEvent(event)
-        if self.shapes == shapes:
-            return
-        if self.canvas is None:
-            raise RuntimeError('self.canvas must be set beforehand.')
-        self.parent.setDirty()
-        self.canvas.loadShapes(self.shapes)
-
-    @property
-    def shapes(self):
-        shapes = []
-        for i in range(self.count()):
-            item = self.item(i)
-            shape = self.get_shape_from_item(item)
-            shapes.append(shape)
-        return shapes

+ 17 - 0
tests/labelme_tests/widgets_tests/test_label_list_widget.py

@@ -0,0 +1,17 @@
+# -*- encoding: utf-8 -*-
+
+from labelme.widgets import LabelListWidget
+from labelme.widgets import LabelListWidgetItem
+
+
+def test_LabelListWidget(qtbot):
+    widget = LabelListWidget()
+
+    item = LabelListWidgetItem(text="person <font color='red'>●</fon>")
+    widget.addItem(item)
+    item = LabelListWidgetItem(text="dog <font color='blue'>●</fon>")
+    widget.addItem(item)
+
+    widget.show()
+    qtbot.addWidget(widget)
+    qtbot.waitForWindowShown(widget)