test_linear_reid_head.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmdet.registry import MODELS
  5. from mmdet.structures import ReIDDataSample
  6. from mmdet.utils import register_all_modules
  7. class TestLinearReIDHead(TestCase):
  8. @classmethod
  9. def setUpClass(cls) -> None:
  10. register_all_modules()
  11. head_cfg = dict(
  12. type='LinearReIDHead',
  13. num_fcs=1,
  14. in_channels=128,
  15. fc_channels=64,
  16. out_channels=32,
  17. num_classes=2,
  18. loss_cls=dict(type='mmpretrain.CrossEntropyLoss', loss_weight=1.0),
  19. loss_triplet=dict(type='TripletLoss', margin=0.3, loss_weight=1.0),
  20. norm_cfg=dict(type='BN1d'),
  21. act_cfg=dict(type='ReLU'))
  22. cls.head = MODELS.build(head_cfg)
  23. cls.inputs = (torch.rand(4, 128), torch.rand(4, 128))
  24. cls.data_samples = [
  25. ReIDDataSample().set_gt_label(label) for label in (0, 0, 1, 1)
  26. ]
  27. def test_forward(self):
  28. outputs = self.head(self.inputs)
  29. assert outputs.shape == (4, 32)
  30. def test_loss(self):
  31. losses = self.head.loss(self.inputs, self.data_samples)
  32. assert losses.keys() == {'triplet_loss', 'ce_loss', 'accuracy_top-1'}
  33. assert losses['ce_loss'].item() >= 0
  34. assert losses['triplet_loss'].item() >= 0
  35. def test_predict(self):
  36. predictions = self.head.predict(self.inputs, self.data_samples)
  37. for pred in predictions:
  38. assert isinstance(pred, ReIDDataSample)
  39. assert isinstance(pred.pred_feature, torch.Tensor)
  40. assert isinstance(pred.gt_label.label, torch.Tensor)
  41. assert pred.pred_feature.shape == (32, )