test_sort.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import time
  3. import unittest
  4. from unittest import TestCase
  5. import torch
  6. from mmengine.logging import MessageHub
  7. from mmengine.registry import init_default_scope
  8. from parameterized import parameterized
  9. from mmdet.registry import MODELS
  10. from mmdet.testing import demo_track_inputs, get_detector_cfg
  11. class TestDeepSORT(TestCase):
  12. @classmethod
  13. def setUpClass(cls):
  14. init_default_scope('mmdet')
  15. @parameterized.expand([
  16. 'sort/sort_faster-rcnn_r50_fpn_8xb2-4e'
  17. '_mot17halftrain_test-mot17halfval.py'
  18. ])
  19. def test_init(self, cfg_file):
  20. model = get_detector_cfg(cfg_file)
  21. model = MODELS.build(model)
  22. assert model.detector
  23. assert model.tracker
  24. @parameterized.expand([
  25. ('sort/sort_faster-rcnn_r50_fpn_8xb2-4e'
  26. '_mot17halftrain_test-mot17halfval.py', ('cpu', 'cuda')),
  27. ])
  28. def test_deepsort_forward_predict_mode(self, cfg_file, devices):
  29. message_hub = MessageHub.get_instance(
  30. f'test_deepsort_forward_predict_mode-{time.time()}')
  31. message_hub.update_info('iter', 0)
  32. message_hub.update_info('epoch', 0)
  33. assert all([device in ['cpu', 'cuda'] for device in devices])
  34. for device in devices:
  35. _model = get_detector_cfg(cfg_file)
  36. model = MODELS.build(_model)
  37. if device == 'cuda':
  38. if not torch.cuda.is_available():
  39. return unittest.skip('test requires GPU and torch+cuda')
  40. model = model.cuda()
  41. packed_inputs = demo_track_inputs(
  42. batch_size=1,
  43. num_frames=2,
  44. image_shapes=[(3, 256, 256)],
  45. num_classes=1)
  46. out_data = model.data_preprocessor(packed_inputs, False)
  47. # Test forward test
  48. model.eval()
  49. with torch.no_grad():
  50. batch_results = model.forward(**out_data, mode='predict')
  51. assert len(batch_results) == 1