test_qdtrack.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
# Copyright (c) OpenMMLab. All rights reserved.
import time
import unittest
from unittest import TestCase

import torch
from mmengine.logging import MessageHub
from mmengine.registry import init_default_scope
from parameterized import parameterized

from mmdet.registry import MODELS
from mmdet.testing import demo_track_inputs, get_detector_cfg
  11. class TestQDTrack(TestCase):
  12. @classmethod
  13. def setUpClass(cls):
  14. init_default_scope('mmdet')
  15. @parameterized.expand([
  16. 'qdtrack/qdtrack_faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain_'
  17. 'test-mot17halfval.py',
  18. ])
  19. def test_qdtrack_init(self, cfg_file):
  20. model = get_detector_cfg(cfg_file)
  21. model = MODELS.build(model)
  22. assert model.detector
  23. assert model.track_head
  24. @parameterized.expand([
  25. ('qdtrack/qdtrack_faster-rcnn_r50_fpn_8xb2-4e_mot17'
  26. 'halftrain_test-mot17halfval.py', ('cpu', 'cuda')),
  27. ])
  28. def test_qdtrack_forward_loss_mode(self, cfg_file, devices):
  29. message_hub = MessageHub.get_instance(
  30. f'test_qdtrack_forward_loss_mode-{time.time()}')
  31. message_hub.update_info('iter', 0)
  32. message_hub.update_info('epoch', 0)
  33. assert all([device in ['cpu', 'cuda'] for device in devices])
  34. for device in devices:
  35. _model = get_detector_cfg(cfg_file)
  36. # _scope_ will be popped after build
  37. model = MODELS.build(_model)
  38. if device == 'cuda':
  39. if not torch.cuda.is_available():
  40. return unittest.skip('test requires GPU and torch+cuda')
  41. model = model.cuda()
  42. packed_inputs = demo_track_inputs(
  43. batch_size=1,
  44. num_frames=2,
  45. key_frames_inds=[0],
  46. image_shapes=(3, 128, 128),
  47. num_items=None)
  48. out_data = model.data_preprocessor(packed_inputs, True)
  49. inputs, data_samples = out_data['inputs'], out_data['data_samples']
  50. # Test forward
  51. losses = model.forward(inputs, data_samples, mode='loss')
  52. assert isinstance(losses, dict)
  53. @parameterized.expand([
  54. ('qdtrack/qdtrack_faster-rcnn_r50_fpn_8xb2-4e_mot17'
  55. 'halftrain_test-mot17halfval.py', ('cpu', 'cuda')),
  56. ])
  57. def test_qdtrack_forward_predict_mode(self, cfg_file, devices):
  58. message_hub = MessageHub.get_instance(
  59. f'test_bytetrack_forward_predict_mode-{time.time()}')
  60. message_hub.update_info('iter', 0)
  61. message_hub.update_info('epoch', 0)
  62. assert all([device in ['cpu', 'cuda'] for device in devices])
  63. for device in devices:
  64. _model = get_detector_cfg(cfg_file)
  65. model = MODELS.build(_model)
  66. if device == 'cuda':
  67. if not torch.cuda.is_available():
  68. return unittest.skip('test requires GPU and torch+cuda')
  69. model = model.cuda()
  70. packed_inputs = demo_track_inputs(
  71. batch_size=1, num_frames=1, image_shapes=(3, 128, 128))
  72. out_data = model.data_preprocessor(packed_inputs, False)
  73. # Test forward test
  74. model.eval()
  75. with torch.no_grad():
  76. batch_results = model.forward(**out_data, mode='predict')
  77. assert len(batch_results) == 1