# repvgg.py

import types
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType
@MODELS.register_module()
class RepVGGBlock(BaseModule):
    """A block in the RepVGG architecture, supporting an optional normalized
    identity branch.

    The block consists of parallel 3x3 and 1x1 convolution branches, plus an
    optional identity shortcut that applies only normalization.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): The stride of the block. Defaults to 1.
        padding (int): The padding of the block. Defaults to 1.
        dilation (int): The dilation of the block. Defaults to 1.
        groups (int): The number of groups in the convolutions.
            Defaults to 1.
        norm_cfg (dict): The config dict for normalization layers.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): The config dict for activation layers.
            Defaults to dict(type='ReLU', inplace=True).
        without_branch_norm (bool): Whether to skip the normalized identity
            branch. Defaults to True.
        init_cfg (dict): The config dict for initialization. Defaults to None.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 padding: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 without_branch_norm: bool = True,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # Add a normalized identity shortcut only when the input and output
        # shapes match (same channels, stride 1, padding == dilation).
        self.branch_norm = None
        if out_channels == in_channels and stride == 1 and \
                padding == dilation and not without_branch_norm:
            self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
        self.branch_3x3 = ConvModule(
            self.in_channels,
            self.out_channels,
            3,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
            dilation=self.dilation,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.branch_1x1 = ConvModule(
            self.in_channels,
            self.out_channels,
            1,
            stride=self.stride,
            groups=self.groups,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.act = build_activation_layer(act_cfg)
    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the RepVGG block.

        The output is the sum of the 3x3 and 1x1 convolution outputs and the
        normalized identity branch output (when present), followed by
        activation.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        if self.branch_norm is None:
            branch_norm_out = 0
        else:
            branch_norm_out = self.branch_norm(x)
        out = self.branch_3x3(x) + self.branch_1x1(x) + branch_norm_out
        out = self.act(out)
        return out
    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Pad a 1x1 kernel to 3x3 with zeros.

        Args:
            kernel1x1 (Tensor): The 1x1 kernel that needs to be padded.

        Returns:
            Tensor: The zero-padded 3x3 kernel.
        """
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
    def _fuse_bn_tensor(self, branch: nn.Module) -> tuple:
        """Derive the equivalent kernel and bias of a specific branch.

        Args:
            branch (nn.Module): The layer to be equivalently transformed,
                which can be a ConvModule or an nn.BatchNorm2d.

        Returns:
            tuple: The equivalent kernel and bias.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvModule):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d))
            # Express the identity branch as a 3x3 conv whose kernel is 1 at
            # the center of each channel's own slice and 0 elsewhere.
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
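        # Fold BN into the conv: scale each output channel of the kernel by
        # gamma / sqrt(running_var + eps) and shift the bias by
        # beta - running_mean * gamma / sqrt(running_var + eps).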
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
    def get_equivalent_kernel_bias(self):
        """Derive the equivalent kernel and bias in a differentiable way.

        Returns:
            tuple: The equivalent kernel and bias.
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1)
        kernelid, biasid = (0, 0) if self.branch_norm is None else \
            self._fuse_bn_tensor(self.branch_norm)
        return (kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
                bias3x3 + bias1x1 + biasid)
    def switch_to_deploy(self, test_cfg: Optional[Dict] = None):
        """Switch the block to deployment mode.

        In deployment mode, the block uses a single convolution operation
        derived from the equivalent kernel and bias, replacing the original
        branches. This reduces computational cost during inference.
        """
        if getattr(self, 'deploy', False):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv_reparam = nn.Conv2d(
            in_channels=self.branch_3x3.conv.in_channels,
            out_channels=self.branch_3x3.conv.out_channels,
            kernel_size=self.branch_3x3.conv.kernel_size,
            stride=self.branch_3x3.conv.stride,
            padding=self.branch_3x3.conv.padding,
            dilation=self.branch_3x3.conv.dilation,
            groups=self.branch_3x3.conv.groups,
            bias=True)
        self.conv_reparam.weight.data = kernel
        self.conv_reparam.bias.data = bias
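        # Stop gradient tracking on the old branch parameters before they
        # are deleted.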
        for para in self.parameters():
            para.detach_()
        self.__delattr__('branch_3x3')
        self.__delattr__('branch_1x1')
        if hasattr(self, 'branch_norm'):
            self.__delattr__('branch_norm')

        def _forward(self, x):
            return self.act(self.conv_reparam(x))

        self.forward = types.MethodType(_forward, self)
        self.deploy = True
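
# A minimal sanity-check sketch, not part of the original file: it assumes a
# working mmdet/mmengine installation and verifies that the re-parameterized
# single-conv forward matches the multi-branch forward in eval mode.
if __name__ == '__main__':
    block = RepVGGBlock(64, 64, without_branch_norm=False)
    block.eval()  # use running BN statistics so both paths are deterministic
    x = torch.randn(1, 64, 32, 32)
    with torch.no_grad():
        y_multi_branch = block(x)
        block.switch_to_deploy()
        y_deploy = block(x)
    # The two outputs should agree up to floating-point error.
    print('max abs diff:', (y_multi_branch - y_deploy).abs().max().item())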