test_reid_dataset.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. from unittest import TestCase
  4. from mmdet.datasets import ReIDDataset
  5. PREFIX = osp.join(osp.dirname(__file__), '../data')
  6. # This is a demo annotation file for ReIDDataset.
  7. REID_ANN_FILE = f'{PREFIX}/demo_reid_data/mot17_reid/ann.txt'
  8. class TestReIDDataset(TestCase):
  9. @classmethod
  10. def setUpClass(cls):
  11. cls.num_ids = 8
  12. cls.ins_per_id = 4
  13. cls.dataset = ReIDDataset(
  14. pipeline=[], ann_file=REID_ANN_FILE, data_prefix=dict(img_path=''))
  15. cls.dataset_triplet = ReIDDataset(
  16. pipeline=[],
  17. triplet_sampler=dict(
  18. num_ids=cls.num_ids, ins_per_id=cls.ins_per_id),
  19. ann_file=REID_ANN_FILE,
  20. data_prefix=dict(img_path=''))
  21. def test_get_data_info(self):
  22. # id 0 has 21 objects
  23. img_id = 0
  24. data_list = [
  25. self.dataset.get_data_info(i) for i in range(len(self.dataset))
  26. ]
  27. assert len([
  28. data_info for data_info in data_list
  29. if data_info['gt_label'] == img_id
  30. ]) == 21
  31. # id 11 doesn't have objects
  32. img_id = 11
  33. assert len([
  34. data_info for data_info in data_list
  35. if data_info['gt_label'] == img_id
  36. ]) == 0
  37. def test_len(self):
  38. assert len(self.dataset) == 704
  39. assert len(self.dataset_triplet) == 704
  40. def test_getitem(self):
  41. for i in range(len(self.dataset)):
  42. results = self.dataset[i]
  43. assert isinstance(results, dict) # no triplet -> dict
  44. assert 'img_path' in results
  45. assert 'gt_label' in results
  46. for i in range(len(self.dataset_triplet)):
  47. num = self.num_ids * self.ins_per_id
  48. results = self.dataset_triplet[i]
  49. assert isinstance(results, dict) # triplet -> dict
  50. assert len(results['img_path']) == num
  51. assert 'img_path' in results
  52. assert 'gt_label' in results
  53. for idx in range(num - 1):
  54. if (idx + 1) % self.ins_per_id != 0:
  55. assert results['gt_label'][idx] == \
  56. results['gt_label'][idx + 1]