import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, Sequential

from mmdet.registry import MODELS
from ..layers import RepVGGBlock


@MODELS.register_module()
class RepVGG(BaseModule):
    """RepVGG backbone.

    Stacks re-parameterizable RepVGG blocks into four stages behind a
    ResNet-style deep stem. ``arch`` selects a preset from ``arch_settings``
    or accepts an equivalent dict directly.
    """

    arch_settings = {
        'A0':
        dict(
            num_blocks=[3, 3, 3, 3],
            width_factor=[1, 1, 1, 1],
            se_cfg=None,
            stem_channels=64),
        'B0':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[1, 1, 1, 2.5],
            se_cfg=None,
            stem_channels=64),
    }
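
    # A custom architecture can also be passed directly as a dict with the
    # same keys, e.g. (illustrative values, not a shipped preset):
    #   RepVGG(arch=dict(num_blocks=[2, 2, 2, 2], width_factor=[1, 1, 1, 2],
    #                    se_cfg=None, stem_channels=64))
    # Note that `se_cfg` is carried in the presets but not consumed in this
    # file.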

    def __init__(self,
                 arch,
                 in_channels=1,
                 base_channels=64,
                 out_indices=(1, 2, 3),
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 deploy=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepVGG, self).__init__(init_cfg)
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')
        assert len(arch['num_blocks']) == len(
            arch['width_factor']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['num_blocks'])

        self.base_channels = base_channels
        self.arch = arch
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        # `deploy` is stored for a later switch to the re-parameterized
        # single-branch form; it is not forwarded to the blocks here.
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        stem_channels = self.arch['stem_channels']
        # Use a deep stem (three 3x3 convs) as in ResNet-C; a single RepVGG
        # block could be tried as the stem later.
        self.stem = nn.Sequential(
            build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                stem_channels // 2,
                stem_channels // 2,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels)[1],
            nn.ReLU(inplace=True))
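
        # With the default settings the stem halves the resolution once
        # (stride 2) and the max pool below halves it again, so stage_1
        # receives stride-4 features.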
        next_create_block_idx = 1
        self.stages = []
        # The first stage consumes the stem output, so start from
        # `stem_channels`.
        channels = stem_channels
        for i in range(len(arch['num_blocks'])):
            num_blocks = self.arch['num_blocks'][i]
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = int(self.base_channels * 2**i *
                               self.arch['width_factor'][i])
            stage, next_create_block_idx = self._make_stage(
                channels, out_channels, num_blocks, stride, dilation,
                next_create_block_idx, init_cfg)
            stage_name = f'stage_{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)
            channels = out_channels
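
        # Stage widths follow base_channels * 2**i * width_factor[i]: for
        # 'A0' this gives 64, 128, 256, 512; for 'B0' the last stage widens
        # to int(64 * 2**3 * 2.5) = 1280.
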
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                    dilation, next_create_block_idx, init_cfg):
        # Only the first block in a stage downsamples; padding equals the
        # dilation so the 3x3 blocks keep the spatial size otherwise.
        strides = [stride] + [1] * (num_blocks - 1)
        dilations = [dilation] * num_blocks
        blocks = []
        for i in range(num_blocks):
            blocks.append(
                RepVGGBlock(
                    in_channels,
                    out_channels,
                    stride=strides[i],
                    padding=dilations[i],
                    dilation=dilations[i],
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    init_cfg=init_cfg))
            in_channels = out_channels
            next_create_block_idx += 1
        return Sequential(*blocks), next_create_block_idx
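
    # Example (hypothetical call): _make_stage(64, 128, num_blocks=4,
    # stride=2, dilation=1, next_create_block_idx=1, init_cfg=None) returns
    # four blocks with strides [2, 1, 1, 1] and the updated block index 5.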

    def forward(self, x):
        """Forward function."""
        x = self.stem(x)
        x = self.maxpool(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
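
    # The returned tuple feeds an FPN-style neck. An illustrative mmdet
    # config snippet (channel numbers assume arch='A0' with the default
    # out_indices=(1, 2, 3)):
    #   backbone=dict(type='RepVGG', arch='A0', out_indices=(1, 2, 3)),
    #   neck=dict(type='FPN', in_channels=[128, 256, 512], out_channels=256,
    #             num_outs=5)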

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super(RepVGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage_{i + 1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False
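

# A minimal smoke test, kept as a sketch: the relative import above means
# this only runs inside its package, e.g. `python -m <package>.rep_vgg`,
# with mmdet and the local RepVGGBlock available. Expected shapes for a
# 224x224 input with arch='A0' and the default out_indices=(1, 2, 3):
# (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7).
if __name__ == '__main__':
    import torch

    model = RepVGG(arch='A0', in_channels=3)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for i, feat in enumerate(feats):
        print(f'out[{i}]: {tuple(feat.shape)}')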