123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
import collections
import hashlib
import threading

import imgviz
import numpy as np
import onnxruntime
import skimage

from ..logger import logger

from . import _utils
class SegmentAnythingModel:
    """Segment Anything predictor backed by ONNX encoder/decoder sessions.

    ``set_image`` makes an image current and, unless its embedding is already
    cached, computes the embedding in a background thread.  Prediction methods
    wait for that thread before decoding.  Up to
    ``_IMAGE_EMBEDDING_CACHE_SIZE`` embeddings are kept; the least recently
    used one is evicted first.
    """

    # Maximum number of image embeddings kept in the cache.
    _IMAGE_EMBEDDING_CACHE_SIZE = 10

    def __init__(self, encoder_path, decoder_path):
        """Create ONNX Runtime sessions for the encoder and decoder models.

        Args:
            encoder_path: path to the image-encoder ONNX model.
            decoder_path: path to the mask-decoder ONNX model.
        """
        self._image_size = 1024
        self._encoder_session = onnxruntime.InferenceSession(encoder_path)
        self._decoder_session = onnxruntime.InferenceSession(decoder_path)

        # Guards _image, _image_key, _image_embedding and the cache.
        self._lock = threading.Lock()
        self._image_embedding_cache = collections.OrderedDict()

        # Most recently started embedding worker, joined before prediction.
        self._thread = None

    @staticmethod
    def _cache_key(image):
        """Return a compact cache key identifying *image*.

        The digest covers shape and dtype as well as the raw bytes, so arrays
        whose bytes coincide but whose layout differs (e.g. ``(2, 6)`` vs
        ``(4, 3)``) get distinct keys.  The cache also stores a short hex
        string per entry instead of a full copy of every image.
        """
        hasher = hashlib.sha256()
        hasher.update(repr((image.shape, image.dtype.str)).encode("ascii"))
        hasher.update(image.tobytes())
        return hasher.hexdigest()

    def set_image(self, image: np.ndarray):
        """Make *image* current, starting embedding computation if not cached."""
        with self._lock:
            self._image = image
            self._image_key = self._cache_key(image)
            self._image_embedding = self._image_embedding_cache.get(
                self._image_key
            )
            if self._image_embedding is not None:
                # Cache hit: mark as most recently used so it is evicted last.
                self._image_embedding_cache.move_to_end(self._image_key)
            else:
                self._thread = threading.Thread(
                    target=self._compute_and_cache_image_embedding
                )
                self._thread.start()

    def _compute_and_cache_image_embedding(self):
        """Compute (or reuse) the embedding of the current image and cache it.

        Runs on the worker thread started by ``set_image``.  The current image
        is read under the lock, so a worker that starts late handles whatever
        image is current at that moment.  If that image's embedding is already
        cached (e.g. produced by another worker), the encoder is not run again.
        """
        with self._lock:
            key = self._image_key
            cached = self._image_embedding_cache.get(key)
            if cached is not None:
                self._image_embedding = cached
                return

            logger.debug("Computing image embedding...")
            self._image_embedding = _compute_image_embedding(
                image_size=self._image_size,
                encoder_session=self._encoder_session,
                image=self._image,
            )
            # Evict least recently used entries so the cache never exceeds
            # _IMAGE_EMBEDDING_CACHE_SIZE after the insert below.
            while (
                len(self._image_embedding_cache)
                >= self._IMAGE_EMBEDDING_CACHE_SIZE
            ):
                self._image_embedding_cache.popitem(last=False)
            self._image_embedding_cache[key] = self._image_embedding
            logger.debug("Done computing image embedding.")

    def _get_image_embedding(self):
        """Wait for any pending embedding worker, then return the embedding."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        with self._lock:
            return self._image_embedding

    def predict_mask_from_points(self, points, point_labels):
        """Return the decoder's boolean mask for the given point prompts.

        Args:
            points: ``(x, y)`` pixel coordinates in the current image.
            point_labels: one label per point, forwarded to the decoder.
        """
        return _compute_mask_from_points(
            image_size=self._image_size,
            decoder_session=self._decoder_session,
            image=self._image,
            image_embedding=self._get_image_embedding(),
            points=points,
            point_labels=point_labels,
        )

    def predict_polygon_from_points(self, points, point_labels):
        """Predict a mask from point prompts and convert it to a polygon."""
        mask = self.predict_mask_from_points(
            points=points, point_labels=point_labels
        )
        return _utils.compute_polygon_from_mask(mask=mask)
- def _compute_scale_to_resize_image(image_size, image):
- height, width = image.shape[:2]
- if width > height:
- scale = image_size / width
- new_height = int(round(height * scale))
- new_width = image_size
- else:
- scale = image_size / height
- new_height = image_size
- new_width = int(round(width * scale))
- return scale, new_height, new_width
def _resize_image(image_size, image):
    """Resize *image* so its longer side equals *image_size*.

    Resizing uses imgviz's pillow backend; the result is cast to float32.

    Returns:
        ``(scale, resized)`` where ``scale`` comes from
        ``_compute_scale_to_resize_image``.
    """
    scale, target_height, target_width = _compute_scale_to_resize_image(
        image_size=image_size, image=image
    )
    resized = imgviz.resize(
        image,
        height=target_height,
        width=target_width,
        backend="pillow",
    )
    return scale, resized.astype(np.float32)
def _compute_image_embedding(image_size, encoder_session, image):
    """Run the image encoder on *image* and return its first output.

    Preprocessing steps:
    - convert to RGB via ``imgviz.asrgb``;
    - resize so the longer side is ``image_size``;
    - normalize with fixed per-channel mean/std;
    - zero-pad bottom/right to ``image_size x image_size``;
    - reorder to a 1xCxHxW batch fed as input ``"x"``.
    """
    rgb = imgviz.asrgb(image)
    _, resized = _resize_image(image_size, rgb)

    pixel_mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    pixel_std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    normalized = (resized - pixel_mean) / pixel_std

    pad_bottom = image_size - normalized.shape[0]
    pad_right = image_size - normalized.shape[1]
    padded = np.pad(normalized, ((0, pad_bottom), (0, pad_right), (0, 0)))

    # HxWxC -> 1xCxHxW
    batch = np.expand_dims(padded.transpose(2, 0, 1), axis=0)
    outputs = encoder_session.run(output_names=None, input_feed={"x": batch})
    return outputs[0]
def _compute_mask_from_points(
    image_size, decoder_session, image, image_embedding, points, point_labels
):
    """Run the mask decoder for point prompts and return a boolean mask.

    Args:
        image_size: side length the encoder input was resized/padded to.
        decoder_session: ONNX Runtime session for the mask decoder.
        image: the original image; only ``image.shape[:2]`` is used here.
        image_embedding: encoder output for ``image``.
        points: ``(x, y)`` pixel coordinates in ``image``.
        point_labels: one label per point, forwarded to the decoder as-is.

    Returns:
        The first decoder mask thresholded at 0. Connected components smaller
        than 5% of the mask's foreground area are removed. The decoder is
        given ``orig_im_size=image.shape[:2]``, so the mask presumably has
        that shape (TODO confirm against the exported model).
    """
    # reshape(-1, 2) keeps an empty prompt list concatenable below.
    input_point = np.asarray(points, dtype=np.float32).reshape(-1, 2)
    input_label = np.asarray(point_labels, dtype=np.int32).reshape(-1)

    # Append a padding point (0, 0) with label -1. This follows SAM's ONNX
    # example when no box prompt is given.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
        None, :, :
    ]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
        None, :
    ].astype(np.float32)

    # Map (x, y) from original-image pixels into the resized encoder frame.
    _, new_height, new_width = _compute_scale_to_resize_image(
        image_size=image_size, image=image
    )
    onnx_coord = (
        onnx_coord.astype(float)
        * (new_width / image.shape[1], new_height / image.shape[0])
    ).astype(np.float32)

    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    # NOTE(review): SAM's ONNX example passes has_mask_input = 0 when no mask
    # prompt is used; -1 is kept here as-is. Verify against the exported
    # decoder before changing.
    onnx_has_mask_input = np.array([-1], dtype=np.float32)

    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }

    masks, _, _ = decoder_session.run(None, decoder_inputs)
    mask = masks[0, 0] > 0.0  # (1, 1, H, W) -> (H, W), binarized

    MIN_SIZE_RATIO = 0.05
    # Use the return value rather than `out=`: that keyword only exists in
    # some scikit-image versions, while the returned array works in all.
    mask = skimage.morphology.remove_small_objects(
        mask, min_size=mask.sum() * MIN_SIZE_RATIO
    )
    return mask
|