test_trident_roi_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
from unittest import TestCase

import torch

from mmdet.registry import MODELS
from mmdet.testing import demo_mm_inputs, demo_mm_proposals, get_roi_head_cfg
from mmdet.utils import register_all_modules

class TestTridentRoIHead(TestCase):

    def setUp(self):
        register_all_modules()
        self.roi_head_cfg = get_roi_head_cfg(
            'tridentnet/tridentnet_r50-caffe_1x_coco.py')

    def test_init(self):
        roi_head = MODELS.build(self.roi_head_cfg)
        self.assertTrue(roi_head.with_bbox)
        self.assertTrue(roi_head.with_shared_head)

    def test_trident_roi_head_predict(self):
        """Tests trident roi head predict."""
        if not torch.cuda.is_available():
            # RoI pooling is only supported on GPU
            self.skipTest('test requires GPU and torch+cuda')

        roi_head_cfg = copy.deepcopy(self.roi_head_cfg)
        roi_head = MODELS.build(roi_head_cfg)
        roi_head = roi_head.cuda()
        s = 256
        feats = []
        for i in range(len(roi_head.bbox_roi_extractor.featmap_strides)):
            feats.append(
                torch.rand(1, 1024, s // (2**(i + 2)),
                           s // (2**(i + 2))).to(device='cuda'))

        image_shapes = [(3, s, s)]
        batch_data_samples = demo_mm_inputs(
            batch_size=1,
            image_shapes=image_shapes,
            num_items=[0],
            num_classes=4,
            with_mask=True,
            device='cuda')['data_samples']
        proposals_list = demo_mm_proposals(
            image_shapes=image_shapes, num_proposals=100, device='cuda')

        # When `test_branch_idx == 1`
        roi_head.predict(feats, proposals_list, batch_data_samples)
        # When `test_branch_idx == -1`
        roi_head_cfg.test_branch_idx = -1
        roi_head = MODELS.build(roi_head_cfg)
        roi_head = roi_head.cuda()
        roi_head.predict(feats, proposals_list, batch_data_samples)
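
The module above defines no entry point of its own; tests like this are normally collected and run by pytest. As a minimal sketch (not part of the original file), a standard unittest guard would also allow running the module directly:

if __name__ == '__main__':
    # Hypothetical addition: run the TestCase above when the module is
    # executed directly, e.g. `python test_trident_roi_head.py`.
    unittest.main()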