- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Sequence, Union
- import torch
- import torch.nn as nn
- from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
- from mmengine.model import BaseModule
- from torch import Tensor
- from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
- from .se_layer import ChannelAttention
- class DarknetBottleneck(BaseModule):
- """The basic bottleneck block used in Darknet.
- Each ResBlock consists of two ConvModules and the input is added to the
- final output. Each ConvModule is composed of Conv, BN, and the
- configured activation (Swish by default).
- The first ConvModule has a 1x1 kernel and the second a 3x3 kernel.
- Args:
- in_channels (int): The input channels of this Module.
- out_channels (int): The output channels of this Module.
- expansion (float): Expand ratio of the hidden channels.
- Defaults to 0.5.
- add_identity (bool): Whether to add identity to the out.
- Defaults to True.
- use_depthwise (bool): Whether to use depthwise separable convolution.
- Defaults to False.
- conv_cfg (dict): Config dict for convolution layer. Defaults to None,
- which means using conv2d.
- norm_cfg (dict): Config dict for normalization layer.
- Defaults to dict(type='BN', momentum=0.03, eps=0.001).
- act_cfg (dict): Config dict for activation layer.
- Defaults to dict(type='Swish').
- """
- def __init__(self,
- in_channels: int,
- out_channels: int,
- expansion: float = 0.5,
- add_identity: bool = True,
- use_depthwise: bool = False,
- conv_cfg: OptConfigType = None,
- norm_cfg: ConfigType = dict(
- type='BN', momentum=0.03, eps=0.001),
- act_cfg: ConfigType = dict(type='Swish'),
- init_cfg: OptMultiConfig = None) -> None:
- super().__init__(init_cfg=init_cfg)
- hidden_channels = int(out_channels * expansion)
- conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
- self.conv1 = ConvModule(
- in_channels,
- hidden_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.conv2 = conv(
- hidden_channels,
- out_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.add_identity = \
- add_identity and in_channels == out_channels
- def forward(self, x: Tensor) -> Tensor:
- """Forward function."""
- identity = x
- out = self.conv1(x)
- out = self.conv2(out)
- if self.add_identity:
- return out + identity
- else:
- return out
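- # Usage sketch (illustrative only; channel and spatial sizes are assumed):
- # a 1x1 conv reduces to hidden_channels, a 3x3 conv restores out_channels,
- # and the identity is added when in_channels == out_channels.
- # >>> block = DarknetBottleneck(64, 64)   # hidden_channels = 32
- # >>> x = torch.rand(2, 64, 56, 56)
- # >>> block(x).shape
- # torch.Size([2, 64, 56, 56])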
- class CSPNeXtBlock(BaseModule):
- """The basic bottleneck block used in CSPNeXt.
- Args:
- in_channels (int): The input channels of this Module.
- out_channels (int): The output channels of this Module.
- expansion (float): Expand ratio of the hidden channel. Defaults to 0.5.
- add_identity (bool): Whether to add identity to the out. Only works
- when in_channels == out_channels. Defaults to True.
- use_depthwise (bool): Whether to use depthwise separable convolution.
- Defaults to False.
- kernel_size (int): The kernel size of the second convolution layer.
- Defaults to 5.
- conv_cfg (dict): Config dict for convolution layer. Defaults to None,
- which means using conv2d.
- norm_cfg (dict): Config dict for normalization layer.
- Defaults to dict(type='BN', momentum=0.03, eps=0.001).
- act_cfg (dict): Config dict for activation layer.
- Defaults to dict(type='SiLU').
- init_cfg (:obj:`ConfigDict` or dict or list[dict] or
- list[:obj:`ConfigDict`], optional): Initialization config dict.
- Defaults to None.
- """
- def __init__(self,
- in_channels: int,
- out_channels: int,
- expansion: float = 0.5,
- add_identity: bool = True,
- use_depthwise: bool = False,
- kernel_size: int = 5,
- conv_cfg: OptConfigType = None,
- norm_cfg: ConfigType = dict(
- type='BN', momentum=0.03, eps=0.001),
- act_cfg: ConfigType = dict(type='SiLU'),
- init_cfg: OptMultiConfig = None) -> None:
- super().__init__(init_cfg=init_cfg)
- hidden_channels = int(out_channels * expansion)
- conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
- self.conv1 = conv(
- in_channels,
- hidden_channels,
- 3,
- stride=1,
- padding=1,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.conv2 = DepthwiseSeparableConvModule(
- hidden_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=kernel_size // 2,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.add_identity = \
- add_identity and in_channels == out_channels
- def forward(self, x: Tensor) -> Tensor:
- """Forward function."""
- identity = x
- out = self.conv1(x)
- out = self.conv2(out)
- if self.add_identity:
- return out + identity
- else:
- return out
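- # Usage sketch (illustrative, assumed sizes): note that conv2 is always a
- # depthwise-separable kernel_size x kernel_size conv (5x5 by default),
- # while use_depthwise only switches the type of conv1.
- # >>> block = CSPNeXtBlock(64, 64)        # hidden_channels = 32
- # >>> x = torch.rand(2, 64, 40, 40)
- # >>> block(x).shape
- # torch.Size([2, 64, 40, 40])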
- class CSPLayer(BaseModule):
- """Cross Stage Partial Layer.
- Args:
- in_channels (int): The input channels of the CSP layer.
- out_channels (int): The output channels of the CSP layer.
- expand_ratio (float): Ratio to adjust the number of channels of the
- hidden layer. Defaults to 0.5.
- num_blocks (int): Number of blocks. Defaults to 1.
- add_identity (bool): Whether to add identity in blocks.
- Defaults to True.
- use_cspnext_block (bool): Whether to use CSPNeXt block.
- Defaults to False.
- use_depthwise (bool): Whether to use depthwise separable convolution in
- blocks. Defaults to False.
- channel_attention (bool): Whether to add channel attention in each
- stage. Defaults to False.
- conv_cfg (dict, optional): Config dict for convolution layer.
- Defaults to None, which means using conv2d.
- norm_cfg (dict): Config dict for normalization layer.
- Defaults to dict(type='BN', momentum=0.03, eps=0.001).
- act_cfg (dict): Config dict for activation layer.
- Defaults to dict(type='Swish').
- init_cfg (:obj:`ConfigDict` or dict or list[dict] or
- list[:obj:`ConfigDict`], optional): Initialization config dict.
- Defaults to None.
- """
- def __init__(self,
- in_channels: int,
- out_channels: int,
- expand_ratio: float = 0.5,
- num_blocks: int = 1,
- add_identity: bool = True,
- use_depthwise: bool = False,
- use_cspnext_block: bool = False,
- channel_attention: bool = False,
- conv_cfg: OptConfigType = None,
- norm_cfg: ConfigType = dict(
- type='BN', momentum=0.03, eps=0.001),
- act_cfg: ConfigType = dict(type='Swish'),
- init_cfg: OptMultiConfig = None) -> None:
- super().__init__(init_cfg=init_cfg)
- block = CSPNeXtBlock if use_cspnext_block else DarknetBottleneck
- mid_channels = int(out_channels * expand_ratio)
- self.channel_attention = channel_attention
- self.main_conv = ConvModule(
- in_channels,
- mid_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.short_conv = ConvModule(
- in_channels,
- mid_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.final_conv = ConvModule(
- 2 * mid_channels,
- out_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.blocks = nn.Sequential(*[
- block(
- mid_channels,
- mid_channels,
- 1.0,
- add_identity,
- use_depthwise,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg) for _ in range(num_blocks)
- ])
- if channel_attention:
- self.attention = ChannelAttention(2 * mid_channels)
- def forward(self, x: Tensor) -> Tensor:
- """Forward function."""
- x_short = self.short_conv(x)
- x_main = self.main_conv(x)
- x_main = self.blocks(x_main)
- x_final = torch.cat((x_main, x_short), dim=1)
- if self.channel_attention:
- x_final = self.attention(x_final)
- return self.final_conv(x_final)
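- # Usage sketch (illustrative, assumed sizes): both branches are reduced to
- # mid_channels = out_channels * expand_ratio, the main branch passes through
- # num_blocks bottlenecks, and the concatenated 2 * mid_channels features are
- # fused by a 1x1 conv to out_channels.
- # >>> layer = CSPLayer(64, 128, num_blocks=2)   # mid_channels = 64
- # >>> x = torch.rand(1, 64, 32, 32)
- # >>> layer(x).shape
- # torch.Size([1, 128, 32, 32])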
- class YoloV8Bottleneck(DarknetBottleneck):
- """The basic bottleneck block used in Darknet.
- Each ResBlock consists of two ConvModules and the input is added to the
- final output. Each ConvModule is composed of Conv, BN, and LeakyReLU.
- The first convLayer has filter size of k1Xk1 and the second one has the
- filter size of k2Xk2.
- Note:
- This DarknetBottleneck is little different from MMDet's, we can
- change the kernel size and padding for each conv.
- Args:
- in_channels (int): The input channels of this Module.
- out_channels (int): The output channels of this Module.
- expansion (float): Expand ratio of the hidden channels.
- Defaults to 0.5.
- kernel_size (Sequence[int]): The kernel size of the convolution.
- Defaults to (1, 3).
- padding (Sequence[int]): The padding size of the convolution.
- Defaults to (0, 1).
- add_identity (bool): Whether to add identity to the out.
- Defaults to True.
- use_depthwise (bool): Whether to use depthwise separable convolution.
- Defaults to False.
- conv_cfg (dict): Config dict for convolution layer. Defaults to None,
- which means using conv2d.
- norm_cfg (dict): Config dict for normalization layer.
- Defaults to dict(type='BN', momentum=0.03, eps=0.001).
- act_cfg (dict): Config dict for activation layer.
- Defaults to dict(type='SiLU', inplace=True).
- """
- def __init__(self,
- in_channels: int,
- out_channels: int,
- expansion: float = 0.5,
- kernel_size: Sequence[int] = (1, 3),
- padding: Sequence[int] = (0, 1),
- add_identity: bool = True,
- use_depthwise: bool = False,
- conv_cfg: OptConfigType = None,
- norm_cfg: ConfigType = dict(
- type='BN', momentum=0.03, eps=0.001),
- act_cfg: ConfigType = dict(type='SiLU', inplace=True),
- init_cfg: OptMultiConfig = None) -> None:
- super().__init__(in_channels, out_channels, init_cfg=init_cfg)
- hidden_channels = int(out_channels * expansion)
- conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
- assert isinstance(kernel_size, Sequence) and len(kernel_size) == 2
- self.conv1 = ConvModule(
- in_channels,
- hidden_channels,
- kernel_size[0],
- padding=padding[0],
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.conv2 = conv(
- hidden_channels,
- out_channels,
- kernel_size[1],
- stride=1,
- padding=padding[1],
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.add_identity = \
- add_identity and in_channels == out_channels
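- # Usage sketch (illustrative, assumed sizes): with kernel_size=(3, 3) and
- # padding=(1, 1) this matches the bottleneck configuration used by
- # CSPLayerWithTwoConv below.
- # >>> block = YoloV8Bottleneck(64, 64, kernel_size=(3, 3), padding=(1, 1))
- # >>> x = torch.rand(1, 64, 20, 20)
- # >>> block(x).shape
- # torch.Size([1, 64, 20, 20])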
- class CSPLayerWithTwoConv(BaseModule):
- """Cross Stage Partial Layer with 2 convolutions.
- Args:
- in_channels (int): The input channels of the CSP layer.
- out_channels (int): The output channels of the CSP layer.
- expand_ratio (float): Ratio to adjust the number of channels of the
- hidden layer. Defaults to 0.5.
- num_blocks (int): Number of blocks. Defaults to 1.
- add_identity (bool): Whether to add identity in blocks.
- Defaults to True.
- conv_cfg (dict, optional): Config dict for convolution layer.
- Defaults to None, which means using conv2d.
- norm_cfg (dict): Config dict for normalization layer.
- Defaults to dict(type='BN').
- act_cfg (dict): Config dict for activation layer.
- Defaults to dict(type='SiLU', inplace=True).
- init_cfg (:obj:`ConfigDict` or dict or list[dict] or
- list[:obj:`ConfigDict`], optional): Initialization config dict.
- Defaults to None.
- """
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- expand_ratio: float = 0.5,
- num_blocks: int = 1,
- add_identity: bool = True, # shortcut
- conv_cfg: OptConfigType = None,
- norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
- act_cfg: ConfigType = dict(type='SiLU', inplace=True),
- init_cfg: OptMultiConfig = None) -> None:
- super().__init__(init_cfg=init_cfg)
- self.mid_channels = int(out_channels * expand_ratio)
- self.main_conv = ConvModule(
- in_channels,
- 2 * self.mid_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.final_conv = ConvModule(
- (2 + num_blocks) * self.mid_channels,
- out_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- self.blocks = nn.ModuleList(
- YoloV8Bottleneck(
- self.mid_channels,
- self.mid_channels,
- expansion=1,
- kernel_size=(3, 3),
- padding=(1, 1),
- add_identity=add_identity,
- use_depthwise=False,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg) for _ in range(num_blocks))
- def forward(self, x: Tensor) -> Tensor:
- """Forward process."""
- x_main = self.main_conv(x)
- x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
- x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
- return self.final_conv(torch.cat(x_main, 1))
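- # Usage sketch (illustrative, assumed sizes): main_conv expands to
- # 2 * mid_channels, the result is split in two, each block appends another
- # mid_channels chunk, so final_conv sees (2 + num_blocks) * mid_channels.
- # >>> layer = CSPLayerWithTwoConv(64, 64, num_blocks=2)   # mid_channels = 32
- # >>> x = torch.rand(1, 64, 20, 20)
- # >>> layer(x).shape                                      # cat of 4 * 32 = 128 -> 64
- # torch.Size([1, 64, 20, 20])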
- # TODO: try switching to a deploy variant that uses only 3x3 convs.
- class SPPFBottleneck(BaseModule):
- """Spatial pyramid pooling - Fast (SPPF) layer for
- YOLOv5, YOLOX and PPYOLOE by Glenn Jocher
- Args:
- in_channels (int): The input channels of this Module.
- out_channels (int): The output channels of this Module.
- kernel_sizes (int, Sequence[int]): A single kernel size or a sequence
- of kernel sizes for the pooling layers. Defaults to 5.
- use_conv_first (bool): Whether to apply a conv before the pooling
- layers. Set to True in YOLOv5 and YOLOX, and to False in PPYOLOE.
- Defaults to True.
- mid_channels_scale (float): Channel multiplier; in_channels is
- multiplied by this value to get mid_channels. Only used when
- use_conv_first=True. Defaults to 0.5.
- conv_cfg (dict, optional): Config dict for convolution layer.
- Defaults to None, which means using conv2d.
- norm_cfg (dict): Config dict for normalization layer.
- Defaults to dict(type='BN', momentum=0.03, eps=0.001).
- act_cfg (dict): Config dict for activation layer.
- Defaults to dict(type='SiLU', inplace=True).
- init_cfg (dict or list[dict], optional): Initialization config dict.
- Defaults to None.
- """
- def __init__(self,
- in_channels: int,
- out_channels: int,
- kernel_sizes: Union[int, Sequence[int]] = 5,
- use_conv_first: bool = True,
- mid_channels_scale: float = 0.5,
- conv_cfg: OptConfigType = None,
- norm_cfg: ConfigType = dict(
- type='BN', momentum=0.03, eps=0.001),
- act_cfg: ConfigType = dict(type='SiLU', inplace=True),
- init_cfg: OptMultiConfig = None) -> None:
- super().__init__(init_cfg=init_cfg)
- if use_conv_first:
- mid_channels = int(in_channels * mid_channels_scale)
- self.conv1 = ConvModule(
- in_channels,
- mid_channels,
- 1,
- stride=1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- else:
- mid_channels = in_channels
- self.conv1 = None
- self.kernel_sizes = kernel_sizes
- if isinstance(kernel_sizes, int):
- self.poolings = nn.MaxPool2d(
- kernel_size=kernel_sizes, stride=1, padding=kernel_sizes // 2)
- conv2_in_channels = mid_channels * 4
- else:
- self.poolings = nn.ModuleList([
- nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
- for ks in kernel_sizes
- ])
- conv2_in_channels = mid_channels * (len(kernel_sizes) + 1)
- self.conv2 = ConvModule(
- conv2_in_channels,
- out_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
- def forward(self, x: Tensor) -> Tensor:
- """Forward process
- Args:
- x (Tensor): The input tensor.
- """
- if self.conv1:
- x = self.conv1(x)
- if isinstance(self.kernel_sizes, int):
- y1 = self.poolings(x)
- y2 = self.poolings(y1)
- x = torch.cat([x, y1, y2, self.poolings(y2)], dim=1)
- else:
- x = torch.cat(
- [x] + [pooling(x) for pooling in self.poolings], dim=1)
- x = self.conv2(x)
- return x
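- # Usage sketch (illustrative, assumed sizes): with an int kernel_sizes the
- # same max-pool is applied three times in series and concatenated with its
- # input, giving 4 * mid_channels features before conv2; a sequence of kernel
- # sizes yields (len(kernel_sizes) + 1) * mid_channels instead.
- # >>> sppf = SPPFBottleneck(256, 256, kernel_sizes=5)   # mid_channels = 128
- # >>> x = torch.rand(1, 256, 20, 20)
- # >>> sppf(x).shape
- # torch.Size([1, 256, 20, 20])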