Преглед изворни кода

Support directory with --output option

Kentaro Wada пре 6 година
родитељ
комит
54f055ade0
2 измењених фајлова са 73 додато и 17 уклоњено
  1. 52 15
      labelme/app.py
  2. 21 2
      labelme/main.py

+ 52 - 15
labelme/app.py

@@ -3,7 +3,6 @@ import io
 import os
 import os.path as osp
 import re
-import warnings
 import webbrowser
 
 import PIL.Image
@@ -19,6 +18,7 @@ from labelme import QT5
 from labelme.config import get_config
 from labelme.label_file import LabelFile
 from labelme.label_file import LabelFileError
+from labelme import logger
 from labelme.shape import DEFAULT_FILL_COLOR
 from labelme.shape import DEFAULT_LINE_COLOR
 from labelme.shape import Shape
@@ -68,7 +68,21 @@ class WindowMixin(object):
 class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = 0, 1, 2
 
-    def __init__(self, config=None, filename=None, output=None):
+    def __init__(
+        self,
+        config=None,
+        filename=None,
+        output=None,
+        output_file=None,
+        output_dir=None,
+    ):
+        if output is not None:
+            logger.warn(
+                'argument output is deprecated, use output_file instead'
+            )
+            if output_file is None:
+                output_file = output
+
         # see labelme/config/default_config.yaml for valid configuration
         if config is None:
             config = get_config()
@@ -507,15 +521,18 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.statusBar().showMessage('%s started.' % __appname__)
         self.statusBar().show()
 
+        if output_file is not None and self._config['auto_save']:
+            logger.warn(
+                'If `auto_save` argument is True, `output_file` argument '
+                'is ignored and output filename is automatically '
+                'set as IMAGE_BASENAME.json.'
+            )
+        self.output_file = output_file
+        self.output_dir = output_dir
+
         # Application state.
         self.image = QtGui.QImage()
         self.imagePath = None
-        if self._config['auto_save'] and output is not None:
-            warnings.warn('If `auto_save` argument is True, `output` argument '
-                          'is ignored and output filename is automatically '
-                          'set as IMAGE_BASENAME.json.')
-        self.labeling_once = output is not None
-        self.output = output
         self.recentFiles = []
         self.maxRecent = 7
         self.lineColor = None
@@ -595,6 +612,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     def setDirty(self):
         if self._config['auto_save'] or self.actions.saveAuto.isChecked():
             label_file = osp.splitext(self.imagePath)[0] + '.json'
+            if self.output_dir:
+                label_file = osp.join(self.output_dir, label_file)
             self.saveLabels(label_file)
             return
         self.dirty = True
@@ -898,6 +917,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
             imagePath = osp.relpath(
                 self.imagePath, osp.dirname(filename))
             imageData = self.imageData if self._config['store_data'] else None
+            if not osp.exists(osp.dirname(filename)):
+                os.makedirs(osp.dirname(filename))
             lf.save(
                 filename=filename,
                 shapes=shapes,
@@ -1053,6 +1074,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         # assumes same name, but json extension
         self.status("Loading %s..." % osp.basename(str(filename)))
         label_file = osp.splitext(filename)[0] + '.json'
+        if self.output_dir:
+            label_file = osp.join(self.output_dir, label_file)
         if QtCore.QFile.exists(label_file) and \
                 LabelFile.isLabelFile(label_file):
             try:
@@ -1255,8 +1278,9 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
             if self.labelFile:
                 # DL20180323 - overwrite when in directory
                 self._saveFile(self.labelFile.filename)
-            elif self.output:
-                self._saveFile(self.output)
+            elif self.output_file:
+                self._saveFile(self.output_file)
+                self.close()
             else:
                 self._saveFile(self.saveFileDialog())
 
@@ -1268,14 +1292,27 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     def saveFileDialog(self):
         caption = '%s - Choose File' % __appname__
         filters = 'Label files (*%s)' % LabelFile.suffix
-        dlg = QtWidgets.QFileDialog(self, caption, self.currentPath(), filters)
+        if self.output_dir:
+            dlg = QtWidgets.QFileDialog(
+                self, caption, self.output_dir, filters
+            )
+        else:
+            dlg = QtWidgets.QFileDialog(
+                self, caption, self.currentPath(), filters
+            )
         dlg.setDefaultSuffix(LabelFile.suffix[1:])
         dlg.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
         dlg.setOption(QtWidgets.QFileDialog.DontConfirmOverwrite, False)
         dlg.setOption(QtWidgets.QFileDialog.DontUseNativeDialog, False)
         basename = osp.splitext(self.filename)[0]
-        default_labelfile_name = osp.join(
-            self.currentPath(), basename + LabelFile.suffix)
+        if self.output_dir:
+            default_labelfile_name = osp.join(
+                self.output_dir, basename + LabelFile.suffix
+            )
+        else:
+            default_labelfile_name = osp.join(
+                self.currentPath(), basename + LabelFile.suffix
+            )
         filename = dlg.getSaveFileName(
             self, 'Choose File', default_labelfile_name,
             'Label files (*%s)' % LabelFile.suffix)
@@ -1288,8 +1325,6 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         if filename and self.saveLabels(filename):
             self.addRecentFile(filename)
             self.setClean()
-            if self.labeling_once:
-                self.close()
 
     def closeFile(self, _value=False):
         if not self.mayContinue():
@@ -1429,6 +1464,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
             if pattern and pattern not in filename:
                 continue
             label_file = osp.splitext(filename)[0] + '.json'
+            if self.output_dir:
+                label_file = osp.join(self.output_dir, label_file)
             item = QtWidgets.QListWidgetItem(filename)
             item.setFlags(Qt.ItemIsEnabled | Qt.ItemIsSelectable)
             if QtCore.QFile.exists(label_file) and \

+ 21 - 2
labelme/main.py

@@ -22,7 +22,13 @@ def main():
         '--reset-config', action='store_true', help='reset qt config'
     )
     parser.add_argument('filename', nargs='?', help='image or label filename')
-    parser.add_argument('--output', '-O', '-o', help='output label name')
+    parser.add_argument(
+        '--output',
+        '-O',
+        '-o',
+        help='output file or directory (if it ends with .json it is '
+             'recognized as file, else as directory)'
+    )
     default_config_file = os.path.join(os.path.expanduser('~'), '.labelmerc')
     parser.add_argument(
         '--config',
@@ -115,10 +121,23 @@ def main():
                      '(ex. ~/.labelmerc).')
         sys.exit(1)
 
+    output_file = None
+    output_dir = None
+    if output is not None:
+        if output.endswith('.json'):
+            output_file = output
+        else:
+            output_dir = output
+
     app = QtWidgets.QApplication(sys.argv)
     app.setApplicationName(__appname__)
     app.setWindowIcon(newIcon('icon'))
-    win = MainWindow(config=config, filename=filename, output=output)
+    win = MainWindow(
+        config=config,
+        filename=filename,
+        output_file=output_file,
+        output_dir=output_dir,
+    )
 
     if reset_config:
         print('Resetting Qt config: %s' % win.settings.fileName())