|
@@ -4,11 +4,12 @@ import threading
|
|
|
import imgviz
|
|
|
import numpy as np
|
|
|
import onnxruntime
|
|
|
-import PIL.Image
|
|
|
import skimage
|
|
|
|
|
|
from ...logger import logger
|
|
|
|
|
|
+from . import _utils
|
|
|
+
|
|
|
|
|
|
class SegmentAnythingModel:
|
|
|
def __init__(self, name, encoder_path, decoder_path):
|
|
@@ -59,29 +60,21 @@ class SegmentAnythingModel:
|
|
|
with self._lock:
|
|
|
return self._image_embedding
|
|
|
|
|
|
- def predict_polygon_from_points(self, points, point_labels):
|
|
|
- image_embedding = self._get_image_embedding()
|
|
|
- polygon = _compute_polygon_from_points(
|
|
|
+ def predict_mask_from_points(self, points, point_labels):
|
|
|
+ return _compute_mask_from_points(
|
|
|
image_size=self._image_size,
|
|
|
decoder_session=self._decoder_session,
|
|
|
image=self._image,
|
|
|
- image_embedding=image_embedding,
|
|
|
+ image_embedding=self._get_image_embedding(),
|
|
|
points=points,
|
|
|
point_labels=point_labels,
|
|
|
)
|
|
|
- return polygon
|
|
|
|
|
|
- def predict_mask_from_points(self, points, point_labels):
|
|
|
- image_embedding = self._get_image_embedding()
|
|
|
- mask = _compute_mask_from_points(
|
|
|
- image_size=self._image_size,
|
|
|
- decoder_session=self._decoder_session,
|
|
|
- image=self._image,
|
|
|
- image_embedding=image_embedding,
|
|
|
- points=points,
|
|
|
- point_labels=point_labels,
|
|
|
+ def predict_polygon_from_points(self, points, point_labels):
|
|
|
+ mask = self.predict_mask_from_points(
|
|
|
+ points=points, point_labels=point_labels
|
|
|
)
|
|
|
- return mask
|
|
|
+ return _utils.compute_polygon_from_mask(mask=mask)
|
|
|
|
|
|
|
|
|
def _compute_scale_to_resize_image(image_size, image):
|
|
@@ -133,12 +126,6 @@ def _compute_image_embedding(image_size, encoder_session, image):
|
|
|
return image_embedding
|
|
|
|
|
|
|
|
|
-def _get_contour_length(contour):
|
|
|
- contour_start = contour
|
|
|
- contour_end = np.r_[contour[1:], contour[0:1]]
|
|
|
- return np.linalg.norm(contour_end - contour_start, axis=1).sum()
|
|
|
-
|
|
|
-
|
|
|
def _compute_mask_from_points(
|
|
|
image_size, decoder_session, image, image_embedding, points, point_labels
|
|
|
):
|
|
@@ -186,36 +173,3 @@ def _compute_mask_from_points(
|
|
|
"mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
|
|
|
)
|
|
|
return mask
|
|
|
-
|
|
|
-
|
|
|
-def _compute_polygon_from_points(
|
|
|
- image_size, decoder_session, image, image_embedding, points, point_labels
|
|
|
-):
|
|
|
- mask = _compute_mask_from_points(
|
|
|
- image_size=image_size,
|
|
|
- decoder_session=decoder_session,
|
|
|
- image=image,
|
|
|
- image_embedding=image_embedding,
|
|
|
- points=points,
|
|
|
- point_labels=point_labels,
|
|
|
- )
|
|
|
-
|
|
|
- contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
|
|
|
- contour = max(contours, key=_get_contour_length)
|
|
|
- POLYGON_APPROX_TOLERANCE = 0.004
|
|
|
- polygon = skimage.measure.approximate_polygon(
|
|
|
- coords=contour,
|
|
|
- tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
|
|
|
- )
|
|
|
- polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
|
|
|
- polygon = polygon[:-1] # drop last point that is duplicate of first point
|
|
|
- if 0:
|
|
|
- image_pil = PIL.Image.fromarray(image)
|
|
|
- imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
|
|
|
- for point in polygon:
|
|
|
- imgviz.draw.circle_(
|
|
|
- image_pil, center=point, diameter=10, fill=(0, 255, 0)
|
|
|
- )
|
|
|
- imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
|
|
|
-
|
|
|
- return polygon[:, ::-1] # yx -> xy
|