# __init__.py

  1. import gdown
  2. from .efficient_sam import EfficientSam
  3. from .segment_anything_model import SegmentAnythingModel
  4. from .barcode_model import BarcodePredictModel
  5. class BarcodePredict(BarcodePredictModel):
  6. name="BarcodePredict(ov)"
  7. def __init__(self,model_path=None):
  8. super().__init__(
  9. model_path=model_path
  10. )
  11. class SegmentAnythingModelVitB(SegmentAnythingModel):
  12. name = "SegmentAnything (speed)"
  13. def __init__(self,model_path=None):
  14. super().__init__(
  15. encoder_path=gdown.cached_download(
  16. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx", # NOQA
  17. md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
  18. ),
  19. decoder_path=gdown.cached_download(
  20. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx", # NOQA
  21. md5="4253558be238c15fc265a7a876aaec82",
  22. ),
  23. )
  24. class SegmentAnythingModelVitL(SegmentAnythingModel):
  25. name = "SegmentAnything (balanced)"
  26. def __init__(self,model_path=None):
  27. super().__init__(
  28. encoder_path=gdown.cached_download(
  29. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx", # NOQA
  30. md5="080004dc9992724d360a49399d1ee24b",
  31. ),
  32. decoder_path=gdown.cached_download(
  33. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx", # NOQA
  34. md5="851b7faac91e8e23940ee1294231d5c7",
  35. ),
  36. )
  37. class SegmentAnythingModelVitH(SegmentAnythingModel):
  38. name = "SegmentAnything (accuracy)"
  39. def __init__(self,model_path=None):
  40. super().__init__(
  41. encoder_path=gdown.cached_download(
  42. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx", # NOQA
  43. md5="958b5710d25b198d765fb6b94798f49e",
  44. ),
  45. decoder_path=gdown.cached_download(
  46. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx", # NOQA
  47. md5="a997a408347aa081b17a3ffff9f42a80",
  48. ),
  49. )
  50. class EfficientSamVitT(EfficientSam):
  51. name = "EfficientSam (speed)"
  52. def __init__(self,model_path=None):
  53. super().__init__(
  54. encoder_path=gdown.cached_download(
  55. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_encoder.onnx", # NOQA
  56. md5="2d4a1303ff0e19fe4a8b8ede69c2f5c7",
  57. ),
  58. decoder_path=gdown.cached_download(
  59. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_decoder.onnx", # NOQA
  60. md5="be3575ca4ed9b35821ac30991ab01843",
  61. ),
  62. )
  63. class EfficientSamVitS(EfficientSam):
  64. name = "EfficientSam (accuracy)"
  65. def __init__(self,model_path=None):
  66. super().__init__(
  67. encoder_path=gdown.cached_download(
  68. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_encoder.onnx", # NOQA
  69. md5="7d97d23e8e0847d4475ca7c9f80da96d",
  70. ),
  71. decoder_path=gdown.cached_download(
  72. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_decoder.onnx", # NOQA
  73. md5="d9372f4a7bbb1a01d236b0508300b994",
  74. ),
  75. )
  76. MODELS = [
  77. SegmentAnythingModelVitB,
  78. SegmentAnythingModelVitL,
  79. SegmentAnythingModelVitH,
  80. EfficientSamVitT,
  81. EfficientSamVitS,
  82. BarcodePredict,
  83. ]