# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import pytest
import torch
from mmengine.structures import InstanceData

from mmdet.models.dense_heads import EmbeddingRPNHead
from mmdet.structures import DetDataSample


class TestEmbeddingRPNHead(TestCase):
    """Unit tests for ``EmbeddingRPNHead``."""

    # Head configuration shared by every test in this class.
    num_proposals = 100
    proposal_feature_channel = 256

    def _build_head(self):
        """Return a fresh ``EmbeddingRPNHead`` using the shared config."""
        return EmbeddingRPNHead(
            num_proposals=self.num_proposals,
            proposal_feature_channel=self.proposal_feature_channel)

    def test_init(self):
        """Test init rpn head."""
        rpn_head = self._build_head()
        rpn_head.init_weights()
        # A plain truthiness check passes for almost any object, so assert
        # the type instead.
        # NOTE(review): both attributes are presumably ``nn.Embedding``
        # layers; only the ``nn.Module`` base class is required here.
        self.assertIsInstance(rpn_head.init_proposal_bboxes, torch.nn.Module)
        self.assertIsInstance(rpn_head.init_proposal_features,
                              torch.nn.Module)

    def test_loss_and_predict(self):
        """Check the return types of ``predict``, ``loss_and_predict`` and
        that ``loss`` raises ``NotImplementedError``."""
        s = 256
        img_meta = {
            'img_shape': (s, s, 3),
            'pad_shape': (s, s, 3),
            'scale_factor': 1,
        }
        rpn_head = self._build_head()

        # Five random feature maps whose spatial size halves at every
        # level, from s // 4 down to s // 64.
        feats = [
            torch.rand(2, 1, s // (2**(i + 2)), s // (2**(i + 2)))
            for i in range(5)
        ]

        data_sample = DetDataSample()
        data_sample.set_metainfo(img_meta)
        batch_data_samples = [data_sample]

        # test predict
        predictions = rpn_head.predict(feats, batch_data_samples)
        self.assertIsInstance(predictions, list)
        self.assertIsInstance(predictions[0], InstanceData)

        # test loss_and_predict
        outputs = rpn_head.loss_and_predict(feats, batch_data_samples)
        self.assertIsInstance(outputs, tuple)
        losses, predictions = outputs[0], outputs[1]
        self.assertIsInstance(losses, dict)
        # The loss dict is expected to come back empty.
        self.assertEqual(len(losses), 0)
        self.assertIsInstance(predictions, list)
        self.assertIsInstance(predictions[0], InstanceData)

        # test loss
        with pytest.raises(NotImplementedError):
            rpn_head.loss(feats, batch_data_samples)