__init__.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import gdown
  2. from .efficient_sam import EfficientSam
  3. from .segment_anything_model import SegmentAnythingModel
  4. from .barcode_model import BarcodePredictModel
  5. class BarcodePredict(BarcodePredictModel):
  6. name="BarcodePredict(ov)"
  7. def __init__(self, detection_model_path=None, segmentation_model_path=None):
  8. super().__init__(
  9. detection_model_path=detection_model_path,
  10. segmentation_model_path=segmentation_model_path
  11. )
  12. class SegmentAnythingModelVitB(SegmentAnythingModel):
  13. name = "SegmentAnything (speed)"
  14. def __init__(self,model_path=None):
  15. super().__init__(
  16. encoder_path=gdown.cached_download(
  17. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx", # NOQA
  18. md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
  19. ),
  20. decoder_path=gdown.cached_download(
  21. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx", # NOQA
  22. md5="4253558be238c15fc265a7a876aaec82",
  23. ),
  24. )
  25. class SegmentAnythingModelVitL(SegmentAnythingModel):
  26. name = "SegmentAnything (balanced)"
  27. def __init__(self,model_path=None):
  28. super().__init__(
  29. encoder_path=gdown.cached_download(
  30. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx", # NOQA
  31. md5="080004dc9992724d360a49399d1ee24b",
  32. ),
  33. decoder_path=gdown.cached_download(
  34. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx", # NOQA
  35. md5="851b7faac91e8e23940ee1294231d5c7",
  36. ),
  37. )
  38. class SegmentAnythingModelVitH(SegmentAnythingModel):
  39. name = "SegmentAnything (accuracy)"
  40. def __init__(self,model_path=None):
  41. super().__init__(
  42. encoder_path=gdown.cached_download(
  43. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx", # NOQA
  44. md5="958b5710d25b198d765fb6b94798f49e",
  45. ),
  46. decoder_path=gdown.cached_download(
  47. url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx", # NOQA
  48. md5="a997a408347aa081b17a3ffff9f42a80",
  49. ),
  50. )
  51. class EfficientSamVitT(EfficientSam):
  52. name = "EfficientSam (speed)"
  53. def __init__(self,model_path=None):
  54. super().__init__(
  55. encoder_path=gdown.cached_download(
  56. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_encoder.onnx", # NOQA
  57. md5="2d4a1303ff0e19fe4a8b8ede69c2f5c7",
  58. ),
  59. decoder_path=gdown.cached_download(
  60. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_decoder.onnx", # NOQA
  61. md5="be3575ca4ed9b35821ac30991ab01843",
  62. ),
  63. )
  64. class EfficientSamVitS(EfficientSam):
  65. name = "EfficientSam (accuracy)"
  66. def __init__(self,model_path=None):
  67. super().__init__(
  68. encoder_path=gdown.cached_download(
  69. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_encoder.onnx", # NOQA
  70. md5="7d97d23e8e0847d4475ca7c9f80da96d",
  71. ),
  72. decoder_path=gdown.cached_download(
  73. url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_decoder.onnx", # NOQA
  74. md5="d9372f4a7bbb1a01d236b0508300b994",
  75. ),
  76. )
  77. MODELS = [
  78. SegmentAnythingModelVitB,
  79. SegmentAnythingModelVitL,
  80. SegmentAnythingModelVitH,
  81. EfficientSamVitT,
  82. EfficientSamVitS,
  83. BarcodePredict,
  84. ]