Jelajahi Sumber

Give options to select AI models

Kentaro Wada 2 tahun lalu
induk
komit
e92bba9219
4 file diubah dengan 106 penambahan dan 28 penghapusan
  1. 42 1
      labelme/ai/__init__.py
  2. 8 22
      labelme/ai/models/segment_anything.py
  3. 32 0
      labelme/app.py
  4. 24 5
      labelme/widgets/canvas.py

+ 42 - 1
labelme/ai/__init__.py

@@ -1,3 +1,44 @@
-# flake8: noqa
+import collections
 
 from .models.segment_anything import SegmentAnythingModel
+
+
+Model = collections.namedtuple("Model", ["name", "encoder_weight", "decoder_weight"])
+
+Weight = collections.namedtuple("Weight", ["url", "md5"])
+
+MODELS = [
+    Model(
+        name="Segment-Anything (speed)",
+        encoder_weight=Weight(
+            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx",  # NOQA
+            md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
+        ),
+        decoder_weight=Weight(
+            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx",  # NOQA
+            md5="4253558be238c15fc265a7a876aaec82",
+        ),
+    ),
+    Model(
+        name="Segment-Anything (balanced)",
+        encoder_weight=Weight(
+            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
+            md5="080004dc9992724d360a49399d1ee24b",
+        ),
+        decoder_weight=Weight(
+            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx",  # NOQA
+            md5="851b7faac91e8e23940ee1294231d5c7",
+        ),
+    ),
+    Model(
+        name="Segment-Anything (accuracy)",
+        encoder_weight=Weight(
+            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx",  # NOQA
+            md5="958b5710d25b198d765fb6b94798f49e",
+        ),
+        decoder_weight=Weight(
+            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx",  # NOQA
+            md5="a997a408347aa081b17a3ffff9f42a80",
+        ),
+    ),
+]

+ 8 - 22
labelme/ai/models/segment_anything.py

@@ -12,25 +12,10 @@ from ...logger import logger
 
 
 class SegmentAnythingModel:
-    def __init__(self):
-        self._image_size = 1024
+    def __init__(self, name, encoder_path, decoder_path):
+        self.name = name
 
-        # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx"  # NOQA
-        # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx"  # NOQA
-        #
-        # encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx"  # NOQA
-        # decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx"  # NOQA
-        #
-        # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx"  # NOQA
-        # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx"  # NOQA
-        encoder_path = gdown.cached_download(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
-            md5="080004dc9992724d360a49399d1ee24b",
-        )
-        decoder_path = gdown.cached_download(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx",  # NOQA
-            md5="851b7faac91e8e23940ee1294231d5c7",
-        )
+        self._image_size = 1024
 
         self._encoder_session = onnxruntime.InferenceSession(encoder_path)
         self._decoder_session = onnxruntime.InferenceSession(decoder_path)
@@ -45,10 +30,11 @@ class SegmentAnythingModel:
                 self._image.tobytes()
             )
 
-        self._thread = threading.Thread(
-            target=self._compute_and_cache_image_embedding
-        )
-        self._thread.start()
+        if self._image_embedding is None:
+            self._thread = threading.Thread(
+                target=self._compute_and_cache_image_embedding
+            )
+            self._thread.start()
 
     def _compute_and_cache_image_embedding(self):
         with self._lock:

+ 32 - 0
labelme/app.py

@@ -19,6 +19,7 @@ from labelme import __appname__
 from labelme import PY2
 
 from . import utils
+from labelme.ai import MODELS
 from labelme.config import get_config
 from labelme.label_file import LabelFile
 from labelme.label_file import LabelFileError
@@ -748,6 +749,24 @@ class MainWindow(QtWidgets.QMainWindow):
             ),
         )
 
+        selectAiModel = QtWidgets.QWidgetAction(self)
+        selectAiModel.setDefaultWidget(QtWidgets.QWidget())
+        selectAiModel.defaultWidget().setLayout(QtWidgets.QVBoxLayout())
+        self._selectAiModelComboBox = QtWidgets.QComboBox()
+        selectAiModel.defaultWidget().layout().addWidget(self._selectAiModelComboBox)
+        self._selectAiModelComboBox.addItems([model.name for model in MODELS])
+        self._selectAiModelComboBox.setCurrentIndex(1)
+        self._selectAiModelComboBox.setEnabled(False)
+        self._selectAiModelComboBox.currentIndexChanged.connect(
+            lambda: self.canvas.initializeAiModel(
+                name=self._selectAiModelComboBox.currentText()
+            )
+        )
+        selectAiModelLabel = QtWidgets.QLabel(self.tr("AI Model"))
+        selectAiModelLabel.setAlignment(QtCore.Qt.AlignCenter)
+        selectAiModelLabel.setFont(QtGui.QFont(None, 10))
+        selectAiModel.defaultWidget().layout().addWidget(selectAiModelLabel)
+
         self.tools = self.toolbar("Tools")
         self.actions.tool = (
             open_,
@@ -768,6 +787,8 @@ class MainWindow(QtWidgets.QMainWindow):
             None,
             zoom,
             fitWidth,
+            None,
+            selectAiModel,
         )
 
         self.statusBar().showMessage(str(self.tr("%s started.")) % __appname__)
@@ -981,6 +1002,7 @@ class MainWindow(QtWidgets.QMainWindow):
             self.actions.createPointMode.setEnabled(True)
             self.actions.createLineStripMode.setEnabled(True)
             self.actions.createAiPolygonMode.setEnabled(True)
+            self._selectAiModelComboBox.setEnabled(False)
         else:
             if createMode == "polygon":
                 self.actions.createMode.setEnabled(False)
@@ -990,6 +1012,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
                 self.actions.createAiPolygonMode.setEnabled(True)
+                self._selectAiModelComboBox.setEnabled(False)
             elif createMode == "rectangle":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(False)
@@ -998,6 +1021,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
                 self.actions.createAiPolygonMode.setEnabled(True)
+                self._selectAiModelComboBox.setEnabled(False)
             elif createMode == "line":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -1006,6 +1030,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
                 self.actions.createAiPolygonMode.setEnabled(True)
+                self._selectAiModelComboBox.setEnabled(False)
             elif createMode == "point":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -1014,6 +1039,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createPointMode.setEnabled(False)
                 self.actions.createLineStripMode.setEnabled(True)
                 self.actions.createAiPolygonMode.setEnabled(True)
+                self._selectAiModelComboBox.setEnabled(False)
             elif createMode == "circle":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -1022,6 +1048,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
                 self.actions.createAiPolygonMode.setEnabled(True)
+                self._selectAiModelComboBox.setEnabled(False)
             elif createMode == "linestrip":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -1030,6 +1057,7 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(False)
                 self.actions.createAiPolygonMode.setEnabled(True)
+                self._selectAiModelComboBox.setEnabled(False)
             elif createMode == "ai_polygon":
                 self.actions.createMode.setEnabled(True)
                 self.actions.createRectangleMode.setEnabled(True)
@@ -1038,6 +1066,10 @@ class MainWindow(QtWidgets.QMainWindow):
                 self.actions.createPointMode.setEnabled(True)
                 self.actions.createLineStripMode.setEnabled(True)
                 self.actions.createAiPolygonMode.setEnabled(False)
+                self.canvas.initializeAiModel(
+                    name=self._selectAiModelComboBox.currentText()
+                )
+                self._selectAiModelComboBox.setEnabled(True)
             else:
                 raise ValueError("Unsupported createMode: %s" % createMode)
         self.actions.editMode.setEnabled(not edit)

+ 24 - 5
labelme/widgets/canvas.py

@@ -1,3 +1,4 @@
+import gdown
 from qtpy import QtCore
 from qtpy import QtGui
 from qtpy import QtWidgets
@@ -126,14 +127,32 @@ class Canvas(QtWidgets.QWidget):
             "ai_polygon",
         ]:
             raise ValueError("Unsupported createMode: %s" % value)
+        self._createMode = value
 
-        if value == "ai_polygon" and self._ai_model is None:
-            self._ai_model = labelme.ai.SegmentAnythingModel()
-            self._ai_model.set_image(
-                image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
+    def initializeAiModel(self, name):
+        if name not in [model.name for model in labelme.ai.MODELS]:
+            raise ValueError("Unsupported ai model: %s" % name)
+        model = [model for model in labelme.ai.MODELS if model.name == name][0]
+
+        if self._ai_model is not None and self._ai_model.name == model.name:
+            logger.debug("AI model is already initialized: %r" % model.name)
+        else:
+            logger.debug("Initializing AI model: %r" % model.name)
+            self._ai_model = labelme.ai.SegmentAnythingModel(
+                name=model.name,
+                encoder_path = gdown.cached_download(
+                    url=model.encoder_weight.url,
+                    md5=model.encoder_weight.md5,
+                ),
+                decoder_path = gdown.cached_download(
+                    url=model.decoder_weight.url,
+                    md5=model.decoder_weight.md5,
+                ),
             )
 
-        self._createMode = value
+        self._ai_model.set_image(
+            image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
+        )
 
     def storeShapes(self):
         shapesBackup = []