# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule

from .se_layer import SELayer


class InvertedResidual(BaseModule):
    """Inverted Residual Block.

    Args:
        in_channels (int): The input channels of this module.
        out_channels (int): The output channels of this module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernel size of the depthwise convolution.
            Default: 3.
        stride (int): The stride of the depthwise convolution. Default: 1.
        se_cfg (dict): Config dict for the SE layer. Default: None, which
            means no SE layer.
        with_expand_conv (bool): Whether to use the expand conv. If set to
            False, mid_channels must be the same as in_channels.
            Default: True.
        conv_cfg (dict): Config dict for the convolution layer. Default:
            None, which means using conv2d.
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for the activation layer.
            Default: dict(type='ReLU').
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        Tensor: The output tensor.
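
    Example:
        A minimal usage sketch; the channel widths and tensor shape below
        are illustrative only::

            >>> import torch
            >>> block = InvertedResidual(
            ...     in_channels=32,
            ...     out_channels=32,
            ...     mid_channels=96)
            >>> x = torch.randn(1, 32, 56, 56)
            >>> block(x).shape
            torch.Size([1, 32, 56, 56])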
- """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_path_rate=0.,
                 with_cp=False,
                 init_cfg=None):
        super(InvertedResidual, self).__init__(init_cfg)
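        # A residual shortcut is only valid when the block keeps the spatial
        # size (stride 1) and the channel count unchanged.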
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2], \
            f'stride must be in [1, 2], but received {stride}.'
        self.with_cp = with_cp
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels
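
        # Optional 1x1 pointwise convolution expanding in_channels to
        # mid_channels before the depthwise step.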
        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
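        # Depthwise convolution (groups == channels); a stride of 2 performs
        # the spatial downsampling.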
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
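        # Optional squeeze-and-excitation attention on the expanded features.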
        if self.with_se:
            self.se = SELayer(**se_cfg)
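        # 1x1 linear projection back to out_channels. No activation follows
        # (act_cfg=None), i.e. the "linear bottleneck" of MobileNetV2.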
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):
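
        # The body is wrapped in a closure so the same computation can be
        # re-run under torch.utils.checkpoint when with_cp is enabled.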
        def _inner_forward(x):
            out = x
            if self.with_expand_conv:
                out = self.expand_conv(out)
            out = self.depthwise_conv(out)
            if self.with_se:
                out = self.se(out)
            out = self.linear_conv(out)
            if self.with_res_shortcut:
                return x + self.drop_path(out)
            else:
                return out
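
        # Trade memory for compute: recompute activations during the backward
        # pass when checkpointing is enabled and gradients are required.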
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
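

# A minimal smoke test, assuming mmcv and mmengine are installed; since this
# module uses a relative import, run it with ``python -m`` from the package
# root. The channel widths and tensor shape are illustrative only.
if __name__ == '__main__':
    import torch

    block = InvertedResidual(in_channels=32, out_channels=32, mid_channels=96)
    x = torch.randn(1, 32, 56, 56)
    # stride == 1 and in_channels == out_channels, so the residual shortcut
    # is active and the output shape matches the input shape.
    assert block(x).shape == (1, 32, 56, 56)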