efficient_sam.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import collections
  2. import threading
  3. import imgviz
  4. import numpy as np
  5. import onnxruntime
  6. import skimage
  7. from ..logger import logger
  8. from . import _utils
  9. class EfficientSam:
  10. def __init__(self, encoder_path, decoder_path):
  11. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  12. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  13. self._lock = threading.Lock()
  14. self._image_embedding_cache = collections.OrderedDict()
  15. self._thread = None
  16. def set_image(self, image: np.ndarray):
  17. with self._lock:
  18. self._image = image
  19. self._image_embedding = self._image_embedding_cache.get(
  20. self._image.tobytes()
  21. )
  22. if self._image_embedding is None:
  23. self._thread = threading.Thread(
  24. target=self._compute_and_cache_image_embedding
  25. )
  26. self._thread.start()
  27. def _compute_and_cache_image_embedding(self):
  28. with self._lock:
  29. logger.debug("Computing image embedding...")
  30. image = imgviz.rgba2rgb(self._image)
  31. batched_images = (
  32. image.transpose(2, 0, 1)[None].astype(np.float32) / 255.0
  33. )
  34. (self._image_embedding,) = self._encoder_session.run(
  35. output_names=None,
  36. input_feed={"batched_images": batched_images},
  37. )
  38. if len(self._image_embedding_cache) > 10:
  39. self._image_embedding_cache.popitem(last=False)
  40. self._image_embedding_cache[
  41. self._image.tobytes()
  42. ] = self._image_embedding
  43. logger.debug("Done computing image embedding.")
  44. def _get_image_embedding(self):
  45. if self._thread is not None:
  46. self._thread.join()
  47. self._thread = None
  48. with self._lock:
  49. return self._image_embedding
  50. def predict_mask_from_points(self, points, point_labels):
  51. return _compute_mask_from_points(
  52. decoder_session=self._decoder_session,
  53. image=self._image,
  54. image_embedding=self._get_image_embedding(),
  55. points=points,
  56. point_labels=point_labels,
  57. )
  58. def predict_polygon_from_points(self, points, point_labels):
  59. mask = self.predict_mask_from_points(
  60. points=points, point_labels=point_labels
  61. )
  62. return _utils.compute_polygon_from_mask(mask=mask)
  63. def _compute_mask_from_points(
  64. decoder_session, image, image_embedding, points, point_labels
  65. ):
  66. input_point = np.array(points, dtype=np.float32)
  67. input_label = np.array(point_labels, dtype=np.float32)
  68. # batch_size, num_queries, num_points, 2
  69. batched_point_coords = input_point[None, None, :, :]
  70. # batch_size, num_queries, num_points
  71. batched_point_labels = input_label[None, None, :]
  72. decoder_inputs = {
  73. "image_embeddings": image_embedding,
  74. "batched_point_coords": batched_point_coords,
  75. "batched_point_labels": batched_point_labels,
  76. "orig_im_size": np.array(image.shape[:2], dtype=np.int64),
  77. }
  78. masks, _, _ = decoder_session.run(None, decoder_inputs)
  79. mask = masks[0, 0, 0, :, :] # (1, 1, 3, H, W) -> (H, W)
  80. mask = mask > 0.0
  81. MIN_SIZE_RATIO = 0.05
  82. skimage.morphology.remove_small_objects(
  83. mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
  84. )
  85. if 0:
  86. imgviz.io.imsave(
  87. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  88. )
  89. return mask