123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import numpy as np
- import torch
- from mmengine.structures import LabelData
- from mmdet.structures import ReIDDataSample
- def _equal(a, b):
- if isinstance(a, (torch.Tensor, np.ndarray)):
- return (a == b).all()
- else:
- return a == b
class TestReIDDataSample(TestCase):
    """Unit tests for ``ReIDDataSample`` metainfo and ``gt_label`` setters."""

    def _assert_score_close(self, actual, expected):
        """Assert that the tensor ``actual`` numerically matches ``expected``.

        ``expected`` may be a plain sequence; it is converted to a tensor
        with the dtype and device of ``actual`` before the comparison. This
        replaces the deprecated ``torch.testing.assert_allclose`` and relies
        only on ``torch.allclose``, which is available in every PyTorch
        release.

        Args:
            actual (torch.Tensor): The tensor produced by the code under test.
            expected (Sequence | torch.Tensor): The reference values.
        """
        expected = torch.as_tensor(
            expected, dtype=actual.dtype, device=actual.device)
        self.assertTrue(
            torch.allclose(actual, expected),
            msg=f'score mismatch: got {actual}, expected {expected}')

    def test_init(self):
        """Metainfo given to the constructor is retrievable from the sample."""
        meta_info = dict(img_shape=(256, 128), ori_shape=(64, 64),
                         num_classes=5)
        data_sample = ReIDDataSample(metainfo=meta_info)

        for key, value in meta_info.items():
            self.assertIn(key, data_sample)
            self.assertTrue(_equal(data_sample.get(key), value))

    def test_set_gt_label(self):
        """``set_gt_label`` stores a ``LabelData`` for each supported input."""
        data_sample = ReIDDataSample(metainfo=dict(num_classes=5))
        set_gt_label = data_sample.set_gt_label

        # Test number
        set_gt_label(1)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor with single number
        set_gt_label(torch.tensor(2))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test array with single number
        set_gt_label(np.array(3))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor
        _label = torch.tensor([1, 2, 3])
        set_gt_label(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, _label))

        # Test array
        _label = np.array([1, 2, 3])
        set_gt_label(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.from_numpy(_label)))

        # Test Sequence
        _label = [1, 2, 3.]
        set_gt_label(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.tensor(_label)))

        # Test set num_classes
        self.assertEqual(label.num_classes, 5)

        # Test unavailable type
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            set_gt_label('hi')

    def test_set_gt_score(self):
        """``set_gt_score`` validates its input and records ``num_classes``."""
        data_sample = ReIDDataSample(metainfo={'num_classes': 5})
        set_gt_score = data_sample.set_gt_score

        # Test set
        score = [0.1, 0.1, 0.6, 0.1, 0.1]
        set_gt_score(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self.assertIn('score', sample_gt_label)
        self._assert_score_close(sample_gt_label.score, score)
        self.assertEqual(sample_gt_label.num_classes, 5)

        # Test set again; the previously fetched ``gt_label`` object is
        # checked on purpose, so it must reflect the new score.
        score = [0.2, 0.1, 0.5, 0.1, 0.1]
        set_gt_score(torch.tensor(score))
        self._assert_score_close(sample_gt_label.score, score)

        # Test invalid type
        with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
            set_gt_score(score)

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            set_gt_score(torch.tensor([score]))

        # Test invalid num_classes
        with self.assertRaisesRegex(AssertionError, r'length of value \(6\)'):
            set_gt_score(torch.tensor(score + [0.1]))

        # Test auto infer num_classes
        data_sample = ReIDDataSample()
        data_sample.set_gt_score(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self.assertEqual(sample_gt_label.num_classes, len(score))

    def test_del_gt_label(self):
        """Deleting ``gt_label`` removes it from the data sample."""
        data_sample = ReIDDataSample()
        self.assertNotIn('gt_label', data_sample)

        data_sample.set_gt_label(1)
        self.assertIn('gt_label', data_sample)

        del data_sample.gt_label
        self.assertNotIn('gt_label', data_sample)
|