# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import List, Optional, Sequence, Union

import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log

from mmdet.registry import METRICS
from ..functional import eval_map


@METRICS.register_module()
class OpenImagesMetric(BaseMetric):
    """OpenImages evaluation metric.

    Evaluate detection mAP for OpenImages. Please refer to
    https://storage.googleapis.com/openimages/web/evaluation.html for more
    details.

    Args:
        iou_thrs (float or List[float]): IoU threshold. Defaults to 0.5.
        ioa_thrs (float or List[float]): IoA threshold. Defaults to 0.5.
        scale_ranges (List[tuple], optional): Scale ranges for evaluating
            mAP. If not specified, all bounding boxes are included in
            evaluation. Defaults to None.
        use_group_of (bool): Whether to consider group-of ground truth
            bboxes during evaluation. Defaults to True.
        get_supercategory (bool): Whether to get the parent classes of the
            current class. Defaults to True.
        filter_labels (bool): Whether to filter out unannotated classes.
            Defaults to True.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different
            evaluators. If prefix is not provided in the argument,
            ``self.default_prefix`` will be used instead. Defaults to None.
    """
    default_prefix: Optional[str] = 'openimages'

    def __init__(self,
                 iou_thrs: Union[float, List[float]] = 0.5,
                 ioa_thrs: Union[float, List[float]] = 0.5,
                 scale_ranges: Optional[List[tuple]] = None,
                 use_group_of: bool = True,
                 get_supercategory: bool = True,
                 filter_labels: bool = True,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.iou_thrs = [iou_thrs] if isinstance(iou_thrs, float) else iou_thrs
        self.ioa_thrs = [ioa_thrs] if (isinstance(ioa_thrs, float)
                                       or ioa_thrs is None) else ioa_thrs
        assert isinstance(self.iou_thrs, list) and isinstance(
            self.ioa_thrs, list)
        assert len(self.iou_thrs) == len(self.ioa_thrs)
        self.scale_ranges = scale_ranges
        self.use_group_of = use_group_of
        self.get_supercategory = get_supercategory
        self.filter_labels = filter_labels

    def _get_supercategory_ann(self, instances: List[dict]) -> List[dict]:
        """Get the parent classes' annotations of the corresponding class.

        Args:
            instances (List[dict]): A list of annotations of the instances.

        Returns:
            List[dict]: Annotations extended with super-category instances.
        """
        supercat_instances = []
        relation_matrix = self.dataset_meta['RELATION_MATRIX']
        for instance in instances:
            labels = np.where(relation_matrix[instance['bbox_label']])[0]
            for label in labels:
                if label == instance['bbox_label']:
                    continue
                new_instance = copy.deepcopy(instance)
                new_instance['bbox_label'] = label
                supercat_instances.append(new_instance)
        return supercat_instances
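
    # A hypothetical illustration of the expansion above (the class names and
    # matrix are not part of the original file): row ``i`` of
    # ``RELATION_MATRIX`` is nonzero at ``i`` itself and at each ancestor of
    # class ``i``. With classes ('cat', 'carnivore', 'animal') and
    #
    #   RELATION_MATRIX = [[1, 1, 1],   # cat is a carnivore and an animal
    #                      [0, 1, 1],   # carnivore is an animal
    #                      [0, 0, 1]]   # animal has no parent
    #
    # an instance with ``bbox_label == 0`` gains two extra copies labelled 1
    # and 2, so a ground-truth 'cat' is also matched as 'carnivore' and
    # 'animal'.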

    def _process_predictions(self, pred_bboxes: np.ndarray,
                             pred_scores: np.ndarray, pred_labels: np.ndarray,
                             gt_instances: list,
                             image_level_labels: np.ndarray) -> tuple:
        """Process the detection results for the corresponding classes.

        Note: it applies up to two processing steps, controlled by the
        parameters:

        1. Whether to add the parent classes of the predicted classes to
           the detection results.
        2. Whether to ignore the classes that are unannotated on that image.

        Args:
            pred_bboxes (np.ndarray): bboxes predicted by the model.
            pred_scores (np.ndarray): scores predicted by the model.
            pred_labels (np.ndarray): labels predicted by the model.
            gt_instances (list): ground truth annotations.
            image_level_labels (np.ndarray): human-verified image level
                labels.

        Returns:
            tuple: Processed bboxes, scores, and labels.
        """
        processed_bboxes = copy.deepcopy(pred_bboxes)
        processed_scores = copy.deepcopy(pred_scores)
        processed_labels = copy.deepcopy(pred_labels)
        gt_labels = np.array([ins['bbox_label'] for ins in gt_instances],
                             dtype=np.int64)
        if image_level_labels is not None:
            allowed_classes = np.unique(
                np.append(gt_labels, image_level_labels))
        else:
            allowed_classes = np.unique(gt_labels)
        relation_matrix = self.dataset_meta['RELATION_MATRIX']
        pred_classes = np.unique(pred_labels)
        for pred_class in pred_classes:
            classes = np.where(relation_matrix[pred_class])[0]
            for cls in classes:
                if (cls in allowed_classes and cls != pred_class
                        and self.get_supercategory):
                    # add supercategory predictions
                    index = np.where(pred_labels == pred_class)[0]
                    processed_scores = np.concatenate(
                        [processed_scores, pred_scores[index]])
                    processed_bboxes = np.concatenate(
                        [processed_bboxes, pred_bboxes[index]])
                    extend_labels = np.full(index.shape, cls, dtype=np.int64)
                    processed_labels = np.concatenate(
                        [processed_labels, extend_labels])
                elif cls not in allowed_classes and self.filter_labels:
                    # remove predictions for unannotated classes
                    index = np.where(processed_labels != cls)[0]
                    processed_scores = processed_scores[index]
                    processed_bboxes = processed_bboxes[index]
                    processed_labels = processed_labels[index]
        return processed_bboxes, processed_scores, processed_labels
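
    # For example (hypothetical values, not from the original file): with
    # ``allowed_classes == {0, 1}``, a prediction labelled 0 whose parent
    # class 1 is also annotated is duplicated with label 1, while a
    # prediction for a class outside ``allowed_classes`` is dropped when
    # ``filter_labels`` is True.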

    # TODO: data_batch is no longer needed, consider adjusting the
    # parameter position
    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of data samples that
                contain annotations and predictions.
        """
        for data_sample in data_samples:
            gt = copy.deepcopy(data_sample)
            # add super-category instances
            # TODO: Need to refactor to support LoadAnnotations
            instances = gt['instances']
            if self.get_supercategory:
                supercat_instances = self._get_supercategory_ann(instances)
                instances.extend(supercat_instances)
            gt_labels = []
            gt_bboxes = []
            is_group_ofs = []
            for ins in instances:
                gt_labels.append(ins['bbox_label'])
                gt_bboxes.append(ins['bbox'])
                is_group_ofs.append(ins['is_group_of'])
            ann = dict(
                labels=np.array(gt_labels, dtype=np.int64),
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape((-1, 4)),
                gt_is_group_ofs=np.array(is_group_ofs, dtype=bool))
            image_level_labels = gt.get('image_level_labels', None)
            pred = data_sample['pred_instances']
            pred_bboxes = pred['bboxes'].cpu().numpy()
            pred_scores = pred['scores'].cpu().numpy()
            pred_labels = pred['labels'].cpu().numpy()
            pred_bboxes, pred_scores, pred_labels = self._process_predictions(
                pred_bboxes, pred_scores, pred_labels, instances,
                image_level_labels)
            dets = []
            for label in range(len(self.dataset_meta['classes'])):
                index = np.where(pred_labels == label)[0]
                pred_bbox_scores = np.hstack(
                    [pred_bboxes[index], pred_scores[index].reshape((-1, 1))])
                dets.append(pred_bbox_scores)
            self.results.append((ann, dets))
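
    # Each entry appended to ``self.results`` above is a pair ``(ann, dets)``
    # for one image: ``ann`` holds the ground-truth arrays and ``dets`` is a
    # per-class list of ``(n, 5)`` arrays of ``[x1, y1, x2, y2, score]``.
    # ``compute_metrics`` below unzips these pairs into ``gts`` and ``preds``.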

    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        logger = MMLogger.get_current_instance()
        gts, preds = zip(*results)
        eval_results = OrderedDict()
        # get dataset type
        dataset_type = self.dataset_meta.get('dataset_type')
        if dataset_type not in ['oid_challenge', 'oid_v6']:
            dataset_type = 'oid_v6'
            print_log(
                'Cannot infer dataset type from the length of the'
                ' classes. Set `oid_v6` as dataset type.',
                logger='current')
        mean_aps = []
        for i, (iou_thr,
                ioa_thr) in enumerate(zip(self.iou_thrs, self.ioa_thrs)):
            if self.use_group_of:
                assert ioa_thr is not None, 'ioa_thr must have value when' \
                                            ' using group_of in evaluation.'
            print_log(f'\n{"-" * 15}iou_thr, ioa_thr: {iou_thr}, {ioa_thr}'
                      f'{"-" * 15}')
            mean_ap, _ = eval_map(
                preds,
                gts,
                scale_ranges=self.scale_ranges,
                iou_thr=iou_thr,
                ioa_thr=ioa_thr,
                dataset=dataset_type,
                logger=logger,
                use_group_of=self.use_group_of)
            mean_aps.append(mean_ap)
            eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
        eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        return eval_results
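

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, added for illustration; it is not part of the
# upstream file. The two-class hierarchy and all sample values below are
# hypothetical, and running it requires mmdet (for ``eval_map``) and torch.
if __name__ == '__main__':
    import torch

    metric = OpenImagesMetric(iou_thrs=0.5, ioa_thrs=0.5)
    # ``dataset_meta`` is normally injected by the evaluator / dataset.
    metric.dataset_meta = dict(
        classes=('cat', 'animal'),
        RELATION_MATRIX=np.array([[1, 1], [0, 1]]),  # cat is an animal
        dataset_type='oid_v6')
    data_sample = dict(
        instances=[
            dict(bbox=[0., 0., 10., 10.], bbox_label=0, is_group_of=False)
        ],
        pred_instances=dict(
            bboxes=torch.tensor([[0., 0., 10., 10.]]),
            scores=torch.tensor([0.9]),
            labels=torch.tensor([0])))
    metric.process({}, [data_sample])
    # A perfect detection should yield AP50 == 1.0 for the class and its
    # supercategory alike.
    print(metric.compute_metrics(metric.results))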