12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import pytest
- import torch
- from mmengine.structures import InstanceData
- from mmdet.models.dense_heads import EmbeddingRPNHead
- from mmdet.structures import DetDataSample
class TestEmbeddingRPNHead(TestCase):
    """Unit tests for :class:`EmbeddingRPNHead`."""

    def test_init(self):
        """Test init rpn head."""
        rpn_head = EmbeddingRPNHead(
            num_proposals=100, proposal_feature_channel=256)
        rpn_head.init_weights()
        # Explicit None checks: truthiness of a torch object is either
        # always True (modules without __len__) or ambiguous (multi-element
        # tensors), so it does not express "the attribute was created".
        self.assertIsNotNone(rpn_head.init_proposal_bboxes)
        self.assertIsNotNone(rpn_head.init_proposal_features)

    def test_loss_and_predict(self):
        """Test ``predict``, ``loss_and_predict`` and ``loss`` of rpn head.

        ``predict`` must return a list of ``InstanceData``.
        ``loss_and_predict`` must return a tuple whose first item is an
        empty loss dict and whose second item is a list of
        ``InstanceData``. ``loss`` must raise ``NotImplementedError``.
        """
        s = 256
        img_meta = {
            'img_shape': (s, s, 3),
            'pad_shape': (s, s, 3),
            'scale_factor': 1,
        }
        rpn_head = EmbeddingRPNHead(
            num_proposals=100, proposal_feature_channel=256)
        # Five feature levels with spatial sizes s/4, s/8, s/16, s/32, s/64.
        # NOTE(review): the batch dim is 2 while only one data sample is
        # passed below; the assertions only inspect the first result.
        feats = [
            torch.rand(2, 1, s // (2**(i + 2)), s // (2**(i + 2)))
            for i in range(5)
        ]
        data_sample = DetDataSample()
        data_sample.set_metainfo(img_meta)

        # test predict
        result_list = rpn_head.predict(feats, [data_sample])
        self.assertIsInstance(result_list, list)
        self.assertIsInstance(result_list[0], InstanceData)

        # test loss_and_predict: the output is (losses, results), and this
        # head is expected to contribute no losses, hence an empty dict.
        outputs = rpn_head.loss_and_predict(feats, [data_sample])
        self.assertIsInstance(outputs, tuple)
        losses, results = outputs[0], outputs[1]
        self.assertIsInstance(losses, dict)
        self.assertEqual(len(losses), 0)
        self.assertIsInstance(results, list)
        self.assertIsInstance(results[0], InstanceData)

        # test loss: a standalone loss is not supported by this head.
        with pytest.raises(NotImplementedError):
            rpn_head.loss(feats, [data_sample])
|