__init__.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os.path as osp
  2. import shutil
  3. import yaml
  4. from labelme.logger import logger
  5. here = osp.dirname(osp.abspath(__file__))
  6. def update_dict(target_dict, new_dict, validate_item=None):
  7. for key, value in new_dict.items():
  8. if validate_item:
  9. validate_item(key, value)
  10. if key not in target_dict:
  11. logger.warn('Skipping unexpected key in config: {}'
  12. .format(key))
  13. continue
  14. if isinstance(target_dict[key], dict) and \
  15. isinstance(value, dict):
  16. update_dict(target_dict[key], value, validate_item=validate_item)
  17. else:
  18. target_dict[key] = value
  19. # -----------------------------------------------------------------------------
  20. def get_default_config():
  21. config_file = osp.join(here, 'default_config.yaml')
  22. with open(config_file) as f:
  23. config = yaml.safe_load(f)
  24. # save default config to ~/.labelmerc
  25. user_config_file = osp.join(osp.expanduser('~'), '.labelmerc')
  26. if not osp.exists(user_config_file):
  27. try:
  28. shutil.copy(config_file, user_config_file)
  29. except Exception:
  30. logger.warn('Failed to save config: {}'.format(user_config_file))
  31. return config
  32. def validate_config_item(key, value):
  33. if key == 'validate_label' and value not in [None, 'exact']:
  34. raise ValueError(
  35. "Unexpected value for config key 'validate_label': {}"
  36. .format(value)
  37. )
  38. if key == 'shape_color' and value not in [None, 'auto', 'manual']:
  39. raise ValueError(
  40. "Unexpected value for config key 'shape_color': {}"
  41. .format(value)
  42. )
  43. if key == 'labels' and value is not None and len(value) != len(set(value)):
  44. raise ValueError(
  45. "Duplicates are detected for config key 'labels': {}".format(value)
  46. )
  47. def get_config(config_file_or_yaml=None, config_from_args=None):
  48. # 1. default config
  49. config = get_default_config()
  50. # 2. specified as file or yaml
  51. if config_file_or_yaml is not None:
  52. config_from_yaml = yaml.safe_load(config_file_or_yaml)
  53. if not isinstance(config_from_yaml, dict):
  54. with open(config_from_yaml) as f:
  55. logger.info(
  56. 'Loading config file from: {}'.format(config_from_yaml)
  57. )
  58. config_from_yaml = yaml.safe_load(f)
  59. update_dict(config, config_from_yaml,
  60. validate_item=validate_config_item)
  61. # 3. command line argument or specified config file
  62. if config_from_args is not None:
  63. update_dict(config, config_from_args,
  64. validate_item=validate_config_item)
  65. return config