Просмотр исходного кода

Add createMode="ai_mask" that generates shape_type="mask"

Kentaro Wada 1 год назад
Родитель
Сommit
897ea14e5b
5 измененных файлов с 151 добавлено и 28 удалено
  1. 26 1
      labelme/app.py
  2. 1 0
      labelme/config/default_config.yaml
  3. 4 0
      labelme/label_file.py
  4. 67 18
      labelme/shape.py
  5. 53 9
      labelme/widgets/canvas.py

+ 26 - 1
labelme/app.py

@@ -383,6 +383,21 @@ class MainWindow(QtWidgets.QMainWindow):
             if self.canvas.createMode == "ai_polygon"
             else None
         )
+        createAiMaskMode = action(
+            self.tr("Create AI-Mask"),
+            lambda: self.toggleDrawMode(False, createMode="ai_mask"),
+            None,
+            "objects",
+            self.tr("Start drawing ai_mask. Ctrl+LeftClick ends creation."),
+            enabled=False,
+        )
+        createAiMaskMode.changed.connect(
+            lambda: self.canvas.initializeAiModel(
+                name=self._selectAiModelComboBox.currentText()
+            )
+            if self.canvas.createMode == "ai_mask"
+            else None
+        )
         editMode = action(
             self.tr("Edit Polygons"),
             self.setEditMode,
@@ -627,6 +642,7 @@ class MainWindow(QtWidgets.QMainWindow):
             createPointMode=createPointMode,
             createLineStripMode=createLineStripMode,
             createAiPolygonMode=createAiPolygonMode,
+            createAiMaskMode=createAiMaskMode,
             zoom=zoom,
             zoomIn=zoomIn,
             zoomOut=zoomOut,
@@ -662,6 +678,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 createPointMode,
                 createLineStripMode,
                 createAiPolygonMode,
+                createAiMaskMode,
                 editMode,
                 edit,
                 duplicate,
@@ -681,6 +698,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 createPointMode,
                 createLineStripMode,
                 createAiPolygonMode,
+                createAiMaskMode,
                 editMode,
                 brightnessContrast,
             ),
@@ -773,7 +791,7 @@ class MainWindow(QtWidgets.QMainWindow):
             lambda: self.canvas.initializeAiModel(
                 name=self._selectAiModelComboBox.currentText()
             )
-            if self.canvas.createMode == "ai_polygon"
+            if self.canvas.createMode in ["ai_polygon", "ai_mask"]
             else None
         )
 
@@ -900,6 +918,7 @@ class MainWindow(QtWidgets.QMainWindow):
             self.actions.createPointMode,
             self.actions.createLineStripMode,
             self.actions.createAiPolygonMode,
+            self.actions.createAiMaskMode,
             self.actions.editMode,
         )
         utils.addActions(self.menus.edit, actions + self.actions.editMenu)
@@ -932,6 +951,7 @@ class MainWindow(QtWidgets.QMainWindow):
         self.actions.createPointMode.setEnabled(True)
         self.actions.createLineStripMode.setEnabled(True)
         self.actions.createAiPolygonMode.setEnabled(True)
+        self.actions.createAiMaskMode.setEnabled(True)
         title = __appname__
         if self.filename is not None:
             title = "{} - {}".format(title, self.filename)
@@ -1008,6 +1028,7 @@ class MainWindow(QtWidgets.QMainWindow):
             "line": self.actions.createLineMode,
             "linestrip": self.actions.createLineStripMode,
             "ai_polygon": self.actions.createAiPolygonMode,
+            "ai_mask": self.actions.createAiMaskMode,
         }
 
         self.canvas.setEditing(edit)
@@ -1232,6 +1253,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 shape_type=shape_type,
                 group_id=group_id,
                 description=description,
+                mask=shape["mask"],
             )
             for x, y in points:
                 shape.addPoint(QtCore.QPointF(x, y))
@@ -1271,6 +1293,9 @@ class MainWindow(QtWidgets.QMainWindow):
                     description=s.description,
                     shape_type=s.shape_type,
                     flags=s.flags,
+                    mask=None
+                    if s.mask is None
+                    else utils.img_arr_to_b64(s.mask),
                 )
             )
             return data

+ 1 - 0
labelme/config/default_config.yaml

@@ -77,6 +77,7 @@ canvas:
     point: false
     linestrip: false
     ai_polygon: false
+    ai_mask: false
 
 shortcuts:
   close: Ctrl+W

+ 4 - 0
labelme/label_file.py

@@ -84,6 +84,7 @@ class LabelFile(object):
             "shape_type",
             "flags",
             "description",
+            "mask",
         ]
         try:
             with open(filename, "r") as f:
@@ -112,6 +113,9 @@ class LabelFile(object):
                     flags=s.get("flags", {}),
                     description=s.get("description"),
                     group_id=s.get("group_id"),
+                    mask=utils.img_b64_to_arr(s["mask"])
+                    if s.get("mask")
+                    else None,
                     other_data={
                         k: v for k, v in s.items() if k not in shape_keys
                     },

+ 67 - 18
labelme/shape.py

@@ -1,8 +1,10 @@
 import copy
 import math
 
+import numpy as np
 from qtpy import QtCore
 from qtpy import QtGui
+import skimage.measure
 
 from labelme.logger import logger
 import labelme.utils
@@ -45,6 +47,7 @@ class Shape(object):
         flags=None,
         group_id=None,
         description=None,
+        mask=None,
     ):
         self.label = label
         self.group_id = group_id
@@ -60,6 +63,7 @@ class Shape(object):
         self.flags = flags
         self.description = description
         self.other_data = {}
+        self.mask = mask
 
         self._highlightIndex = None
         self._highlightMode = self.NEAR_VERTEX
@@ -76,16 +80,17 @@ class Shape(object):
             # is used for drawing the pending line a different color.
             self.line_color = line_color
 
-    def setShapeRefined(self, points, point_labels, shape_type):
-        self._shape_raw = (self.points, self.point_labels, self.shape_type)
+    def setShapeRefined(self, shape_type, points, point_labels, mask=None):
+        self._shape_raw = (self.shape_type, self.points, self.point_labels)
+        self.shape_type = shape_type
         self.points = points
         self.point_labels = point_labels
-        self.shape_type = shape_type
+        self.mask = mask
 
     def restoreShapeRaw(self):
         if self._shape_raw is None:
             return
-        self.points, self.point_labels, self.shape_type = self._shape_raw
+        self.shape_type, self.points, self.point_labels = self._shape_raw
         self._shape_raw = None
 
     @property
@@ -104,6 +109,7 @@ class Shape(object):
             "circle",
             "linestrip",
             "points",
+            "mask",
         ]:
             raise ValueError("Unexpected shape_type: {}".format(value))
         self._shape_type = value
@@ -171,26 +177,56 @@ class Shape(object):
         return QtCore.QRectF(x1, y1, x2 - x1, y2 - y1)
 
     def paint(self, painter):
-        if self.points:
-            color = (
-                self.select_line_color if self.selected else self.line_color
+        if self.mask is None and not self.points:
+            return
+
+        color = self.select_line_color if self.selected else self.line_color
+        pen = QtGui.QPen(color)
+        # Try using integer sizes for smoother drawing(?)
+        pen.setWidth(max(1, int(round(2.0 / self.scale))))
+        painter.setPen(pen)
+
+        if self.mask is not None:
+            image_to_draw = np.zeros(self.mask.shape + (4,), dtype=np.uint8)
+            fill_color = (
+                self.select_fill_color.getRgb()
+                if self.selected
+                else self.fill_color.getRgb()
             )
-            pen = QtGui.QPen(color)
-            # Try using integer sizes for smoother drawing(?)
-            pen.setWidth(max(1, int(round(2.0 / self.scale))))
-            painter.setPen(pen)
+            image_to_draw[self.mask] = fill_color
+            qimage = QtGui.QImage.fromData(
+                labelme.utils.img_arr_to_data(image_to_draw)
+            )
+            painter.drawImage(
+                int(round(self.points[0].x())),
+                int(round(self.points[0].y())),
+                qimage,
+            )
+
+            line_path = QtGui.QPainterPath()
+            contours = skimage.measure.find_contours(
+                np.pad(self.mask, pad_width=1)
+            )
+            for contour in contours:
+                contour += [self.points[0].y(), self.points[0].x()]
+                line_path.moveTo(contour[0, 1], contour[0, 0])
+                for point in contour[1:]:
+                    line_path.lineTo(point[1], point[0])
+            painter.drawPath(line_path)
 
+        if self.points:
             line_path = QtGui.QPainterPath()
             vrtx_path = QtGui.QPainterPath()
             negative_vrtx_path = QtGui.QPainterPath()
 
-            if self.shape_type == "rectangle":
+            if self.shape_type in ["rectangle", "mask"]:
                 assert len(self.points) in [1, 2]
                 if len(self.points) == 2:
                     rectangle = self.getRectFromLine(*self.points)
                     line_path.addRect(rectangle)
-                for i in range(len(self.points)):
-                    self.drawVertex(vrtx_path, i)
+                if self.shape_type == "rectangle":
+                    for i in range(len(self.points)):
+                        self.drawVertex(vrtx_path, i)
             elif self.shape_type == "circle":
                 assert len(self.points) in [1, 2]
                 if len(self.points) == 2:
@@ -226,9 +262,10 @@ class Shape(object):
                     line_path.lineTo(self.points[0])
 
             painter.drawPath(line_path)
-            painter.drawPath(vrtx_path)
-            painter.fillPath(vrtx_path, self._vertex_fill_color)
-            if self.fill:
+            if vrtx_path.length() > 0:
+                painter.drawPath(vrtx_path)
+                painter.fillPath(vrtx_path, self._vertex_fill_color)
+            if self.fill and self.mask is None:
                 color = (
                     self.select_fill_color
                     if self.selected
@@ -281,6 +318,18 @@ class Shape(object):
         return post_i
 
     def containsPoint(self, point):
+        if self.mask is not None:
+            y = np.clip(
+                int(round(point.y() - self.points[0].y())),
+                0,
+                self.mask.shape[0] - 1,
+            )
+            x = np.clip(
+                int(round(point.x() - self.points[0].x())),
+                0,
+                self.mask.shape[1] - 1,
+            )
+            return self.mask[y, x]
         return self.makePath().contains(point)
 
     def getCircleRectFromLine(self, line):
@@ -294,7 +343,7 @@ class Shape(object):
         return rectangle
 
     def makePath(self):
-        if self.shape_type == "rectangle":
+        if self.shape_type in ["rectangle", "mask"]:
             path = QtGui.QPainterPath()
             if len(self.points) == 2:
                 rectangle = self.getRectFromLine(*self.points)

+ 53 - 9
labelme/widgets/canvas.py

@@ -1,4 +1,5 @@
 import gdown
+import imgviz
 from qtpy import QtCore
 from qtpy import QtGui
 from qtpy import QtWidgets
@@ -60,6 +61,7 @@ class Canvas(QtWidgets.QWidget):
                 "point": False,
                 "linestrip": False,
                 "ai_polygon": False,
+                "ai_mask": False,
             },
         )
         super(Canvas, self).__init__(*args, **kwargs)
@@ -125,6 +127,7 @@ class Canvas(QtWidgets.QWidget):
             "point",
             "linestrip",
             "ai_polygon",
+            "ai_mask",
         ]:
             raise ValueError("Unsupported createMode: %s" % value)
         self._createMode = value
@@ -249,7 +252,7 @@ class Canvas(QtWidgets.QWidget):
 
         # Polygon drawing.
         if self.drawing():
-            if self.createMode == "ai_polygon":
+            if self.createMode in ["ai_polygon", "ai_mask"]:
                 self.line.shape_type = "points"
             else:
                 self.line.shape_type = self.createMode
@@ -277,7 +280,7 @@ class Canvas(QtWidgets.QWidget):
             if self.createMode in ["polygon", "linestrip"]:
                 self.line.points = [self.current[-1], pos]
                 self.line.point_labels = [1, 1]
-            elif self.createMode == "ai_polygon":
+            elif self.createMode in ["ai_polygon", "ai_mask"]:
                 self.line.points = [self.current.points[-1], pos]
                 self.line.point_labels = [
                     self.current.point_labels[-1],
@@ -434,7 +437,7 @@ class Canvas(QtWidgets.QWidget):
                         self.line[0] = self.current[-1]
                         if int(ev.modifiers()) == QtCore.Qt.ControlModifier:
                             self.finalise()
-                    elif self.createMode == "ai_polygon":
+                    elif self.createMode in ["ai_polygon", "ai_mask"]:
                         self.current.addPoint(
                             self.line.points[1],
                             label=self.line.point_labels[1],
@@ -449,7 +452,7 @@ class Canvas(QtWidgets.QWidget):
                     # Create new shape.
                     self.current = Shape(
                         shape_type="points"
-                        if self.createMode == "ai_polygon"
+                        if self.createMode in ["ai_polygon", "ai_mask"]
                         else self.createMode
                     )
                     self.current.addPoint(
@@ -458,7 +461,7 @@ class Canvas(QtWidgets.QWidget):
                     if self.createMode == "point":
                         self.finalise()
                     elif (
-                        self.createMode == "ai_polygon"
+                        self.createMode in ["ai_polygon", "ai_mask"]
                         and ev.modifiers() & QtCore.Qt.ControlModifier
                     ):
                         self.finalise()
@@ -467,7 +470,7 @@ class Canvas(QtWidgets.QWidget):
                             self.current.shape_type = "circle"
                         self.line.points = [pos, pos]
                         if (
-                            self.createMode == "ai_polygon"
+                            self.createMode in ["ai_polygon", "ai_mask"]
                             and is_shift_pressed
                         ):
                             self.line.point_labels = [0, 0]
@@ -569,7 +572,7 @@ class Canvas(QtWidgets.QWidget):
 
         if (
             self.createMode == "polygon" and self.canCloseShape()
-        ) or self.createMode == "ai_polygon":
+        ) or self.createMode in ["ai_polygon", "ai_mask"]:
             self.finalise()
 
     def selectShapes(self, shapes):
@@ -770,7 +773,6 @@ class Canvas(QtWidgets.QWidget):
                 point=self.line.points[1],
                 label=self.line.point_labels[1],
             )
-            drawing_shape.selected = True
             points = self._ai_model.predict_polygon_from_points(
                 points=[
                     [point.x(), point.y()] for point in drawing_shape.points
@@ -779,14 +781,38 @@ class Canvas(QtWidgets.QWidget):
             )
             if len(points) > 2:
                 drawing_shape.setShapeRefined(
+                    shape_type="polygon",
                     points=[
                         QtCore.QPointF(point[0], point[1]) for point in points
                     ],
                     point_labels=[1] * len(points),
-                    shape_type="polygon",
                 )
                 drawing_shape.fill = self.fillDrawing()
+                drawing_shape.selected = True
                 drawing_shape.paint(p)
+        elif self.createMode == "ai_mask" and self.current is not None:
+            drawing_shape = self.current.copy()
+            drawing_shape.addPoint(
+                point=self.line.points[1],
+                label=self.line.point_labels[1],
+            )
+            mask = self._ai_model.predict_mask_from_points(
+                points=[
+                    [point.x(), point.y()] for point in drawing_shape.points
+                ],
+                point_labels=drawing_shape.point_labels,
+            )
+            y1, x1, y2, x2 = imgviz.instances.mask_to_bbox([mask])[0].astype(
+                int
+            )
+            drawing_shape.setShapeRefined(
+                shape_type="mask",
+                points=[QtCore.QPointF(x1, y1), QtCore.QPointF(x2, y2)],
+                point_labels=[1, 1],
+                mask=mask[y1:y2, x1:x2],
+            )
+            drawing_shape.selected = True
+            drawing_shape.paint(p)
 
         p.end()
 
@@ -825,6 +851,24 @@ class Canvas(QtWidgets.QWidget):
                 point_labels=[1] * len(points),
                 shape_type="polygon",
             )
+        elif self.createMode == "ai_mask":
+            # convert points to mask by an AI model
+            assert self.current.shape_type == "points"
+            mask = self._ai_model.predict_mask_from_points(
+                points=[
+                    [point.x(), point.y()] for point in self.current.points
+                ],
+                point_labels=self.current.point_labels,
+            )
+            y1, x1, y2, x2 = imgviz.instances.mask_to_bbox([mask])[0].astype(
+                int
+            )
+            self.current.setShapeRefined(
+                shape_type="mask",
+                points=[QtCore.QPointF(x1, y1), QtCore.QPointF(x2, y2)],
+                point_labels=[1, 1],
+                mask=mask[y1:y2, x1:x2],
+            )
         self.current.close()
 
         self.shapes.append(self.current)