Переглянути джерело

Introduce shape_type field to enhance rectangle annotation

Kentaro Wada 6 роки тому
батько
коміт
00cfa6371c
4 змінених файлів з 65 додано та 31 видалено
  1. 4 3
      labelme/app.py
  2. 7 1
      labelme/label_file.py
  3. 50 16
      labelme/shape.py
  4. 4 11
      labelme/widgets/canvas.py

+ 4 - 3
labelme/app.py

@@ -776,8 +776,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
 
     def loadLabels(self, shapes):
         s = []
-        for label, points, line_color, fill_color in shapes:
-            shape = Shape(label=label)
+        for label, points, line_color, fill_color, shape_type in shapes:
+            shape = Shape(label=label, shape_type=shape_type)
             for x, y in points:
                 shape.addPoint(QtCore.QPoint(x, y))
             shape.close()
@@ -805,7 +805,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
                         if s.line_color != self.lineColor else None,
                         fill_color=s.fill_color.getRgb()
                         if s.fill_color != self.fillColor else None,
-                        points=[(p.x(), p.y()) for p in s.points])
+                        points=[(p.x(), p.y()) for p in s.points],
+                        shape_type=s.shape_type)
 
         shapes = [format_shape(shape) for shape in self.labelList.shapes]
         flags = {}

+ 7 - 1
labelme/label_file.py

@@ -48,7 +48,13 @@ class LabelFile(object):
             lineColor = data['lineColor']
             fillColor = data['fillColor']
             shapes = (
-                (s['label'], s['points'], s['line_color'], s['fill_color'])
+                (
+                    s['label'],
+                    s['points'],
+                    s['line_color'],
+                    s['fill_color'],
+                    s.get('shape_type'),
+                )
                 for s in data['shapes']
             )
         except Exception as e:

+ 50 - 16
labelme/shape.py

@@ -1,5 +1,6 @@
 import copy
 
+from qtpy import QtCore
 from qtpy import QtGui
 
 import labelme.utils
@@ -34,7 +35,7 @@ class Shape(object):
     point_size = 8
     scale = 1.0
 
-    def __init__(self, label=None, line_color=None):
+    def __init__(self, label=None, line_color=None, shape_type=None):
         self.label = label
         self.points = []
         self.fill = False
@@ -55,6 +56,20 @@ class Shape(object):
             # is used for drawing the pending line a different color.
             self.line_color = line_color
 
+        self.shape_type = shape_type
+
+    @property
+    def shape_type(self):
+        return self._shape_type
+
+    @shape_type.setter
+    def shape_type(self, value):
+        if value is None:
+            value = 'polygon'
+        if value not in ['polygon', 'rectangle', 'point', 'line']:
+            raise ValueError('Unexpected shape_type: {}'.format(value))
+        self._shape_type = value
+
     def close(self):
         self._closed = True
 
@@ -78,6 +93,11 @@ class Shape(object):
     def setOpen(self):
         self._closed = False
 
+    def getRectFromLine(self, pt1, pt2):
+        x1, y1 = pt1.x(), pt1.y()
+        x2, y2 = pt2.x(), pt2.y()
+        return QtCore.QRectF(x1, y1, x2 - x1, y2 - y1)
+
     def paint(self, painter):
         if self.points:
             color = self.select_line_color \
@@ -90,17 +110,25 @@ class Shape(object):
             line_path = QtGui.QPainterPath()
             vrtx_path = QtGui.QPainterPath()
 
-            line_path.moveTo(self.points[0])
-            # Uncommenting the following line will draw 2 paths
-            # for the 1st vertex, and make it non-filled, which
-            # may be desirable.
-            # self.drawVertex(vrtx_path, 0)
-
-            for i, p in enumerate(self.points):
-                line_path.lineTo(p)
-                self.drawVertex(vrtx_path, i)
-            if self.isClosed():
-                line_path.lineTo(self.points[0])
+            if self.shape_type == 'rectangle':
+                assert len(self.points) in [1, 2]
+                if len(self.points) == 2:
+                    rectangle = self.getRectFromLine(*self.points)
+                    line_path.addRect(rectangle)
+                for i in range(len(self.points)):
+                    self.drawVertex(vrtx_path, i)
+            else:
+                line_path.moveTo(self.points[0])
+                # Uncommenting the following line will draw 2 paths
+                # for the 1st vertex, and make it non-filled, which
+                # may be desirable.
+                # self.drawVertex(vrtx_path, 0)
+
+                for i, p in enumerate(self.points):
+                    line_path.lineTo(p)
+                    self.drawVertex(vrtx_path, i)
+                if self.isClosed():
+                    line_path.lineTo(self.points[0])
 
             painter.drawPath(line_path)
             painter.drawPath(vrtx_path)
@@ -153,9 +181,15 @@ class Shape(object):
         return self.makePath().contains(point)
 
     def makePath(self):
-        path = QtGui.QPainterPath(self.points[0])
-        for p in self.points[1:]:
-            path.lineTo(p)
+        if self.shape_type == 'rectangle':
+            path = QtGui.QPainterPath()
+            if len(self.points) == 2:
+                rectangle = self.getRectFromLine(*self.points)
+                path.addRect(rectangle)
+        else:
+            path = QtGui.QPainterPath(self.points[0])
+            for p in self.points[1:]:
+                path.lineTo(p)
         return path
 
     def boundingRect(self):
@@ -175,7 +209,7 @@ class Shape(object):
         self._highlightIndex = None
 
     def copy(self):
-        shape = Shape(self.label)
+        shape = Shape(label=self.label, shape_type=self.shape_type)
         shape.points = [copy.deepcopy(p) for p in self.points]
         shape.fill = self.fill
         shape.selected = self.selected

+ 4 - 11
labelme/widgets/canvas.py

@@ -155,6 +155,8 @@ class Canvas(QtWidgets.QWidget):
 
         # Polygon drawing.
         if self.drawing():
+            self.line.shape_type = self.createMode
+
             self.overrideCursor(CURSOR_DRAW)
             if not self.current:
                 return
@@ -176,9 +178,7 @@ class Canvas(QtWidgets.QWidget):
                 self.line[0] = self.current[-1]
                 self.line[1] = pos
             elif self.createMode == 'rectangle':
-                self.line.points = list(self.getRectangleFromLine(
-                    (self.current[0], pos)
-                ))
+                self.line.points = [self.current[0], pos]
                 self.line.close()
             elif self.createMode == 'line':
                 self.line.points = [self.current[0], pos]
@@ -271,13 +271,6 @@ class Canvas(QtWidgets.QWidget):
         self.hVertex = index
         self.hEdge = None
 
-    def getRectangleFromLine(self, line):
-        pt1 = line[0]
-        pt3 = line[1]
-        pt2 = QtCore.QPoint(pt3.x(), pt1.y())
-        pt4 = QtCore.QPoint(pt1.x(), pt3.y())
-        return pt1, pt2, pt3, pt4
-
     def mousePressEvent(self, ev):
         if QT5:
             pos = self.transformPos(ev.pos())
@@ -298,7 +291,7 @@ class Canvas(QtWidgets.QWidget):
                         self.finalise()
                 elif not self.outOfPixmap(pos):
                     # Create new shape.
-                    self.current = Shape()
+                    self.current = Shape(shape_type=self.createMode)
                     self.current.addPoint(pos)
                     if self.createMode == 'point':
                         self.finalise()