浏览代码

Fix the order of loading configuration

Close https://github.com/wkentaro/labelme/issues/149

1. default config (lowest priority)
2. config file passed by command line argument or ~/.labelmerc
3. command line argument (highest priority)
Kentaro Wada 7 年之前
父节点
当前提交
1486f03588
共有 2 个文件被更改,包括 27 次插入18 次删除
  1. 23 17
      labelme/config/__init__.py
  2. 4 1
      tests/test_app.py

+ 23 - 17
labelme/config/__init__.py

@@ -1,4 +1,3 @@
-import os
 import os.path as osp
 
 import yaml
@@ -41,29 +40,36 @@ def validate_config_item(key, value):
 
 
 def get_config(config_from_args=None, config_file=None):
-    # default config
+    # Configuration load order:
+    #
+    #   1. default config (lowest priority)
+    #   2. config file passed by command line argument or ~/.labelmerc
+    #   3. command line argument (highest priority)
+
+    # 1. default config
     config = get_default_config()
 
-    if config_from_args is not None:
-        update_dict(config, config_from_args,
-                    validate_item=validate_config_item)
+    # save default config to ~/.labelmerc
+    home = osp.expanduser('~')
+    default_config_file = osp.join(home, '.labelmerc')
+    if not osp.exists(default_config_file):
+        try:
+            with open(config_file, 'w') as f:
+                yaml.safe_dump(config, f, default_flow_style=False)
+        except Exception:
+            logger.warn('Failed to save config: {}'.format(config_file))
 
-    save_config_file = False
+    # 2. config from yaml file
     if config_file is None:
-        home = os.path.expanduser('~')
-        config_file = os.path.join(home, '.labelmerc')
-        save_config_file = True
-
-    if os.path.exists(config_file):
+        config_file = default_config_file
+    if osp.exists(config_file):
         with open(config_file) as f:
             user_config = yaml.load(f) or {}
         update_dict(config, user_config, validate_item=validate_config_item)
 
-    if save_config_file:
-        try:
-            with open(config_file, 'w') as f:
-                yaml.safe_dump(config, f, default_flow_style=False)
-        except Exception:
-            logger.warn('Failed to save config: {}'.format(config_file))
+    # 3. command line argument
+    if config_from_args is not None:
+        update_dict(config, config_from_args,
+                    validate_item=validate_config_item)
 
     return config

+ 4 - 1
tests/test_app.py

@@ -3,6 +3,7 @@ import shutil
 import tempfile
 
 import labelme.app
+import labelme.config
 import labelme.testing
 
 
@@ -33,7 +34,9 @@ def test_MainWindow_annotate_jpg(qtbot):
                 filename)
     output = osp.join(tmp_dir, 'apc2016_obj3.json')
 
-    win = labelme.app.MainWindow(filename=filename, output=output)
+    config = labelme.config.get_default_config()
+    win = labelme.app.MainWindow(
+        config=config, filename=filename, output=output)
     qtbot.addWidget(win)
     win.show()