import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, Sequential
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.registry import MODELS
from ..layers import RepVGGBlock
@MODELS.register_module()
class RepVGG(BaseModule):
    """RepVGG backbone.

    A backbone built from ``RepVGGBlock`` stages that returns the feature
    maps of the stages listed in ``out_indices``. ``arch`` may be a key of
    ``arch_settings`` or a dict with the same fields.
    """

    arch_settings = {
        'A0':
        dict(
            num_blocks=[3, 3, 3, 3],
            width_factor=[1, 1, 1, 1],
            se_cfg=None,
            stem_channels=64),
        'B0':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[1, 1, 1, 2.5],
            se_cfg=None,
            stem_channels=64),
    }
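
    # A custom architecture can be passed as a dict instead of a key of
    # ``arch_settings``. The values below are hypothetical, for illustration
    # only; the keys must match the presets above and the list lengths must
    # match ``strides`` and ``dilations``:
    #
    #   RepVGG(arch=dict(num_blocks=[2, 4, 14, 1],
    #                    width_factor=[0.75, 0.75, 0.75, 2.5],
    #                    se_cfg=None,
    #                    stem_channels=64))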

    def __init__(self,
                 arch,
                 in_channels=1,
                 base_channels=64,
                 out_indices=(1, 2, 3),
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 deploy=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepVGG, self).__init__(init_cfg)
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['num_blocks']) == len(
            arch['width_factor']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['num_blocks'])

        self.base_channels = base_channels
        self.arch = arch
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval

        stem_channels = self.arch['stem_channels']

        # Deep stem (three 3x3 convs), similar to the ResNet variants; a
        # single RepVGG block could be tried as the stem later.
        self.stem = nn.Sequential(
            build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                stem_channels // 2,
                stem_channels // 2,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels)[1],
            nn.ReLU(inplace=True))
        next_create_block_idx = 1
        self.stages = []
        # NOTE: stage_1 consumes the stem output, so this assumes
        # stem_channels == base_channels.
        channels = self.base_channels
        for i in range(len(arch['num_blocks'])):
            num_blocks = self.arch['num_blocks'][i]
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = int(self.base_channels * 2**i *
                               self.arch['width_factor'][i])
            stage, next_create_block_idx = self._make_stage(
                channels, out_channels, num_blocks, stride, dilation,
                next_create_block_idx, init_cfg)
            stage_name = f'stage_{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)
            channels = out_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
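
        # Worked example with the defaults: the stem (stride 2) plus the
        # maxpool (stride 2) put stage_1 at stride 4, and strides=(1, 2, 2, 2)
        # give the four stages output strides 4, 8, 16 and 32, so
        # out_indices=(1, 2, 3) yields strides 8, 16 and 32. For arch 'B0'
        # the stage widths are int(64 * 2**i * width_factor[i]) = 64, 128,
        # 256 and 1280.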

    def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                    dilation, next_create_block_idx, init_cfg):
        strides = [stride] + [1] * (num_blocks - 1)
        dilations = [dilation] * num_blocks

        blocks = []
        for i in range(num_blocks):
            blocks.append(
                RepVGGBlock(
                    in_channels,
                    out_channels,
                    stride=strides[i],
                    padding=dilations[i],
                    dilation=dilations[i],
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    init_cfg=init_cfg))
            in_channels = out_channels
            next_create_block_idx += 1

        return Sequential(*blocks), next_create_block_idx
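
    # Example of what ``_make_stage`` above produces: with arch 'B0', the
    # third stage is 16 RepVGGBlocks where only the first downsamples
    # (stride 2); padding equals dilation, so the 3x3 convs preserve the
    # spatial size. Note that ``se_cfg`` from the arch dict is not consumed
    # here.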

    def forward(self, x):
        x = self.stem(x)
        x = self.maxpool(x)

        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
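
    # The ``deploy`` flag refers to RepVGG's structural re-parameterization:
    # at inference time each RepVGGBlock fuses its 3x3, 1x1 and identity
    # branches (each followed by its own BN) into a single 3x3 conv. The
    # helper below is a minimal, hypothetical sketch of the conv+BN folding
    # step only; it is not called anywhere in this backbone.
    @staticmethod
    def _fuse_conv_bn_sketch(conv, bn):
        """Fold a BatchNorm into the preceding conv (illustration only)."""
        std = (bn.running_var + bn.eps).sqrt()
        scale = bn.weight / std
        # Scale each output channel of the conv kernel by gamma / std.
        fused_weight = conv.weight * scale.reshape(-1, 1, 1, 1)
        fused_bias = bn.bias - bn.running_mean * scale
        if conv.bias is not None:
            fused_bias = fused_bias + conv.bias * scale
        return fused_weight, fused_bias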

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen when ``norm_eval`` is True."""
        super(RepVGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False

        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage_{i + 1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False
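

if __name__ == '__main__':
    # Minimal smoke test (a sketch, assuming mmdet and its RepVGGBlock layer
    # are importable). Because this module uses a relative import, run it as
    # ``python -m <package>.repvgg`` from the package root rather than
    # directly. Builds the backbone and checks the multi-scale output shapes.
    import torch

    model = RepVGG(arch='B0', in_channels=3)
    model.eval()
    with torch.no_grad():
        outs = model(torch.randn(1, 3, 224, 224))
    # With out_indices=(1, 2, 3), expect strides 8, 16 and 32 and widths
    # 128, 256 and 1280 for a 224x224 input.
    for i, out in zip(model.out_indices, outs):
        print(f'stage_{i + 1}: {tuple(out.shape)}')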