test_mot_challenge_metrics.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import tempfile
from unittest import TestCase

import torch
from mmengine.structures import BaseDataElement, InstanceData

from mmdet.evaluation import MOTChallengeMetric
from mmdet.structures import DetDataSample, TrackDataSample


class TestMOTChallengeMetric(TestCase):

    def test_init(self):
        with self.assertRaisesRegex(KeyError, 'metric unknown is not'):
            MOTChallengeMetric(metric='unknown')
        with self.assertRaises(AssertionError):
            MOTChallengeMetric(benchmark='MOT21')

    def __del__(self):
        # tmp_dir is only created by test_evaluate_format_only, so guard the
        # cleanup to avoid an AttributeError during interpreter teardown.
        if hasattr(self, 'tmp_dir'):
            self.tmp_dir.cleanup()

    @staticmethod
    def _get_predictions_demo():
        """Pack two demo frames of tracking predictions into one TrackDataSample."""
        instances = [{
            'bbox_label': 0,
            'bbox': [0, 0, 100, 100],
            'ignore_flag': 0,
            'instance_id': 1,
            'mot_conf': 1.0,
            'category_id': 1,
            'visibility': 1.0
        }, {
            'bbox_label': 0,
            'bbox': [0, 0, 100, 100],
            'ignore_flag': 0,
            'instance_id': 2,
            'mot_conf': 1.0,
            'category_id': 1,
            'visibility': 1.0
        }]
        instances_2 = copy.deepcopy(instances)
        sep = os.sep
        pred_instances_data = dict(
            bboxes=torch.tensor([
                [0, 0, 100, 100],
                [0, 0, 100, 40],
            ]),
            instances_id=torch.tensor([1, 2]),
            scores=torch.tensor([1.0, 1.0]))
        pred_instances_data_2 = copy.deepcopy(pred_instances_data)
        pred_instances = InstanceData(**pred_instances_data)
        pred_instances_2 = InstanceData(**pred_instances_data_2)
        img_data_sample = DetDataSample()
        img_data_sample.pred_track_instances = pred_instances
        img_data_sample.instances = instances
        img_data_sample.set_metainfo(
            dict(
                frame_id=0,
                ori_video_length=2,
                video_length=2,
                img_id=1,
                img_path=f'xxx{sep}MOT17-09-DPM{sep}img1{sep}000001.jpg',
            ))
        img_data_sample_2 = DetDataSample()
        img_data_sample_2.pred_track_instances = pred_instances_2
        img_data_sample_2.instances = instances_2
        img_data_sample_2.set_metainfo(
            dict(
                frame_id=1,
                ori_video_length=2,
                video_length=2,
                img_id=2,
                img_path=f'xxx{sep}MOT17-09-DPM{sep}img1{sep}000002.jpg',
            ))
        track_data_sample = TrackDataSample()
        track_data_sample.video_data_samples = [
            img_data_sample, img_data_sample_2
        ]
        # [TrackDataSample]
        predictions = []
        if isinstance(track_data_sample, BaseDataElement):
            predictions.append(track_data_sample.to_dict())
        return predictions

    def _test_evaluate(self, format_only, outfile_prefix=None):
        """Test using the metric in the same way as Evaluator."""
        metric = MOTChallengeMetric(
            metric=['HOTA', 'CLEAR', 'Identity'],
            format_only=format_only,
            outfile_prefix=outfile_prefix)
        metric.dataset_meta = {'classes': ('pedestrian', )}
        data_batch = dict(input=None, data_samples=None)
        predictions = self._get_predictions_demo()
        metric.process(data_batch, predictions)
        eval_results = metric.evaluate()
        return eval_results

    def test_evaluate(self):
        eval_results = self._test_evaluate(False)
        target = {
            'motchallenge-metric/IDF1': 0.5,
            'motchallenge-metric/MOTA': 0,
            'motchallenge-metric/HOTA': 0.755,
            'motchallenge-metric/IDSW': 0,
        }
        for key in target:
            assert abs(eval_results[key] - target[key]) < 1e-3

    def test_evaluate_format_only(self):
        self.tmp_dir = tempfile.TemporaryDirectory()
        eval_results = self._test_evaluate(
            True, outfile_prefix=self.tmp_dir.name)
        assert eval_results == dict()
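

# A minimal standalone sketch, not part of the original test suite: it reuses the
# demo predictions above to drive MOTChallengeMetric directly, mirroring what
# _test_evaluate does when the metric is invoked through an Evaluator. No calls are
# assumed beyond the ones already exercised in the tests above.
if __name__ == '__main__':
    metric = MOTChallengeMetric(metric=['HOTA', 'CLEAR', 'Identity'])
    metric.dataset_meta = {'classes': ('pedestrian', )}
    predictions = TestMOTChallengeMetric._get_predictions_demo()
    metric.process(dict(input=None, data_samples=None), predictions)
    # Prints a dict keyed like 'motchallenge-metric/HOTA', 'motchallenge-metric/MOTA', ...
    print(metric.evaluate())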