test_grid_roi_head.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. from unittest import TestCase
  4. import torch
  5. from parameterized import parameterized
  6. from mmdet.registry import MODELS
  7. from mmdet.testing import demo_mm_inputs, demo_mm_proposals, get_roi_head_cfg
  8. from mmdet.utils import register_all_modules
  class TestGridRoIHead(TestCase):
      """Unit tests for the Grid R-CNN RoI head built from its config."""

      def setUp(self):
          """Register mmdet modules and load the Grid R-CNN RoI head config."""
          register_all_modules()
          self.roi_head_cfg = get_roi_head_cfg(
              'grid_rcnn/grid-rcnn_r50_fpn_gn-head_2x_coco.py')

      def _skip_if_unavailable(self, device):
          """Skip the running test when ``device`` is CUDA but none exists.

          Uses ``self.skipTest`` (which raises ``unittest.SkipTest``) rather
          than returning ``unittest.skip(...)``: the latter only builds a
          decorator, so the test would be reported as passed, not skipped.
          """
          if device == 'cuda' and not torch.cuda.is_available():
              self.skipTest('test requires GPU and torch+cuda')

      def _build_head_and_feats(self, device, s=256):
          """Build the RoI head on ``device`` plus random multi-level features.

          One ``(1, 256, s // 2**(i + 2), s // 2**(i + 2))`` tensor is made
          per entry of ``bbox_roi_extractor.featmap_strides``; level ``i`` is
          therefore ``s`` downsampled by ``2**(i + 2)``.

          Returns:
              tuple: ``(roi_head, feats)`` where ``feats`` is a list of tensors.
          """
          roi_head = MODELS.build(self.roi_head_cfg)
          roi_head = roi_head.to(device=device)
          num_levels = len(roi_head.bbox_roi_extractor.featmap_strides)
          feats = []
          for i in range(num_levels):
              size = s // (2**(i + 2))
              # Sample on CPU first, then move, so the RNG stream is the same
              # regardless of the target device.
              feats.append(torch.rand(1, 256, size, size).to(device=device))
          return roi_head, feats

      def test_init(self):
          """The head built from the config has a bbox branch."""
          roi_head = MODELS.build(self.roi_head_cfg)
          self.assertTrue(roi_head.with_bbox)

      @parameterized.expand(['cpu', 'cuda'])
      def test_grid_roi_head_loss(self, device):
          """Tests grid roi head loss with and without ground-truth boxes."""
          self._skip_if_unavailable(device)
          s = 256
          roi_head, feats = self._build_head_and_feats(device, s)
          image_shapes = [(3, s, s)]

          # One ground-truth box: both classification and grid losses exist.
          batch_data_samples = demo_mm_inputs(
              batch_size=1,
              image_shapes=image_shapes,
              num_items=[1],
              num_classes=4,
              with_mask=True,
              device=device)['data_samples']
          proposals_list = demo_mm_proposals(
              image_shapes=image_shapes, num_proposals=100, device=device)
          out = roi_head.loss(feats, proposals_list, batch_data_samples)
          loss_cls = out['loss_cls']
          loss_grid = out['loss_grid']
          self.assertGreater(loss_cls.sum(), 0, 'cls loss should be non-zero')
          self.assertGreater(loss_grid.sum(), 0, 'grid loss should be non-zero')

          # No ground-truth boxes: cls loss still computed, no grid loss key.
          batch_data_samples = demo_mm_inputs(
              batch_size=1,
              image_shapes=image_shapes,
              num_items=[0],
              num_classes=4,
              with_mask=True,
              device=device)['data_samples']
          proposals_list = demo_mm_proposals(
              image_shapes=image_shapes, num_proposals=100, device=device)
          out = roi_head.loss(feats, proposals_list, batch_data_samples)
          empty_cls_loss = out['loss_cls']
          self.assertGreater(empty_cls_loss.sum(), 0,
                             'cls loss should be non-zero')
          self.assertNotIn(
              'loss_grid', out,
              'grid loss should be passed when there are no true boxes')

      @parameterized.expand(['cpu', 'cuda'])
      def test_grid_roi_head_predict(self, device):
          """Tests grid roi head predict runs on an image without objects."""
          self._skip_if_unavailable(device)
          s = 256
          roi_head, feats = self._build_head_and_feats(device, s)
          image_shapes = [(3, s, s)]
          batch_data_samples = demo_mm_inputs(
              batch_size=1,
              image_shapes=image_shapes,
              num_items=[0],
              num_classes=4,
              with_mask=True,
              device=device)['data_samples']
          proposals_list = demo_mm_proposals(
              image_shapes=image_shapes, num_proposals=100, device=device)
          roi_head.predict(feats, proposals_list, batch_data_samples)

      @parameterized.expand(['cpu', 'cuda'])
      def test_grid_roi_head_forward(self, device):
          """Tests grid roi head forward on random features and proposals."""
          self._skip_if_unavailable(device)
          s = 256
          roi_head, feats = self._build_head_and_feats(device, s)
          image_shapes = [(3, s, s)]
          proposals_list = demo_mm_proposals(
              image_shapes=image_shapes, num_proposals=100, device=device)
          roi_head.forward(feats, proposals_list)