test_mask2former.py

# Copyright (c) OpenMMLab. All rights reserved.
import time
import unittest
from unittest import TestCase

import torch
from mmengine.logging import MessageHub
from mmengine.registry import init_default_scope
from parameterized import parameterized

from mmdet.registry import MODELS
from mmdet.testing import demo_track_inputs, get_detector_cfg


class TestMask2Former(TestCase):

    @classmethod
    def setUpClass(cls):
        init_default_scope('mmdet')
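
    # Building the model from the VIS config should yield a detector with a
    # backbone and a track head attached.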
    @parameterized.expand([
        'mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py',
    ])
    def test_mask2former_init(self, cfg_file):
        model = get_detector_cfg(cfg_file)
        model = MODELS.build(model)
        assert model.backbone
        assert model.track_head
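
    # Forward a two-frame dummy clip in 'loss' mode on CPU and, when
    # available, CUDA, and check that a loss dict comes back.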
    @parameterized.expand([
        ('mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py',
         ('cpu', 'cuda')),
    ])
    def test_mask2former_forward_loss_mode(self, cfg_file, devices):
        message_hub = MessageHub.get_instance(
            f'test_mask2former_forward_loss_mode-{time.time()}')
        message_hub.update_info('iter', 0)
        message_hub.update_info('epoch', 0)
        assert all([device in ['cpu', 'cuda'] for device in devices])

        for device in devices:
            _model = get_detector_cfg(cfg_file)
            # _scope_ will be popped after build
            model = MODELS.build(_model)

            if device == 'cuda':
                if not torch.cuda.is_available():
                    return unittest.skip('test requires GPU and torch+cuda')
                model = model.cuda()

            packed_inputs = demo_track_inputs(
                batch_size=1,
                num_frames=2,
                key_frames_inds=[0],
                image_shapes=(3, 128, 128),
                num_classes=2,
                with_mask=True)
            out_data = model.data_preprocessor(packed_inputs, True)
            inputs, data_samples = out_data['inputs'], out_data['data_samples']
            losses = model.forward(inputs, data_samples, mode='loss')
            assert isinstance(losses, dict)
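
    # Run inference on a single-frame dummy clip in 'predict' mode and check
    # that one result is returned for the one-video batch.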
    @parameterized.expand([
        ('mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py',
         ('cpu', 'cuda')),
    ])
    def test_mask2former_forward_predict_mode(self, cfg_file, devices):
        message_hub = MessageHub.get_instance(
            f'test_mask2former_forward_predict_mode-{time.time()}')
        message_hub.update_info('iter', 0)
        message_hub.update_info('epoch', 0)
        assert all([device in ['cpu', 'cuda'] for device in devices])

        for device in devices:
            _model = get_detector_cfg(cfg_file)
            model = MODELS.build(_model)

            if device == 'cuda':
                if not torch.cuda.is_available():
                    return unittest.skip('test requires GPU and torch+cuda')
                model = model.cuda()

            packed_inputs = demo_track_inputs(
                batch_size=1,
                num_frames=1,
                image_shapes=(3, 128, 128),
                num_classes=2,
                with_mask=True)
            out_data = model.data_preprocessor(packed_inputs, False)
            # Test the forward pass in eval/predict mode.
            model.eval()
            with torch.no_grad():
                batch_results = model.forward(**out_data, mode='predict')
                assert len(batch_results) == 1