Quellcode durchsuchen

Add support for loading labels

Refactor label file format into separate class/file.
Michael Pitidis vor 13 Jahren
Ursprung
Commit
cf9e6ab53b
3 geänderte Dateien mit 94 neuen und 24 gelöschten Zeilen
  1. 10 0
      canvas.py
  2. 46 0
      labelFile.py
  3. 38 24
      labelme.py

+ 10 - 0
canvas.py

@@ -285,6 +285,16 @@ class Canvas(QWidget):
         assert self.shapes
         self.shapes.pop()
 
+    def loadPixmap(self, pixmap):
+        self.pixmap = pixmap
+        self.shapes = []
+        self.repaint()
+
+    def loadShapes(self, shapes):
+        self.shapes = list(shapes)
+        self.current = None
+        self.repaint()
+
 
 def distance(p):
     return sqrt(p.x() * p.x() + p.y() * p.y())

+ 46 - 0
labelFile.py

@@ -0,0 +1,46 @@
+
+import json
+import os.path
+
+from base64 import b64encode, b64decode
+
+class LabelFileError(Exception):
+    pass
+
+class LabelFile(object):
+    suffix = '.lif'
+
+    def __init__(self):
+        self.shapes = ()
+        self.imagePath = None
+        self.imageData = None
+
+    def load(self, filename):
+        try:
+            with open(filename, 'rb') as f:
+                data = json.load(f)
+                imagePath = data['imagePath']
+                imageData = b64decode(data['imageData'])
+                shapes = ((s['label'], s['points']) for s in data['shapes'])
+                # Only replace data after everything is loaded.
+                self.shapes = shapes
+                self.imagePath = imagePath
+                self.imageData = imageData
+        except Exception, e:
+            raise LabelFileError(e)
+
+    def save(self, filename, shapes, imagePath, imageData):
+        try:
+            with open(filename, 'wb') as f:
+                json.dump(dict(
+                    shapes=[dict(label=l, points=p) for (l, p) in shapes],
+                    imagePath=imagePath,
+                    imageData=b64encode(imageData)),
+                    f, ensure_ascii=True, indent=2)
+        except Exception, e:
+            raise LabelFileError(e)
+
+    @staticmethod
+    def isLabelFile(filename):
+        return os.path.splitext(filename)[1].lower() == LabelFile.suffix
+

+ 38 - 24
labelme.py

@@ -21,6 +21,7 @@ from shape import Shape
 from canvas import Canvas
 from zoomWidget import ZoomWidget
 from labelDialog import LabelDialog
+from labelFile import LabelFile
 
 
 __appname__ = 'labelme'
@@ -176,19 +177,22 @@ class MainWindow(QMainWindow, WindowMixin):
         self.zoom_widget.editingFinished.connect(self.paintCanvas)
 
 
+    def loadLabels(self, shapes):
+        s = []
+        for label, points in shapes:
+            shape = Shape(label=label)
+            shape.fill = True
+            for x, y in points:
+                shape.addPoint(QPointF(x, y))
+            s.append(shape)
+            self.addLabel(label, shape)
+        self.canvas.loadShapes(s)
+
     def saveLabels(self, filename):
-        shapes = []
-        for shape in self.canvas.shapes:
-            data = {}
-            data['points'] = [(p.x(), p.y()) for p in shape.points]
-            data['label'] = unicode(shape.label)
-            shapes.append(data)
-        with open(filename, 'wb') as f:
-            json.dump(dict(
-                shapes=shapes,
-                image_path=unicode(self.filename),
-                image_data=b64encode(self.image_data)),
-                f, ensure_ascii=True, indent=2)
+        lf = LabelFile()
+        shapes = [(unicode(shape.label), [(p.x(), p.y()) for p in shape.points])\
+                for shape in self.canvas.shapes]
+        lf.save(filename, shapes, unicode(self.filename), self.imageData)
 
     def addLabel(self, label, shape):
         item = QListWidgetItem(label)
@@ -248,19 +252,31 @@ class MainWindow(QMainWindow, WindowMixin):
         """Load the specified file, or the last opened file if None."""
         if filename is None:
             filename = self.settings['filename']
-        # FIXME: Load the actual file here.
+        filename = unicode(filename)
         if QFile.exists(filename):
-            # Load image: read data first and store for saving into label file.
-            #image = QImage(filename)
-            self.image_data = read(filename, None)
-            image = QImage.fromData(self.image_data)
+            if LabelFile.isLabelFile(filename):
+                # TODO: Error handling.
+                lf = LabelFile()
+                lf.load(filename)
+                self.labelFile = lf
+                self.imageData = lf.imageData
+            else:
+                # Load image:
+                # read data first and store for saving into label file.
+                self.imageData = read(filename, None)
+                self.labelFile = None
+            image = QImage.fromData(self.imageData)
             if image.isNull():
                 message = "Failed to read %s" % filename
             else:
                 message = "Loaded %s" % os.path.basename(unicode(filename))
                 self.image = image
                 self.filename = filename
-                self.loadPixmap()
+                self.labels = {}
+                self.labelList.clear()
+                self.canvas.loadPixmap(QPixmap.fromImage(image))
+                if self.labelFile:
+                    self.loadLabels(self.labelFile.shapes)
             self.statusBar().showMessage(message)
 
     def resizeEvent(self, event):
@@ -268,10 +284,6 @@ class MainWindow(QMainWindow, WindowMixin):
             self.paintCanvas()
         super(MainWindow, self).resizeEvent(event)
 
-    def loadPixmap(self):
-        assert not self.image.isNull(), "cannot load null image"
-        self.canvas.pixmap = QPixmap.fromImage(self.image)
-
     def paintCanvas(self):
         assert not self.image.isNull(), "cannot paint null image"
         self.canvas.scale = self.fitSize() if self.fit_window\
@@ -313,8 +325,10 @@ class MainWindow(QMainWindow, WindowMixin):
                 if self.filename else '.'
         formats = ['*.%s' % unicode(fmt).lower()\
                 for fmt in QImageReader.supportedImageFormats()]
+        filters = 'Image files (%s)\nLabel files (*%s)'\
+                % (' '.join(formats), LabelFile.suffix)
         filename = unicode(QFileDialog.getOpenFileName(self,
-            '%s - Choose Image', path, 'Image files (%s)' % ' '.join(formats)))
+            '%s - Choose Image', path, filters))
         if filename:
             self.loadFile(filename)
 
@@ -324,7 +338,7 @@ class MainWindow(QMainWindow, WindowMixin):
         assert self.labels, "cannot save empty labels"
         path = os.path.dirname(unicode(self.filename))\
                 if self.filename else '.'
-        formats = ['*.lif']
+        formats = ['*%s' % LabelFile.suffix]
         filename = unicode(QFileDialog.getSaveFileName(self,
             '%s - Choose File', path, 'Label files (%s)' % ''.join(formats)))
         if filename: