from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.cnn import build_plugin_layer
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch.nn.modules.batchnorm import _BatchNorm

from ..layers import CSPLayerWithTwoConv, SPPFBottleneck
from ..utils import make_divisible, make_round


@MODELS.register_module()
class YOLOv8CSPDarknet(BaseModule):
    """CSP-Darknet backbone used in YOLOv8.

    Args:
        last_stage_out_channels (int): Output channels of the final stage
            (before ``widen_factor`` scaling). Defaults to 1024.
        plugins (dict | list[dict], optional): Plugins to append to stages.
            A single dict is treated as a one-element list. Each dict
            contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'. If omitted, the plugin is
              applied to every stage.

            Defaults to None.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Tuple[int]): Indices into ``self.layers`` whose outputs
            are returned (0 is the stem, 1-4 are ``stage1``-``stage4``).
            Defaults to (2, 3, 4).
        frozen_stages (int): Layers ``0..frozen_stages`` of ``self.layers``
            (0 is the stem) are frozen (requires_grad=False and eval mode)
            whenever :meth:`train` is called. -1 means not freezing any
            parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var) when :meth:`train` is
            called. Note: Effect on Batch Norm and its variants only.
            Defaults to False.
        init_cfg (Union[dict,list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOv8CSPDarknet
        >>> import torch
        >>> model = YOLOv8CSPDarknet()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    # The last stage's out_channels is filled in per instance from
    # ``last_stage_out_channels``; this class-level list is only a template.
    arch_settings = [[64, 128, 3, True, False], [128, 256, 6, True, False],
                     [256, 512, 6, True, False], [512, None, 3, True, True]]

    def __init__(self,
                 last_stage_out_channels: int = 1024,
                 plugins: Optional[Union[dict, List[dict]]] = None,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int, ...] = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        # Forward init_cfg so BaseModule stores it; init_weights() reads
        # self.init_cfg to choose between the plain Conv2d reset and the
        # configured initialization.
        super().__init__(init_cfg=init_cfg)

        # Take a per-instance copy of the template before editing it.
        # Writing into the class attribute would leak this instance's
        # last_stage_out_channels into every other instance.
        self.arch_settings = [list(setting) for setting in self.arch_settings]
        self.arch_settings[-1][1] = last_stage_out_channels

        # The annotation allows a single dict; iterating a dict would yield
        # its keys, so wrap it in a list.
        if isinstance(plugins, dict):
            plugins = [plugins]

        self.num_stages = len(self.arch_settings)
        self.input_channels = input_channels
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.widen_factor = widen_factor
        self.deepen_factor = deepen_factor
        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.plugins = plugins

        self.stem = self.build_stem_layer()
        # Names of the sub-modules run in order by forward(); indices into
        # this list are what out_indices and frozen_stages refer to.
        self.layers = ['stem']

        for idx, setting in enumerate(self.arch_settings):
            stage = self.build_stage_layer(idx, setting)
            if plugins is not None:
                stage += self.make_stage_plugins(plugins, idx, setting)
            self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
            self.layers.append(f'stage{idx + 1}')

    def build_stem_layer(self) -> nn.Module:
        """Build the stem: a stride-2 3x3 ConvModule from the input channels
        to the (width-scaled) input channels of the first stage."""
        return ConvModule(
            self.input_channels,
            make_divisible(self.arch_settings[0][0], self.widen_factor),
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        A stage is a stride-2 3x3 ConvModule, a CSPLayerWithTwoConv, and,
        when ``use_spp`` is set, an SPPFBottleneck.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer, in
                the ``arch_settings`` column order.

        Returns:
            list[nn.Module]: The modules of the stage, in execution order.
        """
        in_channels, out_channels, num_blocks, add_identity, use_spp = setting

        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        num_blocks = make_round(num_blocks, self.deepen_factor)

        stage = [
            ConvModule(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            CSPLayerWithTwoConv(
                out_channels,
                out_channels,
                num_blocks=num_blocks,
                add_identity=add_identity,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
        ]
        if use_spp:
            stage.append(
                SPPFBottleneck(
                    out_channels,
                    out_channels,
                    kernel_sizes=5,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        return stage

    def make_stage_plugins(self, plugins, stage_idx, setting):
        """Make plugins for backbone ``stage_idx`` th stage.

        Currently we support to insert ``context_block``,
        ``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
        into the backbone.

        An example of plugins format could be:

        Examples:
            >>> plugins=[
            ...     dict(cfg=dict(type='xxx', arg1='xxx'),
            ...          stages=(False, True, True, True)),
            ...     dict(cfg=dict(type='yyy'),
            ...          stages=(True, True, True, True)),
            ... ]
            >>> model = YOLOv8CSPDarknet()
            >>> setting = model.arch_settings[0]
            >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
            >>> assert len(stage_plugins) == 1

        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv -> csp_layer -> yyy

        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv -> csp_layer -> xxx -> yyy

        Args:
            plugins (list[dict]): List of plugins cfg to build. The postfix is
                required if multiple same type plugins are inserted.
            stage_idx (int): Index of the stage to build plugins for. A
                plugin without a ``stages`` entry is applied to all stages.
            setting (list): The architecture setting of a stage layer.

        Returns:
            list[nn.Module]: Plugins for current stage
        """
        # Use the same rounding as build_stage_layer so the plugin's
        # in_channels matches the stage's actual output channels.
        # TODO: plugins that change the channel count are not supported.
        in_channels = make_divisible(setting[1], self.widen_factor)
        plugin_layers = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            if stages is None or stages[stage_idx]:
                # build_plugin_layer returns (name, layer); only the layer
                # is needed here.
                _, layer = build_plugin_layer(
                    plugin['cfg'], in_channels=in_channels)
                plugin_layers.append(layer)
        return plugin_layers

    def init_weights(self):
        """Initialize the parameters."""
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, torch.nn.Conv2d):
                    # In order to be consistent with the source code,
                    # reset the Conv2d initialization parameters
                    m.reset_parameters()
        else:
            super().init_weights()

    def _freeze_stages(self):
        """Freeze layers ``0..frozen_stages`` of ``self.layers``: put them in
        eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self, self.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode: bool = True):
        """Set training mode while honouring ``frozen_stages`` and
        ``norm_eval``.

        Args:
            mode (bool): Whether to set training mode. Defaults to True.

        Returns:
            YOLOv8CSPDarknet: ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def forward(self, x: torch.Tensor) -> tuple:
        """Forward batch_inputs from the data_preprocessor.

        Runs every module named in ``self.layers`` in order. The output of
        each layer whose index is in ``out_indices`` is collected.

        Returns:
            tuple[torch.Tensor]: The collected outputs, in layer order.
        """
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)