# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet.models import FcModule


class TestFcModule(TestCase):
    """Forward-pass shape checks for ``FcModule``."""

    def test_forward(self):
        """Feed a random (32, 128) tensor through three ``FcModule`` setups.

        Each setup is built with ``in_channels=128`` and ``out_channels=32``
        and must produce an output of shape (32, 32):

        1. no extra configuration (weights initialised explicitly),
        2. with a ``BN1d`` norm config,
        3. with a ``BN1d`` norm config and a ``ReLU`` activation config.
        """
        batch = torch.rand(32, 128)
        channels = {'in_channels': 128, 'out_channels': 32}
        want_shape = (32, 32)

        # Bare module; this is the only case that calls init_weights().
        plain = FcModule(**channels)
        plain.init_weights()
        result = plain(batch)
        assert result.shape == want_shape

        # Module configured with a norm layer.
        with_norm = FcModule(norm_cfg={'type': 'BN1d'}, **channels)
        result = with_norm(batch)
        assert result.shape == want_shape

        # Module configured with both a norm layer and an activation.
        with_norm_act = FcModule(
            norm_cfg={'type': 'BN1d'},
            act_cfg={'type': 'ReLU'},
            **channels,
        )
        result = with_norm_act(batch)
        assert result.shape == want_shape