import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, Sequential

from mmdet.registry import MODELS
from ..layers import RepVGGBlock


@MODELS.register_module()
class RepVGG(BaseModule):
    """RepVGG backbone.

    Stacks ``RepVGGBlock``s into stages behind a ResNet-style deep stem and
    returns the feature maps of the stages listed in ``out_indices``.
    """

    # ``se_cfg`` is carried in the settings but not yet wired into the
    # blocks below.
    arch_settings = {
        'A0': dict(
            num_blocks=[3, 3, 3, 3],
            width_factor=[1, 1, 1, 1],
            se_cfg=None,
            stem_channels=64),
        'B0': dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[1, 1, 1, 2.5],
            se_cfg=None,
            stem_channels=64),
    }

    def __init__(self,
                 arch,
                 in_channels=1,
                 base_channels=64,
                 out_indices=(1, 2, 3),
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 deploy=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super().__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['num_blocks']) == len(
            arch['width_factor']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['num_blocks'])

        self.base_channels = base_channels
        self.arch = arch
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval

        stem_channels = self.arch['stem_channels']
        # Deep stem as in ResNet-C/D: three 3x3 convs instead of one 7x7.
        # A single RepVGG block could be tried as the stem later.
        self.stem = nn.Sequential(
            build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                stem_channels // 2,
                stem_channels // 2,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels)[1],
            nn.ReLU(inplace=True))
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        next_create_block_idx = 1
        self.stages = []
        # The first stage consumes the stem output, so start from
        # ``stem_channels`` (starting from ``base_channels``, as before, only
        # works when the two happen to be equal).
        channels = stem_channels
        for i in range(len(arch['num_blocks'])):
            num_blocks = self.arch['num_blocks'][i]
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = int(self.base_channels * 2**i *
                               self.arch['width_factor'][i])
            stage, next_create_block_idx = self._make_stage(
                channels, out_channels, num_blocks, stride, dilation,
                next_create_block_idx, init_cfg)
            stage_name = f'stage_{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)
            channels = out_channels

    def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                    dilation, next_create_block_idx, init_cfg):
        # Only the first block in a stage downsamples; the rest keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        dilations = [dilation] * num_blocks

        blocks = []
        for i in range(num_blocks):
            blocks.append(
                RepVGGBlock(
                    in_channels,
                    out_channels,
                    stride=strides[i],
                    padding=dilations[i],
                    dilation=dilations[i],
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    init_cfg=init_cfg))
            in_channels = out_channels
            next_create_block_idx += 1

        return Sequential(*blocks), next_create_block_idx
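
    def switch_to_deploy(self):
        # A hedged sketch, not part of the original file: fold every
        # RepVGGBlock into its single-conv inference form, giving the stored
        # ``self.deploy`` flag a consumer. It assumes the local
        # ``RepVGGBlock`` exposes a ``switch_to_deploy()`` method, as the
        # reference RepVGG implementations do; the ``hasattr`` guard makes
        # this a no-op if that assumption does not hold.
        for m in self.modules():
            if isinstance(m, RepVGGBlock) and hasattr(m, 'switch_to_deploy'):
                m.switch_to_deploy()
        self.deploy = True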
    def forward(self, x):
        x = self.stem(x)
        x = self.maxpool(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            # if i + 1 == len(self.stages):
            #     x = self.ppf(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen when ``norm_eval`` is set."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # ``eval()`` only affects BatchNorm here: it stops the
                # running statistics from being updated.
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` stages."""
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False

        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage_{i + 1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False
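

if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original module):
    # build the 'A0' variant on RGB input and print the multi-scale feature
    # shapes. Because of the relative ``..layers`` import, run it as a module
    # from the parent package (``python -m <pkg>.repvgg``), with
    # mmcv/mmengine/mmdet installed.
    model = RepVGG(arch='A0', in_channels=3)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for idx, feat in zip(model.out_indices, feats):
        print(f'stage_{idx + 1}: {tuple(feat.shape)}')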