res_layer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. from mmcv.cnn import build_conv_layer, build_norm_layer
  4. from mmengine.model import BaseModule, Sequential
  5. from torch import Tensor
  6. from torch import nn as nn
  7. import torch
  8. from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
  class SPD(nn.Module):
      """Space-to-depth rearrangement.

      Takes the four 2x2-strided sub-grids of the last two dimensions of ``x``
      (even/even, odd/even, even/odd, odd/odd row/column offsets) and
      concatenates them along ``dimension``. For an input assumed to be NCHW
      and the default ``dimension=1`` this maps ``(N, C, H, W)`` to
      ``(N, 4C, H/2, W/2)`` without discarding any values.

      Args:
          dimension (int): Dimension along which the four sub-grids are
              concatenated. Defaults to 1 (the channel dim for NCHW input).
      """

      def __init__(self, dimension: int = 1):
          super().__init__()
          self.d = dimension

      def forward(self, x: Tensor) -> Tensor:
          """Rearrange each 2x2 spatial neighbourhood into the concat dim."""
          # The last two dims are expected to have even sizes; otherwise the
          # even (``::2``) and odd (``1::2``) slices differ in length and
          # ``torch.cat`` raises.
          return torch.cat(
              [
                  x[..., ::2, ::2],
                  x[..., 1::2, ::2],
                  x[..., ::2, 1::2],
                  x[..., 1::2, 1::2],
              ],
              self.d,
          )
  class ResLayer(Sequential):
      """Stack residual blocks into one ResNet-style stage.

      Exactly one block in the stage receives ``stride`` and the shortcut
      projection: the first block when ``downsample_first`` is True, the last
      block otherwise. A projection shortcut is built only when ``stride`` is
      not 1 or ``inplanes`` differs from ``planes * block.expansion``.

      Args:
          block (nn.Module): Residual block class to instantiate; its
              ``expansion`` attribute scales ``planes`` to the output width.
          inplanes (int): Input channels of the stage.
          planes (int): Base channels of each block.
          num_blocks (int): Number of blocks to stack.
          stride (int): Stride of the downsampling block. Defaults to 1.
          avg_down (bool): Downsample the shortcut with ``AvgPool2d`` followed
              by a stride-1 1x1 conv instead of a strided 1x1 conv.
              Defaults to False.
          conv_cfg (dict, optional): Config used to build conv layers.
              Defaults to None.
          norm_cfg (dict): Config used to build norm layers.
              Defaults to dict(type='BN').
          downsample_first (bool): Downsample in the first block (True, for
              ResNet) or in the last block (False, for Hourglass).
              Defaults to True.
          **kwargs: Extra keyword arguments forwarded to every block.
      """

      def __init__(self,
                   block: BaseModule,
                   inplanes: int,
                   planes: int,
                   num_blocks: int,
                   stride: int = 1,
                   avg_down: bool = False,
                   conv_cfg: OptConfigType = None,
                   norm_cfg: ConfigType = dict(type='BN'),
                   downsample_first: bool = True,
                   **kwargs) -> None:
          self.block = block
          out_channels = planes * block.expansion

          # Optional projection shortcut for the single downsampling block.
          shortcut = None
          if stride != 1 or inplanes != out_channels:
              projection = []
              proj_stride = stride
              if avg_down:
                  # Pooling performs the spatial reduction, so the 1x1 conv
                  # that follows it keeps stride 1.
                  proj_stride = 1
                  projection.append(
                      nn.AvgPool2d(
                          kernel_size=stride,
                          stride=stride,
                          ceil_mode=True,
                          count_include_pad=False))
              projection.append(
                  build_conv_layer(
                      conv_cfg,
                      inplanes,
                      out_channels,
                      kernel_size=1,
                      stride=proj_stride,
                      bias=False))
              projection.append(build_norm_layer(norm_cfg, out_channels)[1])
              shortcut = nn.Sequential(*projection)

          # Keyword arguments shared by every block in the stage.
          common = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs)

          layers = []
          if downsample_first:
              layers.append(
                  block(
                      inplanes=inplanes,
                      planes=planes,
                      stride=stride,
                      downsample=shortcut,
                      **common))
              # Every following block already sees the expanded width.
              layers.extend(
                  block(
                      inplanes=out_channels,
                      planes=planes,
                      stride=1,
                      **common) for _ in range(num_blocks - 1))
          else:
              # Hourglass layout: num_blocks - 1 stride-1 blocks built with
              # planes=inplanes, then a final block applying stride/shortcut.
              layers.extend(
                  block(
                      inplanes=inplanes,
                      planes=inplanes,
                      stride=1,
                      **common) for _ in range(num_blocks - 1))
              layers.append(
                  block(
                      inplanes=inplanes,
                      planes=planes,
                      stride=stride,
                      downsample=shortcut,
                      **common))
          super().__init__(*layers)
  class SimplifiedBasicBlock(BaseModule):
      """Simplified version of the original basic residual block, used in
      `SCNet <https://arxiv.org/abs/2012.10150>`_.

      Differences from the basic block:

      - Norm layers are optional: when ``norm_cfg`` is None no norm layers are
        built and the convs are created with a bias instead.
      - The last ReLU (after the residual addition) is removed.

      Args:
          inplanes (int): Input channels.
          planes (int): Output channels (``expansion`` is 1).
          stride (int): Stride of the first 3x3 conv. Defaults to 1.
          dilation (int): Dilation and padding of the first 3x3 conv.
              Defaults to 1.
          downsample (nn.Module, optional): Applied to the input to form the
              identity branch when given. Defaults to None.
          style (str): Not used by this block; presumably accepted so the
              signature matches other ResNet blocks. Defaults to 'pytorch'.
          with_cp (bool): Must be False; checkpointing is not implemented.
          conv_cfg (dict, optional): Config used to build conv layers.
          norm_cfg (dict, optional): Config used to build norm layers; None
              disables normalization. Defaults to dict(type='BN').
          dcn (dict, optional): Must be None; not implemented.
          plugins (optional): Must be None; not implemented.
          init_cfg (dict or list[dict], optional): Initialization config.
      """

      expansion = 1

      def __init__(self,
                   inplanes: int,
                   planes: int,
                   stride: int = 1,
                   dilation: int = 1,
                   downsample: Optional[nn.Module] = None,
                   style: str = 'pytorch',
                   with_cp: bool = False,
                   conv_cfg: OptConfigType = None,
                   norm_cfg: ConfigType = dict(type='BN'),
                   dcn: OptConfigType = None,
                   plugins: OptConfigType = None,
                   init_cfg: OptMultiConfig = None) -> None:
          super().__init__(init_cfg=init_cfg)
          # Kept as asserts: callers/tests expect AssertionError here.
          assert dcn is None, 'Not implemented yet.'
          assert plugins is None, 'Not implemented yet.'
          assert not with_cp, 'Not implemented yet.'
          self.with_norm = norm_cfg is not None
          # Without a norm layer to absorb it, the convs need their own bias.
          with_bias = norm_cfg is None
          self.conv1 = build_conv_layer(
              conv_cfg,
              inplanes,
              planes,
              3,
              stride=stride,
              padding=dilation,
              dilation=dilation,
              bias=with_bias)
          if self.with_norm:
              self.norm1_name, norm1 = build_norm_layer(
                  norm_cfg, planes, postfix=1)
              self.add_module(self.norm1_name, norm1)
          self.conv2 = build_conv_layer(
              conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
          if self.with_norm:
              self.norm2_name, norm2 = build_norm_layer(
                  norm_cfg, planes, postfix=2)
              self.add_module(self.norm2_name, norm2)
          self.relu = nn.ReLU(inplace=True)
          self.downsample = downsample
          self.stride = stride
          self.dilation = dilation
          self.with_cp = with_cp

      @property
      def norm1(self) -> Optional[nn.Module]:
          """nn.Module: normalization layer after the first convolution layer,
          or None when normalization is disabled."""
          return getattr(self, self.norm1_name) if self.with_norm else None

      @property
      def norm2(self) -> Optional[nn.Module]:
          """nn.Module: normalization layer after the second convolution layer,
          or None when normalization is disabled."""
          return getattr(self, self.norm2_name) if self.with_norm else None

      def forward(self, x: Tensor) -> Tensor:
          """Forward function for SimplifiedBasicBlock."""
          identity = x

          out = self.conv1(x)
          if self.with_norm:
              out = self.norm1(out)
          out = self.relu(out)

          out = self.conv2(out)
          if self.with_norm:
              out = self.norm2(out)

          if self.downsample is not None:
              identity = self.downsample(x)

          # No activation after the residual sum (the "simplified" part).
          out += identity
          return out