inverted_residual.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule

from .se_layer import SELayer


class InvertedResidual(BaseModule):
    """Inverted Residual Block.

    Args:
        in_channels (int): The input channels of this module.
        out_channels (int): The output channels of this module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernel size of the depthwise convolution.
            Default: 3.
        stride (int): The stride of the depthwise convolution. Default: 1.
        se_cfg (dict): Config dict for se layer. Default: None, which means no
            se layer.
        with_expand_conv (bool): Use expand conv or not. If set to False,
            mid_channels must be the same as in_channels.
            Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
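
    Example:
        A minimal usage sketch. The channel sizes and input shape below are
        illustrative, not values mandated by the module:

        >>> import torch
        >>> block = InvertedResidual(32, 32, mid_channels=96)
        >>> x = torch.rand(1, 32, 56, 56)
        >>> block(x).shape
        torch.Size([1, 32, 56, 56])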

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_path_rate=0.,
                 with_cp=False,
                 init_cfg=None):
        super(InvertedResidual, self).__init__(init_cfg)
        # The residual shortcut is only valid when the block preserves both
        # the spatial size and the channel count of its input.
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
            f'But received {stride}.'
        self.with_cp = with_cp
        # Stochastic depth: randomly drops the residual branch during
        # training; an identity op when drop_path_rate is 0.
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels
        if self.with_expand_conv:
            # 1x1 pointwise conv that expands in_channels to mid_channels.
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        # Depthwise conv: groups equals the channel count, so each channel
        # is filtered independently.
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if self.with_se:
            self.se = SELayer(**se_cfg)
        # 1x1 linear projection back to out_channels, with no activation.
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):

        def _inner_forward(x):
            out = x
            if self.with_expand_conv:
                out = self.expand_conv(out)
            out = self.depthwise_conv(out)
            if self.with_se:
                out = self.se(out)
            out = self.linear_conv(out)

            if self.with_res_shortcut:
                return x + self.drop_path(out)
            else:
                return out

        if self.with_cp and x.requires_grad:
            # Trade compute for memory: re-run the block during the backward
            # pass instead of storing its intermediate activations.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
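

if __name__ == '__main__':
    # Minimal smoke test, a sketch assuming torch, mmcv and mmengine are
    # installed. Because of the relative SELayer import above, run it as a
    # module within its package (python -m <package>.inverted_residual),
    # not as a standalone script. Channel sizes and input shape here are
    # illustrative, not values mandated by the module.
    import torch

    block = InvertedResidual(
        in_channels=32,
        out_channels=32,
        mid_channels=96,
        kernel_size=3,
        stride=1)
    x = torch.rand(1, 32, 56, 56)
    out = block(x)
    # stride == 1 and in_channels == out_channels, so the residual shortcut
    # is active and the spatial/channel dimensions are preserved.
    assert out.shape == (1, 32, 56, 56)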