test_pisa_roi_head.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. from unittest import TestCase
  4. import torch
  5. from parameterized import parameterized
  6. from mmdet.registry import MODELS
  7. from mmdet.testing import demo_mm_inputs, demo_mm_proposals, get_roi_head_cfg
  8. from mmdet.utils import register_all_modules
  9. class TestPISARoIHead(TestCase):
  10. def setUp(self):
  11. register_all_modules()
  12. self.roi_head_cfg = get_roi_head_cfg(
  13. 'pisa/faster-rcnn_r50_fpn_pisa_1x_coco.py')
  14. def test_init(self):
  15. roi_head = MODELS.build(self.roi_head_cfg)
  16. self.assertTrue(roi_head.with_bbox)
  17. @parameterized.expand(['cpu', 'cuda'])
  18. def test_pisa_roi_head(self, device):
  19. """Tests trident roi head predict."""
  20. if not torch.cuda.is_available() and device == 'cuda':
  21. # RoI pooling only support in GPU
  22. return unittest.skip('test requires GPU and torch+cuda')
  23. roi_head = MODELS.build(self.roi_head_cfg)
  24. roi_head = roi_head.to(device=device)
  25. s = 256
  26. feats = []
  27. for i in range(len(roi_head.bbox_roi_extractor.featmap_strides)):
  28. feats.append(
  29. torch.rand(1, 256, s // (2**(i + 2)),
  30. s // (2**(i + 2))).to(device=device))
  31. image_shapes = [(3, s, s)]
  32. batch_data_samples = demo_mm_inputs(
  33. batch_size=1,
  34. image_shapes=image_shapes,
  35. num_items=[1],
  36. num_classes=4,
  37. with_mask=True,
  38. device=device)['data_samples']
  39. proposals_list = demo_mm_proposals(
  40. image_shapes=image_shapes, num_proposals=100, device=device)
  41. out = roi_head.loss(feats, proposals_list, batch_data_samples)
  42. loss_cls = out['loss_cls']
  43. loss_bbox = out['loss_bbox']
  44. self.assertGreater(loss_cls.sum(), 0, 'cls loss should be non-zero')
  45. self.assertGreater(loss_bbox.sum(), 0, 'box loss should be non-zero')
  46. batch_data_samples = demo_mm_inputs(
  47. batch_size=1,
  48. image_shapes=image_shapes,
  49. num_items=[0],
  50. num_classes=4,
  51. with_mask=True,
  52. device=device)['data_samples']
  53. proposals_list = demo_mm_proposals(
  54. image_shapes=image_shapes, num_proposals=100, device=device)
  55. out = roi_head.loss(feats, proposals_list, batch_data_samples)
  56. empty_cls_loss = out['loss_cls']
  57. empty_bbox_loss = out['loss_bbox']
  58. self.assertGreater(empty_cls_loss.sum(), 0,
  59. 'cls loss should be non-zero')
  60. self.assertEqual(
  61. empty_bbox_loss.sum(), 0,
  62. 'there should be no box loss when there are no true boxes')