__init__.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import gdown
  2. from .efficient_sam import EfficientSam
  3. from .segment_anything_model import SegmentAnythingModel
  4. class SegmentAnythingModelVitB(SegmentAnythingModel):
  5. name = "SegmentAnything (speed)"
  6. def __init__(self):
  7. super().__init__(
  8. encoder_path=gdown.cached_download(
  9. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx", # NOQA
  10. md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
  11. ),
  12. decoder_path=gdown.cached_download(
  13. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx", # NOQA
  14. md5="4253558be238c15fc265a7a876aaec82",
  15. ),
  16. )
  17. class SegmentAnythingModelVitL(SegmentAnythingModel):
  18. name = "SegmentAnything (balanced)"
  19. def __init__(self):
  20. super().__init__(
  21. encoder_path=gdown.cached_download(
  22. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx", # NOQA
  23. md5="080004dc9992724d360a49399d1ee24b",
  24. ),
  25. decoder_path=gdown.cached_download(
  26. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx", # NOQA
  27. md5="851b7faac91e8e23940ee1294231d5c7",
  28. ),
  29. )
  30. class SegmentAnythingModelVitH(SegmentAnythingModel):
  31. name = "SegmentAnything (accuracy)"
  32. def __init__(self):
  33. super().__init__(
  34. encoder_path=gdown.cached_download(
  35. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx", # NOQA
  36. md5="958b5710d25b198d765fb6b94798f49e",
  37. ),
  38. decoder_path=gdown.cached_download(
  39. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx", # NOQA
  40. md5="a997a408347aa081b17a3ffff9f42a80",
  41. ),
  42. )
  43. class EfficientSamVitT(EfficientSam):
  44. name = "EfficientSam (speed)"
  45. def __init__(self):
  46. super().__init__(
  47. encoder_path=gdown.cached_download(
  48. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_encoder.onnx", # NOQA
  49. md5="2d4a1303ff0e19fe4a8b8ede69c2f5c7",
  50. ),
  51. decoder_path=gdown.cached_download(
  52. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_decoder.onnx", # NOQA
  53. md5="be3575ca4ed9b35821ac30991ab01843",
  54. ),
  55. )
  56. class EfficientSamVitS(EfficientSam):
  57. name = "EfficientSam (accuracy)"
  58. def __init__(self):
  59. super().__init__(
  60. encoder_path=gdown.cached_download(
  61. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_encoder.onnx", # NOQA
  62. md5="7d97d23e8e0847d4475ca7c9f80da96d",
  63. ),
  64. decoder_path=gdown.cached_download(
  65. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_decoder.onnx", # NOQA
  66. md5="d9372f4a7bbb1a01d236b0508300b994",
  67. ),
  68. )
  69. MODELS = [
  70. SegmentAnythingModelVitB,
  71. SegmentAnythingModelVitL,
  72. SegmentAnythingModelVitH,
  73. EfficientSamVitT,
  74. EfficientSamVitS,
  75. ]