__init__.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os.path as osp
  2. import shutil
  3. import yaml
  4. from labelme.logger import logger
# Absolute path of the directory containing this module; the bundled
# default_config.yaml is resolved relative to it.
here = osp.dirname(osp.abspath(__file__))
  6. def update_dict(target_dict, new_dict, validate_item=None):
  7. for key, value in new_dict.items():
  8. if validate_item:
  9. validate_item(key, value)
  10. if key not in target_dict:
  11. logger.warn("Skipping unexpected key in config: {}".format(key))
  12. continue
  13. if isinstance(target_dict[key], dict) and isinstance(value, dict):
  14. update_dict(target_dict[key], value, validate_item=validate_item)
  15. else:
  16. target_dict[key] = value
  17. # -----------------------------------------------------------------------------
  18. def get_default_config():
  19. config_file = osp.join(here, "default_config.yaml")
  20. with open(config_file) as f:
  21. config = yaml.safe_load(f)
  22. # save default config to ~/.labelmerc
  23. user_config_file = osp.join(osp.expanduser("~"), ".labelmerc")
  24. if not osp.exists(user_config_file):
  25. try:
  26. shutil.copy(config_file, user_config_file)
  27. except Exception:
  28. logger.warn("Failed to save config: {}".format(user_config_file))
  29. return config
  30. def validate_config_item(key, value):
  31. if key == "validate_label" and value not in [None, "exact"]:
  32. raise ValueError(
  33. "Unexpected value for config key 'validate_label': {}".format(
  34. value
  35. )
  36. )
  37. if key == "shape_color" and value not in [None, "auto", "manual"]:
  38. raise ValueError(
  39. "Unexpected value for config key 'shape_color': {}".format(value)
  40. )
  41. if key == "labels" and value is not None and len(value) != len(set(value)):
  42. raise ValueError(
  43. "Duplicates are detected for config key 'labels': {}".format(value)
  44. )
  45. def get_config(config_file_or_yaml=None, config_from_args=None):
  46. # 1. default config
  47. config = get_default_config()
  48. # 2. specified as file or yaml
  49. if config_file_or_yaml is not None:
  50. config_from_yaml = yaml.safe_load(config_file_or_yaml)
  51. if not isinstance(config_from_yaml, dict):
  52. with open(config_from_yaml) as f:
  53. logger.info(
  54. "Loading config file from: {}".format(config_from_yaml)
  55. )
  56. config_from_yaml = yaml.safe_load(f)
  57. update_dict(
  58. config, config_from_yaml, validate_item=validate_config_item
  59. )
  60. # 3. command line argument or specified config file
  61. if config_from_args is not None:
  62. update_dict(
  63. config, config_from_args, validate_item=validate_config_item
  64. )
  65. return config