Browse Source

Support negative point label for SAM

Kentaro Wada 2 năm trước cách đây
mục cha
commit
378ab6e572
3 tập tin đã thay đổi với 79 bổ sung17 xóa
  1. 4 3
      labelme/ai/models/segment_anything.py
  2. 27 5
      labelme/shape.py
  3. 48 9
      labelme/widgets/canvas.py

+ 4 - 3
labelme/ai/models/segment_anything.py

@@ -37,7 +37,7 @@ class SegmentAnythingModel:
             )
         return self._image_embedding
 
-    def points_to_polygon_callback(self, points):
+    def points_to_polygon_callback(self, points, point_labels):
         self._thread.join()
         image_embedding = self.get_image_embedding()
 
@@ -47,6 +47,7 @@ class SegmentAnythingModel:
             image=self._image,
             image_embedding=image_embedding,
             points=points,
+            point_labels=point_labels,
         )
         return polygon
 
@@ -86,10 +87,10 @@ def _get_contour_length(contour):
 
 
 def compute_polygon_from_points(
-    image_size, decoder_session, image, image_embedding, points
+    image_size, decoder_session, image, image_embedding, points, point_labels
 ):
     input_point = np.array(points, dtype=np.float32)
-    input_label = np.ones(len(input_point), dtype=np.int32)
+    input_label = np.array(point_labels, dtype=np.int32)
 
     onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
         None, :, :

+ 27 - 5
labelme/shape.py

@@ -57,6 +57,7 @@ class Shape(object):
         self.label = label
         self.group_id = group_id
         self.points = []
+        self.point_labels = []
         self.shape_type = shape_type
         self._shape_raw = None
         self._points_raw = []
@@ -83,15 +84,16 @@ class Shape(object):
             # is used for drawing the pending line a different color.
             self.line_color = line_color
 
-    def setShapeRefined(self, points, shape_type):
-        self._shape_raw = (self.points, self.shape_type)
+    def setShapeRefined(self, points, point_labels, shape_type):
+        self._shape_raw = (self.points, self.point_labels, self.shape_type)
         self.points = points
+        self.point_labels = point_labels
         self.shape_type = shape_type
 
     def restoreShapeRaw(self):
         if self._shape_raw is None:
             return
-        self.points, self.shape_type = self._shape_raw
+        self.points, self.point_labels, self.shape_type = self._shape_raw
         self._shape_raw = None
 
     @property
@@ -117,22 +119,26 @@ class Shape(object):
     def close(self):
         self._closed = True
 
-    def addPoint(self, point):
+    def addPoint(self, point, label=1):
         if self.points and point == self.points[0]:
             self.close()
         else:
             self.points.append(point)
+            self.point_labels.append(label)
 
     def canAddPoint(self):
         return self.shape_type in ["polygon", "linestrip"]
 
     def popPoint(self):
         if self.points:
+            if self.point_labels:
+                self.point_labels.pop()
             return self.points.pop()
         return None
 
-    def insertPoint(self, i, point):
+    def insertPoint(self, i, point, label=1):
         self.points.insert(i, point)
+        self.point_labels.insert(i, label)
 
     def removePoint(self, i):
         if not self.canAddPoint():
@@ -159,6 +165,7 @@ class Shape(object):
             return
 
         self.points.pop(i)
+        self.point_labels.pop(i)
 
     def isClosed(self):
         return self._closed
@@ -183,6 +190,7 @@ class Shape(object):
 
             line_path = QtGui.QPainterPath()
             vrtx_path = QtGui.QPainterPath()
+            negative_vrtx_path = QtGui.QPainterPath()
 
             if self.shape_type == "rectangle":
                 assert len(self.points) in [1, 2]
@@ -203,6 +211,15 @@ class Shape(object):
                 for i, p in enumerate(self.points):
                     line_path.lineTo(p)
                     self.drawVertex(vrtx_path, i)
+            elif self.shape_type == "points":
+                assert len(self.points) == len(self.point_labels)
+                for i, (p, l) in enumerate(
+                    zip(self.points, self.point_labels)
+                ):
+                    if l == 1:
+                        self.drawVertex(vrtx_path, i)
+                    else:
+                        self.drawVertex(negative_vrtx_path, i)
             else:
                 line_path.moveTo(self.points[0])
                 # Uncommenting the following line will draw 2 paths
@@ -227,6 +244,11 @@ class Shape(object):
                 )
                 painter.fillPath(line_path, color)
 
+            pen.setColor(QtGui.QColor(255, 0, 0, 255))
+            painter.setPen(pen)
+            painter.drawPath(negative_vrtx_path)
+            painter.fillPath(negative_vrtx_path, QtGui.QColor(255, 0, 0, 255))
+
     def drawVertex(self, path, i):
         d = self.point_size / self.scale
         shape = self.point_type

+ 48 - 9
labelme/widgets/canvas.py

@@ -6,6 +6,7 @@ from qtpy import QtWidgets
 import labelme.ai
 from labelme import QT5
 from labelme.shape import Shape
+from labelme.shape import Shape
 import labelme.utils
 
 
@@ -223,6 +224,8 @@ class Canvas(QtWidgets.QWidget):
         self.prevMovePoint = pos
         self.restoreCursor()
 
+        is_shift_pressed = int(ev.modifiers()) == QtCore.Qt.ShiftModifier
+
         # Polygon drawing.
         if self.drawing():
             if self.createMode == "ai_polygon":
@@ -250,21 +253,32 @@ class Canvas(QtWidgets.QWidget):
                 pos = self.current[0]
                 self.overrideCursor(CURSOR_POINT)
                 self.current.highlightVertex(0, Shape.NEAR_VERTEX)
-            if self.createMode in ["polygon", "linestrip", "ai_polygon"]:
-                self.line[0] = self.current[-1]
-                self.line[1] = pos
+            if self.createMode in ["polygon", "linestrip"]:
+                self.line.points = [self.current[-1], pos]
+                self.line.point_labels = [1, 1]
+            elif self.createMode == "ai_polygon":
+                self.line.points = [self.current.points[-1], pos]
+                self.line.point_labels = [
+                    self.current.point_labels[-1],
+                    0 if is_shift_pressed else 1,
+                ]
             elif self.createMode == "rectangle":
                 self.line.points = [self.current[0], pos]
+                self.line.point_labels = [1, 1]
                 self.line.close()
             elif self.createMode == "circle":
                 self.line.points = [self.current[0], pos]
+                self.line.point_labels = [1, 1]
                 self.line.shape_type = "circle"
             elif self.createMode == "line":
                 self.line.points = [self.current[0], pos]
+                self.line.point_labels = [1, 1]
                 self.line.close()
             elif self.createMode == "point":
                 self.line.points = [self.current[0]]
+                self.line.point_labels = [1]
                 self.line.close()
+            assert len(self.line.points) == len(self.line.point_labels)
             self.repaint()
             self.current.highlightClear()
             return
@@ -378,6 +392,9 @@ class Canvas(QtWidgets.QWidget):
             pos = self.transformPos(ev.localPos())
         else:
             pos = self.transformPos(ev.posF())
+
+        is_shift_pressed = int(ev.modifiers()) == QtCore.Qt.ShiftModifier
+
         if ev.button() == QtCore.Qt.LeftButton:
             if self.drawing():
                 if self.current:
@@ -391,11 +408,22 @@ class Canvas(QtWidgets.QWidget):
                         assert len(self.current.points) == 1
                         self.current.points = self.line.points
                         self.finalise()
-                    elif self.createMode in ["linestrip", "ai_polygon"]:
+                    elif self.createMode == "linestrip":
                         self.current.addPoint(self.line[1])
                         self.line[0] = self.current[-1]
                         if int(ev.modifiers()) == QtCore.Qt.ControlModifier:
                             self.finalise()
+                    elif self.createMode == "ai_polygon":
+                        self.current.addPoint(
+                            self.line.points[1],
+                            label=self.line.point_labels[1],
+                        )
+                        self.line.points[0] = self.current.points[-1]
+                        self.line.point_labels[0] = self.current.point_labels[
+                            -1
+                        ]
+                        if int(ev.modifiers()) == QtCore.Qt.ControlModifier:
+                            self.finalise()
                 elif not self.outOfPixmap(pos):
                     # Create new shape.
                     self.current = Shape(
@@ -403,13 +431,22 @@ class Canvas(QtWidgets.QWidget):
                         if self.createMode == "ai_polygon"
                         else self.createMode
                     )
-                    self.current.addPoint(pos)
+                    self.current.addPoint(
+                        pos, label=0 if is_shift_pressed else 1
+                    )
                     if self.createMode == "point":
                         self.finalise()
                     else:
                         if self.createMode == "circle":
                             self.current.shape_type = "circle"
                         self.line.points = [pos, pos]
+                        if (
+                            self.createMode == "ai_polygon"
+                            and is_shift_pressed
+                        ):
+                            self.line.point_labels = [0, 0]
+                        else:
+                            self.line.point_labels = [1, 1]
                         self.setHiding()
                         self.drawingPolygon.emit(True)
                         self.update()
@@ -681,6 +718,7 @@ class Canvas(QtWidgets.QWidget):
                 shape.paint(p)
         if self.current:
             self.current.paint(p)
+            assert len(self.line.points) == len(self.line.point_labels)
             self.line.paint(p)
         if self.selectedShapesCopy:
             for s in self.selectedShapesCopy:
@@ -722,15 +760,16 @@ class Canvas(QtWidgets.QWidget):
             # convert points to polygon by an AI model
             assert self.current.shape_type == "points"
             points = self._ai_callback(
-                points=np.array(
-                    [[point.x(), point.y()] for point in self.current.points],
-                    dtype=np.float32,
-                )
+                points=[
+                    [point.x(), point.y()] for point in self.current.points
+                ],
+                point_labels=self.current.point_labels,
             )
             self.current.setShapeRefined(
                 points=[
                     QtCore.QPointF(point[0], point[1]) for point in points
                 ],
+                point_labels=[1] * len(points),
                 shape_type="polygon",
             )
         self.current.close()