gather_models.py

# Copyright (c) OpenMMLab. All rights reserved.
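"""Gather trained MMDetection models and publish them for release.

The script scans ``./configs`` for configs that have a matching experiment
directory under ``root``, picks the final (or, with ``--best``, the best)
checkpoint, strips training-only states from it, and copies the checkpoint,
log and config into ``--out``, together with ``model_info.json`` and
per-folder ``*_metafile.yml`` files.

Example invocation (arguments as defined in ``parse_args`` below; run from
the repository root so that ``./configs`` resolves):

    python gather_models.py work_dirs --out gather --best
"""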
import argparse
import glob
import os
import os.path as osp
import shutil
import subprocess
import time
from collections import OrderedDict

import torch
import yaml
from mmengine.config import Config
from mmengine.fileio import dump
from mmengine.utils import mkdir_or_exist, scandir


def ordered_yaml_dump(data, stream=None, Dumper=yaml.SafeDumper, **kwds):
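    """Dump ``data`` to YAML while preserving ``OrderedDict`` key order."""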

    class OrderedDumper(Dumper):
        pass

    def _dict_representer(dumper, data):
        return dumper.represent_mapping(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, data.items())

    OrderedDumper.add_representer(OrderedDict, _dict_representer)
    return yaml.dump(data, stream, OrderedDumper, **kwds)


def process_checkpoint(in_file, out_file):
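    """Strip training-only states (optimizer, EMA) from a checkpoint, save it
    and rename the file with the first 8 characters of its sha256 sum."""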
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    if 'ema_state_dict' in checkpoint:
        del checkpoint['ema_state_dict']

    # remove ema state_dict
    for key in list(checkpoint['state_dict']):
        if key.startswith('ema_'):
            checkpoint['state_dict'].pop(key)
        elif key.startswith('data_preprocessor'):
            checkpoint['state_dict'].pop(key)

    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    if torch.__version__ >= '1.6':
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
    subprocess.Popen(['mv', out_file, final_file])
    return final_file


def is_by_epoch(config):
    cfg = Config.fromfile('./configs/' + config)
    return cfg.train_cfg.type == 'EpochBasedTrainLoop'


def get_final_epoch_or_iter(config):
    cfg = Config.fromfile('./configs/' + config)
    if cfg.train_cfg.type == 'EpochBasedTrainLoop':
        return cfg.train_cfg.max_epochs
    else:
        return cfg.train_cfg.max_iters


def get_best_epoch_or_iter(exp_dir):
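    """Return the file name of the newest ``best_*.pth`` checkpoint in
    ``exp_dir`` and the epoch/iteration number parsed from that name."""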
    best_epoch_iter_full_path = list(
        sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1]
    best_epoch_or_iter_model_path = best_epoch_iter_full_path.split('/')[-1]
    best_epoch_or_iter = best_epoch_or_iter_model_path.\
        split('_')[-1].split('.')[0]
    return best_epoch_or_iter_model_path, int(best_epoch_or_iter)


def get_real_epoch_or_iter(config):
    cfg = Config.fromfile('./configs/' + config)
    if cfg.train_cfg.type == 'EpochBasedTrainLoop':
        epoch = cfg.train_cfg.max_epochs
        return epoch
    else:
        return cfg.train_cfg.max_iters


def get_final_results(log_json_path,
                      epoch_or_iter,
                      results_lut='coco/bbox_mAP',
                      by_epoch=True):
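    """Read the last line of the scalars json log and return a dict mapping
    ``results_lut`` to the first metric value recorded on that line."""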
    result_dict = dict()
    with open(log_json_path) as f:
        r = f.readlines()[-1]
        last_metric = r.split(',')[0].split(': ')[-1].strip()
    result_dict[results_lut] = last_metric
    return result_dict


def get_dataset_name(config):
    # If there are more datasets, add them here.
    name_map = dict(
        CityscapesDataset='Cityscapes',
        CocoDataset='COCO',
        CocoPanopticDataset='COCO',
        DeepFashionDataset='Deep Fashion',
        LVISV05Dataset='LVIS v0.5',
        LVISV1Dataset='LVIS v1',
        VOCDataset='Pascal VOC',
        WIDERFaceDataset='WIDER Face',
        OpenImagesDataset='OpenImagesDataset',
        OpenImagesChallengeDataset='OpenImagesChallengeDataset',
        Objects365V1Dataset='Objects365 v1',
        Objects365V2Dataset='Objects365 v2')
    cfg = Config.fromfile('./configs/' + config)
    return name_map[cfg.dataset_type]


def find_last_dir(model_dir):
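    """Return the name of the most recent ``%Y%m%d_%H%M%S`` timestamped
    sub-directory of ``model_dir``."""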
    dst_times = []
    for time_stamp in os.scandir(model_dir):
        if osp.isdir(time_stamp):
            dst_time = time.mktime(
                time.strptime(time_stamp.name, '%Y%m%d_%H%M%S'))
            dst_times.append([dst_time, time_stamp.name])
    return max(dst_times, key=lambda x: x[0])[1]


def convert_model_info_to_pwc(model_infos):
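    """Convert gathered model infos into metafile (pwc-style) entries,
    grouped by config folder, for the ``*_metafile.yml`` outputs."""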
    pwc_files = {}
    for model in model_infos:
        cfg_folder_name = osp.split(model['config'])[-2]
        pwc_model_info = OrderedDict()
        pwc_model_info['Name'] = osp.split(model['config'])[-1].split('.')[0]
        pwc_model_info['In Collection'] = 'Please fill in Collection name'
        pwc_model_info['Config'] = osp.join('configs', model['config'])

        # get metadata
        meta_data = OrderedDict()
        if 'epochs' in model:
            meta_data['Epochs'] = get_real_epoch_or_iter(model['config'])
        else:
            meta_data['Iterations'] = get_real_epoch_or_iter(model['config'])
        pwc_model_info['Metadata'] = meta_data

        # get dataset name
        dataset_name = get_dataset_name(model['config'])

        # get results
        results = []
        # if there are more metrics, add here.
        if 'bbox_mAP' in model['results']:
            metric = round(model['results']['bbox_mAP'] * 100, 1)
            results.append(
                OrderedDict(
                    Task='Object Detection',
                    Dataset=dataset_name,
                    Metrics={'box AP': metric}))
        if 'segm_mAP' in model['results']:
            metric = round(model['results']['segm_mAP'] * 100, 1)
            results.append(
                OrderedDict(
                    Task='Instance Segmentation',
                    Dataset=dataset_name,
                    Metrics={'mask AP': metric}))
        if 'PQ' in model['results']:
            metric = round(model['results']['PQ'], 1)
            results.append(
                OrderedDict(
                    Task='Panoptic Segmentation',
                    Dataset=dataset_name,
                    Metrics={'PQ': metric}))
        pwc_model_info['Results'] = results

        link_string = 'https://download.openmmlab.com/mmdetection/v3.0/'
        link_string += '{}/{}'.format(model['config'].rstrip('.py'),
                                      osp.split(model['model_path'])[-1])
        pwc_model_info['Weights'] = link_string
        if cfg_folder_name in pwc_files:
            pwc_files[cfg_folder_name].append(pwc_model_info)
        else:
            pwc_files[cfg_folder_name] = [pwc_model_info]
    return pwc_files


def parse_args():
    parser = argparse.ArgumentParser(description='Gather benchmarked models')
    parser.add_argument(
        'root',
        type=str,
        default='work_dirs',
        help='root path of benchmarked models to be gathered')
    parser.add_argument(
        '--out',
        type=str,
        default='gather',
        help='output path of gathered models to be stored')
    parser.add_argument(
        '--best',
        action='store_true',
        help='whether to gather the best model.')
    args = parser.parse_args()
    return args


def main():
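    """Gather finished experiments under ``root``, publish their checkpoints,
    logs and configs to ``out``, and write ``model_info.json`` plus the
    per-folder metafiles."""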
    args = parse_args()
    models_root = args.root
    models_out = args.out
    mkdir_or_exist(models_out)

    # find all models in the root directory to be gathered
    raw_configs = list(scandir('./configs', '.py', recursive=True))

    # filter out configs that were not trained in the experiments dir
    used_configs = []
    for raw_config in raw_configs:
        if osp.exists(osp.join(models_root, raw_config)):
            used_configs.append(raw_config)
    print(f'Find {len(used_configs)} models to be gathered')

    # find the final checkpoint and log file for each trained config
    # and parse the best performance
    model_infos = []
    for used_config in used_configs:
        exp_dir = osp.join(models_root, used_config)
        by_epoch = is_by_epoch(used_config)
        # check whether the experiment is finished
        if args.best is True:
            final_model, final_epoch_or_iter = get_best_epoch_or_iter(exp_dir)
        else:
            final_epoch_or_iter = get_final_epoch_or_iter(used_config)
            final_model = '{}_{}.pth'.format('epoch' if by_epoch else 'iter',
                                             final_epoch_or_iter)

        model_path = osp.join(exp_dir, final_model)
        # skip if the model is still training
        if not osp.exists(model_path):
            continue

        # get the latest logs
        latest_exp_name = find_last_dir(exp_dir)
        latest_exp_json = osp.join(exp_dir, latest_exp_name, 'vis_data',
                                   latest_exp_name + '.json')

        model_performance = get_final_results(
            latest_exp_json, final_epoch_or_iter, by_epoch=by_epoch)
        if model_performance is None:
            continue

        model_info = dict(
            config=used_config,
            results=model_performance,
            final_model=final_model,
            latest_exp_json=latest_exp_json,
            latest_exp_name=latest_exp_name)
        model_info['epochs' if by_epoch else 'iterations'] =\
            final_epoch_or_iter
        model_infos.append(model_info)

    # publish model for each checkpoint
    publish_model_infos = []
    for model in model_infos:
        model_publish_dir = osp.join(models_out, model['config'].rstrip('.py'))
        mkdir_or_exist(model_publish_dir)

        model_name = osp.split(model['config'])[-1].split('.')[0]
        model_name += '_' + model['latest_exp_name']
        publish_model_path = osp.join(model_publish_dir, model_name)
        trained_model_path = osp.join(models_root, model['config'],
                                      model['final_model'])

        # convert model
        final_model_path = process_checkpoint(trained_model_path,
                                              publish_model_path)

        # copy log
        shutil.copy(model['latest_exp_json'],
                    osp.join(model_publish_dir, f'{model_name}.log.json'))

        # copy config to guarantee reproducibility
        config_path = model['config']
        config_path = osp.join(
            'configs',
            config_path) if 'configs' not in config_path else config_path
        target_config_path = osp.split(config_path)[-1]
        shutil.copy(config_path, osp.join(model_publish_dir,
                                          target_config_path))

        model['model_path'] = final_model_path
        publish_model_infos.append(model)

    models = dict(models=publish_model_infos)
    print(f'Totally gathered {len(publish_model_infos)} models')
    dump(models, osp.join(models_out, 'model_info.json'))

    pwc_files = convert_model_info_to_pwc(publish_model_infos)
    for name in pwc_files:
        with open(osp.join(models_out, name + '_metafile.yml'), 'w') as f:
            ordered_yaml_dump(pwc_files[name], f, encoding='utf-8')


if __name__ == '__main__':
    main()