import pytest
import torch
import torch.nn.functional as F
from mmengine.model import constant_init

from mmdet.models.layers import DyReLU, SELayer


def test_se_layer():
    with pytest.raises(AssertionError):
        # act_cfg sequence length must equal to 2
        SELayer(channels=32, act_cfg=(dict(type='ReLU'), ))

    with pytest.raises(AssertionError):
        # act_cfg sequence must be a tuple of dict
        SELayer(channels=32, act_cfg=[dict(type='ReLU'), dict(type='ReLU')])

    # Test SELayer forward
    layer = SELayer(channels=32)
    layer.init_weights()
    layer.train()
    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    assert x_out.shape == torch.Size((1, 32, 10, 10))
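
    # Assumption: the default act_cfg ends in HSigmoid, so SELayer
    # rescales each channel by a gate in [0, 1] and the output
    # magnitude can never exceed the input magnitude.
    assert torch.all(x_out.abs() <= x.abs() + 1e-6)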


def test_dyrelu():
    with pytest.raises(AssertionError):
        # act_cfg sequence length must equal to 2
        DyReLU(channels=32, act_cfg=(dict(type='ReLU'), ))

    with pytest.raises(AssertionError):
        # act_cfg sequence must be a tuple of dict
        DyReLU(channels=32, act_cfg=[dict(type='ReLU'), dict(type='ReLU')])

    # Test DyReLU forward
    layer = DyReLU(channels=32)
    layer.init_weights()
    layer.train()
    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    assert x_out.shape == torch.Size((1, 32, 10, 10))

    # DyReLU should act as a standard (static) ReLU when the effect of
    # its SE-like coefficient branch is eliminated
    layer = DyReLU(channels=32)
    constant_init(layer.conv2.conv, 0)
    layer.train()
    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    relu_out = F.relu(x)
    assert torch.equal(x_out, relu_out)
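
    # Assumption (DyReLU-B semantics): with conv2 zeroed, the dynamic
    # coefficients fall back to their static initialization, roughly
    # (a1, b1, a2, b2) = (1, 0, 0, 0), so the piecewise output
    # max(a1 * x + b1, a2 * x + b2) reduces to max(x, 0), i.e. relu(x).
    assert torch.equal(torch.max(x, torch.zeros_like(x)), relu_out)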