|
@@ -101,10 +101,7 @@ class Canvas(QtWidgets.QWidget):
|
|
|
self.setMouseTracking(True)
|
|
|
self.setFocusPolicy(QtCore.Qt.WheelFocus)
|
|
|
|
|
|
- self._ai_callback = None
|
|
|
-
|
|
|
- def setAiCallback(self, ai_callback):
|
|
|
- self._ai_callback = ai_callback
|
|
|
+ self._ai_model = None
|
|
|
|
|
|
def fillDrawing(self):
|
|
|
return self._fill_drawing
|
|
@@ -129,6 +126,12 @@ class Canvas(QtWidgets.QWidget):
|
|
|
]:
|
|
|
raise ValueError("Unsupported createMode: %s" % value)
|
|
|
|
|
|
+ if value == "ai_polygon" and self._ai_model is None:
|
|
|
+ self._ai_model = labelme.ai.SegmentAnythingModel()
|
|
|
+ self._ai_model.set_image(
|
|
|
+ image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
|
|
|
+ )
|
|
|
+
|
|
|
self._createMode = value
|
|
|
|
|
|
def storeShapes(self):
|
|
@@ -757,7 +760,7 @@ class Canvas(QtWidgets.QWidget):
|
|
|
if self.createMode == "ai_polygon":
|
|
|
# convert points to polygon by an AI model
|
|
|
assert self.current.shape_type == "points"
|
|
|
- points = self._ai_callback(
|
|
|
+ points = self._ai_model.points_to_polygon_callback(
|
|
|
points=[
|
|
|
[point.x(), point.y()] for point in self.current.points
|
|
|
],
|
|
@@ -961,6 +964,10 @@ class Canvas(QtWidgets.QWidget):
|
|
|
|
|
|
def loadPixmap(self, pixmap, clear_shapes=True):
|
|
|
self.pixmap = pixmap
|
|
|
+ if self._ai_model:
|
|
|
+ self._ai_model.set_image(
|
|
|
+ image=labelme.utils.img_qt_to_arr(self.pixmap.toImage())
|
|
|
+ )
|
|
|
if clear_shapes:
|
|
|
self.shapes = []
|
|
|
self.update()
|