segment_anything.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
import collections
import hashlib
import threading

import gdown
import imgviz
import numpy as np
import onnxruntime
import PIL.Image
import skimage.measure
  9. class SegmentAnythingModel:
  10. def __init__(self):
  11. self._image_size = 1024
  12. # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx" # NOQA
  13. # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx" # NOQA
  14. #
  15. # encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx" # NOQA
  16. # decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx" # NOQA
  17. #
  18. # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx" # NOQA
  19. # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx" # NOQA
  20. encoder_path = gdown.cached_download(
  21. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx", # NOQA
  22. md5="080004dc9992724d360a49399d1ee24b",
  23. )
  24. decoder_path = gdown.cached_download(
  25. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx", # NOQA
  26. md5="851b7faac91e8e23940ee1294231d5c7",
  27. )
  28. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  29. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  30. self._lock = threading.Lock()
  31. self._image_embedding_cache = collections.OrderedDict()
  32. def set_image(self, image: np.ndarray):
  33. with self._lock:
  34. self._image = image
  35. self._image_embedding = self._image_embedding_cache.get(
  36. self._image.tobytes()
  37. )
  38. self._thread = threading.Thread(target=self.get_image_embedding)
  39. self._thread.start()
  40. def get_image_embedding(self):
  41. if self._image_embedding is None:
  42. with self._lock:
  43. self._image_embedding = compute_image_embedding(
  44. image_size=self._image_size,
  45. encoder_session=self._encoder_session,
  46. image=self._image,
  47. )
  48. if len(self._image_embedding_cache) > 10:
  49. self._image_embedding_cache.popitem(last=False)
  50. self._image_embedding_cache[
  51. self._image.tobytes()
  52. ] = self._image_embedding
  53. return self._image_embedding
  54. def points_to_polygon_callback(self, points, point_labels):
  55. self._thread.join()
  56. image_embedding = self.get_image_embedding()
  57. polygon = compute_polygon_from_points(
  58. image_size=self._image_size,
  59. decoder_session=self._decoder_session,
  60. image=self._image,
  61. image_embedding=image_embedding,
  62. points=points,
  63. point_labels=point_labels,
  64. )
  65. return polygon
  66. def compute_image_embedding(image_size, encoder_session, image):
  67. assert image.shape[1] > image.shape[0]
  68. scale = image_size / image.shape[1]
  69. x = imgviz.resize(
  70. image,
  71. height=int(round(image.shape[0] * scale)),
  72. width=image_size,
  73. backend="pillow",
  74. ).astype(np.float32)
  75. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  76. [58.395, 57.12, 57.375], dtype=np.float32
  77. )
  78. x = np.pad(
  79. x,
  80. (
  81. (0, image_size - x.shape[0]),
  82. (0, image_size - x.shape[1]),
  83. (0, 0),
  84. ),
  85. )
  86. x = x.transpose(2, 0, 1)[None, :, :, :]
  87. output = encoder_session.run(output_names=None, input_feed={"x": x})
  88. image_embedding = output[0]
  89. return image_embedding
  90. def _get_contour_length(contour):
  91. contour_start = contour
  92. contour_end = np.r_[contour[1:], contour[0:1]]
  93. return np.linalg.norm(contour_end - contour_start, axis=1).sum()
  94. def compute_polygon_from_points(
  95. image_size, decoder_session, image, image_embedding, points, point_labels
  96. ):
  97. input_point = np.array(points, dtype=np.float32)
  98. input_label = np.array(point_labels, dtype=np.int32)
  99. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  100. None, :, :
  101. ]
  102. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  103. None, :
  104. ].astype(np.float32)
  105. assert image.shape[1] > image.shape[0]
  106. scale = image_size / image.shape[1]
  107. new_height = int(round(image.shape[0] * scale))
  108. new_width = image_size
  109. onnx_coord = (
  110. onnx_coord.astype(float)
  111. * (new_width / image.shape[1], new_height / image.shape[0])
  112. ).astype(np.float32)
  113. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  114. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  115. decoder_inputs = {
  116. "image_embeddings": image_embedding,
  117. "point_coords": onnx_coord,
  118. "point_labels": onnx_label,
  119. "mask_input": onnx_mask_input,
  120. "has_mask_input": onnx_has_mask_input,
  121. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  122. }
  123. masks, _, _ = decoder_session.run(None, decoder_inputs)
  124. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  125. mask = mask > 0.0
  126. if 0:
  127. imgviz.io.imsave(
  128. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  129. )
  130. contours = skimage.measure.find_contours(mask)
  131. contour = max(contours, key=_get_contour_length)
  132. polygon = skimage.measure.approximate_polygon(
  133. coords=contour,
  134. tolerance=np.ptp(contour, axis=0).max() / 100,
  135. )
  136. if 0:
  137. image_pil = PIL.Image.fromarray(image)
  138. imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
  139. for point in polygon:
  140. imgviz.draw.circle_(
  141. image_pil, center=point, diameter=10, fill=(0, 255, 0)
  142. )
  143. imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
  144. return polygon[:, ::-1] # yx -> xy