import copy
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_plugin_layer
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch.nn.modules.batchnorm import _BatchNorm

from ..layers import CSPLayerWithTwoConv, SPPFBottleneck
from ..utils import make_divisible, make_round


@MODELS.register_module()
class YOLOv8CSPDarknet(BaseModule):
    """CSP-Darknet backbone used in YOLOv8.

    Args:
        last_stage_out_channels (int): Final layer output channel.
            Defaults to 1024.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Tuple[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (Union[dict, list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOv8CSPDarknet
        >>> import torch
        >>> model = YOLOv8CSPDarknet()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """

    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp.
    # The final out_channels is set from ``last_stage_out_channels``.
    arch_settings = [[64, 128, 3, True, False], [128, 256, 6, True, False],
                     [256, 512, 6, True, False], [512, None, 3, True, True]]
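    # For example, with widen_factor=0.5 and deepen_factor=0.33 (the
    # YOLOv8-s style scaling), the first stage becomes a 32 -> 64 channel
    # stage with make_round(3, 0.33) = 1 block.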

    def __init__(self,
                 last_stage_out_channels: int = 1024,
                 plugins: Optional[Union[dict, List[dict]]] = None,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        # Copy the settings so that overriding the last stage's output
        # channels does not mutate the class attribute shared by all
        # instances.
        self.arch_settings = copy.deepcopy(self.arch_settings)
        self.arch_settings[-1][1] = last_stage_out_channels
        self.num_stages = len(self.arch_settings)
        self.input_channels = input_channels
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.widen_factor = widen_factor
        self.deepen_factor = deepen_factor
        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.plugins = plugins

        self.stem = self.build_stem_layer()
        self.layers = ['stem']

        for idx, setting in enumerate(self.arch_settings):
            stage = []
            stage += self.build_stage_layer(idx, setting)
            if plugins is not None:
                stage += self.make_stage_plugins(plugins, idx, setting)
            self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
            self.layers.append(f'stage{idx + 1}')

    def build_stem_layer(self) -> nn.Module:
        """Build a stem layer."""
        return ConvModule(
            self.input_channels,
            make_divisible(self.arch_settings[0][0], self.widen_factor),
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer.

        Returns:
            list: Modules of the stage: a stride-2 downsampling conv, a
            CSP layer and, for the last stage only, an SPPF bottleneck.
        """
        in_channels, out_channels, num_blocks, add_identity, use_spp = setting
        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        num_blocks = make_round(num_blocks, self.deepen_factor)
        stage = []
        conv_layer = ConvModule(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(conv_layer)
        csp_layer = CSPLayerWithTwoConv(
            out_channels,
            out_channels,
            num_blocks=num_blocks,
            add_identity=add_identity,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(csp_layer)
        if use_spp:
            spp = SPPFBottleneck(
                out_channels,
                out_channels,
                kernel_sizes=5,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            stage.append(spp)
        return stage

    def make_stage_plugins(self, plugins, stage_idx, setting):
        """Make plugins for backbone ``stage_idx`` th stage.

        Currently we support inserting ``context_block``,
        ``empirical_attention_block``, ``nonlocal_block`` and
        ``dropout_block`` into the backbone.

        An example of plugins format could be:

        Examples:
            >>> plugins = [
            ...     dict(cfg=dict(type='xxx', arg1='xxx'),
            ...          stages=(False, True, True, True)),
            ...     dict(cfg=dict(type='yyy'),
            ...          stages=(True, True, True, True)),
            ... ]
            >>> model = YOLOv8CSPDarknet()
            >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
            >>> assert len(stage_plugins) == 1

        Suppose ``stage_idx=0``, the structure of blocks in the stage would
        be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> yyy

        Suppose ``stage_idx=1``, the structure of blocks in the stage would
        be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> xxx -> yyy

        Args:
            plugins (list[dict]): List of plugin cfgs to build. The postfix
                is required if multiple plugins of the same type are
                inserted.
            stage_idx (int): Index of the stage to build. If ``stages`` is
                missing from a plugin dict, that plugin is applied to all
                stages.
            setting (list): The architecture setting of a stage layer.

        Returns:
            list[nn.Module]: Plugins for the current stage.
        """
        # TODO: It is not general enough to support any channel and needs
        # to be refactored.
        in_channels = int(setting[1] * self.widen_factor)
        plugin_layers = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            if stages is None or stages[stage_idx]:
                layer = build_plugin_layer(
                    plugin['cfg'], in_channels=in_channels)[1]
                plugin_layers.append(layer)
        return plugin_layers

    def init_weights(self):
        """Initialize the parameters."""
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, torch.nn.Conv2d):
                    # In order to be consistent with the source code,
                    # reset the Conv2d initialization parameters.
                    m.reset_parameters()
        else:
            super().init_weights()
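
    # Outputs are taken after the stages listed in ``out_indices``; with the
    # default (2, 3, 4) these correspond to strides 8, 16 and 32, since the
    # stem and every stage each downsample by 2 (e.g. a 416x416 input yields
    # the 52x52, 26x26 and 13x13 maps shown in the class docstring example).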
    def forward(self, x: torch.Tensor) -> tuple:
        """Forward batch_inputs from the data_preprocessor."""
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
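
    # The ``frozen_stages`` and ``norm_eval`` options documented above are
    # not enforced anywhere yet; the two methods below are a minimal sketch
    # following the standard mmengine/mmdet backbone pattern.
    def _freeze_stages(self):
        """Freeze the parameters of the specified stages."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self, self.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode: bool = True):
        """Keep frozen stages and, if ``norm_eval`` is set, all batch norm
        layers in eval mode while training."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()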
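
# A quick smoke test of a scaled variant (assumes this file lives in its
# usual mmyolo package context so the relative imports resolve):
#
#   model = YOLOv8CSPDarknet(widen_factor=0.5, deepen_factor=0.33)
#   model.eval()
#   feats = model(torch.rand(1, 3, 640, 640))
#   assert [f.shape[1] for f in feats] == [128, 256, 512]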