barcode_model.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import collections
  2. import threading
  3. import numpy as np
  4. import openvino as ov
  5. import os.path as osp
  6. import cv2
  7. from ..logger import logger
  8. from . import _utils
  9. from labelme.utils import img_qt_to_arr
  10. from qtpy import QtGui
  11. class Normalize:
  12. def __init__(self, mean=(0.5,), std=(0.5,)):
  13. if not (isinstance(mean, (list, tuple)) and isinstance(std, (list, tuple))):
  14. raise ValueError("mean and std should be of type list or tuple.")
  15. self.mean = np.array(mean, dtype=np.float32)
  16. self.std = np.array(std, dtype=np.float32)
  17. if np.any(self.std == 0):
  18. raise ValueError("std should not contain zero values.")
  19. def __call__(self, img):
  20. img = img.astype(np.float32) / 255.0 # Scale pixel values to [0, 1]
  21. img = (img - self.mean) / self.std # Normalize
  22. return img
  23. class BarcodePredictModel:
  24. def __init__(self, detection_model_path, segmentation_model_path=None):
  25. self.ie = ov.Core()
  26. # Load detection model
  27. self.detection_net = self.ie.read_model(model=detection_model_path)
  28. self.detection_sess = self.ie.compile_model(model=self.detection_net, device_name="CPU")
  29. self.detection_request = self.detection_sess.create_infer_request()
  30. # Load segmentation model if provided
  31. self.segmentation_net = None
  32. self.segmentation_sess = None
  33. if segmentation_model_path:
  34. self.segmentation_net = self.ie.read_model(model=segmentation_model_path)
  35. self.segmentation_sess = self.ie.compile_model(model=self.segmentation_net, device_name="CPU")
  36. self._lock = threading.Lock()
  37. self.input_height = 640 # Input shape for detection model (example size)
  38. self.input_width = 640
  39. self.segmentation_input_shape = (1, 3, 128, 256) # Input shape for segmentation model
  40. self._image_embedding_cache = collections.OrderedDict()
  41. self._max_cache_size = 10
  42. self.normalize = Normalize() # Normalization instance
  43. def set_image(self, image: np.ndarray):
  44. with self._lock:
  45. self.raw_width = image.shape[1]
  46. self.raw_height = image.shape[0]
  47. # Preprocess the image
  48. input_tensor = self.preprocess_image(image)
  49. self._image = input_tensor
  50. # Prepare other inputs
  51. self._im_shape = np.array([[self.raw_height, self.raw_width]], dtype=np.float32)
  52. self._scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)
  53. self._thread = threading.Thread(
  54. target=self._compute_and_cache_image_embedding
  55. )
  56. self._thread.start()
  57. def preprocess_image(self, image, for_segmentation=False):
  58. if for_segmentation:
  59. # Resize image to segmentation model input size
  60. # logger.debug(f"Preprocessing image for segmentation: {image.shape}")
  61. resized_image = cv2.resize(image, (self.segmentation_input_shape[3], self.segmentation_input_shape[2])) # Width, Height
  62. resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
  63. resized_image = self.normalize(resized_image) # Normalize for segmentation model
  64. else:
  65. # Resize image for detection model input size
  66. logger.debug(f"Preprocessing image for detection: {image.shape}")
  67. resized_image = cv2.resize(image, (self.input_width, self.input_height))
  68. resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
  69. resized_image = resized_image.astype('float32') / 255.0
  70. input_tensor = resized_image.transpose(2, 0, 1) # Convert HWC to CHW
  71. input_tensor = np.expand_dims(input_tensor, 0) # Add batch dimension
  72. logger.debug(f"Processed image shape: {input_tensor.shape}")
  73. return input_tensor
  74. def _compute_and_cache_image_embedding(self):
  75. with self._lock:
  76. # Prepare the inputs dictionary
  77. inputs = {
  78. 'image': self._image,
  79. 'im_shape': self._im_shape,
  80. 'scale_factor': self._scale_factor
  81. }
  82. # Perform inference
  83. self._result = self.detection_request.infer(inputs)
  84. # print("models results:", self._result)
  85. def _get_image_embedding(self):
  86. if self._thread is not None:
  87. self._thread.join()
  88. self._thread = None
  89. with self._lock:
  90. new_result = self._result
  91. return new_result
  92. def predict_mask_from_points(self,points=None,point_labels=None):
  93. return _collect_result_from_output(
  94. outputs=self._get_image_embedding(),
  95. raw_width=self.raw_width,
  96. raw_height=self.raw_height,
  97. )
  98. def predict_polygon_from_points(self,points=None,point_labels=None):
  99. result_list=self.predict_mask_from_points(points,point_labels)
  100. return result_list
  101. def _collect_result_from_output(outputs, raw_width, raw_height):
  102. # Extract the desired output array from outputs dictionary
  103. output_array = None
  104. for key in outputs:
  105. if 'save_infer_model/scale_0.tmp_0' in key.names:
  106. output_array = outputs[key]
  107. break
  108. if output_array is None:
  109. raise ValueError("Desired output not found in outputs")
  110. outputs = output_array # shape [50,6]
  111. point_list = []
  112. thresh_hold = 0.7
  113. for bbox_info in outputs:
  114. score = bbox_info[1]
  115. if score > thresh_hold:
  116. x1_raw = bbox_info[2]
  117. y1_raw = bbox_info[3]
  118. x2_raw = bbox_info[4]
  119. y2_raw = bbox_info[5]
  120. print(f"Raw bbox coordinates: x1={x1_raw}, y1={y1_raw}, x2={x2_raw}, y2={y2_raw}")
  121. x1 = max(min(int(x1_raw), raw_width - 1), 0)
  122. y1 = max(min(int(y1_raw), raw_height - 1), 0)
  123. x2 = max(min(int(x2_raw), raw_width - 1), 0)
  124. y2 = max(min(int(y2_raw), raw_height - 1), 0)
  125. print(f"Clamped bbox coordinates: x1={x1}, y1={y1}, x2={x2}, y2={y2}")
  126. point_xy = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
  127. point_list.append(point_xy)
  128. return point_list