segment_anything.py

import threading

import gdown
import imgviz
import numpy as np
import onnxruntime
import PIL.Image
import skimage.measure


class SegmentAnythingModel:
    def __init__(self):
        self._image_size = 1024

        # Alternative local checkpoints (quantized ONNX exports of the SAM
        # ViT-H / ViT-L / ViT-B models):
        # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx"  # NOQA
        # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx"  # NOQA
        #
        # encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx"  # NOQA
        # decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx"  # NOQA
        #
        # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx"  # NOQA
        # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx"  # NOQA

        # Download (and cache) the quantized SAM ViT-L encoder/decoder ONNX
        # models, verifying the files against their MD5 checksums.
        encoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
            md5="080004dc9992724d360a49399d1ee24b",
        )
        decoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx",  # NOQA
            md5="851b7faac91e8e23940ee1294231d5c7",
        )

        self._encoder_session = onnxruntime.InferenceSession(encoder_path)
        self._decoder_session = onnxruntime.InferenceSession(decoder_path)

    def set_image(self, image: np.ndarray):
        self._image = image
        self._image_embedding = None

        # Computing the image embedding is slow, so start it in a background
        # thread; points_to_polygon_callback() joins the thread before use.
        self._thread = threading.Thread(target=self.get_image_embedding)
        self._thread.start()

    def get_image_embedding(self):
        if self._image_embedding is None:
            self._image_embedding = compute_image_embedding(
                image_size=self._image_size,
                encoder_session=self._encoder_session,
                image=self._image,
            )
        return self._image_embedding

    def points_to_polygon_callback(self, points, point_labels):
        # Wait for the embedding computation started in set_image().
        self._thread.join()
        image_embedding = self.get_image_embedding()

        polygon = compute_polygon_from_points(
            image_size=self._image_size,
            decoder_session=self._decoder_session,
            image=self._image,
            image_embedding=image_embedding,
            points=points,
            point_labels=point_labels,
        )
        return polygon


def compute_image_embedding(image_size, encoder_session, image):
    # The preprocessing below assumes a landscape image (width > height).
    assert image.shape[1] > image.shape[0]

    # Resize so that the longer side (width) becomes image_size (1024).
    scale = image_size / image.shape[1]
    x = imgviz.resize(
        image,
        height=int(round(image.shape[0] * scale)),
        width=image_size,
        backend="pillow",
    ).astype(np.float32)

    # Normalize with SAM's pixel mean / std, then zero-pad the bottom and
    # right to a square image_size x image_size input.
    x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
        [58.395, 57.12, 57.375], dtype=np.float32
    )
    x = np.pad(
        x,
        (
            (0, image_size - x.shape[0]),
            (0, image_size - x.shape[1]),
            (0, 0),
        ),
    )

    # HWC -> NCHW for the ONNX encoder.
    x = x.transpose(2, 0, 1)[None, :, :, :]

    output = encoder_session.run(output_names=None, input_feed={"x": x})
    image_embedding = output[0]

    return image_embedding


def _get_contour_length(contour):
    # Perimeter of the contour, treated as a closed loop.
    contour_start = contour
    contour_end = np.r_[contour[1:], contour[0:1]]
    return np.linalg.norm(contour_end - contour_start, axis=1).sum()


def compute_polygon_from_points(
    image_size, decoder_session, image, image_embedding, points, point_labels
):
    input_point = np.array(points, dtype=np.float32)
    input_label = np.array(point_labels, dtype=np.int32)

    # Append the padding point (0, 0) with label -1 expected by the SAM ONNX
    # decoder, and add a batch dimension.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
        None, :, :
    ]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
        None, :
    ].astype(np.float32)

    # Scale the point coordinates from the original image to the resized
    # (image_size-wide) image that was fed to the encoder.
    assert image.shape[1] > image.shape[0]
    scale = image_size / image.shape[1]
    new_height = int(round(image.shape[0] * scale))
    new_width = image_size
    onnx_coord = (
        onnx_coord.astype(float)
        * (new_width / image.shape[1], new_height / image.shape[0])
    ).astype(np.float32)

    # No previous low-resolution mask is supplied to the decoder.
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.array([-1], dtype=np.float32)

    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }

    masks, _, _ = decoder_session.run(None, decoder_inputs)
    mask = masks[0, 0]  # (1, 1, H, W) -> (H, W)
    mask = mask > 0.0

    if 0:  # debug visualization of the predicted mask
        imgviz.io.imsave(
            "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
        )

    # Take the longest contour of the binary mask and simplify it into a
    # polygon with a tolerance relative to the contour's extent.
    contours = skimage.measure.find_contours(mask)
    contour = max(contours, key=_get_contour_length)
    polygon = skimage.measure.approximate_polygon(
        coords=contour,
        tolerance=np.ptp(contour, axis=0).max() / 100,
    )

    if 0:  # debug visualization of the simplified polygon
        image_pil = PIL.Image.fromarray(image)
        imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
        for point in polygon:
            imgviz.draw.circle_(
                image_pil, center=point, diameter=10, fill=(0, 255, 0)
            )
        imgviz.io.imsave("contour.jpg", np.asarray(image_pil))

    return polygon[:, ::-1]  # yx -> xy
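

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# landscape RGB image on disk; "example.jpg" and the click coordinates are
# placeholders. It shows the intended call order: set_image() starts the
# embedding computation in a background thread, and
# points_to_polygon_callback() blocks until the embedding is ready.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = SegmentAnythingModel()

    # HxWx3 uint8 RGB image; width must exceed height because of the asserts
    # in compute_image_embedding / compute_polygon_from_points.
    image = np.asarray(PIL.Image.open("example.jpg").convert("RGB"))
    model.set_image(image=image)

    # One positive click (label 1) given in (x, y) pixel coordinates of the
    # original image; label 0 would mark a background click.
    polygon_xy = model.points_to_polygon_callback(
        points=[[image.shape[1] // 2, image.shape[0] // 2]],
        point_labels=[1],
    )
    print(polygon_xy)  # (N, 2) array of polygon vertices in xy order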