test_embedding_rpn_head.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import pytest
  4. import torch
  5. from mmengine.structures import InstanceData
  6. from mmdet.models.dense_heads import EmbeddingRPNHead
  7. from mmdet.structures import DetDataSample
  8. class TestEmbeddingRPNHead(TestCase):
  9. def test_init(self):
  10. """Test init rpn head."""
  11. rpn_head = EmbeddingRPNHead(
  12. num_proposals=100, proposal_feature_channel=256)
  13. rpn_head.init_weights()
  14. self.assertTrue(rpn_head.init_proposal_bboxes)
  15. self.assertTrue(rpn_head.init_proposal_features)
  16. def test_loss_and_predict(self):
  17. s = 256
  18. img_meta = {
  19. 'img_shape': (s, s, 3),
  20. 'pad_shape': (s, s, 3),
  21. 'scale_factor': 1,
  22. }
  23. rpn_head = EmbeddingRPNHead(
  24. num_proposals=100, proposal_feature_channel=256)
  25. feats = [
  26. torch.rand(2, 1, s // (2**(i + 2)), s // (2**(i + 2)))
  27. for i in range(5)
  28. ]
  29. data_sample = DetDataSample()
  30. data_sample.set_metainfo(img_meta)
  31. # test predict
  32. result_list = rpn_head.predict(feats, [data_sample])
  33. self.assertTrue(isinstance(result_list, list))
  34. self.assertTrue(isinstance(result_list[0], InstanceData))
  35. # test loss_and_predict
  36. result_list = rpn_head.loss_and_predict(feats, [data_sample])
  37. self.assertTrue(isinstance(result_list, tuple))
  38. self.assertTrue(isinstance(result_list[0], dict))
  39. self.assertEqual(len(result_list[0]), 0)
  40. self.assertTrue(isinstance(result_list[1], list))
  41. self.assertTrue(isinstance(result_list[1][0], InstanceData))
  42. # test loss
  43. with pytest.raises(NotImplementedError):
  44. rpn_head.loss(feats, [data_sample])