segment_anything.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import collections
  2. import threading
  3. import imgviz
  4. import numpy as np
  5. import onnxruntime
  6. import PIL.Image
  7. import skimage
  8. from ...logger import logger
  9. class SegmentAnythingModel:
  10. def __init__(self, name, encoder_path, decoder_path):
  11. self.name = name
  12. self._image_size = 1024
  13. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  14. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  15. self._lock = threading.Lock()
  16. self._image_embedding_cache = collections.OrderedDict()
  17. self._thread = None
  18. def set_image(self, image: np.ndarray):
  19. with self._lock:
  20. self._image = image
  21. self._image_embedding = self._image_embedding_cache.get(
  22. self._image.tobytes()
  23. )
  24. if self._image_embedding is None:
  25. self._thread = threading.Thread(
  26. target=self._compute_and_cache_image_embedding
  27. )
  28. self._thread.start()
  29. def _compute_and_cache_image_embedding(self):
  30. with self._lock:
  31. logger.debug("Computing image embedding...")
  32. self._image_embedding = _compute_image_embedding(
  33. image_size=self._image_size,
  34. encoder_session=self._encoder_session,
  35. image=self._image,
  36. )
  37. if len(self._image_embedding_cache) > 10:
  38. self._image_embedding_cache.popitem(last=False)
  39. self._image_embedding_cache[
  40. self._image.tobytes()
  41. ] = self._image_embedding
  42. logger.debug("Done computing image embedding.")
  43. def _get_image_embedding(self):
  44. if self._thread is not None:
  45. self._thread.join()
  46. self._thread = None
  47. with self._lock:
  48. return self._image_embedding
  49. def predict_polygon_from_points(self, points, point_labels):
  50. image_embedding = self._get_image_embedding()
  51. polygon = _compute_polygon_from_points(
  52. image_size=self._image_size,
  53. decoder_session=self._decoder_session,
  54. image=self._image,
  55. image_embedding=image_embedding,
  56. points=points,
  57. point_labels=point_labels,
  58. )
  59. return polygon
  60. def predict_mask_from_points(self, points, point_labels):
  61. image_embedding = self._get_image_embedding()
  62. mask = _compute_mask_from_points(
  63. image_size=self._image_size,
  64. decoder_session=self._decoder_session,
  65. image=self._image,
  66. image_embedding=image_embedding,
  67. points=points,
  68. point_labels=point_labels,
  69. )
  70. return mask
  71. def _compute_scale_to_resize_image(image_size, image):
  72. height, width = image.shape[:2]
  73. if width > height:
  74. scale = image_size / width
  75. new_height = int(round(height * scale))
  76. new_width = image_size
  77. else:
  78. scale = image_size / height
  79. new_height = image_size
  80. new_width = int(round(width * scale))
  81. return scale, new_height, new_width
  82. def _resize_image(image_size, image):
  83. scale, new_height, new_width = _compute_scale_to_resize_image(
  84. image_size=image_size, image=image
  85. )
  86. scaled_image = imgviz.resize(
  87. image,
  88. height=new_height,
  89. width=new_width,
  90. backend="pillow",
  91. ).astype(np.float32)
  92. return scale, scaled_image
  93. def _compute_image_embedding(image_size, encoder_session, image):
  94. image = imgviz.asrgb(image)
  95. scale, x = _resize_image(image_size, image)
  96. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  97. [58.395, 57.12, 57.375], dtype=np.float32
  98. )
  99. x = np.pad(
  100. x,
  101. (
  102. (0, image_size - x.shape[0]),
  103. (0, image_size - x.shape[1]),
  104. (0, 0),
  105. ),
  106. )
  107. x = x.transpose(2, 0, 1)[None, :, :, :]
  108. output = encoder_session.run(output_names=None, input_feed={"x": x})
  109. image_embedding = output[0]
  110. return image_embedding
  111. def _get_contour_length(contour):
  112. contour_start = contour
  113. contour_end = np.r_[contour[1:], contour[0:1]]
  114. return np.linalg.norm(contour_end - contour_start, axis=1).sum()
  115. def _compute_mask_from_points(
  116. image_size, decoder_session, image, image_embedding, points, point_labels
  117. ):
  118. input_point = np.array(points, dtype=np.float32)
  119. input_label = np.array(point_labels, dtype=np.int32)
  120. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  121. None, :, :
  122. ]
  123. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  124. None, :
  125. ].astype(np.float32)
  126. scale, new_height, new_width = _compute_scale_to_resize_image(
  127. image_size=image_size, image=image
  128. )
  129. onnx_coord = (
  130. onnx_coord.astype(float)
  131. * (new_width / image.shape[1], new_height / image.shape[0])
  132. ).astype(np.float32)
  133. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  134. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  135. decoder_inputs = {
  136. "image_embeddings": image_embedding,
  137. "point_coords": onnx_coord,
  138. "point_labels": onnx_label,
  139. "mask_input": onnx_mask_input,
  140. "has_mask_input": onnx_has_mask_input,
  141. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  142. }
  143. masks, _, _ = decoder_session.run(None, decoder_inputs)
  144. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  145. mask = mask > 0.0
  146. MIN_SIZE_RATIO = 0.05
  147. skimage.morphology.remove_small_objects(
  148. mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
  149. )
  150. if 0:
  151. imgviz.io.imsave(
  152. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  153. )
  154. return mask
  155. def _compute_polygon_from_points(
  156. image_size, decoder_session, image, image_embedding, points, point_labels
  157. ):
  158. mask = _compute_mask_from_points(
  159. image_size=image_size,
  160. decoder_session=decoder_session,
  161. image=image,
  162. image_embedding=image_embedding,
  163. points=points,
  164. point_labels=point_labels,
  165. )
  166. contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
  167. contour = max(contours, key=_get_contour_length)
  168. POLYGON_APPROX_TOLERANCE = 0.004
  169. polygon = skimage.measure.approximate_polygon(
  170. coords=contour,
  171. tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
  172. )
  173. polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
  174. polygon = polygon[:-1] # drop last point that is duplicate of first point
  175. if 0:
  176. image_pil = PIL.Image.fromarray(image)
  177. imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
  178. for point in polygon:
  179. imgviz.draw.circle_(
  180. image_pil, center=point, diameter=10, fill=(0, 255, 0)
  181. )
  182. imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
  183. return polygon[:, ::-1] # yx -> xy