coco_occluded_metric.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, List, Optional, Union
  3. import mmengine
  4. import numpy as np
  5. from mmengine.fileio import load
  6. from mmengine.logging import print_log
  7. from pycocotools import mask as coco_mask
  8. from terminaltables import AsciiTable
  9. from mmdet.registry import METRICS
  10. from .coco_metric import CocoMetric
  11. @METRICS.register_module()
  12. class CocoOccludedSeparatedMetric(CocoMetric):
  13. """Metric of separated and occluded masks which presented in paper `A Tri-
  14. Layer Plugin to Improve Occluded Detection.
  15. <https://arxiv.org/abs/2210.10046>`_.
  16. Separated COCO and Occluded COCO are automatically generated subsets of
  17. COCO val dataset, collecting separated objects and partially occluded
  18. objects for a large variety of categories. In this way, we define
  19. occlusion into two major categories: separated and partially occluded.
  20. - Separation: target object segmentation mask is separated into distinct
  21. regions by the occluder.
  22. - Partial Occlusion: target object is partially occluded but the
  23. segmentation mask is connected.
  24. These two new scalable real-image datasets are to benchmark a model's
  25. capability to detect occluded objects of 80 common categories.
  26. Please cite the paper if you use this dataset:
  27. @article{zhan2022triocc,
  28. title={A Tri-Layer Plugin to Improve Occluded Detection},
  29. author={Zhan, Guanqi and Xie, Weidi and Zisserman, Andrew},
  30. journal={British Machine Vision Conference},
  31. year={2022}
  32. }
  33. Args:
  34. occluded_ann (str): Path to the occluded coco annotation file.
  35. separated_ann (str): Path to the separated coco annotation file.
  36. score_thr (float): Score threshold of the detection masks.
  37. Defaults to 0.3.
  38. iou_thr (float): IoU threshold for the recall calculation.
  39. Defaults to 0.75.
  40. metric (str | List[str]): Metrics to be evaluated. Valid metrics
  41. include 'bbox', 'segm', 'proposal', and 'proposal_fast'.
  42. Defaults to 'bbox'.
  43. """
  44. default_prefix: Optional[str] = 'coco'
  45. def __init__(
  46. self,
  47. *args,
  48. occluded_ann:
  49. str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/occluded_coco.pkl', # noqa
  50. separated_ann:
  51. str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/separated_coco.pkl', # noqa
  52. score_thr: float = 0.3,
  53. iou_thr: float = 0.75,
  54. metric: Union[str, List[str]] = ['bbox', 'segm'],
  55. **kwargs) -> None:
  56. super().__init__(*args, metric=metric, **kwargs)
  57. self.occluded_ann = load(occluded_ann)
  58. self.separated_ann = load(separated_ann)
  59. self.score_thr = score_thr
  60. self.iou_thr = iou_thr
  61. def compute_metrics(self, results: list) -> Dict[str, float]:
  62. """Compute the metrics from processed results.
  63. Args:
  64. results (list): The processed results of each batch.
  65. Returns:
  66. Dict[str, float]: The computed metrics. The keys are the names of
  67. the metrics, and the values are corresponding results.
  68. """
  69. coco_metric_res = super().compute_metrics(results)
  70. eval_res = self.evaluate_occluded_separated(results)
  71. coco_metric_res.update(eval_res)
  72. return coco_metric_res
  73. def evaluate_occluded_separated(self, results: List[tuple]) -> dict:
  74. """Compute the recall of occluded and separated masks.
  75. Args:
  76. results (list[tuple]): Testing results of the dataset.
  77. Returns:
  78. dict[str, float]: The recall of occluded and separated masks.
  79. """
  80. dict_det = {}
  81. print_log('processing detection results...')
  82. prog_bar = mmengine.ProgressBar(len(results))
  83. for i in range(len(results)):
  84. gt, dt = results[i]
  85. img_id = dt['img_id']
  86. cur_img_name = self._coco_api.imgs[img_id]['file_name']
  87. if cur_img_name not in dict_det.keys():
  88. dict_det[cur_img_name] = []
  89. for bbox, score, label, mask in zip(dt['bboxes'], dt['scores'],
  90. dt['labels'], dt['masks']):
  91. cur_binary_mask = coco_mask.decode(mask)
  92. dict_det[cur_img_name].append([
  93. score, self.dataset_meta['classes'][label],
  94. cur_binary_mask, bbox
  95. ])
  96. dict_det[cur_img_name].sort(
  97. key=lambda x: (-x[0], x[3][0], x[3][1])
  98. ) # rank by confidence from high to low, avoid same confidence
  99. prog_bar.update()
  100. print_log('\ncomputing occluded mask recall...', logger='current')
  101. occluded_correct_num, occluded_recall = self.compute_recall(
  102. dict_det, gt_ann=self.occluded_ann, is_occ=True)
  103. print_log(
  104. f'\nCOCO occluded mask recall: {occluded_recall:.2f}%',
  105. logger='current')
  106. print_log(
  107. f'COCO occluded mask success num: {occluded_correct_num}',
  108. logger='current')
  109. print_log('computing separated mask recall...', logger='current')
  110. separated_correct_num, separated_recall = self.compute_recall(
  111. dict_det, gt_ann=self.separated_ann, is_occ=False)
  112. print_log(
  113. f'\nCOCO separated mask recall: {separated_recall:.2f}%',
  114. logger='current')
  115. print_log(
  116. f'COCO separated mask success num: {separated_correct_num}',
  117. logger='current')
  118. table_data = [
  119. ['mask type', 'recall', 'num correct'],
  120. ['occluded', f'{occluded_recall:.2f}%', occluded_correct_num],
  121. ['separated', f'{separated_recall:.2f}%', separated_correct_num]
  122. ]
  123. table = AsciiTable(table_data)
  124. print_log('\n' + table.table, logger='current')
  125. return dict(
  126. occluded_recall=occluded_recall, separated_recall=separated_recall)
  127. def compute_recall(self,
  128. result_dict: dict,
  129. gt_ann: list,
  130. is_occ: bool = True) -> tuple:
  131. """Compute the recall of occluded or separated masks.
  132. Args:
  133. result_dict (dict): Processed mask results.
  134. gt_ann (list): Occluded or separated coco annotations.
  135. is_occ (bool): Whether the annotation is occluded mask.
  136. Defaults to True.
  137. Returns:
  138. tuple: number of correct masks and the recall.
  139. """
  140. correct = 0
  141. prog_bar = mmengine.ProgressBar(len(gt_ann))
  142. for iter_i in range(len(gt_ann)):
  143. cur_item = gt_ann[iter_i]
  144. cur_img_name = cur_item[0]
  145. cur_gt_bbox = cur_item[3]
  146. if is_occ:
  147. cur_gt_bbox = [
  148. cur_gt_bbox[0], cur_gt_bbox[1],
  149. cur_gt_bbox[0] + cur_gt_bbox[2],
  150. cur_gt_bbox[1] + cur_gt_bbox[3]
  151. ]
  152. cur_gt_class = cur_item[1]
  153. cur_gt_mask = coco_mask.decode(cur_item[4])
  154. assert cur_img_name in result_dict.keys()
  155. cur_detections = result_dict[cur_img_name]
  156. correct_flag = False
  157. for i in range(len(cur_detections)):
  158. cur_det_confidence = cur_detections[i][0]
  159. if cur_det_confidence < self.score_thr:
  160. break
  161. cur_det_class = cur_detections[i][1]
  162. if cur_det_class != cur_gt_class:
  163. continue
  164. cur_det_mask = cur_detections[i][2]
  165. cur_iou = self.mask_iou(cur_det_mask, cur_gt_mask)
  166. if cur_iou >= self.iou_thr:
  167. correct_flag = True
  168. break
  169. if correct_flag:
  170. correct += 1
  171. prog_bar.update()
  172. recall = correct / len(gt_ann) * 100
  173. return correct, recall
  174. def mask_iou(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
  175. """Compute IoU between two masks."""
  176. mask1_area = np.count_nonzero(mask1 == 1)
  177. mask2_area = np.count_nonzero(mask2 == 1)
  178. intersection = np.count_nonzero(np.logical_and(mask1 == 1, mask2 == 1))
  179. iou = intersection / (mask1_area + mask2_area - intersection)
  180. return iou