segment_anything.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import collections
  2. import threading
  3. import gdown
  4. import imgviz
  5. import numpy as np
  6. import onnxruntime
  7. import PIL.Image
  8. import skimage.measure
  9. from ...logger import logger
  10. class SegmentAnythingModel:
  11. def __init__(self):
  12. self._image_size = 1024
  13. # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx" # NOQA
  14. # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx" # NOQA
  15. #
  16. # encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx" # NOQA
  17. # decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx" # NOQA
  18. #
  19. # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx" # NOQA
  20. # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx" # NOQA
  21. encoder_path = gdown.cached_download(
  22. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx", # NOQA
  23. md5="080004dc9992724d360a49399d1ee24b",
  24. )
  25. decoder_path = gdown.cached_download(
  26. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx", # NOQA
  27. md5="851b7faac91e8e23940ee1294231d5c7",
  28. )
  29. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  30. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  31. self._lock = threading.Lock()
  32. self._image_embedding_cache = collections.OrderedDict()
  33. def set_image(self, image: np.ndarray):
  34. with self._lock:
  35. self._image = image
  36. self._image_embedding = self._image_embedding_cache.get(
  37. self._image.tobytes()
  38. )
  39. self._thread = threading.Thread(target=self._get_image_embedding)
  40. self._thread.start()
  41. def _get_image_embedding(self):
  42. if self._image_embedding is None:
  43. logger.debug("Computing image embedding...")
  44. with self._lock:
  45. self._image_embedding = _compute_image_embedding(
  46. image_size=self._image_size,
  47. encoder_session=self._encoder_session,
  48. image=self._image,
  49. )
  50. if len(self._image_embedding_cache) > 10:
  51. self._image_embedding_cache.popitem(last=False)
  52. self._image_embedding_cache[
  53. self._image.tobytes()
  54. ] = self._image_embedding
  55. logger.debug("Done computing image embedding.")
  56. return self._image_embedding
  57. def points_to_polygon_callback(self, points, point_labels):
  58. logger.debug("Waiting for image embedding...")
  59. self._thread.join()
  60. image_embedding = self._get_image_embedding()
  61. logger.debug("Done waiting for image embedding.")
  62. polygon = _compute_polygon_from_points(
  63. image_size=self._image_size,
  64. decoder_session=self._decoder_session,
  65. image=self._image,
  66. image_embedding=image_embedding,
  67. points=points,
  68. point_labels=point_labels,
  69. )
  70. return polygon
  71. def _compute_scale_to_resize_image(image_size, image):
  72. height, width = image.shape[:2]
  73. if width > height:
  74. scale = image_size / width
  75. new_height = int(round(height * scale))
  76. new_width = image_size
  77. else:
  78. scale = image_size / height
  79. new_height = image_size
  80. new_width = int(round(width * scale))
  81. return scale, new_height, new_width
  82. def _resize_image(image_size, image):
  83. scale, new_height, new_width = _compute_scale_to_resize_image(
  84. image_size=image_size, image=image
  85. )
  86. scaled_image = imgviz.resize(
  87. image,
  88. height=new_height,
  89. width=new_width,
  90. backend="pillow",
  91. ).astype(np.float32)
  92. return scale, scaled_image
  93. def _compute_image_embedding(image_size, encoder_session, image):
  94. image = imgviz.asrgb(image)
  95. scale, x = _resize_image(image_size, image)
  96. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  97. [58.395, 57.12, 57.375], dtype=np.float32
  98. )
  99. x = np.pad(
  100. x,
  101. (
  102. (0, image_size - x.shape[0]),
  103. (0, image_size - x.shape[1]),
  104. (0, 0),
  105. ),
  106. )
  107. x = x.transpose(2, 0, 1)[None, :, :, :]
  108. output = encoder_session.run(output_names=None, input_feed={"x": x})
  109. image_embedding = output[0]
  110. return image_embedding
  111. def _get_contour_length(contour):
  112. contour_start = contour
  113. contour_end = np.r_[contour[1:], contour[0:1]]
  114. return np.linalg.norm(contour_end - contour_start, axis=1).sum()
  115. def _compute_polygon_from_points(
  116. image_size, decoder_session, image, image_embedding, points, point_labels
  117. ):
  118. input_point = np.array(points, dtype=np.float32)
  119. input_label = np.array(point_labels, dtype=np.int32)
  120. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  121. None, :, :
  122. ]
  123. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  124. None, :
  125. ].astype(np.float32)
  126. scale, new_height, new_width = _compute_scale_to_resize_image(
  127. image_size=image_size, image=image
  128. )
  129. onnx_coord = (
  130. onnx_coord.astype(float)
  131. * (new_width / image.shape[1], new_height / image.shape[0])
  132. ).astype(np.float32)
  133. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  134. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  135. decoder_inputs = {
  136. "image_embeddings": image_embedding,
  137. "point_coords": onnx_coord,
  138. "point_labels": onnx_label,
  139. "mask_input": onnx_mask_input,
  140. "has_mask_input": onnx_has_mask_input,
  141. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  142. }
  143. masks, _, _ = decoder_session.run(None, decoder_inputs)
  144. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  145. mask = mask > 0.0
  146. if 0:
  147. imgviz.io.imsave(
  148. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  149. )
  150. contours = skimage.measure.find_contours(mask)
  151. contour = max(contours, key=_get_contour_length)
  152. polygon = skimage.measure.approximate_polygon(
  153. coords=contour,
  154. tolerance=np.ptp(contour, axis=0).max() / 100,
  155. )
  156. if 0:
  157. image_pil = PIL.Image.fromarray(image)
  158. imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
  159. for point in polygon:
  160. imgviz.draw.circle_(
  161. image_pil, center=point, diameter=10, fill=(0, 255, 0)
  162. )
  163. imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
  164. return polygon[:, ::-1] # yx -> xy