segment_anything.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import collections
  2. import threading
  3. import gdown
  4. import imgviz
  5. import numpy as np
  6. import onnxruntime
  7. import PIL.Image
  8. import skimage.measure
  9. from ...logger import logger
  10. class SegmentAnythingModel:
  11. def __init__(self, name, encoder_path, decoder_path):
  12. self.name = name
  13. self._image_size = 1024
  14. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  15. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  16. self._lock = threading.Lock()
  17. self._image_embedding_cache = collections.OrderedDict()
  18. self._thread = None
  19. def set_image(self, image: np.ndarray):
  20. with self._lock:
  21. self._image = image
  22. self._image_embedding = self._image_embedding_cache.get(
  23. self._image.tobytes()
  24. )
  25. if self._image_embedding is None:
  26. self._thread = threading.Thread(
  27. target=self._compute_and_cache_image_embedding
  28. )
  29. self._thread.start()
  30. def _compute_and_cache_image_embedding(self):
  31. with self._lock:
  32. logger.debug("Computing image embedding...")
  33. self._image_embedding = _compute_image_embedding(
  34. image_size=self._image_size,
  35. encoder_session=self._encoder_session,
  36. image=self._image,
  37. )
  38. if len(self._image_embedding_cache) > 10:
  39. self._image_embedding_cache.popitem(last=False)
  40. self._image_embedding_cache[
  41. self._image.tobytes()
  42. ] = self._image_embedding
  43. logger.debug("Done computing image embedding.")
  44. def _get_image_embedding(self):
  45. if self._thread is not None:
  46. self._thread.join()
  47. self._thread = None
  48. with self._lock:
  49. return self._image_embedding
  50. def predict_polygon_from_points(self, points, point_labels):
  51. image_embedding = self._get_image_embedding()
  52. polygon = _compute_polygon_from_points(
  53. image_size=self._image_size,
  54. decoder_session=self._decoder_session,
  55. image=self._image,
  56. image_embedding=image_embedding,
  57. points=points,
  58. point_labels=point_labels,
  59. )
  60. return polygon
  61. def _compute_scale_to_resize_image(image_size, image):
  62. height, width = image.shape[:2]
  63. if width > height:
  64. scale = image_size / width
  65. new_height = int(round(height * scale))
  66. new_width = image_size
  67. else:
  68. scale = image_size / height
  69. new_height = image_size
  70. new_width = int(round(width * scale))
  71. return scale, new_height, new_width
  72. def _resize_image(image_size, image):
  73. scale, new_height, new_width = _compute_scale_to_resize_image(
  74. image_size=image_size, image=image
  75. )
  76. scaled_image = imgviz.resize(
  77. image,
  78. height=new_height,
  79. width=new_width,
  80. backend="pillow",
  81. ).astype(np.float32)
  82. return scale, scaled_image
  83. def _compute_image_embedding(image_size, encoder_session, image):
  84. image = imgviz.asrgb(image)
  85. scale, x = _resize_image(image_size, image)
  86. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  87. [58.395, 57.12, 57.375], dtype=np.float32
  88. )
  89. x = np.pad(
  90. x,
  91. (
  92. (0, image_size - x.shape[0]),
  93. (0, image_size - x.shape[1]),
  94. (0, 0),
  95. ),
  96. )
  97. x = x.transpose(2, 0, 1)[None, :, :, :]
  98. output = encoder_session.run(output_names=None, input_feed={"x": x})
  99. image_embedding = output[0]
  100. return image_embedding
  101. def _get_contour_length(contour):
  102. contour_start = contour
  103. contour_end = np.r_[contour[1:], contour[0:1]]
  104. return np.linalg.norm(contour_end - contour_start, axis=1).sum()
  105. def _compute_polygon_from_points(
  106. image_size, decoder_session, image, image_embedding, points, point_labels
  107. ):
  108. input_point = np.array(points, dtype=np.float32)
  109. input_label = np.array(point_labels, dtype=np.int32)
  110. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  111. None, :, :
  112. ]
  113. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  114. None, :
  115. ].astype(np.float32)
  116. scale, new_height, new_width = _compute_scale_to_resize_image(
  117. image_size=image_size, image=image
  118. )
  119. onnx_coord = (
  120. onnx_coord.astype(float)
  121. * (new_width / image.shape[1], new_height / image.shape[0])
  122. ).astype(np.float32)
  123. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  124. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  125. decoder_inputs = {
  126. "image_embeddings": image_embedding,
  127. "point_coords": onnx_coord,
  128. "point_labels": onnx_label,
  129. "mask_input": onnx_mask_input,
  130. "has_mask_input": onnx_has_mask_input,
  131. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  132. }
  133. masks, _, _ = decoder_session.run(None, decoder_inputs)
  134. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  135. mask = mask > 0.0
  136. if 0:
  137. imgviz.io.imsave(
  138. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  139. )
  140. contours = skimage.measure.find_contours(mask)
  141. contour = max(contours, key=_get_contour_length)
  142. polygon = skimage.measure.approximate_polygon(
  143. coords=contour,
  144. tolerance=np.ptp(contour, axis=0).max() / 100,
  145. )
  146. polygon = polygon[:-1] # drop last point that is duplicate of first point
  147. if 0:
  148. image_pil = PIL.Image.fromarray(image)
  149. imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
  150. for point in polygon:
  151. imgviz.draw.circle_(
  152. image_pil, center=point, diameter=10, fill=(0, 255, 0)
  153. )
  154. imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
  155. return polygon[:, ::-1] # yx -> xy