test_masktrack_rcnn.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import time
  3. import unittest
  4. from unittest import TestCase
  5. import torch
  6. from mmengine.logging import MessageHub
  7. from mmengine.registry import init_default_scope
  8. from parameterized import parameterized
  9. from mmdet.registry import MODELS
  10. from mmdet.testing import demo_track_inputs, get_detector_cfg
  11. class TestMaskTrackRCNN(TestCase):
  12. @classmethod
  13. def setUpClass(cls):
  14. init_default_scope('mmdet')
  15. @parameterized.expand([
  16. 'masktrack_rcnn/masktrack-rcnn_mask-rcnn_r50_fpn_8xb1-12e_youtubevis2019.py', # noqa: E501
  17. ])
  18. def test_mask_track_rcnn_init(self, cfg_file):
  19. model = get_detector_cfg(cfg_file)
  20. model = MODELS.build(model)
  21. assert model.detector
  22. assert model.track_head
  23. assert model.tracker
  24. @parameterized.expand([
  25. (
  26. 'masktrack_rcnn/masktrack-rcnn_mask-rcnn_r50_fpn_8xb1-12e_youtubevis2019.py', # noqa: E501
  27. ('cpu', 'cuda')),
  28. ])
  29. def test_mask_track_rcnn_forward_loss_mode(self, cfg_file, devices):
  30. message_hub = MessageHub.get_instance(
  31. f'test_mask_track_rcnn_forward_loss_mode-{time.time()}')
  32. message_hub.update_info('iter', 0)
  33. message_hub.update_info('epoch', 0)
  34. assert all([device in ['cpu', 'cuda'] for device in devices])
  35. for device in devices:
  36. _model = get_detector_cfg(cfg_file)
  37. # _scope_ will be popped after build
  38. model = MODELS.build(_model)
  39. if device == 'cuda':
  40. if not torch.cuda.is_available():
  41. return unittest.skip('test requires GPU and torch+cuda')
  42. model = model.cuda()
  43. packed_inputs = demo_track_inputs(
  44. batch_size=1,
  45. num_frames=2,
  46. key_frames_inds=[0],
  47. image_shapes=(3, 128, 128),
  48. num_classes=2,
  49. with_mask=True)
  50. out_data = model.data_preprocessor(packed_inputs, True)
  51. # Test forward
  52. losses = model.forward(**out_data, mode='loss')
  53. assert isinstance(losses, dict)
  54. @parameterized.expand([
  55. (
  56. 'masktrack_rcnn/masktrack-rcnn_mask-rcnn_r50_fpn_8xb1-12e_youtubevis2019.py', # noqa: E501
  57. ('cpu', 'cuda')),
  58. ])
  59. def test_mask_track_rcnn_forward_predict_mode(self, cfg_file, devices):
  60. message_hub = MessageHub.get_instance(
  61. f'test_mask_track_rcnn_forward_predict_mode-{time.time()}')
  62. message_hub.update_info('iter', 0)
  63. message_hub.update_info('epoch', 0)
  64. assert all([device in ['cpu', 'cuda'] for device in devices])
  65. for device in devices:
  66. _model = get_detector_cfg(cfg_file)
  67. model = MODELS.build(_model)
  68. if device == 'cuda':
  69. if not torch.cuda.is_available():
  70. return unittest.skip('test requires GPU and torch+cuda')
  71. model = model.cuda()
  72. packed_inputs = demo_track_inputs(
  73. batch_size=1,
  74. num_frames=1,
  75. image_shapes=(3, 128, 128),
  76. num_classes=2,
  77. with_mask=True)
  78. out_data = model.data_preprocessor(packed_inputs, False)
  79. # Test forward test
  80. model.eval()
  81. with torch.no_grad():
  82. batch_results = model.forward(**out_data, mode='predict')
  83. assert len(batch_results) == 1