|  | @@ -8,7 +8,8 @@ import labelme.utils
 | 
	
		
			
				|  |  |  from labelme import QT5
 | 
	
		
			
				|  |  |  from labelme.logger import logger
 | 
	
		
			
				|  |  |  from labelme.shape import Shape
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +import numpy as np
 | 
	
		
			
				|  |  | +import cv2
 | 
	
		
			
				|  |  |  # TODO(unknown):
 | 
	
		
			
				|  |  |  # - [maybe] Find optimal epsilon value.
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -41,6 +42,8 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def __init__(self, *args, **kwargs):
 | 
	
		
			
				|  |  |          self.epsilon = kwargs.pop("epsilon", 10.0)
 | 
	
		
			
				|  |  | +        self.detection_model_path=kwargs.pop("detection_model_path",None)
 | 
	
		
			
				|  |  | +        self.segmentation_model_path=kwargs.pop("segmentation_model_path",None)
 | 
	
		
			
				|  |  |          self.double_click = kwargs.pop("double_click", "close")
 | 
	
		
			
				|  |  |          if self.double_click not in [None, "close"]:
 | 
	
		
			
				|  |  |              raise ValueError(
 | 
	
	
		
			
				|  | @@ -93,6 +96,9 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          self.hShapeIsSelected = False
 | 
	
		
			
				|  |  |          self._painter = QtGui.QPainter()
 | 
	
		
			
				|  |  |          self._cursor = CURSOR_DEFAULT
 | 
	
		
			
				|  |  | +        self.draw_pred=False
 | 
	
		
			
				|  |  | +        self.pred_bbox_points=None
 | 
	
		
			
				|  |  | +        self.current_bbox_point=None
 | 
	
		
			
				|  |  |          # Menus:
 | 
	
		
			
				|  |  |          # 0: right-click without selection and dragging of shapes
 | 
	
		
			
				|  |  |          # 1: right-click with selection and dragging of shapes
 | 
	
	
		
			
				|  | @@ -102,7 +108,8 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          self.setFocusPolicy(QtCore.Qt.WheelFocus)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          self._ai_model = None
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        self._detection_model = None
 | 
	
		
			
				|  |  | +        self._segmentation_model = None
 | 
	
		
			
				|  |  |      def fillDrawing(self):
 | 
	
		
			
				|  |  |          return self._fill_drawing
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -128,16 +135,36 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |              raise ValueError("Unsupported createMode: %s" % value)
 | 
	
		
			
				|  |  |          self._createMode = value
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def initializeAiModel(self, name):
 | 
	
		
			
				|  |  | -        if name not in [model.name for model in labelme.ai.MODELS]:
 | 
	
		
			
				|  |  | -            raise ValueError("Unsupported ai model: %s" % name)
 | 
	
		
			
				|  |  | -        model = [model for model in labelme.ai.MODELS if model.name == name][0]
 | 
	
		
			
    def initializeBarcodeModel(self, detection_model_path, segmentation_model_path=None):
        """Initialize the barcode detection model and, optionally, the
        combined detection + segmentation model, then prime the detector
        with the currently loaded pixmap.

        Args:
            detection_model_path: Path to the detection model weights
                (required).
            segmentation_model_path: Optional path to segmentation model
                weights. When given, a second ``BarcodePredictModel`` is
                created with both paths and stored as the segmentation model.

        Raises:
            ValueError: If ``detection_model_path`` is empty/None.
        """
        if not detection_model_path:
            raise ValueError("Detection model path is required.")

        # Fix: the old message claimed "only detection" even when a
        # segmentation path was supplied; log unconditionally and accurately.
        logger.debug("Initializing detection model: %r", "BarcodePredictModel")
        self._detection_model = labelme.ai.BarcodePredictModel(detection_model_path)

        if segmentation_model_path:
            logger.debug(
                "Initializing barcode detection & segmentation model: %r",
                "BarcodePredictModel",
            )
            self._segmentation_model = labelme.ai.BarcodePredictModel(
                detection_model_path, segmentation_model_path
            )

        if self.pixmap is None:
            # No image loaded yet; loadPixmap() will call set_image() later.
            logger.warning("Pixmap is not set yet")
            return

        # NOTE(review): only the detection model is primed with the image
        # here; the segmentation model (if any) is fed crops on demand in
        # detect_and_segment() — confirm that is intentional.
        self._detection_model.set_image(
            image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
        )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def initializeAiModel(self, name, weight_path=None):
 | 
	
		
			
				|  |  | +        if self._ai_model is not None:
 | 
	
		
			
				|  |  | +            logger.debug("AI model is already initialized.")
 | 
	
		
			
				|  |  | +            return
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if name not in [model.name for model in labelme.ai.MODELS]:
 | 
	
		
			
				|  |  | +            raise ValueError("Unsupported AI model: %s" % name)
 | 
	
		
			
				|  |  | +        model_class = [model for model in labelme.ai.MODELS if model.name == name][0]
 | 
	
		
			
				|  |  | +        logger.debug(f"Initializing AI model: {name}")
 | 
	
		
			
				|  |  | +        self._ai_model = model_class(weight_path)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          if self.pixmap is None:
 | 
	
		
			
				|  |  |              logger.warning("Pixmap is not set yet")
 | 
	
	
		
			
				|  | @@ -145,6 +172,7 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          self._ai_model.set_image(
 | 
	
		
			
				|  |  |              image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
 | 
	
		
			
				|  |  | +            # image=self.pixmap.toImage()
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def storeShapes(self):
 | 
	
	
		
			
				|  | @@ -154,7 +182,7 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          if len(self.shapesBackups) > self.num_backups:
 | 
	
		
			
				|  |  |              self.shapesBackups = self.shapesBackups[-self.num_backups - 1 :]
 | 
	
		
			
				|  |  |          self.shapesBackups.append(shapesBackup)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        
 | 
	
		
			
				|  |  |      @property
 | 
	
		
			
				|  |  |      def isShapeRestorable(self):
 | 
	
		
			
				|  |  |          # We save the state AFTER each edit (not before) so for an
 | 
	
	
		
			
				|  | @@ -170,6 +198,7 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          # and app.py::loadShapes and our own Canvas::loadShapes function.
 | 
	
		
			
				|  |  |          if not self.isShapeRestorable:
 | 
	
		
			
				|  |  |              return
 | 
	
		
			
				|  |  | +        print(f"shape is restorable")
 | 
	
		
			
				|  |  |          self.shapesBackups.pop()  # latest
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # The application will eventually call Canvas.loadShapes which will
 | 
	
	
		
			
				|  | @@ -358,7 +387,7 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |                  self.setStatusTip(self.toolTip())
 | 
	
		
			
				|  |  |                  self.update()
 | 
	
		
			
				|  |  |                  break
 | 
	
		
			
				|  |  | -            elif shape.containsPoint(pos):
 | 
	
		
			
				|  |  | +            elif len(shape.points)!=0 and shape.containsPoint(pos) :
 | 
	
		
			
				|  |  |                  if self.selectedVertex():
 | 
	
		
			
				|  |  |                      self.hShape.highlightClear()
 | 
	
		
			
				|  |  |                  self.prevhVertex = self.hVertex
 | 
	
	
		
			
				|  | @@ -683,7 +712,6 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |      def paintEvent(self, event):
 | 
	
		
			
				|  |  |          if not self.pixmap:
 | 
	
		
			
				|  |  |              return super(Canvas, self).paintEvent(event)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |          p = self._painter
 | 
	
		
			
				|  |  |          p.begin(self)
 | 
	
		
			
				|  |  |          p.setRenderHint(QtGui.QPainter.Antialiasing)
 | 
	
	
		
			
				|  | @@ -728,7 +756,19 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          if self.selectedShapesCopy:
 | 
	
		
			
				|  |  |              for s in self.selectedShapesCopy:
 | 
	
		
			
				|  |  |                  s.paint(p)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        # if(self.draw_pred and self.current is not None):
 | 
	
		
			
				|  |  | +        #     print("pred mode on")
 | 
	
		
			
				|  |  | +        #     for bbox_points in self.pred_bbox_points:
 | 
	
		
			
				|  |  | +        #         drawing_shape = self.current.copy()
 | 
	
		
			
				|  |  | +        #         drawing_shape.setShapeRefined(
 | 
	
		
			
				|  |  | +        #             shape_type="polygon",
 | 
	
		
			
				|  |  | +        #             points=[QtCore.QPointF(point[0], point[1]) for point in bbox_points],
 | 
	
		
			
				|  |  | +        #             point_labels=[1]*len(bbox_points)
 | 
	
		
			
				|  |  | +        #         )
 | 
	
		
			
				|  |  | +        #         drawing_shape.fill = self.fillDrawing()
 | 
	
		
			
				|  |  | +        #         drawing_shape.selected = True
 | 
	
		
			
				|  |  | +        #         drawing_shape.paint(p)
 | 
	
		
			
				|  |  | +        #     self.draw_pred=False
 | 
	
		
			
				|  |  |          if (
 | 
	
		
			
				|  |  |              self.fillDrawing()
 | 
	
		
			
				|  |  |              and self.createMode == "polygon"
 | 
	
	
		
			
				|  | @@ -751,7 +791,7 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |                  point=self.line.points[1],
 | 
	
		
			
				|  |  |                  label=self.line.point_labels[1],
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  | -            points = self._ai_model.predict_polygon_from_points(
 | 
	
		
			
				|  |  | +            points = self._detection_model.predict_polygon_from_points(
 | 
	
		
			
				|  |  |                  points=[[point.x(), point.y()] for point in drawing_shape.points],
 | 
	
		
			
				|  |  |                  point_labels=drawing_shape.point_labels,
 | 
	
		
			
				|  |  |              )
 | 
	
	
		
			
				|  | @@ -770,7 +810,7 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |                  point=self.line.points[1],
 | 
	
		
			
				|  |  |                  label=self.line.point_labels[1],
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  | -            mask = self._ai_model.predict_mask_from_points(
 | 
	
		
			
				|  |  | +            mask = self._detection_model.predict_mask_from_points(
 | 
	
		
			
				|  |  |                  points=[[point.x(), point.y()] for point in drawing_shape.points],
 | 
	
		
			
				|  |  |                  point_labels=drawing_shape.point_labels,
 | 
	
		
			
				|  |  |              )
 | 
	
	
		
			
				|  | @@ -804,11 +844,12 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          return not (0 <= p.x() <= w - 1 and 0 <= p.y() <= h - 1)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def finalise(self):
 | 
	
		
			
				|  |  | -        assert self.current
 | 
	
		
			
				|  |  | +        if(self.current is None):
 | 
	
		
			
				|  |  | +            return
 | 
	
		
			
				|  |  |          if self.createMode == "ai_polygon":
 | 
	
		
			
				|  |  |              # convert points to polygon by an AI model
 | 
	
		
			
				|  |  |              assert self.current.shape_type == "points"
 | 
	
		
			
				|  |  | -            points = self._ai_model.predict_polygon_from_points(
 | 
	
		
			
				|  |  | +            points = self._detection_model.predict_polygon_from_points(
 | 
	
		
			
				|  |  |                  points=[[point.x(), point.y()] for point in self.current.points],
 | 
	
		
			
				|  |  |                  point_labels=self.current.point_labels,
 | 
	
		
			
				|  |  |              )
 | 
	
	
		
			
				|  | @@ -820,7 +861,7 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          elif self.createMode == "ai_mask":
 | 
	
		
			
				|  |  |              # convert points to mask by an AI model
 | 
	
		
			
				|  |  |              assert self.current.shape_type == "points"
 | 
	
		
			
				|  |  | -            mask = self._ai_model.predict_mask_from_points(
 | 
	
		
			
				|  |  | +            mask = self._detection_model.predict_mask_from_points(
 | 
	
		
			
				|  |  |                  points=[[point.x(), point.y()] for point in self.current.points],
 | 
	
		
			
				|  |  |                  point_labels=self.current.point_labels,
 | 
	
		
			
				|  |  |              )
 | 
	
	
		
			
				|  | @@ -831,8 +872,30 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |                  point_labels=[1, 1],
 | 
	
		
			
				|  |  |                  mask=mask[y1 : y2 + 1, x1 : x2 + 1],
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  | +        elif self.pred_bbox_points is not None and self.draw_pred:
 | 
	
		
			
				|  |  | +            print("pred mode on")
 | 
	
		
			
				|  |  | +            current_copy=self.current.copy()
 | 
	
		
			
				|  |  | +            for bbox_point in self.pred_bbox_points:
 | 
	
		
			
				|  |  | +                drawing_shape=current_copy.copy()
 | 
	
		
			
				|  |  | +                drawing_shape.setShapeRefined(
 | 
	
		
			
				|  |  | +                    shape_type="polygon",
 | 
	
		
			
				|  |  | +                    points=[QtCore.QPointF(point[0], point[1]) for point in bbox_point],
 | 
	
		
			
				|  |  | +                    point_labels=[1]*len(bbox_point)
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  | +                drawing_shape.close()
 | 
	
		
			
				|  |  | +                self.shapes.append(drawing_shape)
 | 
	
		
			
				|  |  | +                self.storeShapes()
 | 
	
		
			
				|  |  | +                self.update()
 | 
	
		
			
				|  |  | +                self.newShape.emit()
 | 
	
		
			
				|  |  | +            current_copy.close()
 | 
	
		
			
				|  |  | +            current_copy=None
 | 
	
		
			
				|  |  | +            if(self.current):
 | 
	
		
			
				|  |  | +                self.current.close()
 | 
	
		
			
				|  |  | +                self.current = None
 | 
	
		
			
				|  |  | +            self.setHiding(False)
 | 
	
		
			
				|  |  | +            self.draw_pred=False
 | 
	
		
			
				|  |  | +            return
 | 
	
		
			
				|  |  |          self.current.close()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |          self.shapes.append(self.current)
 | 
	
		
			
				|  |  |          self.storeShapes()
 | 
	
		
			
				|  |  |          self.current = None
 | 
	
	
		
			
				|  | @@ -959,6 +1022,29 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |                  self.finalise()
 | 
	
		
			
				|  |  |              elif modifiers == QtCore.Qt.AltModifier:
 | 
	
		
			
				|  |  |                  self.snapping = False
 | 
	
		
			
				|  |  | +            elif key == QtCore.Qt.Key_V:
 | 
	
		
			
				|  |  | +                if self._detection_model is None:
 | 
	
		
			
				|  |  | +                    logger.info(f"Initializing AI model")
 | 
	
		
			
				|  |  | +                    self.initializeBarcodeModel(self.detection_model_path, self.segmentation_model_path)
 | 
	
		
			
				|  |  | + 
 | 
	
		
			
				|  |  | +                self.current = Shape(
 | 
	
		
			
				|  |  | +                    shape_type="points" if self.createMode in ["ai_polygon", "ai_mask"] else self.createMode
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  | +                        
 | 
	
		
			
				|  |  | +                if self._detection_model:
 | 
	
		
			
				|  |  | +                    if self._segmentation_model is None:
 | 
	
		
			
				|  |  | +                        logger.info(f"Performing detection only.")
 | 
	
		
			
				|  |  | +                        # Get prediction from model
 | 
	
		
			
				|  |  | +                        self.pred_bbox_points = self._detection_model.predict_polygon_from_points()
 | 
	
		
			
				|  |  | +                        print("Predicted Bounding Box Points:", self.pred_bbox_points)                        
 | 
	
		
			
				|  |  | +                        if self.pred_bbox_points:
 | 
	
		
			
				|  |  | +                            self.draw_pred = True
 | 
	
		
			
				|  |  | +                            self.finalise()
 | 
	
		
			
				|  |  | +                        else:
 | 
	
		
			
				|  |  | +                            print("No bounding boxes detected.")    
 | 
	
		
			
				|  |  | +                    else:
 | 
	
		
			
				|  |  | +                            logger.info(f"Performing detection and segmentation.")
 | 
	
		
			
				|  |  | +                            self.detect_and_segment() 
 | 
	
		
			
				|  |  |          elif self.editing():
 | 
	
		
			
				|  |  |              if key == QtCore.Qt.Key_Up:
 | 
	
		
			
				|  |  |                  self.moveByKeyboard(QtCore.QPointF(0.0, -MOVE_SPEED))
 | 
	
	
		
			
				|  | @@ -969,6 +1055,164 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |              elif key == QtCore.Qt.Key_Right:
 | 
	
		
			
				|  |  |                  self.moveByKeyboard(QtCore.QPointF(MOVE_SPEED, 0.0))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
    def scale_points(self, approx, mask_shape, cropped_shape, x_min, y_min):
        """Map OpenCV contour points from mask coordinates to full-image
        coordinates.

        Each entry of ``approx`` is an OpenCV-style contour point of the
        form ``[[x, y]]``.  Points are rescaled from the mask resolution
        (``mask_shape``) to the cropped-region resolution
        (``cropped_shape``) and then shifted by the crop origin
        ``(x_min, y_min)``.  Returns a list of ``[x, y]`` integer pairs.
        """
        ratio_x = cropped_shape[1] / mask_shape[1]
        ratio_y = cropped_shape[0] / mask_shape[0]
        mapped = []
        for contour_pt in approx:
            px = contour_pt[0][0]
            py = contour_pt[0][1]
            mapped.append([int(px * ratio_x) + x_min, int(py * ratio_y) + y_min])
        return mapped
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
    def detect_and_segment(self):
        """Run the detection model, then refine each detected box with the
        segmentation model, and finalise one polygon shape per detection.

        Pipeline per detection: extract the bounding box, crop the pixmap,
        rotate the crop 90° clockwise when it is taller than wide (the
        segmentation model expects landscape input), preprocess, run
        segmentation inference, extract the largest contour from the mask,
        approximate it to a quadrilateral, map the points back to full-image
        coordinates, and emit the shape via ``finalise()``.

        Side effects: sets ``self.current``, ``self.pred_bbox_points`` and
        ``self.draw_pred`` (reset after each detection), appends shapes via
        ``finalise()``, and writes debug images to the working directory.
        """
        logger.info("Performing detection and segmentation.")
        self.current = Shape(
            shape_type="points" if self.createMode in ["ai_polygon", "ai_mask"] else self.createMode
        )

        # Step 1: detection bounding box points
        detection_results = self._detection_model.predict_polygon_from_points()

        # Simplified: `not detection_results` already covers the empty case.
        if not detection_results:
            logger.warning("No detection found")
            return

        logger.debug(f"Detection results: {detection_results}")
        all_segmentation_results = []
        rotated = False
        # Step 2: loop through each detection result (multiple per image)
        for detection_idx, detection_result in enumerate(detection_results):
            logger.debug(f"Processing detection {detection_idx + 1}/{len(detection_results)}")

            try:
                # Extract x/y coordinates of the detection polygon.
                x_coords = [point[0] for point in detection_result]
                y_coords = [point[1] for point in detection_result]

                # Axis-aligned bounding box of the detection.
                x_min, x_max = min(x_coords), max(x_coords)
                y_min, y_max = min(y_coords), max(y_coords)

                logger.debug(f"Bounding box for detection {detection_idx + 1} - x_min: {x_min}, y_min: {y_min}, x_max: {x_max}, y_max: {y_max}")
            except Exception as e:
                logger.error(f"Error extracting bounding box coordinates for detection {detection_idx + 1}: {e}")
                continue

            # Convert bounding box values to integers for cropping.
            x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)

            # Step 3: crop the image based on the detection output
            try:
                cropped_image = self.pixmap.toImage().copy(x_min, y_min, x_max - x_min, y_max - y_min)
                cropped_image = labelme.utils.img_qt_to_arr(cropped_image)

                orig_height, orig_width = cropped_image.shape[:2]
                logger.debug(f"Original height: {orig_height}, Original width: {orig_width}")
                # Portrait crops are rotated to landscape for segmentation.
                # (Dead store of orig_cropped_shape removed here: Step 4
                # unconditionally reassigns it before any use.)
                if orig_height > orig_width:
                    cropped_image = cv2.rotate(cropped_image, cv2.ROTATE_90_CLOCKWISE)
                    logger.debug(f"Rotated cropped image for detection {detection_idx + 1} due to height > width.")
                    rotated = True
                else:
                    rotated = False
                # Save crop image for debugging.
                cv2.imwrite(f"cropped_image_{detection_idx + 1}.png", cropped_image)
                logger.debug(f"Saved cropped image for detection {detection_idx + 1}: {cropped_image.shape}")
            except Exception as e:
                logger.error(f"Error cropping the image for detection {detection_idx + 1}: {e}")
                continue

            # Step 4: resize the cropped image to the segmentation input size
            try:
                orig_cropped_shape = cropped_image.shape[:2]  # Save the (possibly rotated) cropped image size
                preprocessed_img = self._detection_model.preprocess_image(cropped_image, for_segmentation=True)
                logger.debug(f"Preprocessed image shape for segmentation detection {detection_idx + 1}: {preprocessed_img.shape}")
            except Exception as e:
                logger.error(f"Error preprocessing the image for segmentation for detection {detection_idx + 1}: {e}")
                continue

            # Step 5: segmentation model inference on the cropped image
            try:
                seg_result = self._segmentation_model.segmentation_sess.infer_new_request({'x': preprocessed_img})
                logger.debug(f"Segmentation model inference completed for detection {detection_idx + 1}.")
            except Exception as e:
                logger.error(f"Error during segmentation model inference for detection {detection_idx + 1}: {e}")
                continue

            # Step 6: convert the binary mask to a polygon (contours)
            try:
                mask = seg_result['save_infer_model/scale_0.tmp_0']  # model output name
                mask = mask.squeeze()  # Remove batch dimension
                logger.debug(f"Segmentation mask shape for detection {detection_idx + 1}: {mask.shape}")

                # Normalize the mask to 0..255 and convert to uint8.
                mask = (mask * 255).astype(np.uint8)

                logger.debug(f"Converted mask shape for detection {detection_idx + 1}: {mask.shape}, dtype: {mask.dtype}")
                cv2.imwrite(f"segmentation_mask_{detection_idx + 1}.png", mask)
                if rotated:
                    # Undo the Step-3 rotation so contours are in the
                    # original crop orientation.
                    cropped_image = cv2.rotate(cropped_image, cv2.ROTATE_90_COUNTERCLOCKWISE)
                    mask = cv2.rotate(mask, cv2.ROTATE_90_COUNTERCLOCKWISE)
                    rotated_cropped_shape = cropped_image.shape[:2]

                logger.debug(f"Saved segmentation mask for detection {detection_idx + 1}.")

                # Step 7: find contours
                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                logger.debug(f"Found {len(contours)} contours in the mask for detection {detection_idx + 1}.")

                if len(contours) > 0:
                    largest_contour = max(contours, key=cv2.contourArea)

                    # Step 8: approximate a quadrilateral from the contour
                    epsilon = 0.02 * cv2.arcLength(largest_contour, True)  # epsilon for precision
                    approx = cv2.approxPolyDP(largest_contour, epsilon, True)

                    # If the approximation doesn't yield 4 points, fall back
                    # to the axis-aligned bounding rectangle.
                    if len(approx) != 4:
                        # Fix: replaced stray `print("log here")` debug
                        # statement with a proper logger call.
                        logger.debug(f"Approximation gave {len(approx)} points for detection {detection_idx + 1}; falling back to boundingRect.")
                        # NOTE(review): this fallback adds the crop offset
                        # without rescaling from mask resolution to crop
                        # resolution (unlike scale_points below) — confirm
                        # the mask and crop are the same size here.
                        x, y, w, h = cv2.boundingRect(largest_contour)
                        point_xy = [
                            [x + x_min, y + y_min],          # Top-left
                            [x + w + x_min, y + y_min],      # Top-right
                            [x + w + x_min, y + h + y_min],  # Bottom-right
                            [x + x_min, y + h + y_min]       # Bottom-left
                        ]
                    else:
                        if rotated:
                            point_xy = self.scale_points(approx, mask.shape, rotated_cropped_shape, x_min, y_min)
                        else:
                            point_xy = self.scale_points(approx, mask.shape, orig_cropped_shape, x_min, y_min)
                    logger.debug(f"Generated 4 corner points for the polygon for detection {detection_idx + 1}: {point_xy}")
                    self.pred_bbox_points = [point_xy]
                    logger.debug(f"Predicted Bounding Box Points for detection {detection_idx + 1}: {self.pred_bbox_points}")
                    if self.pred_bbox_points:
                        self.draw_pred = True
                        self.finalise()
                    else:
                        logger.info(f"No bounding boxes detected for detection {detection_idx + 1}.")

                    # Collect the segmentation result
                    all_segmentation_results.append(self.pred_bbox_points)
            except Exception as e:
                logger.error(f"Error creating the polygon shape for detection {detection_idx + 1}: {e}")

            # Reset per-detection state so stale predictions never leak into
            # the next iteration (or into later paint/finalise calls).
            self.pred_bbox_points = None
            self.draw_pred = False

        if all_segmentation_results:
            logger.info(f"Segmentation results for all detections: {all_segmentation_results}")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def keyReleaseEvent(self, ev):
 | 
	
		
			
				|  |  |          modifiers = ev.modifiers()
 | 
	
		
			
				|  |  |          if self.drawing():
 | 
	
	
		
			
				|  | @@ -985,6 +1229,8 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def setLastLabel(self, text, flags):
 | 
	
		
			
				|  |  |          assert text
 | 
	
		
			
				|  |  | +        if(self.shapes is None or len(self.shapes)==0):
 | 
	
		
			
				|  |  | +            return
 | 
	
		
			
				|  |  |          self.shapes[-1].label = text
 | 
	
		
			
				|  |  |          self.shapes[-1].flags = flags
 | 
	
		
			
				|  |  |          self.shapesBackups.pop()
 | 
	
	
		
			
				|  | @@ -992,11 +1238,12 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |          return self.shapes[-1]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def undoLastLine(self):
 | 
	
		
			
				|  |  | -        assert self.shapes
 | 
	
		
			
				|  |  | +        if(self.shapes is None or len(self.shapes)==0):
 | 
	
		
			
				|  |  | +            return
 | 
	
		
			
				|  |  |          self.current = self.shapes.pop()
 | 
	
		
			
				|  |  |          self.current.setOpen()
 | 
	
		
			
				|  |  |          self.current.restoreShapeRaw()
 | 
	
		
			
				|  |  | -        if self.createMode in ["polygon", "linestrip"]:
 | 
	
		
			
				|  |  | +        if self.createMode in ["polygon", "linestrip"] and self.draw_pred is False:
 | 
	
		
			
				|  |  |              self.line.points = [self.current[-1], self.current[0]]
 | 
	
		
			
				|  |  |          elif self.createMode in ["rectangle", "line", "circle"]:
 | 
	
		
			
				|  |  |              self.current.points = self.current.points[0:1]
 | 
	
	
		
			
				|  | @@ -1017,8 +1264,8 @@ class Canvas(QtWidgets.QWidget):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def loadPixmap(self, pixmap, clear_shapes=True):
 | 
	
		
			
				|  |  |          self.pixmap = pixmap
 | 
	
		
			
				|  |  | -        if self._ai_model:
 | 
	
		
			
				|  |  | -            self._ai_model.set_image(
 | 
	
		
			
				|  |  | +        if self._detection_model:
 | 
	
		
			
				|  |  | +            self._detection_model.set_image(
 | 
	
		
			
				|  |  |                  image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  |          if clear_shapes:
 |