Эх сурвалжийг харах

Fix uninitialized image for AI model when annotating consecutive images

Kentaro Wada 2 жил өмнө
parent
commit
f3f880e5bb

+ 2 - 0
labelme/ai/models/segment_anything.py

@@ -109,6 +109,8 @@ def _resize_image(image_size, image):
 
 
 def _compute_image_embedding(image_size, encoder_session, image):
+    image = imgviz.asrgb(image)
+
     scale, x = _resize_image(image_size, image)
     x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
         [58.395, 57.12, 57.375], dtype=np.float32

+ 0 - 14
labelme/app.py

@@ -18,7 +18,6 @@ from qtpy import QtWidgets
 from labelme import __appname__
 from labelme import PY2
 
-from . import ai
 from . import utils
 from labelme.config import get_config
 from labelme.label_file import LabelFile
@@ -829,12 +828,6 @@ class MainWindow(QtWidgets.QMainWindow):
         # if self.firstStart:
         #    QWhatsThis.enterWhatsThisMode()
 
-    @property
-    def _ai_model(self):
-        if not hasattr(self, "_ai_model_initialized"):
-            self._ai_model_initialized = ai.SegmentAnythingModel()
-        return self._ai_model_initialized
-
     def menu(self, title, actions=None):
         menu = self.menuBar().addMenu(title)
         if actions:
@@ -973,13 +966,6 @@ class MainWindow(QtWidgets.QMainWindow):
     def toggleDrawMode(self, edit=True, createMode="polygon"):
         self.canvas.setEditing(edit)
         self.canvas.createMode = createMode
-        if createMode == "ai_polygon":
-            self._ai_model.set_image(utils.img_data_to_arr(self.imageData))
-            self.canvas.setAiCallback(
-                self._ai_model.points_to_polygon_callback
-            )
-        else:
-            self.canvas.setAiCallback(None)
         if edit:
             self.actions.createMode.setEnabled(True)
             self.actions.createRectangleMode.setEnabled(True)

+ 1 - 0
labelme/utils/__init__.py

@@ -9,6 +9,7 @@ from .image import img_data_to_arr
 from .image import img_data_to_pil
 from .image import img_data_to_png_data
 from .image import img_pil_to_data
+from .image import img_qt_to_arr
 
 from .shape import labelme_shapes_to_label
 from .shape import masks_to_bboxes

+ 7 - 0
labelme/utils/image.py

@@ -56,6 +56,13 @@ def img_data_to_png_data(img_data):
             return f.read()
 
 
+def img_qt_to_arr(img_qt):
+    w, h, d = img_qt.size().width(), img_qt.size().height(), img_qt.depth()
+    bytes_ = img_qt.bits().asstring(w * h * d // 8)
+    img_arr = np.frombuffer(bytes_, dtype=np.uint8).reshape((h, w, d // 8))
+    return img_arr
+
+
 def apply_exif_orientation(image):
     try:
         exif = image._getexif()

+ 12 - 5
labelme/widgets/canvas.py

@@ -101,10 +101,7 @@ class Canvas(QtWidgets.QWidget):
         self.setMouseTracking(True)
         self.setFocusPolicy(QtCore.Qt.WheelFocus)
 
-        self._ai_callback = None
-
-    def setAiCallback(self, ai_callback):
-        self._ai_callback = ai_callback
+        self._ai_model = None
 
     def fillDrawing(self):
         return self._fill_drawing
@@ -129,6 +126,12 @@ class Canvas(QtWidgets.QWidget):
         ]:
             raise ValueError("Unsupported createMode: %s" % value)
 
+        if value == "ai_polygon" and self._ai_model is None:
+            self._ai_model = labelme.ai.SegmentAnythingModel()
+            self._ai_model.set_image(
+                image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
+            )
+
         self._createMode = value
 
     def storeShapes(self):
@@ -757,7 +760,7 @@ class Canvas(QtWidgets.QWidget):
         if self.createMode == "ai_polygon":
             # convert points to polygon by an AI model
             assert self.current.shape_type == "points"
-            points = self._ai_callback(
+            points = self._ai_model.points_to_polygon_callback(
                 points=[
                     [point.x(), point.y()] for point in self.current.points
                 ],
@@ -961,6 +964,10 @@ class Canvas(QtWidgets.QWidget):
 
     def loadPixmap(self, pixmap, clear_shapes=True):
         self.pixmap = pixmap
+        if self._ai_model:
+            self._ai_model.set_image(
+                image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
+            )
         if clear_shapes:
             self.shapes = []
         self.update()