# Copyright (c) OpenMMLab. All rights reserved.
import time
import unittest
from unittest import TestCase

import torch
from mmengine.logging import MessageHub
from mmengine.registry import init_default_scope
from parameterized import parameterized

from mmdet.registry import MODELS
from mmdet.testing import demo_mm_inputs, demo_track_inputs, get_detector_cfg
  11. class TestByteTrack(TestCase):
  12. @classmethod
  13. def setUpClass(cls):
  14. init_default_scope('mmdet')
  15. @parameterized.expand([
  16. 'bytetrack/bytetrack_yolox_x_8xb4-80e_crowdhuman-mot17halftrain'
  17. '_test-mot17halfval.py',
  18. ])
  19. def test_bytetrack_init(self, cfg_file):
  20. model = get_detector_cfg(cfg_file)
  21. model.detector.neck.out_channels = 1
  22. model.detector.neck.num_csp_blocks = 1
  23. model.detector.bbox_head.in_channels = 1
  24. model.detector.bbox_head.feat_channels = 1
  25. model = MODELS.build(model)
  26. assert model.detector
  27. @parameterized.expand([
  28. ('bytetrack/bytetrack_yolox_x_8xb4-80e_crowdhuman-mot17halftrain_'
  29. 'test-mot17halfval.py', ('cpu', 'cuda')),
  30. ])
  31. def test_bytetrack_forward_loss_mode(self, cfg_file, devices):
  32. message_hub = MessageHub.get_instance(
  33. f'test_bytetrack_forward_loss_mode-{time.time()}')
  34. message_hub.update_info('iter', 0)
  35. message_hub.update_info('epoch', 0)
  36. assert all([device in ['cpu', 'cuda'] for device in devices])
  37. for device in devices:
  38. _model = get_detector_cfg(cfg_file)
  39. _model.detector.neck.out_channels = 1
  40. _model.detector.neck.num_csp_blocks = 1
  41. _model.detector.bbox_head.num_classes = 10
  42. _model.detector.bbox_head.in_channels = 1
  43. _model.detector.bbox_head.feat_channels = 1
  44. # _scope_ will be popped after build
  45. model = MODELS.build(_model)
  46. if device == 'cuda':
  47. if not torch.cuda.is_available():
  48. return unittest.skip('test requires GPU and torch+cuda')
  49. model = model.cuda()
  50. packed_inputs = demo_mm_inputs(2, [[3, 128, 128], [3, 125, 130]])
  51. data = model.data_preprocessor(packed_inputs, True)
  52. losses = model.forward(**data, mode='loss')
  53. assert isinstance(losses, dict)
  54. @parameterized.expand([
  55. ('bytetrack/bytetrack_yolox_x_8xb4-80e_crowdhuman-mot17halftrain_'
  56. 'test-mot17halfval.py', ('cpu', 'cuda')),
  57. ])
  58. def test_bytetrack_forward_predict_mode(self, cfg_file, devices):
  59. message_hub = MessageHub.get_instance(
  60. f'test_bytetrack_forward_predict_mode-{time.time()}')
  61. message_hub.update_info('iter', 0)
  62. message_hub.update_info('epoch', 0)
  63. assert all([device in ['cpu', 'cuda'] for device in devices])
  64. for device in devices:
  65. _model = get_detector_cfg(cfg_file)
  66. _model.detector.neck.out_channels = 1
  67. _model.detector.neck.num_csp_blocks = 1
  68. _model.detector.bbox_head.in_channels = 1
  69. _model.detector.bbox_head.feat_channels = 1
  70. model = MODELS.build(_model)
  71. if device == 'cuda':
  72. if not torch.cuda.is_available():
  73. return unittest.skip('test requires GPU and torch+cuda')
  74. model = model.cuda()
  75. packed_inputs = demo_track_inputs(
  76. batch_size=1,
  77. num_frames=2,
  78. image_shapes=[(3, 256, 256)],
  79. num_classes=1)
  80. out_data = model.data_preprocessor(packed_inputs, False)
  81. # Test forward test
  82. model.eval()
  83. with torch.no_grad():
  84. batch_results = model.forward(**out_data, mode='predict')
  85. assert len(batch_results) == 1