|
@@ -71,6 +71,18 @@ class SegmentAnythingModel:
|
|
)
|
|
)
|
|
return polygon
|
|
return polygon
|
|
|
|
|
|
|
|
+ def predict_mask_from_points(self, points, point_labels):
|
|
|
|
+ image_embedding = self._get_image_embedding()
|
|
|
|
+ mask = _compute_mask_from_points(
|
|
|
|
+ image_size=self._image_size,
|
|
|
|
+ decoder_session=self._decoder_session,
|
|
|
|
+ image=self._image,
|
|
|
|
+ image_embedding=image_embedding,
|
|
|
|
+ points=points,
|
|
|
|
+ point_labels=point_labels,
|
|
|
|
+ )
|
|
|
|
+ return mask
|
|
|
|
+
|
|
|
|
|
|
def _compute_scale_to_resize_image(image_size, image):
|
|
def _compute_scale_to_resize_image(image_size, image):
|
|
height, width = image.shape[:2]
|
|
height, width = image.shape[:2]
|
|
@@ -127,7 +139,7 @@ def _get_contour_length(contour):
|
|
return np.linalg.norm(contour_end - contour_start, axis=1).sum()
|
|
return np.linalg.norm(contour_end - contour_start, axis=1).sum()
|
|
|
|
|
|
|
|
|
|
-def _compute_polygon_from_points(
|
|
|
|
|
|
+def _compute_mask_from_points(
|
|
image_size, decoder_session, image, image_embedding, points, point_labels
|
|
image_size, decoder_session, image, image_embedding, points, point_labels
|
|
):
|
|
):
|
|
input_point = np.array(points, dtype=np.float32)
|
|
input_point = np.array(points, dtype=np.float32)
|
|
@@ -167,6 +179,20 @@ def _compute_polygon_from_points(
|
|
imgviz.io.imsave(
|
|
imgviz.io.imsave(
|
|
"mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
|
|
"mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
|
|
)
|
|
)
|
|
|
|
+ return mask
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def _compute_polygon_from_points(
|
|
|
|
+ image_size, decoder_session, image, image_embedding, points, point_labels
|
|
|
|
+):
|
|
|
|
+ mask = _compute_mask_from_points(
|
|
|
|
+ image_size=image_size,
|
|
|
|
+ decoder_session=decoder_session,
|
|
|
|
+ image=image,
|
|
|
|
+ image_embedding=image_embedding,
|
|
|
|
+ points=points,
|
|
|
|
+ point_labels=point_labels,
|
|
|
|
+ )
|
|
|
|
|
|
contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
|
|
contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
|
|
contour = max(contours, key=_get_contour_length)
|
|
contour = max(contours, key=_get_contour_length)
|