lvis_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import os.path as osp
import tempfile
import warnings
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
from mmengine.fileio import get_local_path
from mmengine.logging import MMLogger
from terminaltables import AsciiTable

from mmdet.registry import METRICS
from mmdet.structures.mask import encode_mask_results
from ..functional import eval_recalls
from .coco_metric import CocoMetric

try:
    import lvis

    if getattr(lvis, '__version__', '0') >= '10.5.3':
        warnings.warn(
            'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"',  # noqa: E501
            UserWarning)
    from lvis import LVIS, LVISEval, LVISResults
except ImportError:
    lvis = None
    LVISEval = None
    LVISResults = None
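
# A quick sanity check for the optional dependency guarded above (a sketch;
# it simply mirrors the import and the install hint used in this module):
#
#     python -c "from lvis import LVIS, LVISEval, LVISResults"
#
# If this raises ImportError, install the official lvis-api with
# "pip install git+https://github.com/lvis-dataset/lvis-api.git".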


@METRICS.register_module()
class LVISMetric(CocoMetric):
    """LVIS evaluation metric.

    Args:
        ann_file (str, optional): Path to the coco format annotation file.
            If not specified, ground truth annotations from the dataset will
            be converted to coco format. Defaults to None.
        metric (str | List[str]): Metrics to be evaluated. Valid metrics
            include 'bbox', 'segm', 'proposal', and 'proposal_fast'.
            Defaults to 'bbox'.
        classwise (bool): Whether to evaluate the metric class-wise.
            Defaults to False.
        proposal_nums (Sequence[int]): Numbers of proposals to be evaluated.
            Defaults to (100, 300, 1000).
        iou_thrs (float | List[float], optional): IoU threshold to compute AP
            and AR. If not specified, IoUs from 0.5 to 0.95 will be used.
            Defaults to None.
        metric_items (List[str], optional): Metric result names to be
            recorded in the evaluation result. Defaults to None.
        format_only (bool): Format the output results without performing
            evaluation. It is useful when you want to format the result
            to a specific format and submit it to the test server.
            Defaults to False.
        outfile_prefix (str, optional): The prefix of json files. It includes
            the file path and the prefix of filename, e.g., "a/b/prefix".
            If not specified, a temp file will be created. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """

    default_prefix: Optional[str] = 'lvis'

    def __init__(self,
                 ann_file: Optional[str] = None,
                 metric: Union[str, List[str]] = 'bbox',
                 classwise: bool = False,
                 proposal_nums: Sequence[int] = (100, 300, 1000),
                 iou_thrs: Optional[Union[float, Sequence[float]]] = None,
                 metric_items: Optional[Sequence[str]] = None,
                 format_only: bool = False,
                 outfile_prefix: Optional[str] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 file_client_args: dict = None,
                 backend_args: dict = None) -> None:
        if lvis is None:
            raise RuntimeError(
                'Package lvis is not installed. Please run "pip install '
                'git+https://github.com/lvis-dataset/lvis-api.git".')
        super().__init__(collect_device=collect_device, prefix=prefix)
        # coco evaluation metrics
        self.metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in self.metrics:
            if metric not in allowed_metrics:
                raise KeyError(
                    "metric should be one of 'bbox', 'segm', 'proposal', "
                    f"'proposal_fast', but got {metric}.")

        # do class wise evaluation, default False
        self.classwise = classwise

        # proposal_nums used to compute recall or precision.
        self.proposal_nums = list(proposal_nums)

        # iou_thrs used to compute recall or precision.
        if iou_thrs is None:
            iou_thrs = np.linspace(
                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
        self.iou_thrs = iou_thrs

        self.metric_items = metric_items
        self.format_only = format_only
        if self.format_only:
            assert outfile_prefix is not None, \
                'outfile_prefix must not be None when format_only is True, ' \
                'otherwise the result files will be saved to a temp ' \
                'directory which will be cleaned up at the end.'

        self.outfile_prefix = outfile_prefix
        self.backend_args = backend_args
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )

        # if ann_file is not specified,
        # initialize lvis api with the converted dataset
        if ann_file is not None:
            with get_local_path(
                    ann_file, backend_args=self.backend_args) as local_path:
                self._lvis_api = LVIS(local_path)
        else:
            self._lvis_api = None

        # handle dataset lazy init
        self.cat_ids = None
        self.img_ids = None
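
    # A minimal evaluator config sketch for an MMEngine-style config file
    # (the annotation path below is an illustrative assumption, not shipped
    # with this module):
    #
    #     val_evaluator = dict(
    #         type='LVISMetric',
    #         ann_file='data/lvis_v1/annotations/lvis_v1_val.json',
    #         metric=['bbox', 'segm'])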

    def fast_eval_recall(self,
                         results: List[dict],
                         proposal_nums: Sequence[int],
                         iou_thrs: Sequence[float],
                         logger: Optional[MMLogger] = None) -> np.ndarray:
        """Evaluate proposal recall with LVIS's fast_eval_recall.

        Args:
            results (List[dict]): Results of the dataset.
            proposal_nums (Sequence[int]): Proposal numbers used for
                evaluation.
            iou_thrs (Sequence[float]): IoU thresholds used for evaluation.
            logger (MMLogger, optional): Logger used for logging the recall
                summary.

        Returns:
            np.ndarray: Averaged recall results.
        """
        gt_bboxes = []
        pred_bboxes = [result['bboxes'] for result in results]
        for i in range(len(self.img_ids)):
            ann_ids = self._lvis_api.get_ann_ids(img_ids=[self.img_ids[i]])
            ann_info = self._lvis_api.load_anns(ann_ids)
            if len(ann_info) == 0:
                gt_bboxes.append(np.zeros((0, 4)))
                continue
            bboxes = []
            for ann in ann_info:
                x1, y1, w, h = ann['bbox']
                bboxes.append([x1, y1, x1 + w, y1 + h])
            bboxes = np.array(bboxes, dtype=np.float32)
            if bboxes.shape[0] == 0:
                bboxes = np.zeros((0, 4))
            gt_bboxes.append(bboxes)

        recalls = eval_recalls(
            gt_bboxes, pred_bboxes, proposal_nums, iou_thrs, logger=logger)
        ar = recalls.mean(axis=1)
        return ar
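
    # Worked example for the box conversion above (made-up numbers): LVIS
    # annotations store boxes as [x, y, w, h], so [10.0, 20.0, 30.0, 40.0]
    # becomes [x1, y1, x2, y2] = [10.0, 20.0, 40.0, 60.0] before being
    # handed to `eval_recalls`.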

    # TODO: data_batch is no longer needed, consider adjusting the
    #  parameter position
    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of data samples that
                contain annotations and predictions.
        """
        for data_sample in data_samples:
            result = dict()
            pred = data_sample['pred_instances']
            result['img_id'] = data_sample['img_id']
            result['bboxes'] = pred['bboxes'].cpu().numpy()
            result['scores'] = pred['scores'].cpu().numpy()
            result['labels'] = pred['labels'].cpu().numpy()
            # encode mask to RLE
            if 'masks' in pred:
                result['masks'] = encode_mask_results(
                    pred['masks'].detach().cpu().numpy())
            # some detectors use different scores for bbox and mask
            if 'mask_scores' in pred:
                result['mask_scores'] = pred['mask_scores'].cpu().numpy()

            # parse gt
            gt = dict()
            gt['width'] = data_sample['ori_shape'][1]
            gt['height'] = data_sample['ori_shape'][0]
            gt['img_id'] = data_sample['img_id']
            if self._lvis_api is None:
                # TODO: Need to refactor to support LoadAnnotations
                assert 'instances' in data_sample, \
                    'ground truth is required for evaluation when ' \
                    '`ann_file` is not provided'
                gt['anns'] = data_sample['instances']
            # add converted result to the results list
            self.results.append((gt, result))
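
    # The minimal `data_sample` layout consumed above (a sketch; the tensor
    # shapes and the category count are illustrative assumptions):
    #
    #     data_sample = {
    #         'img_id': 0,
    #         'ori_shape': (480, 640),  # (height, width)
    #         'pred_instances': {
    #             'bboxes': torch.rand(5, 4),   # xyxy boxes
    #             'scores': torch.rand(5),
    #             'labels': torch.randint(0, 1203, (5, )),
    #         },
    #     }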

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        # split gt and prediction list
        gts, preds = zip(*results)

        tmp_dir = None
        if self.outfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            outfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            outfile_prefix = self.outfile_prefix

        if self._lvis_api is None:
            # use converted gt json file to initialize coco api
            logger.info('Converting ground truth to coco format...')
            coco_json_path = self.gt_to_coco_json(
                gt_dicts=gts, outfile_prefix=outfile_prefix)
            self._lvis_api = LVIS(coco_json_path)

        # handle lazy init
        if self.cat_ids is None:
            self.cat_ids = self._lvis_api.get_cat_ids()
        if self.img_ids is None:
            self.img_ids = self._lvis_api.get_img_ids()

        # convert predictions to coco format and dump to json file
        result_files = self.results2json(preds, outfile_prefix)

        eval_results = OrderedDict()
        if self.format_only:
            logger.info('results are saved in '
                        f'{osp.dirname(outfile_prefix)}')
            return eval_results

        lvis_gt = self._lvis_api

        for metric in self.metrics:
            logger.info(f'Evaluating {metric}...')

            # TODO: May refactor fast_eval_recall to an independent metric?
            # fast eval recall
            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(
                    preds, self.proposal_nums, self.iou_thrs, logger=logger)
                log_msg = []
                for i, num in enumerate(self.proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
                log_msg = ''.join(log_msg)
                logger.info(log_msg)
                continue

            try:
                lvis_dt = LVISResults(lvis_gt, result_files[metric])
            except IndexError:
                logger.info(
                    'The testing results of the whole dataset are empty.')
                break

            iou_type = 'bbox' if metric == 'proposal' else metric
            lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type)
            lvis_eval.params.imgIds = self.img_ids
            metric_items = self.metric_items
            if metric == 'proposal':
                lvis_eval.params.useCats = 0
                lvis_eval.params.maxDets = list(self.proposal_nums)
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                if metric_items is None:
                    metric_items = ['AR@300', 'ARs@300', 'ARm@300', 'ARl@300']
                for k, v in lvis_eval.get_results().items():
                    if k in metric_items:
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[k] = val
            else:
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                lvis_results = lvis_eval.get_results()
                if self.classwise:  # Compute per-category AP
                    # Compute per-category AP, adapted from
                    # https://github.com/facebookresearch/detectron2/
                    precisions = lvis_eval.eval['precision']
                    # the dimensions of precisions are
                    # [num_thrs, num_recalls, num_cats, num_area_rngs]
                    assert len(self.cat_ids) == precisions.shape[2]
                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        nm = self._lvis_api.load_cats([catId])[0]
                        precision = precisions[:, :, idx, 0]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))
                        eval_results[f'{nm["name"]}_precision'] = round(ap, 3)

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    logger.info('\n' + table.table)

                if metric_items is None:
                    metric_items = [
                        'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'APr',
                        'APc', 'APf'
                    ]

                for k, v in lvis_results.items():
                    if k in metric_items:
                        key = '{}_{}'.format(metric, k)
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[key] = val

            lvis_eval.print_results()

        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results
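
# Shape of the dict returned by `compute_metrics` for `metric=['bbox']`
# (a sketch; the numbers are made up for illustration):
#
#     {'bbox_AP': 0.253, 'bbox_AP50': 0.412, 'bbox_AP75': 0.265,
#      'bbox_APs': 0.181, 'bbox_APm': 0.338, 'bbox_APl': 0.421,
#      'bbox_APr': 0.102, 'bbox_APc': 0.247, 'bbox_APf': 0.331}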