Forráskód Böngészése

Fix the use of --config arg in main.py

Kentaro Wada 5 éve
szülő
commit
e62be5c570
2 módosított fájl, 17 hozzáadás és 11 törlés
  1. 16 4
      labelme/config/__init__.py
  2. 1 7
      labelme/main.py

+ 16 - 4
labelme/config/__init__.py

@@ -55,13 +55,25 @@ def validate_config_item(key, value):
         )
 
 
-def get_config(config_specified=None):
+def get_config(config_file_or_yaml=None, config_from_args=None):
     # 1. default config
     config = get_default_config()
 
-    # 2. command line argument or specified config file
-    if config_specified is not None:
-        update_dict(config, config_specified,
+    # 2. specified as file or yaml
+    if config_file_or_yaml is not None:
+        config_from_yaml = yaml.safe_load(config_file_or_yaml)
+        if not isinstance(config_from_yaml, dict):
+            with open(config_from_yaml) as f:
+                logger.info(
+                    'Loading config file from: {}'.format(config_from_yaml)
+                )
+                config_from_yaml = yaml.safe_load(f)
+        update_dict(config, config_from_yaml,
+                    validate_item=validate_config_item)
+
+    # 3. command line argument or specified config file
+    if config_from_args is not None:
+        update_dict(config, config_from_args,
                     validate_item=validate_config_item)
 
     return config

+ 1 - 7
labelme/main.py

@@ -142,13 +142,7 @@ def main():
     filename = config_from_args.pop('filename')
     output = config_from_args.pop('output')
     config_file_or_yaml = config_from_args.pop('config')
-    config = yaml.safe_load(config_file_or_yaml)
-    if not isinstance(config, dict):
-        logger.info('Loading config file from: {}'.format(config))
-        with open(config) as f:
-            config = yaml.safe_load(f)
-    config.update(config_from_args)  # prioritize config_from_args
-    config = get_config(config)
+    config = get_config(config_file_or_yaml, config_from_args)
 
     if not config['labels'] and config['validate_label']:
         logger.error('--labels must be specified with --validatelabel or '