segment_anything.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import threading
  2. import imgviz
  3. import numpy as np
  4. import onnxruntime
  5. import PIL.Image
  6. import skimage.measure
  7. class SegmentAnythingModel:
  8. def __init__(self):
  9. self._image_size = 1024
  10. # encoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.encoder.onnx" # NOQA
  11. # decoder_path = "../segment-anything/models/sam_vit_h_4b8939.quantized.decoder.onnx" # NOQA
  12. encoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.encoder.onnx" # NOQA
  13. decoder_path = "../segment-anything/models/sam_vit_l_0b3195.quantized.decoder.onnx" # NOQA
  14. # encoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.encoder.onnx" # NOQA
  15. # decoder_path = "../segment-anything/models/sam_vit_b_01ec64.quantized.decoder.onnx" # NOQA
  16. self._encoder_session = onnxruntime.InferenceSession(encoder_path)
  17. self._decoder_session = onnxruntime.InferenceSession(decoder_path)
  18. def set_image(self, image: np.ndarray):
  19. self._image = image
  20. self._image_embedding = None
  21. self._thread = threading.Thread(target=self.get_image_embedding)
  22. self._thread.start()
  23. def get_image_embedding(self):
  24. if self._image_embedding is None:
  25. self._image_embedding = compute_image_embedding(
  26. image_size=self._image_size,
  27. encoder_session=self._encoder_session,
  28. image=self._image,
  29. )
  30. return self._image_embedding
  31. def points_to_polygon_callback(self, points, point_labels):
  32. self._thread.join()
  33. image_embedding = self.get_image_embedding()
  34. polygon = compute_polygon_from_points(
  35. image_size=self._image_size,
  36. decoder_session=self._decoder_session,
  37. image=self._image,
  38. image_embedding=image_embedding,
  39. points=points,
  40. point_labels=point_labels,
  41. )
  42. return polygon
  43. def compute_image_embedding(image_size, encoder_session, image):
  44. assert image.shape[1] > image.shape[0]
  45. scale = image_size / image.shape[1]
  46. x = imgviz.resize(
  47. image,
  48. height=int(round(image.shape[0] * scale)),
  49. width=image_size,
  50. backend="pillow",
  51. ).astype(np.float32)
  52. x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
  53. [58.395, 57.12, 57.375], dtype=np.float32
  54. )
  55. x = np.pad(
  56. x,
  57. (
  58. (0, image_size - x.shape[0]),
  59. (0, image_size - x.shape[1]),
  60. (0, 0),
  61. ),
  62. )
  63. x = x.transpose(2, 0, 1)[None, :, :, :]
  64. output = encoder_session.run(output_names=None, input_feed={"x": x})
  65. image_embedding = output[0]
  66. return image_embedding
  67. def _get_contour_length(contour):
  68. contour_start = contour
  69. contour_end = np.r_[contour[1:], contour[0:1]]
  70. return np.linalg.norm(contour_end - contour_start, axis=1).sum()
  71. def compute_polygon_from_points(
  72. image_size, decoder_session, image, image_embedding, points, point_labels
  73. ):
  74. input_point = np.array(points, dtype=np.float32)
  75. input_label = np.array(point_labels, dtype=np.int32)
  76. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
  77. None, :, :
  78. ]
  79. onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
  80. None, :
  81. ].astype(np.float32)
  82. assert image.shape[1] > image.shape[0]
  83. scale = image_size / image.shape[1]
  84. new_height = int(round(image.shape[0] * scale))
  85. new_width = image_size
  86. onnx_coord = (
  87. onnx_coord.astype(float)
  88. * (new_width / image.shape[1], new_height / image.shape[0])
  89. ).astype(np.float32)
  90. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  91. onnx_has_mask_input = np.array([-1], dtype=np.float32)
  92. decoder_inputs = {
  93. "image_embeddings": image_embedding,
  94. "point_coords": onnx_coord,
  95. "point_labels": onnx_label,
  96. "mask_input": onnx_mask_input,
  97. "has_mask_input": onnx_has_mask_input,
  98. "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
  99. }
  100. masks, _, _ = decoder_session.run(None, decoder_inputs)
  101. mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
  102. mask = mask > 0.0
  103. if 0:
  104. imgviz.io.imsave(
  105. "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
  106. )
  107. contours = skimage.measure.find_contours(mask)
  108. contour = max(contours, key=_get_contour_length)
  109. polygon = skimage.measure.approximate_polygon(
  110. coords=contour,
  111. tolerance=np.ptp(contour, axis=0).max() / 100,
  112. )
  113. if 0:
  114. image_pil = PIL.Image.fromarray(image)
  115. imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
  116. for point in polygon:
  117. imgviz.draw.circle_(
  118. image_pil, center=point, diameter=10, fill=(0, 255, 0)
  119. )
  120. imgviz.io.imsave("contour.jpg", np.asarray(image_pil))
  121. return polygon[:, ::-1] # yx -> xy