浏览代码

Add predict_mask_from_points to SegmentAnythingModel

Kentaro Wada 1 年之前
父节点
当前提交
1cca2447c6
共有 1 个文件被更改,包括 27 次插入1 次删除
  1. 27 1
      labelme/ai/models/segment_anything.py

+ 27 - 1
labelme/ai/models/segment_anything.py

@@ -71,6 +71,18 @@ class SegmentAnythingModel:
         )
         return polygon
 
+    def predict_mask_from_points(self, points, point_labels):
+        image_embedding = self._get_image_embedding()
+        mask = _compute_mask_from_points(
+            image_size=self._image_size,
+            decoder_session=self._decoder_session,
+            image=self._image,
+            image_embedding=image_embedding,
+            points=points,
+            point_labels=point_labels,
+        )
+        return mask
+
 
 def _compute_scale_to_resize_image(image_size, image):
     height, width = image.shape[:2]
@@ -127,7 +139,7 @@ def _get_contour_length(contour):
     return np.linalg.norm(contour_end - contour_start, axis=1).sum()
 
 
-def _compute_polygon_from_points(
+def _compute_mask_from_points(
     image_size, decoder_session, image, image_embedding, points, point_labels
 ):
     input_point = np.array(points, dtype=np.float32)
@@ -167,6 +179,20 @@ def _compute_polygon_from_points(
         imgviz.io.imsave(
             "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
         )
+    return mask
+
+
+def _compute_polygon_from_points(
+    image_size, decoder_session, image, image_embedding, points, point_labels
+):
+    mask = _compute_mask_from_points(
+        image_size=image_size,
+        decoder_session=decoder_session,
+        image=image,
+        image_embedding=image_embedding,
+        points=points,
+        point_labels=point_labels,
+    )
 
     contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
     contour = max(contours, key=_get_contour_length)