benchmark.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import time
  4. from functools import partial
  5. from typing import List, Optional, Union
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from mmcv.cnn import fuse_conv_bn
  10. # TODO need update
  11. # from mmcv.runner import wrap_fp16_model
  12. from mmengine import MMLogger
  13. from mmengine.config import Config
  14. from mmengine.device import get_max_cuda_memory
  15. from mmengine.dist import get_world_size
  16. from mmengine.runner import Runner, load_checkpoint
  17. from mmengine.utils.dl_utils import set_multi_processing
  18. from torch.nn.parallel import DistributedDataParallel
  19. from mmdet.registry import DATASETS, MODELS
  20. try:
  21. import psutil
  22. except ImportError:
  23. psutil = None
  24. def custom_round(value: Union[int, float],
  25. factor: Union[int, float],
  26. precision: int = 2) -> float:
  27. """Custom round function."""
  28. return round(value / factor, precision)
  29. gb_round = partial(custom_round, factor=1024**3)
  30. def print_log(msg: str, logger: Optional[MMLogger] = None) -> None:
  31. """Print a log message."""
  32. if logger is None:
  33. print(msg, flush=True)
  34. else:
  35. logger.info(msg)
  36. def print_process_memory(p: psutil.Process,
  37. logger: Optional[MMLogger] = None) -> None:
  38. """print process memory info."""
  39. mem_used = gb_round(psutil.virtual_memory().used)
  40. memory_full_info = p.memory_full_info()
  41. uss_mem = gb_round(memory_full_info.uss)
  42. if hasattr(memory_full_info, 'pss'):
  43. pss_mem = gb_round(memory_full_info.pss)
  44. for children in p.children():
  45. child_mem_info = children.memory_full_info()
  46. uss_mem += gb_round(child_mem_info.uss)
  47. if hasattr(child_mem_info, 'pss'):
  48. pss_mem += gb_round(child_mem_info.pss)
  49. process_count = 1 + len(p.children())
  50. log_msg = f'(GB) mem_used: {mem_used:.2f} | uss: {uss_mem:.2f} | '
  51. if hasattr(memory_full_info, 'pss'):
  52. log_msg += f'pss: {pss_mem:.2f} | '
  53. log_msg += f'total_proc: {process_count}'
  54. print_log(log_msg, logger)
  55. class BaseBenchmark:
  56. """The benchmark base class.
  57. The ``run`` method is an external calling interface, and it will
  58. call the ``run_once`` method ``repeat_num`` times for benchmarking.
  59. Finally, call the ``average_multiple_runs`` method to further process
  60. the results of multiple runs.
  61. Args:
  62. max_iter (int): maximum iterations of benchmark.
  63. log_interval (int): interval of logging.
  64. num_warmup (int): Number of Warmup.
  65. logger (MMLogger, optional): Formatted logger used to record messages.
  66. """
  67. def __init__(self,
  68. max_iter: int,
  69. log_interval: int,
  70. num_warmup: int,
  71. logger: Optional[MMLogger] = None):
  72. self.max_iter = max_iter
  73. self.log_interval = log_interval
  74. self.num_warmup = num_warmup
  75. self.logger = logger
  76. def run(self, repeat_num: int = 1) -> dict:
  77. """benchmark entry method.
  78. Args:
  79. repeat_num (int): Number of repeat benchmark.
  80. Defaults to 1.
  81. """
  82. assert repeat_num >= 1
  83. results = []
  84. for _ in range(repeat_num):
  85. results.append(self.run_once())
  86. results = self.average_multiple_runs(results)
  87. return results
  88. def run_once(self) -> dict:
  89. """Executes the benchmark once."""
  90. raise NotImplementedError()
  91. def average_multiple_runs(self, results: List[dict]) -> dict:
  92. """Average the results of multiple runs."""
  93. raise NotImplementedError()
  94. class InferenceBenchmark(BaseBenchmark):
  95. """The inference benchmark class. It will be statistical inference FPS,
  96. CUDA memory and CPU memory information.
  97. Args:
  98. cfg (mmengine.Config): config.
  99. checkpoint (str): Accept local filepath, URL, ``torchvision://xxx``,
  100. ``open-mmlab://xxx``.
  101. distributed (bool): distributed testing flag.
  102. is_fuse_conv_bn (bool): Whether to fuse conv and bn, this will
  103. slightly increase the inference speed.
  104. max_iter (int): maximum iterations of benchmark. Defaults to 2000.
  105. log_interval (int): interval of logging. Defaults to 50.
  106. num_warmup (int): Number of Warmup. Defaults to 5.
  107. logger (MMLogger, optional): Formatted logger used to record messages.
  108. """
  109. def __init__(self,
  110. cfg: Config,
  111. checkpoint: str,
  112. distributed: bool,
  113. is_fuse_conv_bn: bool,
  114. max_iter: int = 2000,
  115. log_interval: int = 50,
  116. num_warmup: int = 5,
  117. logger: Optional[MMLogger] = None):
  118. super().__init__(max_iter, log_interval, num_warmup, logger)
  119. assert get_world_size(
  120. ) == 1, 'Inference benchmark does not allow distributed multi-GPU'
  121. self.cfg = copy.deepcopy(cfg)
  122. self.distributed = distributed
  123. if psutil is None:
  124. raise ImportError('psutil is not installed, please install it by: '
  125. 'pip install psutil')
  126. self._process = psutil.Process()
  127. env_cfg = self.cfg.get('env_cfg')
  128. if env_cfg.get('cudnn_benchmark'):
  129. torch.backends.cudnn.benchmark = True
  130. mp_cfg: dict = env_cfg.get('mp_cfg', {})
  131. set_multi_processing(**mp_cfg, distributed=self.distributed)
  132. print_log('before build: ', self.logger)
  133. print_process_memory(self._process, self.logger)
  134. self.model = self._init_model(checkpoint, is_fuse_conv_bn)
  135. # Because multiple processes will occupy additional CPU resources,
  136. # FPS statistics will be more unstable when num_workers is not 0.
  137. # It is reasonable to set num_workers to 0.
  138. dataloader_cfg = cfg.test_dataloader
  139. dataloader_cfg['num_workers'] = 0
  140. dataloader_cfg['batch_size'] = 1
  141. dataloader_cfg['persistent_workers'] = False
  142. self.data_loader = Runner.build_dataloader(dataloader_cfg)
  143. print_log('after build: ', self.logger)
  144. print_process_memory(self._process, self.logger)
  145. def _init_model(self, checkpoint: str, is_fuse_conv_bn: bool) -> nn.Module:
  146. """Initialize the model."""
  147. model = MODELS.build(self.cfg.model)
  148. # TODO need update
  149. # fp16_cfg = self.cfg.get('fp16', None)
  150. # if fp16_cfg is not None:
  151. # wrap_fp16_model(model)
  152. load_checkpoint(model, checkpoint, map_location='cpu')
  153. if is_fuse_conv_bn:
  154. model = fuse_conv_bn(model)
  155. model = model.cuda()
  156. if self.distributed:
  157. model = DistributedDataParallel(
  158. model,
  159. device_ids=[torch.cuda.current_device()],
  160. broadcast_buffers=False,
  161. find_unused_parameters=False)
  162. model.eval()
  163. return model
  164. def run_once(self) -> dict:
  165. """Executes the benchmark once."""
  166. pure_inf_time = 0
  167. fps = 0
  168. for i, data in enumerate(self.data_loader):
  169. if (i + 1) % self.log_interval == 0:
  170. print_log('==================================', self.logger)
  171. torch.cuda.synchronize()
  172. start_time = time.perf_counter()
  173. with torch.no_grad():
  174. self.model(data)
  175. torch.cuda.synchronize()
  176. elapsed = time.perf_counter() - start_time
  177. if i >= self.num_warmup:
  178. pure_inf_time += elapsed
  179. if (i + 1) % self.log_interval == 0:
  180. fps = (i + 1 - self.num_warmup) / pure_inf_time
  181. cuda_memory = get_max_cuda_memory()
  182. print_log(
  183. f'Done image [{i + 1:<3}/{self.max_iter}], '
  184. f'fps: {fps:.1f} img/s, '
  185. f'times per image: {1000 / fps:.1f} ms/img, '
  186. f'cuda memory: {cuda_memory} MB', self.logger)
  187. print_process_memory(self._process, self.logger)
  188. if (i + 1) == self.max_iter:
  189. fps = (i + 1 - self.num_warmup) / pure_inf_time
  190. break
  191. return {'fps': fps}
  192. def average_multiple_runs(self, results: List[dict]) -> dict:
  193. """Average the results of multiple runs."""
  194. print_log('============== Done ==================', self.logger)
  195. fps_list_ = [round(result['fps'], 1) for result in results]
  196. avg_fps_ = sum(fps_list_) / len(fps_list_)
  197. outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_}
  198. print(fps_list_)
  199. if len(fps_list_) > 1:
  200. times_pre_image_list_ = [
  201. round(1000 / result['fps'], 1) for result in results
  202. ]
  203. avg_times_pre_image_ = sum(times_pre_image_list_) / len(
  204. times_pre_image_list_)
  205. print_log(
  206. f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, '
  207. 'times per image: '
  208. f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] '
  209. 'ms/img', self.logger)
  210. else:
  211. print_log(
  212. f'Overall fps: {fps_list_[0]:.1f} img/s, '
  213. f'times per image: {1000 / fps_list_[0]:.1f} ms/img',
  214. self.logger)
  215. print_log(f'cuda memory: {get_max_cuda_memory()} MB', self.logger)
  216. print_process_memory(self._process, self.logger)
  217. return outputs
  218. class DataLoaderBenchmark(BaseBenchmark):
  219. """The dataloader benchmark class. It will be statistical inference FPS and
  220. CPU memory information.
  221. Args:
  222. cfg (mmengine.Config): config.
  223. distributed (bool): distributed testing flag.
  224. dataset_type (str): benchmark data type, only supports ``train``,
  225. ``val`` and ``test``.
  226. max_iter (int): maximum iterations of benchmark. Defaults to 2000.
  227. log_interval (int): interval of logging. Defaults to 50.
  228. num_warmup (int): Number of Warmup. Defaults to 5.
  229. logger (MMLogger, optional): Formatted logger used to record messages.
  230. """
  231. def __init__(self,
  232. cfg: Config,
  233. distributed: bool,
  234. dataset_type: str,
  235. max_iter: int = 2000,
  236. log_interval: int = 50,
  237. num_warmup: int = 5,
  238. logger: Optional[MMLogger] = None):
  239. super().__init__(max_iter, log_interval, num_warmup, logger)
  240. assert dataset_type in ['train', 'val', 'test'], \
  241. 'dataset_type only supports train,' \
  242. f' val and test, but got {dataset_type}'
  243. assert get_world_size(
  244. ) == 1, 'Dataloader benchmark does not allow distributed multi-GPU'
  245. self.cfg = copy.deepcopy(cfg)
  246. self.distributed = distributed
  247. if psutil is None:
  248. raise ImportError('psutil is not installed, please install it by: '
  249. 'pip install psutil')
  250. self._process = psutil.Process()
  251. mp_cfg = self.cfg.get('env_cfg', {}).get('mp_cfg')
  252. if mp_cfg is not None:
  253. set_multi_processing(distributed=self.distributed, **mp_cfg)
  254. else:
  255. set_multi_processing(distributed=self.distributed)
  256. print_log('before build: ', self.logger)
  257. print_process_memory(self._process, self.logger)
  258. if dataset_type == 'train':
  259. self.data_loader = Runner.build_dataloader(cfg.train_dataloader)
  260. elif dataset_type == 'test':
  261. self.data_loader = Runner.build_dataloader(cfg.test_dataloader)
  262. else:
  263. self.data_loader = Runner.build_dataloader(cfg.val_dataloader)
  264. self.batch_size = self.data_loader.batch_size
  265. self.num_workers = self.data_loader.num_workers
  266. print_log('after build: ', self.logger)
  267. print_process_memory(self._process, self.logger)
  268. def run_once(self) -> dict:
  269. """Executes the benchmark once."""
  270. pure_inf_time = 0
  271. fps = 0
  272. # benchmark with 2000 image and take the average
  273. start_time = time.perf_counter()
  274. for i, data in enumerate(self.data_loader):
  275. elapsed = time.perf_counter() - start_time
  276. if (i + 1) % self.log_interval == 0:
  277. print_log('==================================', self.logger)
  278. if i >= self.num_warmup:
  279. pure_inf_time += elapsed
  280. if (i + 1) % self.log_interval == 0:
  281. fps = (i + 1 - self.num_warmup) / pure_inf_time
  282. print_log(
  283. f'Done batch [{i + 1:<3}/{self.max_iter}], '
  284. f'fps: {fps:.1f} batch/s, '
  285. f'times per batch: {1000 / fps:.1f} ms/batch, '
  286. f'batch size: {self.batch_size}, num_workers: '
  287. f'{self.num_workers}', self.logger)
  288. print_process_memory(self._process, self.logger)
  289. if (i + 1) == self.max_iter:
  290. fps = (i + 1 - self.num_warmup) / pure_inf_time
  291. break
  292. start_time = time.perf_counter()
  293. return {'fps': fps}
  294. def average_multiple_runs(self, results: List[dict]) -> dict:
  295. """Average the results of multiple runs."""
  296. print_log('============== Done ==================', self.logger)
  297. fps_list_ = [round(result['fps'], 1) for result in results]
  298. avg_fps_ = sum(fps_list_) / len(fps_list_)
  299. outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_}
  300. if len(fps_list_) > 1:
  301. times_pre_image_list_ = [
  302. round(1000 / result['fps'], 1) for result in results
  303. ]
  304. avg_times_pre_image_ = sum(times_pre_image_list_) / len(
  305. times_pre_image_list_)
  306. print_log(
  307. f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, '
  308. 'times per batch: '
  309. f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] '
  310. f'ms/batch, batch size: {self.batch_size}, num_workers: '
  311. f'{self.num_workers}', self.logger)
  312. else:
  313. print_log(
  314. f'Overall fps: {fps_list_[0]:.1f} batch/s, '
  315. f'times per batch: {1000 / fps_list_[0]:.1f} ms/batch, '
  316. f'batch size: {self.batch_size}, num_workers: '
  317. f'{self.num_workers}', self.logger)
  318. print_process_memory(self._process, self.logger)
  319. return outputs
  320. class DatasetBenchmark(BaseBenchmark):
  321. """The dataset benchmark class. It will be statistical inference FPS, FPS
  322. pre transform and CPU memory information.
  323. Args:
  324. cfg (mmengine.Config): config.
  325. dataset_type (str): benchmark data type, only supports ``train``,
  326. ``val`` and ``test``.
  327. max_iter (int): maximum iterations of benchmark. Defaults to 2000.
  328. log_interval (int): interval of logging. Defaults to 50.
  329. num_warmup (int): Number of Warmup. Defaults to 5.
  330. logger (MMLogger, optional): Formatted logger used to record messages.
  331. """
  332. def __init__(self,
  333. cfg: Config,
  334. dataset_type: str,
  335. max_iter: int = 2000,
  336. log_interval: int = 50,
  337. num_warmup: int = 5,
  338. logger: Optional[MMLogger] = None):
  339. super().__init__(max_iter, log_interval, num_warmup, logger)
  340. assert dataset_type in ['train', 'val', 'test'], \
  341. 'dataset_type only supports train,' \
  342. f' val and test, but got {dataset_type}'
  343. assert get_world_size(
  344. ) == 1, 'Dataset benchmark does not allow distributed multi-GPU'
  345. self.cfg = copy.deepcopy(cfg)
  346. if dataset_type == 'train':
  347. dataloader_cfg = copy.deepcopy(cfg.train_dataloader)
  348. elif dataset_type == 'test':
  349. dataloader_cfg = copy.deepcopy(cfg.test_dataloader)
  350. else:
  351. dataloader_cfg = copy.deepcopy(cfg.val_dataloader)
  352. dataset_cfg = dataloader_cfg.pop('dataset')
  353. dataset = DATASETS.build(dataset_cfg)
  354. if hasattr(dataset, 'full_init'):
  355. dataset.full_init()
  356. self.dataset = dataset
  357. def run_once(self) -> dict:
  358. """Executes the benchmark once."""
  359. pure_inf_time = 0
  360. fps = 0
  361. total_index = list(range(len(self.dataset)))
  362. np.random.shuffle(total_index)
  363. start_time = time.perf_counter()
  364. for i, idx in enumerate(total_index):
  365. if (i + 1) % self.log_interval == 0:
  366. print_log('==================================', self.logger)
  367. get_data_info_start_time = time.perf_counter()
  368. data_info = self.dataset.get_data_info(idx)
  369. get_data_info_elapsed = time.perf_counter(
  370. ) - get_data_info_start_time
  371. if (i + 1) % self.log_interval == 0:
  372. print_log(f'get_data_info - {get_data_info_elapsed * 1000} ms',
  373. self.logger)
  374. for t in self.dataset.pipeline.transforms:
  375. transform_start_time = time.perf_counter()
  376. data_info = t(data_info)
  377. transform_elapsed = time.perf_counter() - transform_start_time
  378. if (i + 1) % self.log_interval == 0:
  379. print_log(
  380. f'{t.__class__.__name__} - '
  381. f'{transform_elapsed * 1000} ms', self.logger)
  382. if data_info is None:
  383. break
  384. elapsed = time.perf_counter() - start_time
  385. if i >= self.num_warmup:
  386. pure_inf_time += elapsed
  387. if (i + 1) % self.log_interval == 0:
  388. fps = (i + 1 - self.num_warmup) / pure_inf_time
  389. print_log(
  390. f'Done img [{i + 1:<3}/{self.max_iter}], '
  391. f'fps: {fps:.1f} img/s, '
  392. f'times per img: {1000 / fps:.1f} ms/img', self.logger)
  393. if (i + 1) == self.max_iter:
  394. fps = (i + 1 - self.num_warmup) / pure_inf_time
  395. break
  396. start_time = time.perf_counter()
  397. return {'fps': fps}
  398. def average_multiple_runs(self, results: List[dict]) -> dict:
  399. """Average the results of multiple runs."""
  400. print_log('============== Done ==================', self.logger)
  401. fps_list_ = [round(result['fps'], 1) for result in results]
  402. avg_fps_ = sum(fps_list_) / len(fps_list_)
  403. outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_}
  404. if len(fps_list_) > 1:
  405. times_pre_image_list_ = [
  406. round(1000 / result['fps'], 1) for result in results
  407. ]
  408. avg_times_pre_image_ = sum(times_pre_image_list_) / len(
  409. times_pre_image_list_)
  410. print_log(
  411. f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, '
  412. 'times per img: '
  413. f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] '
  414. 'ms/img', self.logger)
  415. else:
  416. print_log(
  417. f'Overall fps: {fps_list_[0]:.1f} img/s, '
  418. f'times per img: {1000 / fps_list_[0]:.1f} ms/img',
  419. self.logger)
  420. return outputs