segment_anything_model.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import collections
  2. import threading
  3. import imgviz
  4. import numpy as np
  5. import onnxruntime
  6. import skimage
  7. from ..logger import logger
  8. from . import _utils
  9. class SegmentAnythingModel:
  10. def __init__(self, name, encoder_path, decoder_path):
  11. self.name = name
  12. self._image_size = 1024
  13. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  14. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  15. self._lock = threading.Lock()
  16. self._image_embedding_cache = collections.OrderedDict()
  17. self._thread = None
  18. def set_image(self, image: np.ndarray):
  19. with self._lock:
  20. self._image = image
  21. self._image_embedding = self._image_embedding_cache.get(
  22. self._image.tobytes()
  23. )
  24. if self._image_embedding is None:
  25. self._thread = threading.Thread(
  26. target=self._compute_and_cache_image_embedding
  27. )
  28. self._thread.start()
  29. def _compute_and_cache_image_embedding(self):
  30. with self._lock:
  31. logger.debug("Computing image embedding...")
  32. self._image_embedding = _compute_image_embedding(
  33. image_size=self._image_size,
  34. encoder_session=self._encoder_session,
  35. image=self._image,
  36. )
  37. if len(self._image_embedding_cache) > 10:
  38. self._image_embedding_cache.popitem(last=False)
  39. self._image_embedding_cache[
  40. self._image.tobytes()
  41. ] = self._image_embedding
  42. logger.debug("Done computing image embedding.")
  43. def _get_image_embedding(self):
  44. if self._thread is not None:
  45. self._thread.join()
  46. self._thread = None
  47. with self._lock:
  48. return self._image_embedding
  49. def predict_mask_from_points(self, points, point_labels):
  50. return _compute_mask_from_points(
  51. image_size=self._image_size,
  52. decoder_session=self._decoder_session,
  53. image=self._image,
  54. image_embedding=self._get_image_embedding(),
  55. points=points,
  56. point_labels=point_labels,
  57. )
  58. def predict_polygon_from_points(self, points, point_labels):
  59. mask = self.predict_mask_from_points(
  60. points=points, point_labels=point_labels
  61. )
  62. return _utils.compute_polygon_from_mask(mask=mask)
  63. def _compute_scale_to_resize_image(image_size, image):
  64. height, width = image.shape[:2]
  65. if width > height:
  66. scale = image_size / width
  67. new_height = int(round(height * scale))
  68. new_width = image_size
  69. else:
  70. scale = image_size / height
  71. new_height = image_size
  72. new_width = int(round(width * scale))
  73. return scale, new_height, new_width
  74. def _resize_image(image_size, image):
  75. scale, new_height, new_width = _compute_scale_to_resize_image(
  76. image_size=image_size, image=image
  77. )
  78. scaled_image = imgviz.resize(
  79. image,
  80. height=new_height,
  81. width=new_width,
  82. backend="pillow",
  83. ).astype(np.float32)
  84. return scale, scaled_image
  85. def _compute_image_embedding(image_size, encoder_session, image):
  86. image = imgviz.asrgb(image)
  87. scale, x = _resize_image(image_size, image)
  88. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  89. [58.395, 57.12, 57.375], dtype=np.float32
  90. )
  91. x = np.pad(
  92. x,
  93. (
  94. (0, image_size - x.shape[0]),
  95. (0, image_size - x.shape[1]),
  96. (0, 0),
  97. ),
  98. )
  99. x = x.transpose(2, 0, 1)[None, :, :, :]
  100. output = encoder_session.run(output_names=None, input_feed={"x": x})
  101. image_embedding = output[0]
  102. return image_embedding
  103. def _compute_mask_from_points(
  104. image_size, decoder_session, image, image_embedding, points, point_labels
  105. ):
  106. input_point = np.array(points, dtype=np.float32)
  107. input_label = np.array(point_labels, dtype=np.int32)
  108. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  109. None, :, :
  110. ]
  111. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  112. None, :
  113. ].astype(np.float32)
  114. scale, new_height, new_width = _compute_scale_to_resize_image(
  115. image_size=image_size, image=image
  116. )
  117. onnx_coord = (
  118. onnx_coord.astype(float)
  119. * (new_width / image.shape[1], new_height / image.shape[0])
  120. ).astype(np.float32)
  121. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  122. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  123. decoder_inputs = {
  124. "image_embeddings": image_embedding,
  125. "point_coords": onnx_coord,
  126. "point_labels": onnx_label,
  127. "mask_input": onnx_mask_input,
  128. "has_mask_input": onnx_has_mask_input,
  129. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  130. }
  131. masks, _, _ = decoder_session.run(None, decoder_inputs)
  132. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  133. mask = mask > 0.0
  134. MIN_SIZE_RATIO = 0.05
  135. skimage.morphology.remove_small_objects(
  136. mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
  137. )
  138. if 0:
  139. imgviz.io.imsave(
  140. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  141. )
  142. return mask