Explorar el Código

Refactor segment_anything.py

Kentaro Wada hace 1 año
padre
commit
9afd7556ff
Se han modificado 2 ficheros con 44 adiciones y 55 borrados
  1. 35 0
      labelme/ai/models/_utils.py
  2. 9 55
      labelme/ai/models/segment_anything.py

+ 35 - 0
labelme/ai/models/_utils.py

@@ -0,0 +1,35 @@
+import numpy as np
+import skimage
+import PIL.Image
+import imgviz
+
+
+def _get_contour_length(contour):
+    """Return the perimeter of ``contour`` treated as a closed polyline.
+
+    Each row of ``contour`` is one vertex; the last vertex is joined back to
+    the first, and the Euclidean lengths of all segments are summed.
+    """
+    # Pair every vertex with its successor (the last one wraps to the first).
+    next_vertices = np.roll(contour, shift=-1, axis=0)
+    segment_lengths = np.linalg.norm(next_vertices - contour, axis=1)
+    return segment_lengths.sum()
+
+
+def compute_polygon_from_mask(mask):
+    """Approximate the longest contour of ``mask`` with a polygon.
+
+    The mask is zero-padded by one pixel on every side before contour
+    extraction (presumably so that regions touching the border still yield
+    closed contours).  The longest contour is then simplified with a
+    tolerance of ``POLYGON_APPROX_TOLERANCE`` times the larger side of its
+    bounding box.
+
+    Args:
+        mask: 2D array; its two dimensions bound the clipped coordinates.
+
+    Returns:
+        Array of shape ``(N, 2)`` holding the polygon vertices in ``(x, y)``
+        order, clipped to the mask bounds.  An empty ``(0, 2)`` array is
+        returned when no contour is found (e.g. an all-background mask).
+    """
+    contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
+    if len(contours) == 0:
+        # ``max()`` on an empty sequence would raise an opaque ValueError;
+        # report "no polygon" explicitly instead.
+        return np.empty((0, 2), dtype=np.float64)
+
+    contour = max(contours, key=_get_contour_length)
+    POLYGON_APPROX_TOLERANCE = 0.004
+    polygon = skimage.measure.approximate_polygon(
+        coords=contour,
+        tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
+    )
+    # NOTE(review): contour coordinates are in the padded frame (shifted by
+    # +1 relative to mask indices) and no offset is subtracted before
+    # clipping -- confirm this is intended.
+    polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
+    polygon = polygon[:-1]  # drop last point that is duplicate of first point
+
+    # Debug toggle: flip to a truthy value to dump the polygon drawn over
+    # the mask to "contour.jpg".
+    if 0:
+        image_pil = PIL.Image.fromarray(
+            imgviz.gray2rgb(imgviz.bool2ubyte(mask))
+        )
+        imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
+        for point in polygon:
+            imgviz.draw.circle_(
+                image_pil, center=point, diameter=10, fill=(0, 255, 0)
+            )
+        imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
+
+    return polygon[:, ::-1]  # yx -> xy

+ 9 - 55
labelme/ai/models/segment_anything.py

@@ -4,11 +4,12 @@ import threading
 import imgviz
 import numpy as np
 import onnxruntime
-import PIL.Image
 import skimage
 
 from ...logger import logger
 
+from . import _utils
+
 
 class SegmentAnythingModel:
     def __init__(self, name, encoder_path, decoder_path):
@@ -59,29 +60,21 @@ class SegmentAnythingModel:
         with self._lock:
             return self._image_embedding
 
-    def predict_polygon_from_points(self, points, point_labels):
-        image_embedding = self._get_image_embedding()
-        polygon = _compute_polygon_from_points(
+    def predict_mask_from_points(self, points, point_labels):
+        return _compute_mask_from_points(
             image_size=self._image_size,
             decoder_session=self._decoder_session,
             image=self._image,
-            image_embedding=image_embedding,
+            image_embedding=self._get_image_embedding(),
             points=points,
             point_labels=point_labels,
         )
-        return polygon
 
-    def predict_mask_from_points(self, points, point_labels):
-        image_embedding = self._get_image_embedding()
-        mask = _compute_mask_from_points(
-            image_size=self._image_size,
-            decoder_session=self._decoder_session,
-            image=self._image,
-            image_embedding=image_embedding,
-            points=points,
-            point_labels=point_labels,
+    def predict_polygon_from_points(self, points, point_labels):
+        mask = self.predict_mask_from_points(
+            points=points, point_labels=point_labels
         )
-        return mask
+        return _utils.compute_polygon_from_mask(mask=mask)
 
 
 def _compute_scale_to_resize_image(image_size, image):
@@ -133,12 +126,6 @@ def _compute_image_embedding(image_size, encoder_session, image):
     return image_embedding
 
 
-def _get_contour_length(contour):
-    contour_start = contour
-    contour_end = np.r_[contour[1:], contour[0:1]]
-    return np.linalg.norm(contour_end - contour_start, axis=1).sum()
-
-
 def _compute_mask_from_points(
     image_size, decoder_session, image, image_embedding, points, point_labels
 ):
@@ -186,36 +173,3 @@ def _compute_mask_from_points(
             "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
         )
     return mask
-
-
-def _compute_polygon_from_points(
-    image_size, decoder_session, image, image_embedding, points, point_labels
-):
-    mask = _compute_mask_from_points(
-        image_size=image_size,
-        decoder_session=decoder_session,
-        image=image,
-        image_embedding=image_embedding,
-        points=points,
-        point_labels=point_labels,
-    )
-
-    contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
-    contour = max(contours, key=_get_contour_length)
-    POLYGON_APPROX_TOLERANCE = 0.004
-    polygon = skimage.measure.approximate_polygon(
-        coords=contour,
-        tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
-    )
-    polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
-    polygon = polygon[:-1]  # drop last point that is duplicate of first point
-    if 0:
-        image_pil = PIL.Image.fromarray(image)
-        imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
-        for point in polygon:
-            imgviz.draw.circle_(
-                image_pil, center=point, diameter=10, fill=(0, 255, 0)
-            )
-        imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
-
-    return polygon[:, ::-1]  # yx -> xy