Jelajahi Sumber

Enable user configure GUI with args and a yaml file

Kentaro Wada 7 tahun lalu
induk
melakukan
6b308ea563
4 mengubah file dengan 84 tambahan dan 56 penghapusan
  1. 42 40
      labelme/app.py
  2. 31 14
      labelme/config/__init__.py
  3. 8 0
      labelme/config/default_config.yaml
  4. 3 2
      labelme/labelDialog.py

+ 42 - 40
labelme/app.py

@@ -121,14 +121,14 @@ class LabelQListWidget(QtWidgets.QListWidget):
 class MainWindow(QtWidgets.QMainWindow, WindowMixin):
     FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = 0, 1, 2
 
-    def __init__(self, filename=None, output=None, store_data=True,
-                 labels=None, sort_labels=True, auto_save=False,
-                 validate_label=None, config=None):
-        super(MainWindow, self).__init__()
-        self.setWindowTitle(__appname__)
-
+    def __init__(self, config=None, filename=None, output=None):
+        # see labelme/config/default_config.yaml for valid configuration
         if config is None:
             config = get_config()
+        self._config = config
+
+        super(MainWindow, self).__init__()
+        self.setWindowTitle(__appname__)
 
         # Whether we need to save or not.
         self.dirty = False
@@ -136,8 +136,12 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self._noSelectionSlot = False
 
         # Main widgets and related state.
-        self.labelDialog = LabelDialog(parent=self, labels=labels,
-                                       sort_labels=sort_labels)
+        self.labelDialog = LabelDialog(
+            parent=self,
+            labels=self._config['labels'],
+            sort_labels=self._config['sort_labels'],
+            show_text_field=self._config['show_label_text_field'],
+        )
 
         self.labelList = LabelQListWidget()
         self.lastOpenDir = None
@@ -164,8 +168,8 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         self.uniqLabelList.setToolTip(
             "Select label to start annotating for it. "
             "Press 'Esc' to deselect.")
-        if labels:
-            self.uniqLabelList.addItems(labels)
+        if self._config['labels']:
+            self.uniqLabelList.addItems(self._config['labels'])
             self.uniqLabelList.sortItems()
         self.labelsdock = QtWidgets.QDockWidget(u'Label List', self)
         self.labelsdock.setObjectName(u'Label List')
@@ -220,7 +224,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
 
         # Actions
         action = functools.partial(newAction, self)
-        shortcuts = config['shortcuts']
+        shortcuts = self._config['shortcuts']
         quit = action('&Quit', self.close, shortcuts['quit'], 'quit',
                       'Quit application')
         open_ = action('&Open', self.openFile, shortcuts['open'], 'open',
@@ -400,18 +404,12 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         # Application state.
         self.image = QtGui.QImage()
         self.imagePath = None
-        if auto_save and output is not None:
+        if self._config['auto_save'] and output is not None:
             warnings.warn('If `auto_save` argument is True, `output` argument '
                           'is ignored and output filename is automatically '
                           'set as IMAGE_BASENAME.json.')
         self.labeling_once = output is not None
         self.output = output
-        self._auto_save = auto_save
-        self._store_data = store_data
-        if validate_label not in [None, 'exact', 'instance']:
-            raise ValueError('Unexpected `validate_label`: {}'
-                             .format(validate_label))
-        self._validate_label = validate_label
         self.recentFiles = []
         self.maxRecent = 7
         self.lineColor = None
@@ -477,7 +475,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         addActions(self.menus.edit, actions + self.actions.editMenu)
 
     def setDirty(self):
-        if self._auto_save:
+        if self._config['auto_save']:
             label_file = os.path.splitext(self.imagePath)[0] + '.json'
             self.saveLabels(label_file)
             return
@@ -587,15 +585,15 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
 
     def validateLabel(self, label):
         # no validation
-        if self._validate_label is None:
+        if self._config['validate_label'] is None:
             return True
 
         for i in range(self.uniqLabelList.count()):
             l = self.uniqLabelList.item(i).text()
-            if self._validate_label in ['exact', 'instance']:
+            if self._config['validate_label'] in ['exact', 'instance']:
                 if l == label:
                     return True
-            if self._validate_label == 'instance':
+            if self._config['validate_label'] == 'instance':
                 m = re.match(r'^{}-[0-9]*$'.format(l), label)
                 if m:
                     return True
@@ -611,7 +609,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         if not self.validateLabel(text):
             self.errorMessage('Invalid label',
                               "Invalid label '{}' with validation type '{}'"
-                              .format(text, self._validate_label))
+                              .format(text, self._config['validate_label']))
             return
         item.setText(text)
         self.setDirty()
@@ -702,7 +700,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         try:
             imagePath = os.path.relpath(
                 self.imagePath, os.path.dirname(filename))
-            imageData = self.imageData if self._store_data else None
+            imageData = self.imageData if self._config['store_data'] else None
             lf.save(filename, shapes, imagePath, imageData,
                     self.lineColor.getRgb(), self.fillColor.getRgb(),
                     self.otherData)
@@ -750,7 +748,7 @@ class MainWindow(QtWidgets.QMainWindow, WindowMixin):
         if text is not None and not self.validateLabel(text):
             self.errorMessage('Invalid label',
                               "Invalid label '{}' with validation type '{}'"
-                              .format(text, self._validate_label))
+                              .format(text, self._config['validate_label']))
             text = None
         if text is None:
             self.canvas.undoLastLine()
@@ -1200,21 +1198,24 @@ def read(filename, default=None):
 
 
 def main():
-    """Standard boilerplate Qt application code."""
     parser = argparse.ArgumentParser()
     parser.add_argument('--version', '-V', action='store_true',
                         help='show version')
     parser.add_argument('filename', nargs='?', help='image or label filename')
     parser.add_argument('--output', '-O', '-o', help='output label name')
+    parser.add_argument('--config', dest='config_file', help='config file')
+    # config
     parser.add_argument('--nodata', dest='store_data', action='store_false',
                         help='stop storing image data to JSON file')
-    parser.add_argument('--autosave', action='store_true', help='auto save')
+    parser.add_argument('--autosave', dest='auto_save', action='store_true',
+                        help='auto save')
     parser.add_argument('--labels',
                         help='comma separated list of labels OR file '
                         'containing one label per line')
     parser.add_argument('--nosortlabels', dest='sort_labels',
                         action='store_false', help='stop sorting labels')
-    parser.add_argument('--validatelabel', choices=['exact', 'instance'],
+    parser.add_argument('--validatelabel', dest='validate_label',
+                        choices=['exact', 'instance'],
                         help='label validation types')
     args = parser.parse_args()
 
@@ -1223,8 +1224,10 @@ def main():
         sys.exit(0)
 
     if args.labels is None:
-        if args.validatelabel is not None:
-            logger.error('--labels must be specified with --validatelabel')
+        if args.validate_label is not None:
+            logger.error('--labels must be specified with --validatelabel or '
+                         'validate_label: true in the config file '
+                         '(ex. ~/.labelmerc).')
             sys.exit(1)
     else:
         if os.path.isfile(args.labels):
@@ -1233,18 +1236,17 @@ def main():
         else:
             args.labels = [l for l in args.labels.split(',') if l]
 
+    config_from_args = args.__dict__
+    config_from_args.pop('version')
+    filename = config_from_args.pop('filename')
+    output = config_from_args.pop('output')
+    config_file = config_from_args.pop('config_file')
+    config = get_config(config_from_args, config_file)
+
     app = QtWidgets.QApplication(sys.argv)
     app.setApplicationName(__appname__)
-    app.setWindowIcon(newIcon("icon"))
-    win = MainWindow(
-        filename=args.filename,
-        output=args.output,
-        store_data=args.store_data,
-        labels=args.labels,
-        sort_labels=args.sort_labels,
-        auto_save=args.autosave,
-        validate_label=args.validatelabel,
-    )
+    app.setWindowIcon(newIcon('icon'))
+    win = MainWindow(config=config, filename=filename, output=output)
     win.show()
     win.raise_()
     sys.exit(app.exec_())

+ 31 - 14
labelme/config/__init__.py

@@ -9,42 +9,59 @@ from labelme import logger
 here = osp.dirname(osp.abspath(__file__))
 
 
-def update_dict(target_dict, new_dict):
+def update_dict(target_dict, new_dict, validate_item=None):
     for key, value in new_dict.items():
+        if validate_item:
+            validate_item(key, value)
         if key not in target_dict:
             logger.warn('Skipping unexpected key in config: {}'
                         .format(key))
             continue
         if isinstance(target_dict[key], dict) and \
                 isinstance(value, dict):
-            update_dict(target_dict[key], value)
+            update_dict(target_dict[key], value, validate_item=validate_item)
         else:
             target_dict[key] = value
 
 
+# -----------------------------------------------------------------------------
+
+
 def get_default_config():
     config_file = osp.join(here, 'default_config.yaml')
     config = yaml.load(open(config_file))
     return config
 
 
-def get_config():
+def validate_config_item(key, value):
+    if key == 'validate_label' and value not in [None, 'exact', 'instance']:
+        raise ValueError('Unexpected value `{}` for key `{}`'
+                         .format(value, key))
+
+
+def get_config(config_from_args=None, config_file=None):
     # default config
     config = get_default_config()
 
-    # shortcuts for actions
-    home = os.path.expanduser('~')
-    config_file = os.path.join(home, '.labelmerc')
+    if config_from_args is not None:
+        update_dict(config, config_from_args,
+                    validate_item=validate_config_item)
+
+    save_config_file = False
+    if config_file is None:
+        home = os.path.expanduser('~')
+        config_file = os.path.join(home, '.labelmerc')
+        save_config_file = True
 
     if os.path.exists(config_file):
         user_config = yaml.load(open(config_file)) or {}
-        update_dict(config, user_config)
-
-    # save config
-    try:
-        yaml.safe_dump(config, open(config_file, 'w'),
-                       default_flow_style=False)
-    except Exception:
-        logger.warn('Failed to save config: {}'.format(config_file))
+        update_dict(config, user_config, validate_item=validate_config_item)
+
+    if save_config_file:
+        try:
+            yaml.safe_dump(config, open(config_file, 'w'),
+                           default_flow_style=False)
+        except Exception:
+            logger.warn('Failed to save config: {}'.format(config_file))
 
     return config

+ 8 - 0
labelme/config/default_config.yaml

@@ -1,3 +1,11 @@
+auto_save: false
+store_data: true
+
+labels: null
+sort_labels: true
+validate_label: null
+show_label_text_field: true
+
 shortcuts:
   close: Ctrl+W
   open: Ctrl+O

+ 3 - 2
labelme/labelDialog.py

@@ -28,14 +28,15 @@ class LabelQLineEdit(QtWidgets.QLineEdit):
 class LabelDialog(QtWidgets.QDialog):
 
     def __init__(self, text="Enter object label", parent=None, labels=None,
-                 sort_labels=True):
+                 sort_labels=True, show_text_field=True):
         super(LabelDialog, self).__init__(parent)
         self.edit = LabelQLineEdit()
         self.edit.setPlaceholderText(text)
         self.edit.setValidator(labelValidator())
         self.edit.editingFinished.connect(self.postProcess)
         layout = QtWidgets.QVBoxLayout()
-        layout.addWidget(self.edit)
+        if show_text_field:
+            layout.addWidget(self.edit)
         # buttons
         self.buttonBox = bb = QtWidgets.QDialogButtonBox(
             QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel,