test_dynamic_roi_head.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. from unittest import TestCase
  4. import torch
  5. from parameterized import parameterized
  6. from mmdet.registry import MODELS
  7. from mmdet.testing import demo_mm_inputs, demo_mm_proposals, get_roi_head_cfg
  8. from mmdet.utils import register_all_modules
  9. class TestDynamicRoIHead(TestCase):
  10. def setUp(self):
  11. register_all_modules()
  12. self.roi_head_cfg = get_roi_head_cfg(
  13. 'dynamic_rcnn/dynamic-rcnn_r50_fpn_1x_coco.py')
  14. def test_init(self):
  15. roi_head = MODELS.build(self.roi_head_cfg)
  16. self.assertTrue(roi_head.with_bbox)
  17. @parameterized.expand(['cpu', 'cuda'])
  18. def test_dynamic_roi_head_loss(self, device):
  19. """Tests trident roi head predict."""
  20. if not torch.cuda.is_available() and device == 'cuda':
  21. # RoI pooling only support in GPU
  22. return unittest.skip('test requires GPU and torch+cuda')
  23. roi_head = MODELS.build(self.roi_head_cfg)
  24. roi_head = roi_head.to(device=device)
  25. s = 256
  26. feats = []
  27. for i in range(len(roi_head.bbox_roi_extractor.featmap_strides)):
  28. feats.append(
  29. torch.rand(1, 256, s // (2**(i + 2)),
  30. s // (2**(i + 2))).to(device=device))
  31. image_shapes = [(3, s, s)]
  32. batch_data_samples = demo_mm_inputs(
  33. batch_size=1,
  34. image_shapes=image_shapes,
  35. num_items=[1],
  36. num_classes=4,
  37. with_mask=True,
  38. device=device)['data_samples']
  39. proposals_list = demo_mm_proposals(
  40. image_shapes=image_shapes, num_proposals=100, device=device)
  41. out = roi_head.loss(feats, proposals_list, batch_data_samples)
  42. loss_cls = out['loss_cls']
  43. loss_bbox = out['loss_bbox']
  44. self.assertGreater(loss_cls.sum(), 0, 'cls loss should be non-zero')
  45. self.assertGreater(loss_bbox.sum(), 0, 'box loss should be non-zero')
  46. batch_data_samples = demo_mm_inputs(
  47. batch_size=1,
  48. image_shapes=image_shapes,
  49. num_items=[0],
  50. num_classes=4,
  51. with_mask=True,
  52. device=device)['data_samples']
  53. proposals_list = demo_mm_proposals(
  54. image_shapes=image_shapes, num_proposals=100, device=device)
  55. out = roi_head.loss(feats, proposals_list, batch_data_samples)
  56. empty_cls_loss = out['loss_cls']
  57. empty_bbox_loss = out['loss_bbox']
  58. self.assertGreater(empty_cls_loss.sum(), 0,
  59. 'cls loss should be non-zero')
  60. self.assertEqual(
  61. empty_bbox_loss.sum(), 0,
  62. 'there should be no box loss when there are no true boxes')