# test_glip.py
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest import TestCase

import torch
from parameterized import parameterized

from mmdet.structures import DetDataSample
from mmdet.testing import demo_mm_inputs, get_detector_cfg
from mmdet.utils import register_all_modules
  9. class TestGLIP(TestCase):
  10. def setUp(self):
  11. register_all_modules()
  12. @parameterized.expand(
  13. ['glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py'])
  14. def test_init(self, cfg_file):
  15. model = get_detector_cfg(cfg_file)
  16. model.backbone.init_cfg = None
  17. from mmdet.registry import MODELS
  18. detector = MODELS.build(model)
  19. self.assertTrue(detector.backbone)
  20. self.assertTrue(detector.language_model)
  21. self.assertTrue(detector.neck)
  22. self.assertTrue(detector.bbox_head)
  23. @parameterized.expand([
  24. ('glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py', ('cpu',
  25. 'cuda'))
  26. ])
  27. def test_glip_forward_predict_mode(self, cfg_file, devices):
  28. model = get_detector_cfg(cfg_file)
  29. model.backbone.init_cfg = None
  30. from mmdet.registry import MODELS
  31. assert all([device in ['cpu', 'cuda'] for device in devices])
  32. for device in devices:
  33. detector = MODELS.build(model)
  34. if device == 'cuda':
  35. if not torch.cuda.is_available():
  36. return unittest.skip('test requires GPU and torch+cuda')
  37. detector = detector.cuda()
  38. # test custom_entities is True
  39. packed_inputs = demo_mm_inputs(
  40. 2, [[3, 128, 128], [3, 125, 130]],
  41. texts=['a', 'b'],
  42. custom_entities=True)
  43. data = detector.data_preprocessor(packed_inputs, False)
  44. # Test forward test
  45. detector.eval()
  46. with torch.no_grad():
  47. batch_results = detector.forward(**data, mode='predict')
  48. self.assertEqual(len(batch_results), 2)
  49. self.assertIsInstance(batch_results[0], DetDataSample)
  50. # test custom_entities is False
  51. packed_inputs = demo_mm_inputs(
  52. 2, [[3, 128, 128], [3, 125, 130]],
  53. texts=['a', 'b'],
  54. custom_entities=False)
  55. data = detector.data_preprocessor(packed_inputs, False)
  56. # Test forward test
  57. detector.eval()
  58. with torch.no_grad():
  59. batch_results = detector.forward(**data, mode='predict')
  60. self.assertEqual(len(batch_results), 2)
  61. self.assertIsInstance(batch_results[0], DetDataSample)