segment_anything.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import collections
  2. import threading
  3. import gdown
  4. import imgviz
  5. import numpy as np
  6. import onnxruntime
  7. import PIL.Image
  8. import skimage.measure
  9. from ...logger import logger
  10. class SegmentAnythingModel:
  11. def __init__(self):
  12. self._image_size = 1024
  13. # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx" # NOQA
  14. # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx" # NOQA
  15. #
  16. # encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx" # NOQA
  17. # decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx" # NOQA
  18. #
  19. # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx" # NOQA
  20. # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx" # NOQA
  21. encoder_path = gdown.cached_download(
  22. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx", # NOQA
  23. md5="080004dc9992724d360a49399d1ee24b",
  24. )
  25. decoder_path = gdown.cached_download(
  26. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx", # NOQA
  27. md5="851b7faac91e8e23940ee1294231d5c7",
  28. )
  29. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  30. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  31. self._lock = threading.Lock()
  32. self._image_embedding_cache = collections.OrderedDict()
  33. def set_image(self, image: np.ndarray):
  34. with self._lock:
  35. self._image = image
  36. self._image_embedding = self._image_embedding_cache.get(
  37. self._image.tobytes()
  38. )
  39. self._thread = threading.Thread(target=self.get_image_embedding)
  40. self._thread.start()
  41. def get_image_embedding(self):
  42. if self._image_embedding is None:
  43. logger.debug("Computing image embedding...")
  44. with self._lock:
  45. self._image_embedding = compute_image_embedding(
  46. image_size=self._image_size,
  47. encoder_session=self._encoder_session,
  48. image=self._image,
  49. )
  50. if len(self._image_embedding_cache) > 10:
  51. self._image_embedding_cache.popitem(last=False)
  52. self._image_embedding_cache[
  53. self._image.tobytes()
  54. ] = self._image_embedding
  55. logger.debug("Done computing image embedding.")
  56. return self._image_embedding
  57. def points_to_polygon_callback(self, points, point_labels):
  58. logger.debug("Waiting for image embedding...")
  59. self._thread.join()
  60. image_embedding = self.get_image_embedding()
  61. logger.debug("Done waiting for image embedding.")
  62. polygon = compute_polygon_from_points(
  63. image_size=self._image_size,
  64. decoder_session=self._decoder_session,
  65. image=self._image,
  66. image_embedding=image_embedding,
  67. points=points,
  68. point_labels=point_labels,
  69. )
  70. return polygon
  71. def compute_image_embedding(image_size, encoder_session, image):
  72. assert image.shape[1] > image.shape[0]
  73. scale = image_size / image.shape[1]
  74. x = imgviz.resize(
  75. image,
  76. height=int(round(image.shape[0] * scale)),
  77. width=image_size,
  78. backend="pillow",
  79. ).astype(np.float32)
  80. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  81. [58.395, 57.12, 57.375], dtype=np.float32
  82. )
  83. x = np.pad(
  84. x,
  85. (
  86. (0, image_size - x.shape[0]),
  87. (0, image_size - x.shape[1]),
  88. (0, 0),
  89. ),
  90. )
  91. x = x.transpose(2, 0, 1)[None, :, :, :]
  92. output = encoder_session.run(output_names=None, input_feed={"x": x})
  93. image_embedding = output[0]
  94. return image_embedding
  95. def _get_contour_length(contour):
  96. contour_start = contour
  97. contour_end = np.r_[contour[1:], contour[0:1]]
  98. return np.linalg.norm(contour_end - contour_start, axis=1).sum()
  99. def compute_polygon_from_points(
  100. image_size, decoder_session, image, image_embedding, points, point_labels
  101. ):
  102. input_point = np.array(points, dtype=np.float32)
  103. input_label = np.array(point_labels, dtype=np.int32)
  104. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  105. None, :, :
  106. ]
  107. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  108. None, :
  109. ].astype(np.float32)
  110. assert image.shape[1] > image.shape[0]
  111. scale = image_size / image.shape[1]
  112. new_height = int(round(image.shape[0] * scale))
  113. new_width = image_size
  114. onnx_coord = (
  115. onnx_coord.astype(float)
  116. * (new_width / image.shape[1], new_height / image.shape[0])
  117. ).astype(np.float32)
  118. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  119. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  120. decoder_inputs = {
  121. "image_embeddings": image_embedding,
  122. "point_coords": onnx_coord,
  123. "point_labels": onnx_label,
  124. "mask_input": onnx_mask_input,
  125. "has_mask_input": onnx_has_mask_input,
  126. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  127. }
  128. masks, _, _ = decoder_session.run(None, decoder_inputs)
  129. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  130. mask = mask > 0.0
  131. if 0:
  132. imgviz.io.imsave(
  133. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  134. )
  135. contours = skimage.measure.find_contours(mask)
  136. contour = max(contours, key=_get_contour_length)
  137. polygon = skimage.measure.approximate_polygon(
  138. coords=contour,
  139. tolerance=np.ptp(contour, axis=0).max() / 100,
  140. )
  141. if 0:
  142. image_pil = PIL.Image.fromarray(image)
  143. imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
  144. for point in polygon:
  145. imgviz.draw.circle_(
  146. image_pil, center=point, diameter=10, fill=(0, 255, 0)
  147. )
  148. imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
  149. return polygon[:, ::-1] # yx -> xy