test_reid_data_sample.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine.structures import LabelData
  6. from mmdet.structures import ReIDDataSample
  7. def _equal(a, b):
  8. if isinstance(a, (torch.Tensor, np.ndarray)):
  9. return (a == b).all()
  10. else:
  11. return a == b
  12. class TestReIDDataSample(TestCase):
  13. def test_init(self):
  14. img_shape = (256, 128)
  15. ori_shape = (64, 64)
  16. num_classes = 5
  17. meta_info = dict(
  18. img_shape=img_shape, ori_shape=ori_shape, num_classes=num_classes)
  19. data_sample = ReIDDataSample(metainfo=meta_info)
  20. self.assertIn('img_shape', data_sample)
  21. self.assertIn('ori_shape', data_sample)
  22. self.assertIn('num_classes', data_sample)
  23. self.assertTrue(_equal(data_sample.get('img_shape'), img_shape))
  24. self.assertTrue(_equal(data_sample.get('ori_shape'), ori_shape))
  25. self.assertTrue(_equal(data_sample.get('num_classes'), num_classes))
  26. def test_set_gt_label(self):
  27. data_sample = ReIDDataSample(metainfo=dict(num_classes=5))
  28. method = getattr(data_sample, 'set_' + 'gt_label')
  29. # Test number
  30. method(1)
  31. label = data_sample.get('gt_label')
  32. self.assertIsInstance(label, LabelData)
  33. self.assertIsInstance(label.label, torch.LongTensor)
  34. # Test tensor with single number
  35. method(torch.tensor(2))
  36. label = data_sample.get('gt_label')
  37. self.assertIsInstance(label, LabelData)
  38. self.assertIsInstance(label.label, torch.LongTensor)
  39. # Test array with single number
  40. method(np.array(3))
  41. label = data_sample.get('gt_label')
  42. self.assertIsInstance(label, LabelData)
  43. self.assertIsInstance(label.label, torch.LongTensor)
  44. # Test tensor
  45. _label = torch.tensor([1, 2, 3])
  46. method(_label)
  47. label = data_sample.get('gt_label')
  48. self.assertIsInstance(label, LabelData)
  49. self.assertIsInstance(label.label, torch.Tensor)
  50. self.assertTrue(_equal(label.label, _label))
  51. # Test array
  52. _label = np.array([1, 2, 3])
  53. method(_label)
  54. label = data_sample.get('gt_label')
  55. self.assertIsInstance(label, LabelData)
  56. self.assertIsInstance(label.label, torch.Tensor)
  57. self.assertTrue(_equal(label.label, torch.from_numpy(_label)))
  58. # Test Sequence
  59. _label = [1, 2, 3.]
  60. method(_label)
  61. label = data_sample.get('gt_label')
  62. self.assertIsInstance(label, LabelData)
  63. self.assertIsInstance(label.label, torch.Tensor)
  64. self.assertTrue(_equal(label.label, torch.tensor(_label)))
  65. # Test set num_classes
  66. self.assertEqual(label.num_classes, 5)
  67. # Test unavailable type
  68. with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
  69. method('hi')
  70. def test_set_gt_score(self):
  71. data_sample = ReIDDataSample(metainfo={'num_classes': 5})
  72. method = getattr(data_sample, 'set_' + 'gt_score')
  73. # Test set
  74. score = [0.1, 0.1, 0.6, 0.1, 0.1]
  75. method(torch.tensor(score))
  76. sample_gt_label = getattr(data_sample, 'gt_label')
  77. self.assertIn('score', sample_gt_label)
  78. torch.testing.assert_allclose(sample_gt_label.score, score)
  79. self.assertEqual(sample_gt_label.num_classes, 5)
  80. # Test set again
  81. score = [0.2, 0.1, 0.5, 0.1, 0.1]
  82. method(torch.tensor(score))
  83. torch.testing.assert_allclose(sample_gt_label.score, score)
  84. # Test invalid type
  85. with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
  86. method(score)
  87. # Test invalid dims
  88. with self.assertRaisesRegex(AssertionError, 'but got 2'):
  89. method(torch.tensor([score]))
  90. # Test invalid num_classes
  91. with self.assertRaisesRegex(AssertionError, r'length of value \(6\)'):
  92. method(torch.tensor(score + [0.1]))
  93. # Test auto inter num_classes
  94. data_sample = ReIDDataSample()
  95. method = getattr(data_sample, 'set_gt_score')
  96. method(torch.tensor(score))
  97. sample_gt_label = getattr(data_sample, 'gt_label')
  98. self.assertEqual(sample_gt_label.num_classes, len(score))
  99. def test_del_gt_label(self):
  100. data_sample = ReIDDataSample()
  101. self.assertNotIn('gt_label', data_sample)
  102. data_sample.set_gt_label(1)
  103. self.assertIn('gt_label', data_sample)
  104. del data_sample.gt_label
  105. self.assertNotIn('gt_label', data_sample)