segment_anything.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import collections
  2. import threading
  3. import imgviz
  4. import numpy as np
  5. import onnxruntime
  6. import PIL.Image
  7. import skimage.measure
  8. from ...logger import logger
  9. class SegmentAnythingModel:
  10. def __init__(self, name, encoder_path, decoder_path):
  11. self.name = name
  12. self._image_size = 1024
  13. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  14. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  15. self._lock = threading.Lock()
  16. self._image_embedding_cache = collections.OrderedDict()
  17. self._thread = None
  18. def set_image(self, image: np.ndarray):
  19. with self._lock:
  20. self._image = image
  21. self._image_embedding = self._image_embedding_cache.get(
  22. self._image.tobytes()
  23. )
  24. if self._image_embedding is None:
  25. self._thread = threading.Thread(
  26. target=self._compute_and_cache_image_embedding
  27. )
  28. self._thread.start()
  29. def _compute_and_cache_image_embedding(self):
  30. with self._lock:
  31. logger.debug("Computing image embedding...")
  32. self._image_embedding = _compute_image_embedding(
  33. image_size=self._image_size,
  34. encoder_session=self._encoder_session,
  35. image=self._image,
  36. )
  37. if len(self._image_embedding_cache) > 10:
  38. self._image_embedding_cache.popitem(last=False)
  39. self._image_embedding_cache[
  40. self._image.tobytes()
  41. ] = self._image_embedding
  42. logger.debug("Done computing image embedding.")
  43. def _get_image_embedding(self):
  44. if self._thread is not None:
  45. self._thread.join()
  46. self._thread = None
  47. with self._lock:
  48. return self._image_embedding
  49. def predict_polygon_from_points(self, points, point_labels):
  50. image_embedding = self._get_image_embedding()
  51. polygon = _compute_polygon_from_points(
  52. image_size=self._image_size,
  53. decoder_session=self._decoder_session,
  54. image=self._image,
  55. image_embedding=image_embedding,
  56. points=points,
  57. point_labels=point_labels,
  58. )
  59. return polygon
  60. def _compute_scale_to_resize_image(image_size, image):
  61. height, width = image.shape[:2]
  62. if width > height:
  63. scale = image_size / width
  64. new_height = int(round(height * scale))
  65. new_width = image_size
  66. else:
  67. scale = image_size / height
  68. new_height = image_size
  69. new_width = int(round(width * scale))
  70. return scale, new_height, new_width
  71. def _resize_image(image_size, image):
  72. scale, new_height, new_width = _compute_scale_to_resize_image(
  73. image_size=image_size, image=image
  74. )
  75. scaled_image = imgviz.resize(
  76. image,
  77. height=new_height,
  78. width=new_width,
  79. backend="pillow",
  80. ).astype(np.float32)
  81. return scale, scaled_image
  82. def _compute_image_embedding(image_size, encoder_session, image):
  83. image = imgviz.asrgb(image)
  84. scale, x = _resize_image(image_size, image)
  85. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  86. [58.395, 57.12, 57.375], dtype=np.float32
  87. )
  88. x = np.pad(
  89. x,
  90. (
  91. (0, image_size - x.shape[0]),
  92. (0, image_size - x.shape[1]),
  93. (0, 0),
  94. ),
  95. )
  96. x = x.transpose(2, 0, 1)[None, :, :, :]
  97. output = encoder_session.run(output_names=None, input_feed={"x": x})
  98. image_embedding = output[0]
  99. return image_embedding
  100. def _get_contour_length(contour):
  101. contour_start = contour
  102. contour_end = np.r_[contour[1:], contour[0:1]]
  103. return np.linalg.norm(contour_end - contour_start, axis=1).sum()
  104. def _compute_polygon_from_points(
  105. image_size, decoder_session, image, image_embedding, points, point_labels
  106. ):
  107. input_point = np.array(points, dtype=np.float32)
  108. input_label = np.array(point_labels, dtype=np.int32)
  109. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  110. None, :, :
  111. ]
  112. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  113. None, :
  114. ].astype(np.float32)
  115. scale, new_height, new_width = _compute_scale_to_resize_image(
  116. image_size=image_size, image=image
  117. )
  118. onnx_coord = (
  119. onnx_coord.astype(float)
  120. * (new_width / image.shape[1], new_height / image.shape[0])
  121. ).astype(np.float32)
  122. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  123. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  124. decoder_inputs = {
  125. "image_embeddings": image_embedding,
  126. "point_coords": onnx_coord,
  127. "point_labels": onnx_label,
  128. "mask_input": onnx_mask_input,
  129. "has_mask_input": onnx_has_mask_input,
  130. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  131. }
  132. masks, _, _ = decoder_session.run(None, decoder_inputs)
  133. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  134. mask = mask > 0.0
  135. if 0:
  136. imgviz.io.imsave(
  137. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  138. )
  139. contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
  140. contour = max(contours, key=_get_contour_length)
  141. POLYGON_APPROX_TOLERANCE = 0.004
  142. polygon = skimage.measure.approximate_polygon(
  143. coords=contour,
  144. tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
  145. )
  146. polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
  147. polygon = polygon[:-1] # drop last point that is duplicate of first point
  148. if 0:
  149. image_pil = PIL.Image.fromarray(image)
  150. imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
  151. for point in polygon:
  152. imgviz.draw.circle_(
  153. image_pil, center=point, diameter=10, fill=(0, 255, 0)
  154. )
  155. imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
  156. return polygon[:, ::-1] # yx -> xy