소스 검색

Use class to represent AI model

Kentaro Wada 1년 전
부모
커밋
2ce212dbf7
4개의 변경된 파일에서 53개의 추가 그리고 54개의 삭제
  1. 51 39
      labelme/ai/__init__.py
  2. 0 0
      labelme/ai/models/__init__.py
  3. 1 3
      labelme/ai/segment_anything_model.py
  4. 1 12
      labelme/widgets/canvas.py

+ 51 - 39
labelme/ai/__init__.py

@@ -1,46 +1,58 @@
-import collections
+import gdown
 
-from .segment_anything_model import SegmentAnythingModel  # NOQA
+from .segment_anything_model import SegmentAnythingModel
 
 
-Model = collections.namedtuple(
-    "Model", ["name", "encoder_weight", "decoder_weight"]
-)
+class SegmentAnythingModelVitB(SegmentAnythingModel):
+    name = "SegmentAnything (speed)"
+
+    def __init__(self):
+        super().__init__(
+            encoder_path=gdown.cached_download(
+                url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx",  # NOQA
+                md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
+            ),
+            decoder_path=gdown.cached_download(
+                url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx",  # NOQA
+                md5="4253558be238c15fc265a7a876aaec82",
+            ),
+        )
+
+
+class SegmentAnythingModelVitL(SegmentAnythingModel):
+    name = "SegmentAnything (balanced)"
+
+    def __init__(self):
+        super().__init__(
+            encoder_path=gdown.cached_download(
+                url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
+                md5="080004dc9992724d360a49399d1ee24b",
+            ),
+            decoder_path=gdown.cached_download(
+                url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx",  # NOQA
+                md5="851b7faac91e8e23940ee1294231d5c7",
+            ),
+        )
+
+
+class SegmentAnythingModelVitH(SegmentAnythingModel):
+    name = "SegmentAnything (accuracy)"
+
+    def __init__(self):
+        super().__init__(
+            encoder_path=gdown.cached_download(
+                url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx",  # NOQA
+                md5="958b5710d25b198d765fb6b94798f49e",
+            ),
+            decoder_path=gdown.cached_download(
+                url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx",  # NOQA
+                md5="a997a408347aa081b17a3ffff9f42a80",
+            ),
+        )
 
-Weight = collections.namedtuple("Weight", ["url", "md5"])
 
 MODELS = [
-    Model(
-        name="Segment-Anything (speed)",
-        encoder_weight=Weight(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx",  # NOQA
-            md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
-        ),
-        decoder_weight=Weight(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx",  # NOQA
-            md5="4253558be238c15fc265a7a876aaec82",
-        ),
-    ),
-    Model(
-        name="Segment-Anything (balanced)",
-        encoder_weight=Weight(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
-            md5="080004dc9992724d360a49399d1ee24b",
-        ),
-        decoder_weight=Weight(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx",  # NOQA
-            md5="851b7faac91e8e23940ee1294231d5c7",
-        ),
-    ),
-    Model(
-        name="Segment-Anything (accuracy)",
-        encoder_weight=Weight(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx",  # NOQA
-            md5="958b5710d25b198d765fb6b94798f49e",
-        ),
-        decoder_weight=Weight(
-            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx",  # NOQA
-            md5="a997a408347aa081b17a3ffff9f42a80",
-        ),
-    ),
+    SegmentAnythingModelVitL,
+    SegmentAnythingModelVitB,
+    SegmentAnythingModelVitH,
 ]

+ 0 - 0
labelme/ai/models/__init__.py


+ 1 - 3
labelme/ai/segment_anything_model.py

@@ -12,9 +12,7 @@ from . import _utils
 
 
 class SegmentAnythingModel:
-    def __init__(self, name, encoder_path, decoder_path):
-        self.name = name
-
+    def __init__(self, encoder_path, decoder_path):
         self._image_size = 1024
 
         self._encoder_session = onnxruntime.InferenceSession(encoder_path)

+ 1 - 12
labelme/widgets/canvas.py

@@ -1,4 +1,3 @@
-import gdown
 import imgviz
 from qtpy import QtCore
 from qtpy import QtGui
@@ -141,17 +140,7 @@ class Canvas(QtWidgets.QWidget):
             logger.debug("AI model is already initialized: %r" % model.name)
         else:
             logger.debug("Initializing AI model: %r" % model.name)
-            self._ai_model = labelme.ai.SegmentAnythingModel(
-                name=model.name,
-                encoder_path=gdown.cached_download(
-                    url=model.encoder_weight.url,
-                    md5=model.encoder_weight.md5,
-                ),
-                decoder_path=gdown.cached_download(
-                    url=model.decoder_weight.url,
-                    md5=model.decoder_weight.md5,
-                ),
-            )
+            self._ai_model = model()
 
         if self.pixmap is None:
             logger.warning("Pixmap is not set yet")