# test_base_reid.py (1.5 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from parameterized import parameterized
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import ReIDDataSample
  7. from mmdet.testing import get_detector_cfg
  8. from mmdet.utils import register_all_modules
  9. class TestBaseReID(TestCase):
  10. @classmethod
  11. def setUpClass(cls) -> None:
  12. register_all_modules()
  13. @parameterized.expand([
  14. 'reid/reid_r50_8xb32-6e_mot17train80_test-mot17val20.py',
  15. ])
  16. def test_forward(self, cfg_file):
  17. model_cfg = get_detector_cfg(cfg_file)
  18. model = MODELS.build(model_cfg)
  19. inputs = torch.rand(1, 4, 3, 256, 128)
  20. data_samples = [
  21. ReIDDataSample().set_gt_label(label) for label in (0, 0, 1, 1)
  22. ]
  23. # test mode='tensor'
  24. feats = model(inputs, mode='tensor')
  25. assert feats.shape == (4, 128)
  26. # test mode='loss'
  27. losses = model(inputs, data_samples, mode='loss')
  28. assert losses.keys() == {'triplet_loss', 'ce_loss', 'accuracy_top-1'}
  29. assert losses['ce_loss'].item() > 0
  30. assert losses['triplet_loss'].item() > 0
  31. # test mode='predict'
  32. predictions = model(inputs, data_samples, mode='predict')
  33. for pred in predictions:
  34. assert isinstance(pred, ReIDDataSample)
  35. assert isinstance(pred.pred_feature, torch.Tensor)
  36. assert isinstance(pred.gt_label.label, torch.Tensor)
  37. assert pred.pred_feature.shape == (128, )