Changes made to labelme to integrate with Paddle Detection and Paddle Segmentation

Noman · 7 months ago
commit 96c7132de1

+ 2 - 20
examples/bbox_detection/labels.txt

@@ -1,22 +1,4 @@
 __ignore__
 _background_
-aeroplane
-bicycle
-bird
-boat
-bottle
-bus
-car
-cat
-chair
-cow
-diningtable
-dog
-horse
-motorbike
-person
-potted plant
-sheep
-sofa
-train
-tv/monitor
+barcode
+tough_barcode

+ 13 - 1
labelme/__main__.py

@@ -103,6 +103,17 @@ def main():
         help="epsilon to find nearest vertex on canvas",
         default=argparse.SUPPRESS,
     )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="model weight for barcode detection",
+        default=argparse.SUPPRESS,
+    )
+    parser.add_argument(
+        "--segmentation_model",
+        type=str,
+        help="path to the segmentation model",
+        default=argparse.SUPPRESS,
+    )
     args = parser.parse_args()
 
     if args.version:
@@ -139,7 +150,8 @@ def main():
     output = config_from_args.pop("output")
     config_file_or_yaml = config_from_args.pop("config")
     config = get_config(config_file_or_yaml, config_from_args)
-
+    config["model"]=config_from_args.pop("model")
+    config["segmentation_model"]=config_from_args.pop("segmentation_model")
     if not config["labels"] and config["validate_label"]:
         logger.error(
             "--labels must be specified with --validatelabel or "

+ 14 - 6
labelme/ai/__init__.py

@@ -1,13 +1,20 @@
 import gdown
-
 from .efficient_sam import EfficientSam
 from .segment_anything_model import SegmentAnythingModel
+from .barcode_model import BarcodePredictModel
 
+class BarcodePredict(BarcodePredictModel):
+    name = "BarcodePredict(ov)"
+
+    def __init__(self, detection_model_path=None, segmentation_model_path=None):
+        super().__init__(
+            detection_model_path=detection_model_path,
+            segmentation_model_path=segmentation_model_path,
+        )
 
 class SegmentAnythingModelVitB(SegmentAnythingModel):
     name = "SegmentAnything (speed)"
 
-    def __init__(self):
+    def __init__(self, model_path=None):
         super().__init__(
             encoder_path=gdown.cached_download(
                 url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx",  # NOQA
@@ -23,7 +30,7 @@ class SegmentAnythingModelVitB(SegmentAnythingModel):
 class SegmentAnythingModelVitL(SegmentAnythingModel):
     name = "SegmentAnything (balanced)"
 
-    def __init__(self):
+    def __init__(self, model_path=None):
         super().__init__(
             encoder_path=gdown.cached_download(
                 url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
@@ -39,7 +46,7 @@ class SegmentAnythingModelVitL(SegmentAnythingModel):
 class SegmentAnythingModelVitH(SegmentAnythingModel):
     name = "SegmentAnything (accuracy)"
 
-    def __init__(self):
+    def __init__(self, model_path=None):
         super().__init__(
             encoder_path=gdown.cached_download(
                 url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx",  # NOQA
@@ -55,7 +62,7 @@ class SegmentAnythingModelVitH(SegmentAnythingModel):
 class EfficientSamVitT(EfficientSam):
     name = "EfficientSam (speed)"
 
-    def __init__(self):
+    def __init__(self, model_path=None):
         super().__init__(
             encoder_path=gdown.cached_download(
                 url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_encoder.onnx",  # NOQA
@@ -71,7 +78,7 @@ class EfficientSamVitT(EfficientSam):
 class EfficientSamVitS(EfficientSam):
     name = "EfficientSam (accuracy)"
 
-    def __init__(self):
+    def __init__(self, model_path=None):
         super().__init__(
             encoder_path=gdown.cached_download(
                 url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_encoder.onnx",  # NOQA
@@ -90,4 +97,5 @@ MODELS = [
     SegmentAnythingModelVitH,
     EfficientSamVitT,
     EfficientSamVitS,
+    BarcodePredict,
 ]
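
labelme/ai/barcode_model.py itself is not part of this diff; the surface below is inferred from the call sites in canvas.py (set_image, predict_polygon_from_points, preprocess_image, and an OpenVINO-style segmentation_sess). A sketch of the implied interface, not the actual implementation:

    class BarcodePredictModel:
        """Interface implied by the canvas.py call sites (inferred sketch)."""

        def __init__(self, detection_model_path, segmentation_model_path=None):
            self.detection_model_path = detection_model_path
            self.segmentation_model_path = segmentation_model_path
            # Populated when a segmentation model is given; canvas.py calls
            # segmentation_sess.infer_new_request({"x": ...}) on it.
            self.segmentation_sess = None
            self._image = None

        def set_image(self, image):
            # Grayscale ndarray produced by labelme.utils.img_qt_to_arr.
            self._image = image

        def predict_polygon_from_points(self, points=None, point_labels=None):
            # Returns one polygon per detected barcode:
            # [[[x1, y1], [x2, y2], ...], ...]
            raise NotImplementedError

        def preprocess_image(self, image, for_segmentation=False):
            # Resize/normalize a crop to the segmentation input shape (1, 64, 256).
            raise NotImplementedError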

+ 2 - 1
labelme/app.py

@@ -126,7 +126,6 @@ class MainWindow(QtWidgets.QMainWindow):
         self.shape_dock = QtWidgets.QDockWidget(self.tr("Polygon Labels"), self)
         self.shape_dock.setObjectName("Labels")
         self.shape_dock.setWidget(self.labelList)
-
         self.uniqLabelList = UniqueLabelQListWidget()
         self.uniqLabelList.setToolTip(
             self.tr(
@@ -167,6 +166,8 @@ class MainWindow(QtWidgets.QMainWindow):
             double_click=self._config["canvas"]["double_click"],
             num_backups=self._config["canvas"]["num_backups"],
             crosshair=self._config["canvas"]["crosshair"],
+            detection_model_path=self._config["model"],
+            segmentation_model_path=self._config["segmentation_model"],
         )
         self.canvas.zoomRequest.connect(self.zoomRequest)
         self.canvas.mouseMoved.connect(
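
MainWindow only forwards the two configured paths; Canvas strips them out of **kwargs before the remainder reaches QWidget (see the canvas.py hunk below). A minimal illustration of that kwargs.pop pattern (file names hypothetical):

    class Widget:
        def __init__(self, **kwargs):
            # Pop custom options so the base class never sees them.
            self.detection_model_path = kwargs.pop("detection_model_path", None)
            self.segmentation_model_path = kwargs.pop("segmentation_model_path", None)
            # super().__init__(**kwargs) would now receive only Qt-native kwargs.

    w = Widget(detection_model_path="det.xml", segmentation_model_path="seg.xml")
    print(w.detection_model_path, w.segmentation_model_path)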

+ 1 - 0
labelme/config/__init__.py

@@ -69,6 +69,7 @@ def get_config(config_file_or_yaml=None, config_from_args=None):
         update_dict(config, config_from_yaml, validate_item=validate_config_item)
 
     # 3. command line argument or specified config file
+    # print(config_from_args)
     if config_from_args is not None:
         update_dict(config, config_from_args, validate_item=validate_config_item)
 

+ 2 - 1
labelme/utils/image.py

@@ -3,7 +3,7 @@
 
 import base64
 import io
-
+import cv2
 import numpy as np
 import PIL.ExifTags
 import PIL.Image
@@ -63,6 +63,7 @@ def img_qt_to_arr(img_qt):
     w, h, d = img_qt.size().width(), img_qt.size().height(), img_qt.depth()
     bytes_ = img_qt.bits().asstring(w * h * d // 8)
     img_arr = np.frombuffer(bytes_, dtype=np.uint8).reshape((h, w, d // 8))
+    img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2GRAY)
     return img_arr
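
Note that this makes img_qt_to_arr return a single-channel grayscale array for every caller, which suits the barcode models but silently changes what the SegmentAnything/EfficientSam paths receive. A more conservative variant would make the conversion opt-in (a sketch; the grayscale parameter is hypothetical):

    import cv2
    import numpy as np


    def img_qt_to_arr(img_qt, grayscale=False):
        w, h, d = img_qt.size().width(), img_qt.size().height(), img_qt.depth()
        bytes_ = img_qt.bits().asstring(w * h * d // 8)
        img_arr = np.frombuffer(bytes_, dtype=np.uint8).reshape((h, w, d // 8))
        if grayscale:
            # Collapse channels only when the caller asks for it.
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2GRAY)
        return img_arr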
 
 

+ 272 - 25
labelme/widgets/canvas.py

@@ -8,7 +8,8 @@ import labelme.utils
 from labelme import QT5
 from labelme.logger import logger
 from labelme.shape import Shape
-
+import numpy as np
+import cv2
 # TODO(unknown):
 # - [maybe] Find optimal epsilon value.
 
@@ -41,6 +42,8 @@ class Canvas(QtWidgets.QWidget):
 
     def __init__(self, *args, **kwargs):
         self.epsilon = kwargs.pop("epsilon", 10.0)
+        self.detection_model_path = kwargs.pop("detection_model_path", None)
+        self.segmentation_model_path = kwargs.pop("segmentation_model_path", None)
         self.double_click = kwargs.pop("double_click", "close")
         if self.double_click not in [None, "close"]:
             raise ValueError(
@@ -93,6 +96,9 @@ class Canvas(QtWidgets.QWidget):
         self.hShapeIsSelected = False
         self._painter = QtGui.QPainter()
         self._cursor = CURSOR_DEFAULT
+        self.draw_pred = False
+        self.pred_bbox_points = None
+        self.current_bbox_point = None
         # Menus:
         # 0: right-click without selection and dragging of shapes
         # 1: right-click with selection and dragging of shapes
@@ -102,7 +108,8 @@ class Canvas(QtWidgets.QWidget):
         self.setFocusPolicy(QtCore.Qt.WheelFocus)
 
         self._ai_model = None
+        self._detection_model = None
+        self._segmentation_model = None
 
     def fillDrawing(self):
         return self._fill_drawing
 
@@ -128,16 +135,36 @@ class Canvas(QtWidgets.QWidget):
             raise ValueError("Unsupported createMode: %s" % value)
         self._createMode = value
 
-    def initializeAiModel(self, name):
-        if name not in [model.name for model in labelme.ai.MODELS]:
-            raise ValueError("Unsupported ai model: %s" % name)
-        model = [model for model in labelme.ai.MODELS if model.name == name][0]
+    def initializeBarcodeModel(self, detection_model_path, segmentation_model_path=None):
+        if not detection_model_path:
+            raise ValueError("Detection model path is required.")
 
-        if self._ai_model is not None and self._ai_model.name == model.name:
-            logger.debug("AI model is already initialized: %r" % model.name)
-        else:
-            logger.debug("Initializing AI model: %r" % model.name)
-            self._ai_model = model()
+        logger.debug("Initializing only detection model: %r" % "BarcodePredictModel")
+        self._detection_model = labelme.ai.BarcodePredictModel(detection_model_path)
+
+        if segmentation_model_path:
+            logger.debug("Initializing barcode detection  & Segmentation model: %r" % "BarcodePredictModel")
+            self._segmentation_model = labelme.ai.BarcodePredictModel(
+                detection_model_path, segmentation_model_path
+            )
+        if self.pixmap is None:
+            logger.warning("Pixmap is not set yet")
+            return
+
+        self._detection_model.set_image(
+            image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
+        )
+
+    def initializeAiModel(self, name, weight_path=None):
+        if self._ai_model is not None:
+            logger.debug("AI model is already initialized.")
+            return
+
+        if name not in [model.name for model in labelme.ai.MODELS]:
+            raise ValueError("Unsupported AI model: %s" % name)
+        model_class = [model for model in labelme.ai.MODELS if model.name == name][0]
+        logger.debug(f"Initializing AI model: {name}")
+        self._ai_model = model_class(weight_path)
 
         if self.pixmap is None:
             logger.warning("Pixmap is not set yet")
@@ -145,6 +172,7 @@ class Canvas(QtWidgets.QWidget):
 
         self._ai_model.set_image(
             image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
         )
 
     def storeShapes(self):
@@ -154,7 +182,7 @@ class Canvas(QtWidgets.QWidget):
         if len(self.shapesBackups) > self.num_backups:
             self.shapesBackups = self.shapesBackups[-self.num_backups - 1 :]
         self.shapesBackups.append(shapesBackup)
 
     @property
     def isShapeRestorable(self):
         # We save the state AFTER each edit (not before) so for an
@@ -170,6 +198,7 @@ class Canvas(QtWidgets.QWidget):
         # and app.py::loadShapes and our own Canvas::loadShapes function.
         if not self.isShapeRestorable:
             return
+        print(f"shape is restorable")
         self.shapesBackups.pop()  # latest
 
         # The application will eventually call Canvas.loadShapes which will
@@ -358,7 +387,7 @@ class Canvas(QtWidgets.QWidget):
                 self.setStatusTip(self.toolTip())
                 self.update()
                 break
-            elif shape.containsPoint(pos):
+            elif len(shape.points) != 0 and shape.containsPoint(pos):
                 if self.selectedVertex():
                     self.hShape.highlightClear()
                 self.prevhVertex = self.hVertex
@@ -683,7 +712,6 @@ class Canvas(QtWidgets.QWidget):
     def paintEvent(self, event):
         if not self.pixmap:
             return super(Canvas, self).paintEvent(event)
-
         p = self._painter
         p.begin(self)
         p.setRenderHint(QtGui.QPainter.Antialiasing)
@@ -728,7 +756,19 @@ class Canvas(QtWidgets.QWidget):
         if self.selectedShapesCopy:
             for s in self.selectedShapesCopy:
                 s.paint(p)
-
         if (
             self.fillDrawing()
             and self.createMode == "polygon"
@@ -751,7 +791,7 @@ class Canvas(QtWidgets.QWidget):
                 point=self.line.points[1],
                 label=self.line.point_labels[1],
             )
-            points = self._ai_model.predict_polygon_from_points(
+            points = self._detection_model.predict_polygon_from_points(
                 points=[[point.x(), point.y()] for point in drawing_shape.points],
                 point_labels=drawing_shape.point_labels,
             )
@@ -770,7 +810,7 @@ class Canvas(QtWidgets.QWidget):
                 point=self.line.points[1],
                 label=self.line.point_labels[1],
             )
-            mask = self._ai_model.predict_mask_from_points(
+            mask = self._detection_model.predict_mask_from_points(
                 points=[[point.x(), point.y()] for point in drawing_shape.points],
                 point_labels=drawing_shape.point_labels,
             )
@@ -804,11 +844,12 @@ class Canvas(QtWidgets.QWidget):
         return not (0 <= p.x() <= w - 1 and 0 <= p.y() <= h - 1)
 
     def finalise(self):
-        assert self.current
+        if self.current is None:
+            return
         if self.createMode == "ai_polygon":
             # convert points to polygon by an AI model
             assert self.current.shape_type == "points"
-            points = self._ai_model.predict_polygon_from_points(
+            points = self._detection_model.predict_polygon_from_points(
                 points=[[point.x(), point.y()] for point in self.current.points],
                 point_labels=self.current.point_labels,
             )
@@ -820,7 +861,7 @@ class Canvas(QtWidgets.QWidget):
         elif self.createMode == "ai_mask":
             # convert points to mask by an AI model
             assert self.current.shape_type == "points"
-            mask = self._ai_model.predict_mask_from_points(
+            mask = self._detection_model.predict_mask_from_points(
                 points=[[point.x(), point.y()] for point in self.current.points],
                 point_labels=self.current.point_labels,
             )
@@ -831,8 +872,30 @@ class Canvas(QtWidgets.QWidget):
                 point_labels=[1, 1],
                 mask=mask[y1 : y2 + 1, x1 : x2 + 1],
             )
+        elif self.pred_bbox_points is not None and self.draw_pred:
+            logger.debug("Drawing predicted bounding boxes")
+            current_copy = self.current.copy()
+            for bbox_point in self.pred_bbox_points:
+                drawing_shape = current_copy.copy()
+                drawing_shape.setShapeRefined(
+                    shape_type="polygon",
+                    points=[QtCore.QPointF(point[0], point[1]) for point in bbox_point],
+                    point_labels=[1] * len(bbox_point),
+                )
+                drawing_shape.close()
+                self.shapes.append(drawing_shape)
+                self.storeShapes()
+                self.update()
+                self.newShape.emit()
+            current_copy.close()
+            current_copy = None
+            if self.current:
+                self.current.close()
+                self.current = None
+            self.setHiding(False)
+            self.draw_pred = False
+            return
         self.current.close()
-
         self.shapes.append(self.current)
         self.storeShapes()
         self.current = None
@@ -959,6 +1022,29 @@ class Canvas(QtWidgets.QWidget):
                 self.finalise()
             elif modifiers == QtCore.Qt.AltModifier:
                 self.snapping = False
+            elif key == QtCore.Qt.Key_V:
+                if self._detection_model is None:
+                    logger.info("Initializing barcode model")
+                    self.initializeBarcodeModel(
+                        self.detection_model_path, self.segmentation_model_path
+                    )
+
+                self.current = Shape(
+                    shape_type="points"
+                    if self.createMode in ["ai_polygon", "ai_mask"]
+                    else self.createMode
+                )
+
+                if self._detection_model:
+                    if self._segmentation_model is None:
+                        logger.info("Performing detection only.")
+                        # Get predictions from the detection model
+                        self.pred_bbox_points = (
+                            self._detection_model.predict_polygon_from_points()
+                        )
+                        logger.debug(
+                            "Predicted bounding box points: %s", self.pred_bbox_points
+                        )
+                        if self.pred_bbox_points:
+                            self.draw_pred = True
+                            self.finalise()
+                        else:
+                            logger.info("No bounding boxes detected.")
+                    else:
+                        logger.info("Performing detection and segmentation.")
+                        self.detect_and_segment()
         elif self.editing():
             if key == QtCore.Qt.Key_Up:
                 self.moveByKeyboard(QtCore.QPointF(0.0, -MOVE_SPEED))
@@ -969,6 +1055,164 @@ class Canvas(QtWidgets.QWidget):
             elif key == QtCore.Qt.Key_Right:
                 self.moveByKeyboard(QtCore.QPointF(MOVE_SPEED, 0.0))
 
+
+    def scale_points(self, approx, mask_shape, cropped_shape, x_min, y_min):
+        # Map contour points from mask coordinates back to full-image coordinates.
+        scale_x = cropped_shape[1] / mask_shape[1]  # Scale factor for x-axis
+        scale_y = cropped_shape[0] / mask_shape[0]  # Scale factor for y-axis
+        return [
+            [int(pt[0][0] * scale_x) + x_min, int(pt[0][1] * scale_y) + y_min]
+            for pt in approx
+        ]
+
+    def detect_and_segment(self):
+        """
+        Perform detection and segmentation (if both models are available).
+        """
+        logger.info("Performing detection and segmentation.")
+        self.current = Shape(
+            shape_type="points" if self.createMode in ["ai_polygon", "ai_mask"] else self.createMode
+        )
+
+        # Step 1: detection bounding box points
+        detection_results = self._detection_model.predict_polygon_from_points()
+
+        if not detection_results:
+            logger.warning("No detection found")
+            return
+
+        logger.debug(f"Detection results: {detection_results}")
+        all_segmentation_results = []
+        rotated = False
+        # Step 2: Loop through each detection result since there are multiple per image
+        for detection_idx, detection_result in enumerate(detection_results):
+            logger.debug(f"Processing detection {detection_idx + 1}/{len(detection_results)}")
+
+            try:
+                # Extract the polygon's x/y coordinates
+                x_coords = [point[0] for point in detection_result]
+                y_coords = [point[1] for point in detection_result]
+                
+                # Min and max values for x and y
+                x_min, x_max = min(x_coords), max(x_coords)
+                y_min, y_max = min(y_coords), max(y_coords)
+                
+                logger.debug(f"Bounding box for detection {detection_idx + 1} - x_min: {x_min}, y_min: {y_min}, x_max: {x_max}, y_max: {y_max}")
+            except Exception as e:
+                logger.error(f"Error extracting bounding box coordinates for detection {detection_idx + 1}: {e}")
+                continue
+
+            # Converting bounding box values to integers for cropping
+            x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
+
+            # Step 3: Cropping image based on detection output
+            try:
+                cropped_image = self.pixmap.toImage().copy(x_min, y_min, x_max - x_min, y_max - y_min)
+                cropped_image = labelme.utils.img_qt_to_arr(cropped_image)
+
+                orig_height, orig_width = cropped_image.shape[:2]
+                logger.debug(f"Original height: {orig_height}, Original width: {orig_width}")
+                # If the height is greater than the width, rotate before segmentation
+                if orig_height > orig_width:
+                    cropped_image = cv2.rotate(cropped_image, cv2.ROTATE_90_CLOCKWISE)
+                    logger.debug(f"Rotated cropped image for detection {detection_idx + 1} due to height > width.")
+                    orig_cropped_shape = cropped_image.shape[:2]
+                    rotated = True
+                else:
+                    rotated = False
+                # Save the cropped image for debugging
+                cv2.imwrite(f"cropped_image_{detection_idx + 1}.png", cropped_image)
+                logger.debug(f"Saved cropped image for detection {detection_idx + 1}: {cropped_image.shape}")
+
+                # logger.debug(f"Cropped image shape for detection {detection_idx + 1}: {cropped_image.shape}")
+            except Exception as e:
+                logger.error(f"Error cropping the image for detection {detection_idx + 1}: {e}")
+                continue
+
+            # Step 4: Resize the cropped image to match the segmentation input size (1, 64, 256)
+            try:
+                orig_cropped_shape = cropped_image.shape[:2]  # Save the original cropped image size
+                preprocessed_img = self._detection_model.preprocess_image(cropped_image, for_segmentation=True)
+                logger.debug(f"Preprocessed image shape for segmentation detection {detection_idx + 1}: {preprocessed_img.shape}")
+            except Exception as e:
+                logger.error(f"Error preprocessing the image for segmentation for detection {detection_idx + 1}: {e}")
+                continue
+
+            # Step 5: inference on segmentation model on cropped image
+            try:
+                seg_result = self._segmentation_model.segmentation_sess.infer_new_request({'x': preprocessed_img})
+                logger.debug(f"Segmentation model inference completed for detection {detection_idx + 1}.")
+            except Exception as e:
+                logger.error(f"Error during segmentation model inference for detection {detection_idx + 1}: {e}")
+                continue
+
+            # Step 6: Convert binary mask to polygon (contours)
+            try:
+                mask = seg_result['save_infer_model/scale_0.tmp_0']  # model output name
+                mask = mask.squeeze()  # Remove batch dimension, should result in (64, 256)
+                logger.debug(f"Segmentation mask shape for detection {detection_idx + 1}: {mask.shape}")
+
+                # Normalize the mask to 0 and 255 and convert to uint8
+                mask = (mask * 255).astype(np.uint8)
+
+                logger.debug(f"Converted mask shape for detection {detection_idx + 1}: {mask.shape}, dtype: {mask.dtype}")
+                cv2.imwrite(f"segmentation_mask_{detection_idx + 1}.png", mask)
+                if rotated:
+                    cropped_image = cv2.rotate(cropped_image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                    mask = cv2.rotate(mask, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                    rotated_cropped_shape = cropped_image.shape[:2]
+
+                logger.debug(f"Saved segmentation mask for detection {detection_idx + 1}.")
+                
+                # Step 7: Find contours
+                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                logger.debug(f"Found {len(contours)} contours in the mask for detection {detection_idx + 1}.")
+
+                if len(contours) > 0:
+                    largest_contour = max(contours, key=cv2.contourArea)
+                    
+                    # Step 8: Approximate a polygon with exactly 4 points (quadrilateral)
+                    epsilon = 0.02 * cv2.arcLength(largest_contour, True)  # epsilon for precision
+                    approx = cv2.approxPolyDP(largest_contour, epsilon, True)
+
+                    # If the approximation doesn't result in 4 points, force it
+                    if len(approx) != 4:
+                        # Using boundingRect as fallback in case of insufficient points
+                        print("log here")
+                        x, y, w, h = cv2.boundingRect(largest_contour)
+                        point_xy = [
+                            [x + x_min, y + y_min],          # Top-left
+                            [x + w + x_min, y + y_min],      # Top-right
+                            [x + w + x_min, y + h + y_min],  # Bottom-right
+                            [x + x_min, y + h + y_min]       # Bottom-left
+                        ]
+                    else:
+                        if rotated:
+                            point_xy = self.scale_points(approx, mask.shape, rotated_cropped_shape, x_min, y_min)
+                        else:
+                            point_xy = self.scale_points(approx, mask.shape, orig_cropped_shape, x_min, y_min)
+                    logger.debug(f"Generated 4 corner points for the polygon for detection {detection_idx + 1}: {point_xy}")
+                    self.pred_bbox_points = [point_xy]
+                    logger.debug(f"Predicted Bounding Box Points for detection {detection_idx + 1}: {self.pred_bbox_points}")
+                    if self.pred_bbox_points:
+                        self.draw_pred = True
+                        self.finalise()
+                    else:
+                        logger.info(f"No bounding boxes detected for detection {detection_idx + 1}.")
+                        
+                    # Collect the segmentation result
+                    all_segmentation_results.append(self.pred_bbox_points)
+            except Exception as e:
+                logger.error(f"Error creating the polygon shape for detection {detection_idx + 1}: {e}")
+
+            # Reset prediction state after each detection:
+            self.pred_bbox_points = None
+            self.draw_pred = False
+
+        # You now have a list of segmentation results for all detections
+        if all_segmentation_results:
+            logger.info(f"Segmentation results for all detections: {all_segmentation_results}")
+
+
     def keyReleaseEvent(self, ev):
         modifiers = ev.modifiers()
         if self.drawing():
@@ -985,6 +1229,8 @@ class Canvas(QtWidgets.QWidget):
 
     def setLastLabel(self, text, flags):
         assert text
+        if not self.shapes:
+            return
         self.shapes[-1].label = text
         self.shapes[-1].flags = flags
         self.shapesBackups.pop()
@@ -992,11 +1238,12 @@ class Canvas(QtWidgets.QWidget):
         return self.shapes[-1]
 
     def undoLastLine(self):
-        assert self.shapes
+        if not self.shapes:
+            return
         self.current = self.shapes.pop()
         self.current.setOpen()
         self.current.restoreShapeRaw()
-        if self.createMode in ["polygon", "linestrip"]:
+        if self.createMode in ["polygon", "linestrip"] and self.draw_pred is False:
             self.line.points = [self.current[-1], self.current[0]]
         elif self.createMode in ["rectangle", "line", "circle"]:
             self.current.points = self.current.points[0:1]
@@ -1017,8 +1264,8 @@ class Canvas(QtWidgets.QWidget):
 
     def loadPixmap(self, pixmap, clear_shapes=True):
         self.pixmap = pixmap
         if self._ai_model:
             self._ai_model.set_image(
                 image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
             )
+        if self._detection_model:
+            self._detection_model.set_image(
+                image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
+            )
         if clear_shapes:
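
Two notes on the canvas changes. First, scale_points inverts the resize done before segmentation: a contour point found on the 64x256 mask is scaled by the crop-to-mask ratio and shifted by the crop origin (x_min, y_min). A quick numeric check of that mapping (values chosen for illustration):

    # Contour point at (128, 32) on a 64x256 mask, for a 512x1024 crop
    # taken at full-image origin (x_min, y_min) = (100, 50):
    mask_h, mask_w = 64, 256
    crop_h, crop_w = 512, 1024
    x_min, y_min = 100, 50

    scale_x = crop_w / mask_w  # 4.0
    scale_y = crop_h / mask_h  # 8.0

    x, y = 128, 32  # mask coordinates
    print(int(x * scale_x) + x_min, int(y * scale_y) + y_min)  # 612 306

Second, the ai_polygon/ai_mask code paths in drawing and finalise now call _detection_model rather than _ai_model, so those modes assume initializeBarcodeModel has already run (pressing V triggers it lazily); with only a SegmentAnything/EfficientSam model loaded, _detection_model is still None.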

+ 17 - 0
labelme/widgets/label_dialog.py

@@ -47,6 +47,7 @@ class LabelDialog(QtWidgets.QDialog):
         self.edit.setPlaceholderText(text)
         self.edit.setValidator(labelme.utils.labelValidator())
         self.edit.editingFinished.connect(self.postProcess)
+        self.dragging = False
         if flags:
             self.edit.textChanged.connect(self.updateFlags)
         self.edit_group_id = QtWidgets.QLineEdit()
@@ -200,6 +201,7 @@ class LabelDialog(QtWidgets.QDialog):
         return None
 
     def popUp(self, text=None, move=True, flags=None, group_id=None, description=None):
+        self.show()
         if self._fit_to_content["row"]:
             self.labelList.setMinimumHeight(
                 self.labelList.sizeHintForRow(0) * self.labelList.count() + 2
@@ -242,3 +244,18 @@ class LabelDialog(QtWidgets.QDialog):
             )
         else:
             return None, None, None, None
+
+    def mousePressEvent(self, event):
+        if event.button() == QtCore.Qt.LeftButton:
+            self.dragging = True
+            self.startPos = event.globalPos() - self.frameGeometry().topLeft()
+            event.accept()
+
+    def mouseMoveEvent(self, event):
+        if self.dragging:
+            self.move(event.globalPos() - self.startPos)
+            event.accept()
+
+    def mouseReleaseEvent(self, event):
+        if event.button() == QtCore.Qt.LeftButton:
+            self.dragging = False
+            event.accept()
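
The three handlers added at the end implement the standard pattern for dragging a window from anywhere in its client area: record the offset between the cursor and the window's top-left corner on press, then keep that offset constant while the mouse moves. A self-contained sketch of the same pattern (PyQt5 shown here; labelme itself imports Qt through its own wrapper):

    from PyQt5 import QtCore, QtWidgets


    class DraggableDialog(QtWidgets.QDialog):
        def __init__(self, parent=None):
            super().__init__(parent)
            self.dragging = False
            self.startPos = QtCore.QPoint()

        def mousePressEvent(self, event):
            if event.button() == QtCore.Qt.LeftButton:
                self.dragging = True
                # Offset of the click from the window's top-left corner.
                self.startPos = event.globalPos() - self.frameGeometry().topLeft()
                event.accept()

        def mouseMoveEvent(self, event):
            if self.dragging:
                self.move(event.globalPos() - self.startPos)
                event.accept()

        def mouseReleaseEvent(self, event):
            if event.button() == QtCore.Qt.LeftButton:
                self.dragging = False
                event.accept()


    if __name__ == "__main__":
        app = QtWidgets.QApplication([])
        dlg = DraggableDialog()
        dlg.resize(240, 120)
        dlg.show()
        app.exec_()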