# Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase import numpy as np import torch from mmengine.structures import LabelData from mmdet.structures import ReIDDataSample def _equal(a, b): if isinstance(a, (torch.Tensor, np.ndarray)): return (a == b).all() else: return a == b class TestReIDDataSample(TestCase): def test_init(self): img_shape = (256, 128) ori_shape = (64, 64) num_classes = 5 meta_info = dict( img_shape=img_shape, ori_shape=ori_shape, num_classes=num_classes) data_sample = ReIDDataSample(metainfo=meta_info) self.assertIn('img_shape', data_sample) self.assertIn('ori_shape', data_sample) self.assertIn('num_classes', data_sample) self.assertTrue(_equal(data_sample.get('img_shape'), img_shape)) self.assertTrue(_equal(data_sample.get('ori_shape'), ori_shape)) self.assertTrue(_equal(data_sample.get('num_classes'), num_classes)) def test_set_gt_label(self): data_sample = ReIDDataSample(metainfo=dict(num_classes=5)) method = getattr(data_sample, 'set_' + 'gt_label') # Test number method(1) label = data_sample.get('gt_label') self.assertIsInstance(label, LabelData) self.assertIsInstance(label.label, torch.LongTensor) # Test tensor with single number method(torch.tensor(2)) label = data_sample.get('gt_label') self.assertIsInstance(label, LabelData) self.assertIsInstance(label.label, torch.LongTensor) # Test array with single number method(np.array(3)) label = data_sample.get('gt_label') self.assertIsInstance(label, LabelData) self.assertIsInstance(label.label, torch.LongTensor) # Test tensor _label = torch.tensor([1, 2, 3]) method(_label) label = data_sample.get('gt_label') self.assertIsInstance(label, LabelData) self.assertIsInstance(label.label, torch.Tensor) self.assertTrue(_equal(label.label, _label)) # Test array _label = np.array([1, 2, 3]) method(_label) label = data_sample.get('gt_label') self.assertIsInstance(label, LabelData) self.assertIsInstance(label.label, torch.Tensor) self.assertTrue(_equal(label.label, torch.from_numpy(_label))) # Test Sequence _label = [1, 2, 3.] method(_label) label = data_sample.get('gt_label') self.assertIsInstance(label, LabelData) self.assertIsInstance(label.label, torch.Tensor) self.assertTrue(_equal(label.label, torch.tensor(_label))) # Test set num_classes self.assertEqual(label.num_classes, 5) # Test unavailable type with self.assertRaisesRegex(TypeError, " is not"): method('hi') def test_set_gt_score(self): data_sample = ReIDDataSample(metainfo={'num_classes': 5}) method = getattr(data_sample, 'set_' + 'gt_score') # Test set score = [0.1, 0.1, 0.6, 0.1, 0.1] method(torch.tensor(score)) sample_gt_label = getattr(data_sample, 'gt_label') self.assertIn('score', sample_gt_label) torch.testing.assert_allclose(sample_gt_label.score, score) self.assertEqual(sample_gt_label.num_classes, 5) # Test set again score = [0.2, 0.1, 0.5, 0.1, 0.1] method(torch.tensor(score)) torch.testing.assert_allclose(sample_gt_label.score, score) # Test invalid type with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'): method(score) # Test invalid dims with self.assertRaisesRegex(AssertionError, 'but got 2'): method(torch.tensor([score])) # Test invalid num_classes with self.assertRaisesRegex(AssertionError, r'length of value \(6\)'): method(torch.tensor(score + [0.1])) # Test auto inter num_classes data_sample = ReIDDataSample() method = getattr(data_sample, 'set_gt_score') method(torch.tensor(score)) sample_gt_label = getattr(data_sample, 'gt_label') self.assertEqual(sample_gt_label.num_classes, len(score)) def test_del_gt_label(self): data_sample = ReIDDataSample() self.assertNotIn('gt_label', data_sample) data_sample.set_gt_label(1) self.assertIn('gt_label', data_sample) del data_sample.gt_label self.assertNotIn('gt_label', data_sample)