test_reid_metric.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.registry import init_default_scope
  5. from mmdet.registry import METRICS
  6. from mmdet.structures import ReIDDataSample
  7. class TestReIDMetrics(TestCase):
  8. @classmethod
  9. def setUpClass(cls):
  10. init_default_scope('mmdet')
  11. def test_evaluate(self):
  12. """Test using the metric in the same way as Evaluator."""
  13. data_samples = [
  14. ReIDDataSample().set_gt_label(i).to_dict()
  15. for i in [0, 0, 1, 1, 1, 1]
  16. ]
  17. pred_batch = [
  18. dict(pred_feature=torch.tensor(
  19. [1., .0, .1])), # [x,√,x,x,x],R1=0,R5=1,AP=0.50
  20. dict(pred_feature=torch.tensor(
  21. [.8, .0, .0])), # [x,√,x,x,x],R1=0,R5=1,AP=0.50
  22. dict(pred_feature=torch.tensor(
  23. [.1, 1., .1])), # [√,√,x,√,x],R1=1,R5=1,AP≈0.92
  24. dict(pred_feature=torch.tensor(
  25. [.0, .9, .1])), # [√,√,√,x,x],R1=1,R5=1,AP=1.00
  26. dict(pred_feature=torch.tensor(
  27. [.9, .1, .0])), # [x,x,√,√,√],R1=0,R5=1,AP≈0.48
  28. dict(pred_feature=torch.tensor(
  29. [.0, .1, 1.])), # [√,√,x,√,x],R1=1,R5=1,AP≈0.92
  30. ]
  31. # get union
  32. for idx in range(len(data_samples)):
  33. data_samples[idx] = {**data_samples[idx], **pred_batch[idx]}
  34. metric = METRICS.build(
  35. dict(
  36. type='ReIDMetrics',
  37. metric=['mAP', 'CMC'],
  38. metric_options=dict(rank_list=[1, 5], max_rank=5),
  39. ))
  40. prefix = 'reid-metric'
  41. data_batch = dict(input=None, data_samples=None)
  42. metric.process(data_batch, data_samples)
  43. results = metric.evaluate(6)
  44. self.assertIsInstance(results, dict)
  45. self.assertEqual(results[f'{prefix}/mAP'], 0.719)
  46. self.assertEqual(results[f'{prefix}/R1'], 0.5)
  47. self.assertEqual(results[f'{prefix}/R5'], 1.0)