segment_anything.py

import collections
import threading

import gdown
import imgviz
import numpy as np
import onnxruntime
import PIL.Image
import skimage.measure

from ...logger import logger


class SegmentAnythingModel:
    def __init__(self):
        self._image_size = 1024

        # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx"  # NOQA
        # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx"  # NOQA
        #
        # encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx"  # NOQA
        # decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx"  # NOQA
        #
        # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx"  # NOQA
        # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx"  # NOQA

        # Download the quantized SAM ViT-L ONNX weights (cached locally by gdown).
        encoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
            md5="080004dc9992724d360a49399d1ee24b",
        )
        decoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx",  # NOQA
            md5="851b7faac91e8e23940ee1294231d5c7",
        )
        self._encoder_session = onnxruntime.InferenceSession(encoder_path)
        self._decoder_session = onnxruntime.InferenceSession(decoder_path)

        self._lock = threading.Lock()
        # Keyed by the raw image bytes; acts as a small FIFO cache of embeddings.
        self._image_embedding_cache = collections.OrderedDict()

    def set_image(self, image: np.ndarray):
        with self._lock:
            self._image = image
            self._image_embedding = self._image_embedding_cache.get(
                self._image.tobytes()
            )

        # Compute the embedding in a background thread so the caller is not
        # blocked by the encoder.
        self._thread = threading.Thread(
            target=self._compute_and_cache_image_embedding
        )
        self._thread.start()

    def _compute_and_cache_image_embedding(self):
        with self._lock:
            logger.debug("Computing image embedding...")
            self._image_embedding = _compute_image_embedding(
                image_size=self._image_size,
                encoder_session=self._encoder_session,
                image=self._image,
            )
            # Keep the cache small: evict the oldest entry beyond 10 images.
            if len(self._image_embedding_cache) > 10:
                self._image_embedding_cache.popitem(last=False)
            self._image_embedding_cache[
                self._image.tobytes()
            ] = self._image_embedding
            logger.debug("Done computing image embedding.")

    def _get_image_embedding(self):
        # Wait for the background computation started in set_image.
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        with self._lock:
            return self._image_embedding

    def predict_polygon_from_points(self, points, point_labels):
        image_embedding = self._get_image_embedding()
        polygon = _compute_polygon_from_points(
            image_size=self._image_size,
            decoder_session=self._decoder_session,
            image=self._image,
            image_embedding=image_embedding,
            points=points,
            point_labels=point_labels,
        )
        return polygon


def _compute_scale_to_resize_image(image_size, image):
    # Scale so that the longer side becomes exactly image_size.
    height, width = image.shape[:2]
    if width > height:
        scale = image_size / width
        new_height = int(round(height * scale))
        new_width = image_size
    else:
        scale = image_size / height
        new_height = image_size
        new_width = int(round(width * scale))
    return scale, new_height, new_width


def _resize_image(image_size, image):
    scale, new_height, new_width = _compute_scale_to_resize_image(
        image_size=image_size, image=image
    )
    scaled_image = imgviz.resize(
        image,
        height=new_height,
        width=new_width,
        backend="pillow",
    ).astype(np.float32)
    return scale, scaled_image


def _compute_image_embedding(image_size, encoder_session, image):
    image = imgviz.asrgb(image)

    scale, x = _resize_image(image_size, image)
    # Normalize with SAM's pixel mean/std (ImageNet statistics in 0-255 range),
    # then pad the bottom/right so the input is image_size x image_size.
    x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
        [58.395, 57.12, 57.375], dtype=np.float32
    )
    x = np.pad(
        x,
        (
            (0, image_size - x.shape[0]),
            (0, image_size - x.shape[1]),
            (0, 0),
        ),
    )
    # HWC -> NCHW for the ONNX encoder.
    x = x.transpose(2, 0, 1)[None, :, :, :]

    output = encoder_session.run(output_names=None, input_feed={"x": x})
    image_embedding = output[0]

    return image_embedding


def _get_contour_length(contour):
    # Perimeter of a closed contour given as an (N, 2) array of points.
    contour_start = contour
    contour_end = np.r_[contour[1:], contour[0:1]]
    return np.linalg.norm(contour_end - contour_start, axis=1).sum()


def _compute_polygon_from_points(
    image_size, decoder_session, image, image_embedding, points, point_labels
):
    input_point = np.array(points, dtype=np.float32)
    input_label = np.array(point_labels, dtype=np.int32)

    # Append the padding point with label -1 expected by the SAM ONNX decoder
    # when no box prompt is given, and add a batch dimension.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
        None, :, :
    ]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
        None, :
    ].astype(np.float32)

    # Map the prompt coordinates into the resized-image coordinate frame.
    scale, new_height, new_width = _compute_scale_to_resize_image(
        image_size=image_size, image=image
    )
    onnx_coord = (
        onnx_coord.astype(float)
        * (new_width / image.shape[1], new_height / image.shape[0])
    ).astype(np.float32)

    # No low-resolution mask prompt is supplied to the decoder.
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.array([-1], dtype=np.float32)

    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }
    masks, _, _ = decoder_session.run(None, decoder_inputs)
    mask = masks[0, 0]  # (1, 1, H, W) -> (H, W)
    mask = mask > 0.0
    if 0:
        imgviz.io.imsave(
            "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
        )

    # Keep the longest contour of the binary mask and simplify it
    # (Douglas-Peucker) with a tolerance relative to the contour's extent.
    contours = skimage.measure.find_contours(mask)
    contour = max(contours, key=_get_contour_length)
    polygon = skimage.measure.approximate_polygon(
        coords=contour,
        tolerance=np.ptp(contour, axis=0).max() / 100,
    )[:-1]  # drop last point that is duplicate of first point
    if 0:
        image_pil = PIL.Image.fromarray(image)
        imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
        for point in polygon:
            imgviz.draw.circle_(
                image_pil, center=point, diameter=10, fill=(0, 255, 0)
            )
        imgviz.io.imsave("contour.jpg", np.asarray(image_pil))

    return polygon[:, ::-1]  # yx -> xy
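

# A minimal usage sketch of SegmentAnythingModel, assuming the class is
# imported from within the labelme package. The image path and the click
# coordinates are hypothetical; in SAM's prompt convention, point_labels uses
# 1 for a foreground click and 0 for a background click. Because of the
# relative logger import, this module is not meant to be run directly.
#
#     model = SegmentAnythingModel()
#     image = imgviz.io.imread("example.jpg")  # hypothetical input image
#     model.set_image(image)  # embedding starts computing in the background
#     polygon = model.predict_polygon_from_points(
#         points=[[100, 200]],  # (x, y) clicks in original image pixels
#         point_labels=[1],  # 1 = foreground
#     )
#     # polygon: (N, 2) array of xy vertices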