Prechádzať zdrojové kódy

Introduce Segment-Anything model for crateMode="ai_polygon"

Kentaro Wada 2 rokov pred
rodič
commit
ac3a7c63b4

+ 1 - 0
labelme/ai/__init__.py

@@ -0,0 +1 @@
+from .models.segment_anything import SegmentAnythingModel

+ 0 - 0
labelme/ai/models/__init__.py


+ 145 - 0
labelme/ai/models/segment_anything.py

@@ -0,0 +1,145 @@
+import threading
+
+import imgviz
+import numpy as np
+import onnxruntime
+import PIL.Image
+import skimage.measure
+
+
+class SegmentAnythingModel:
+    def __init__(self):
+        self._image_size = 1024
+
+        # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx"  # NOQA
+        # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx"  # NOQA
+        encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx"  # NOQA
+        decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx"  # NOQA
+        # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx"  # NOQA
+        # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx"  # NOQA
+
+        self._encoder_session = onnxruntime.InferenceSession(encoder_path)
+        self._decoder_session = onnxruntime.InferenceSession(decoder_path)
+
+    def set_image(self, image: np.ndarray):
+        self._image = image
+        self._image_embedding = None
+
+        self._thread = threading.Thread(target=self.get_image_embedding)
+        self._thread.start()
+
+    def get_image_embedding(self):
+        if self._image_embedding is None:
+            self._image_embedding = compute_image_embedding(
+                image_size=self._image_size,
+                encoder_session=self._encoder_session,
+                image=self._image,
+            )
+        return self._image_embedding
+
+    def points_to_polygon_callback(self, points):
+        self._thread.join()
+        image_embedding = self.get_image_embedding()
+
+        polygon = compute_polygon_from_points(
+            image_size=self._image_size,
+            decoder_session=self._decoder_session,
+            image=self._image,
+            image_embedding=image_embedding,
+            points=points,
+        )
+        return polygon
+
+
+def compute_image_embedding(image_size, encoder_session, image):
+    assert image.shape[1] > image.shape[0]
+    scale = image_size / image.shape[1]
+    x = imgviz.resize(
+        image,
+        height=int(round(image.shape[0] * scale)),
+        width=image_size,
+        backend="pillow",
+    ).astype(np.float32)
+    x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
+        [58.395, 57.12, 57.375], dtype=np.float32
+    )
+    x = np.pad(
+        x,
+        (
+            (0, image_size - x.shape[0]),
+            (0, image_size - x.shape[1]),
+            (0, 0),
+        ),
+    )
+    x = x.transpose(2, 0, 1)[None, :, :, :]
+
+    output = encoder_session.run(output_names=None, input_feed={"x": x})
+    image_embedding = output[0]
+
+    return image_embedding
+
+
+def _get_contour_length(contour):
+    contour_start = contour
+    contour_end = np.r_[contour[1:], contour[0:1]]
+    return np.linalg.norm(contour_end - contour_start, axis=1).sum()
+
+
+def compute_polygon_from_points(
+    image_size, decoder_session, image, image_embedding, points
+):
+    input_point = np.array(points, dtype=np.float32)
+    input_label = np.ones(len(input_point), dtype=np.int32)
+
+    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
+        None, :, :
+    ]
+    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
+        None, :
+    ].astype(np.float32)
+
+    assert image.shape[1] > image.shape[0]
+    scale = image_size / image.shape[1]
+    new_height = int(round(image.shape[0] * scale))
+    new_width = image_size
+    onnx_coord = (
+        onnx_coord.astype(float)
+        * (new_width / image.shape[1], new_height / image.shape[0])
+    ).astype(np.float32)
+
+    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+    onnx_has_mask_input = np.array([-1], dtype=np.float32)
+
+    decoder_inputs = {
+        "image_embeddings": image_embedding,
+        "point_coords": onnx_coord,
+        "point_labels": onnx_label,
+        "mask_input": onnx_mask_input,
+        "has_mask_input": onnx_has_mask_input,
+        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
+    }
+
+    masks, _, _ = decoder_session.run(None, decoder_inputs)
+    mask = masks[0, 0]  # (1, 1, H, W) -> (H, W)
+    mask = mask > 0.0
+    if 0:
+        imgviz.io.imsave(
+            "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
+        )
+
+    contours = skimage.measure.find_contours(mask)
+    contour = max(contours, key=_get_contour_length)
+    polygon = skimage.measure.approximate_polygon(
+        coords=contour,
+        tolerance=np.ptp(contour, axis=0).max() / 100,
+    )
+    if 0:
+        image_pil = PIL.Image.fromarray(image)
+        imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
+        for point in polygon:
+            imgviz.draw.circle_(
+                image_pil, center=point, diameter=10, fill=(0, 255, 0)
+            )
+        imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
+
+    return polygon[:, ::-1]  # yx -> xy

+ 42 - 0
labelme/app.py

@@ -18,6 +18,7 @@ from qtpy import QtWidgets
 from labelme import __appname__
 from labelme import PY2
 
+from . import ai
 from . import utils
 from labelme.config import get_config
 from labelme.label_file import LabelFile
@@ -367,6 +368,14 @@ class MainWindow(QtWidgets.QMainWindow):
             self.tr("Start drawing linestrip. Ctrl+LeftClick ends creation."),
             enabled=False,
         )
+        createAiPolygonMode = action(
+            self.tr("Create AI-Polygon"),
+            lambda: self.toggleDrawMode(False, createMode="ai_polygon"),
+            None,
+            "objects",
+            self.tr("Start drawing ai_polygon. Ctrl+LeftClick ends creation."),
+            enabled=False,
+        )
         editMode = action(
             self.tr("Edit Polygons"),
             self.setEditMode,
@@ -603,6 +612,7 @@ class MainWindow(QtWidgets.QMainWindow):
             createLineMode=createLineMode,
             createPointMode=createPointMode,
             createLineStripMode=createLineStripMode,
+            createAiPolygonMode=createAiPolygonMode,
             zoom=zoom,
             zoomIn=zoomIn,
             zoomOut=zoomOut,
@@ -637,6 +647,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 createLineMode,
                 createPointMode,
                 createLineStripMode,
+                createAiPolygonMode,
                 editMode,
                 edit,
                 duplicate,
@@ -655,6 +666,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 createLineMode,
                 createPointMode,
                 createLineStripMode,
+                createAiPolygonMode,
                 editMode,
                 brightnessContrast,
             ),
@@ -817,6 +829,12 @@ class MainWindow(QtWidgets.QMainWindow):
         # if self.firstStart:
         #    QWhatsThis.enterWhatsThisMode()
 
+    @property
+    def _ai_model(self):
+        if not hasattr(self, "_ai_model_initialized"):
+            self._ai_model_initialized = ai.SegmentAnythingModel()
+        return self._ai_model_initialized
+
     def menu(self, title, actions=None):
         menu = self.menuBar().addMenu(title)
         if actions:
@@ -852,6 +870,7 @@ class MainWindow(QtWidgets.QMainWindow):
             self.actions.createLineMode,
             self.actions.createPointMode,
             self.actions.createLineStripMode,
+            self.actions.createAiPolygonMode,
             self.actions.editMode,
         )
         utils.addActions(self.menus.edit, actions + self.actions.editMenu)
@@ -883,6 +902,7 @@ class MainWindow(QtWidgets.QMainWindow):
         self.actions.createLineMode.setEnabled(True)
         self.actions.createPointMode.setEnabled(True)
         self.actions.createLineStripMode.setEnabled(True)
+        self.actions.createAiPolygonMode.setEnabled(True)
         title = __appname__
         if self.filename is not None:
             title = "{} - {}".format(title, self.filename)
@@ -953,6 +973,13 @@ class MainWindow(QtWidgets.QMainWindow):
     def toggleDrawMode(self, edit=True, createMode="polygon"):
         self.canvas.setEditing(edit)
         self.canvas.createMode = createMode
+        if createMode == "ai_polygon":
+            self._ai_model.set_image(utils.img_data_to_arr(self.imageData))
+            self.canvas.setAiCallback(
+                self._ai_model.points_to_polygon_callback
+            )
+        else:
+            self.canvas.setAiCallback(None)
         if edit:
             self.actions.createMode.setEnabled(True)
             self.actions.createRectangleMode.setEnabled(True)
@@ -960,6 +987,7 @@ class MainWindow(QtWidgets.QMainWindow):
             self.actions.createLineMode.setEnabled(True)
             self.actions.createPointMode.setEnabled(True)
             self.actions.createLineStripMode.setEnabled(True)
+            self.actions.createAiPolygonMode.setEnabled(True)
         else:
             if createMode == "polygon":
                 self.actions.createMode.setEnabled(False)
@@ -968,6 +996,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createLineMode.setEnabled(True)
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
+                self.actions.createAiPolygonMode.setEnabled(True)
             elif createMode == "rectangle":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(False)
@@ -975,6 +1004,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createLineMode.setEnabled(True)
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
+                self.actions.createAiPolygonMode.setEnabled(True)
             elif createMode == "line":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -982,6 +1012,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createLineMode.setEnabled(False)
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
+                self.actions.createAiPolygonMode.setEnabled(True)
             elif createMode == "point":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -989,6 +1020,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createLineMode.setEnabled(True)
                 self.actions.createPointMode.setEnabled(False)
                 self.actions.createLineStripMode.setEnabled(True)
+                self.actions.createAiPolygonMode.setEnabled(True)
             elif createMode == "circle":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -996,6 +1028,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createLineMode.setEnabled(True)
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
+                self.actions.createAiPolygonMode.setEnabled(True)
             elif createMode == "linestrip":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -1003,6 +1036,15 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createLineMode.setEnabled(True)
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(False)
+                self.actions.createAiPolygonMode.setEnabled(True)
+            elif createMode == "ai_polygon":
+                self.actions.createMode.setEnabled(True)
+                self.actions.createRectangleMode.setEnabled(True)
+                self.actions.createCircleMode.setEnabled(True)
+                self.actions.createLineMode.setEnabled(True)
+                self.actions.createPointMode.setEnabled(True)
+                self.actions.createLineStripMode.setEnabled(True)
+                self.actions.createAiPolygonMode.setEnabled(False)
             else:
                 raise ValueError("Unsupported createMode: %s" % createMode)
         self.actions.editMode.setEnabled(not edit)

+ 1 - 0
labelme/config/default_config.yaml

@@ -75,6 +75,7 @@ canvas:
     line: false
     point: false
     linestrip: false
+    ai_polygon: false
 
 shortcuts:
   close: Ctrl+W

+ 14 - 0
labelme/shape.py

@@ -57,6 +57,10 @@ class Shape(object):
         self.label = label
         self.group_id = group_id
         self.points = []
+        self.shape_type = shape_type
+        self._shape_raw = None
+        self._points_raw = []
+        self._shape_type_raw = None
         self.fill = False
         self.selected = False
         self.shape_type = shape_type
@@ -79,8 +83,17 @@ class Shape(object):
             # is used for drawing the pending line a different color.
             self.line_color = line_color
 
+    def setShapeRefined(self, points, shape_type):
+        self._shape_raw = (self.points, self.shape_type)
+        self.points = points
         self.shape_type = shape_type
 
+    def restoreShapeRaw(self):
+        if self._shape_raw is None:
+            return
+        self.points, self.shape_type = self._shape_raw
+        self._shape_raw = None
+
     @property
     def shape_type(self):
         return self._shape_type
@@ -96,6 +109,7 @@ class Shape(object):
             "line",
             "circle",
             "linestrip",
+            "points",
         ]:
             raise ValueError("Unexpected shape_type: {}".format(value))
         self._shape_type = value

+ 38 - 4
labelme/widgets/canvas.py

@@ -1,7 +1,9 @@
+import numpy as np
 from qtpy import QtCore
 from qtpy import QtGui
 from qtpy import QtWidgets
 
+import labelme.ai
 from labelme import QT5
 from labelme.shape import Shape
 import labelme.utils
@@ -56,6 +58,7 @@ class Canvas(QtWidgets.QWidget):
                 "line": False,
                 "point": False,
                 "linestrip": False,
+                "ai_polygon": False,
             },
         )
         super(Canvas, self).__init__(*args, **kwargs)
@@ -99,6 +102,11 @@ class Canvas(QtWidgets.QWidget):
         self.setMouseTracking(True)
         self.setFocusPolicy(QtCore.Qt.WheelFocus)
 
+        self._ai_callback = None
+
+    def setAiCallback(self, ai_callback):
+        self._ai_callback = ai_callback
+
     def fillDrawing(self):
         return self._fill_drawing
 
@@ -118,8 +126,10 @@ class Canvas(QtWidgets.QWidget):
             "line",
             "point",
             "linestrip",
+            "ai_polygon",
         ]:
             raise ValueError("Unsupported createMode: %s" % value)
+
         self._createMode = value
 
     def storeShapes(self):
@@ -215,7 +225,10 @@ class Canvas(QtWidgets.QWidget):
 
         # Polygon drawing.
         if self.drawing():
-            self.line.shape_type = self.createMode
+            if self.createMode == "ai_polygon":
+                self.line.shape_type = "points"
+            else:
+                self.line.shape_type = self.createMode
 
             self.overrideCursor(CURSOR_DRAW)
             if not self.current:
@@ -237,7 +250,7 @@ class Canvas(QtWidgets.QWidget):
                 pos = self.current[0]
                 self.overrideCursor(CURSOR_POINT)
                 self.current.highlightVertex(0, Shape.NEAR_VERTEX)
-            if self.createMode in ["polygon", "linestrip"]:
+            if self.createMode in ["polygon", "linestrip", "ai_polygon"]:
                 self.line[0] = self.current[-1]
                 self.line[1] = pos
             elif self.createMode == "rectangle":
@@ -378,14 +391,18 @@ class Canvas(QtWidgets.QWidget):
                         assert len(self.current.points) == 1
                         self.current.points = self.line.points
                         self.finalise()
-                    elif self.createMode == "linestrip":
+                    elif self.createMode in ["linestrip", "ai_polygon"]:
                         self.current.addPoint(self.line[1])
                         self.line[0] = self.current[-1]
                         if int(ev.modifiers()) == QtCore.Qt.ControlModifier:
                             self.finalise()
                 elif not self.outOfPixmap(pos):
                     # Create new shape.
-                    self.current = Shape(shape_type=self.createMode)
+                    self.current = Shape(
+                        shape_type="points"
+                        if self.createMode == "ai_polygon"
+                        else self.createMode
+                    )
                     self.current.addPoint(pos)
                     if self.createMode == "point":
                         self.finalise()
@@ -701,7 +718,23 @@ class Canvas(QtWidgets.QWidget):
 
     def finalise(self):
         assert self.current
+        if self.createMode == "ai_polygon":
+            # convert points to polygon by an AI model
+            assert self.current.shape_type == "points"
+            points = self._ai_callback(
+                points=np.array(
+                    [[point.x(), point.y()] for point in self.current.points],
+                    dtype=np.float32,
+                )
+            )
+            self.current.setShapeRefined(
+                points=[
+                    QtCore.QPointF(point[0], point[1]) for point in points
+                ],
+                shape_type="polygon",
+            )
         self.current.close()
+
         self.shapes.append(self.current)
         self.storeShapes()
         self.current = None
@@ -869,6 +902,7 @@ class Canvas(QtWidgets.QWidget):
         assert self.shapes
         self.current = self.shapes.pop()
         self.current.setOpen()
+        self.current.restoreShapeRaw()
         if self.createMode in ["polygon", "linestrip"]:
             self.line.points = [self.current[-1], self.current[0]]
         elif self.createMode in ["rectangle", "line", "circle"]: