coco_video_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence

from mmengine.dist import broadcast_object_list, is_main_process

from mmdet.registry import METRICS
from .base_video_metric import collect_tracking_results
from .coco_metric import CocoMetric


@METRICS.register_module()
class CocoVideoMetric(CocoMetric):
    """COCO evaluation metric.

    Evaluate AR, AP, and mAP for detection tasks including proposal/box
    detection and instance segmentation. Please refer to
    https://cocodataset.org/#detection-eval for more details.
    """

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of data samples that
                contain annotations and predictions.
        """
        for track_data_sample in data_samples:
            video_data_samples = track_data_sample['video_data_samples']
            ori_video_len = video_data_samples[0].ori_video_length
            video_len = len(video_data_samples)
            if ori_video_len == video_len:
                # video process
                for frame_id in range(video_len):
                    img_data_sample = video_data_samples[frame_id].to_dict()
                    super().process(None, [img_data_sample])
            else:
                # image process
                img_data_sample = video_data_samples[0].to_dict()
                super().process(None, [img_data_sample])

    def evaluate(self, size: int = 1) -> dict:
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are the
            names of the metrics, and the values are corresponding results.
        """
        if len(self.results) == 0:
            warnings.warn(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in `process` method.')

        results = collect_tracking_results(self.results, self.collect_device)

        if is_main_process():
            _metrics = self.compute_metrics(results)  # type: ignore
            # Add prefix to metric names
            if self.prefix:
                _metrics = {
                    '/'.join((self.prefix, k)): v
                    for k, v in _metrics.items()
                }
            metrics = [_metrics]
        else:
            metrics = [None]  # type: ignore

        broadcast_object_list(metrics)

        # reset the results list
        self.results.clear()
        return metrics[0]
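

# --- Usage sketch (not part of this file) ---
# Because CocoVideoMetric is registered with METRICS, it can be selected by
# name as the evaluator in an MMDetection-style config. This is a minimal
# sketch: the annotation path below is a hypothetical placeholder, and the
# keyword arguments (ann_file, metric, format_only) are inherited from the
# parent CocoMetric.
#
# val_evaluator = dict(
#     type='CocoVideoMetric',
#     ann_file='data/my_dataset/annotations/val_cocoformat.json',  # placeholder
#     metric='bbox',
#     format_only=False)
# test_evaluator = val_evaluator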