# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .se_layer import ChannelAttention


class DarknetBottleneck(BaseModule):
    """The basic bottleneck block used in Darknet.

    Each ResBlock consists of two ConvModules and the input is added to the
    final output. Each ConvModule is composed of a convolution, a
    normalization layer and an activation layer (BN and Swish by default).
    The first convLayer has filter size of 1x1 and the second one has the
    filter size of 3x3.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (float): Ratio used to compute the hidden channels,
            i.e. ``int(out_channels * expansion)``. Defaults to 0.5.
        add_identity (bool): Whether to add identity to the out. Only takes
            effect when ``in_channels == out_channels``. Defaults to True.
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the second (3x3) convolution. Defaults to False.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='Swish').
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expansion: float = 0.5,
                 add_identity: bool = True,
                 use_depthwise: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='Swish'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        hidden_channels = int(out_channels * expansion)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
        self.conv1 = ConvModule(
            in_channels,
            hidden_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = conv(
            hidden_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # The residual sum is only shape-compatible when channels match.
        self.add_identity = \
            add_identity and in_channels == out_channels

    def forward(self, x: Tensor) -> Tensor:
        """Forward function."""
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)

        if self.add_identity:
            return out + identity
        else:
            return out


class CSPNeXtBlock(BaseModule):
    """The basic bottleneck block used in CSPNeXt.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (float): Expand ratio of the hidden channel. Defaults to 0.5.
        add_identity (bool): Whether to add identity to the out. Only works
            when in_channels == out_channels. Defaults to True.
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the first (3x3) convolution. Defaults to False.
        kernel_size (int): The kernel size of the second convolution layer,
            which is always a depthwise separable convolution. Defaults to 5.
        conv_cfg (dict, optional): Config dict for convolution layer, used by
            both convolutions. Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU').
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expansion: float = 0.5,
                 add_identity: bool = True,
                 use_depthwise: bool = False,
                 kernel_size: int = 5,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        hidden_channels = int(out_channels * expansion)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
        # conv_cfg is forwarded here as well so that a custom convolution
        # config applies to both layers, like in every other block.
        self.conv1 = conv(
            in_channels,
            hidden_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = DepthwiseSeparableConvModule(
            hidden_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # The residual sum is only shape-compatible when channels match.
        self.add_identity = \
            add_identity and in_channels == out_channels

    def forward(self, x: Tensor) -> Tensor:
        """Forward function."""
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)

        if self.add_identity:
            return out + identity
        else:
            return out


class CSPLayer(BaseModule):
    """Cross Stage Partial Layer.

    The input is projected by two parallel 1x1 convs (``main_conv`` and
    ``short_conv``); only the main branch goes through the stacked blocks.
    Both branches are concatenated, optionally re-weighted by channel
    attention, and fused by ``final_conv``.

    Args:
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Defaults to 0.5.
        num_blocks (int): Number of blocks. Defaults to 1.
        add_identity (bool): Whether to add identity in blocks.
            Defaults to True.
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Defaults to False.
        use_cspnext_block (bool): Whether to use CSPNeXt block instead of
            Darknet bottleneck. Defaults to False.
        channel_attention (bool): Whether to add channel attention in each
            stage. Defaults to False.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='Swish').
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expand_ratio: float = 0.5,
                 num_blocks: int = 1,
                 add_identity: bool = True,
                 use_depthwise: bool = False,
                 use_cspnext_block: bool = False,
                 channel_attention: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='Swish'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        block = CSPNeXtBlock if use_cspnext_block else DarknetBottleneck
        mid_channels = int(out_channels * expand_ratio)
        self.channel_attention = channel_attention
        self.main_conv = ConvModule(
            in_channels,
            mid_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.short_conv = ConvModule(
            in_channels,
            mid_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.final_conv = ConvModule(
            2 * mid_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Keyword arguments keep the binding correct for both block classes,
        # which share these parameter names.
        self.blocks = nn.Sequential(*[
            block(
                mid_channels,
                mid_channels,
                expansion=1.0,
                add_identity=add_identity,
                use_depthwise=use_depthwise,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for _ in range(num_blocks)
        ])
        if channel_attention:
            self.attention = ChannelAttention(2 * mid_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function."""
        x_short = self.short_conv(x)

        x_main = self.main_conv(x)
        x_main = self.blocks(x_main)

        x_final = torch.cat((x_main, x_short), dim=1)

        if self.channel_attention:
            x_final = self.attention(x_final)
        return self.final_conv(x_final)


class YoloV8Bottleneck(DarknetBottleneck):
    """Darknet-style bottleneck block with configurable kernels (YOLOv8).

    Each ResBlock consists of two ConvModules and the input is added to the
    final output. Each ConvModule is composed of a convolution, a
    normalization layer and an activation layer (BN and SiLU by default).
    The first convLayer has filter size of k1Xk1 and the second one has the
    filter size of k2Xk2. ``forward`` is inherited from
    :class:`DarknetBottleneck`.

    Note:
    This DarknetBottleneck is a little different from MMDet's: the kernel
    size and padding of each conv can be changed.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (float): Ratio used to compute the hidden channels,
            i.e. ``int(out_channels * expansion)``. Defaults to 0.5.
        kernel_size (Sequence[int]): The kernel sizes of the two
            convolutions. Must have length 2. Defaults to (1, 3).
        padding (Sequence[int]): The padding sizes of the two convolutions.
            Defaults to (0, 1).
        add_identity (bool): Whether to add identity to the out. Only takes
            effect when ``in_channels == out_channels``. Defaults to True.
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the second convolution. Defaults to False.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expansion: float = 0.5,
                 kernel_size: Sequence[int] = (1, 3),
                 padding: Sequence[int] = (0, 1),
                 add_identity: bool = True,
                 use_depthwise: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None) -> None:
        # NOTE(review): the parent constructor also builds conv1/conv2 with
        # its own default configs; both are replaced just below. Skipping it
        # would change the RNG draws made during construction (and thus the
        # seeded initial weights), so the call is kept as is.
        super().__init__(in_channels, out_channels, init_cfg=init_cfg)
        hidden_channels = int(out_channels * expansion)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
        assert isinstance(kernel_size, Sequence) and len(kernel_size) == 2

        self.conv1 = ConvModule(
            in_channels,
            hidden_channels,
            kernel_size[0],
            padding=padding[0],
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = conv(
            hidden_channels,
            out_channels,
            kernel_size[1],
            stride=1,
            padding=padding[1],
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.add_identity = \
            add_identity and in_channels == out_channels


class CSPLayerWithTwoConv(BaseModule):
    """Cross Stage Partial Layer with 2 convolutions.

    ``main_conv`` produces ``2 * mid_channels`` channels that are split in
    two halves. The blocks are chained on the second half, and every
    intermediate output is kept. Both halves plus all block outputs are
    concatenated and fused by ``final_conv``.

    Args:
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Defaults to 0.5.
        num_blocks (int): Number of blocks. Defaults to 1.
        add_identity (bool): Whether to add identity in blocks.
            Defaults to True.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            expand_ratio: float = 0.5,
            num_blocks: int = 1,
            add_identity: bool = True,  # shortcut
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.mid_channels = int(out_channels * expand_ratio)
        self.main_conv = ConvModule(
            in_channels,
            2 * self.mid_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Two split halves plus one output per block are concatenated.
        self.final_conv = ConvModule(
            (2 + num_blocks) * self.mid_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.blocks = nn.ModuleList(
            YoloV8Bottleneck(
                self.mid_channels,
                self.mid_channels,
                expansion=1,
                kernel_size=(3, 3),
                padding=(1, 1),
                add_identity=add_identity,
                use_depthwise=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for _ in range(num_blocks))

    def forward(self, x: Tensor) -> Tensor:
        """Forward process."""
        x_main = self.main_conv(x)
        outs = list(x_main.split((self.mid_channels, self.mid_channels), 1))
        # Each block consumes the most recent feature map; all outputs are
        # kept for the final concatenation.
        for block in self.blocks:
            outs.append(block(outs[-1]))
        return self.final_conv(torch.cat(outs, 1))


# TODO: should try switching to deploy with only 3x3 conv.
class SPPFBottleneck(BaseModule):
    """Spatial pyramid pooling - Fast (SPPF) layer for
    YOLOv5, YOLOX and PPYOLOE by Glenn Jocher.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        kernel_sizes (int, tuple[int]): Sequential or number of kernel
            sizes of pooling layers. An int selects the fast (SPPF) variant,
            which applies one pooling layer three times in sequence; a
            sequence applies one pooling per kernel size in parallel.
            Defaults to 5.
        use_conv_first (bool): Whether to use conv before pooling layer.
            In YOLOv5 and YOLOX, the para set to True.
            In PPYOLOE, the para set to False.
            Defaults to True.
        mid_channels_scale (float): Channel multiplier, multiply in_channels
            by this amount to get mid_channels. This parameter is valid only
            when use_conv_first=True. Defaults to 0.5.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_sizes: Union[int, Sequence[int]] = 5,
                 use_conv_first: bool = True,
                 mid_channels_scale: float = 0.5,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)

        if use_conv_first:
            mid_channels = int(in_channels * mid_channels_scale)
            self.conv1 = ConvModule(
                in_channels,
                mid_channels,
                1,
                stride=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            mid_channels = in_channels
            self.conv1 = None

        self.kernel_sizes = kernel_sizes
        if isinstance(kernel_sizes, int):
            self.poolings = nn.MaxPool2d(
                kernel_size=kernel_sizes, stride=1, padding=kernel_sizes // 2)
            # Input plus three successive poolings are concatenated.
            conv2_in_channels = mid_channels * 4
        else:
            self.poolings = nn.ModuleList([
                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
                for ks in kernel_sizes
            ])
            # Input plus one pooled map per kernel size are concatenated.
            conv2_in_channels = mid_channels * (len(kernel_sizes) + 1)

        self.conv2 = ConvModule(
            conv2_in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x: Tensor) -> Tensor:
        """Forward process.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: ``conv2`` applied to the channel-wise concatenation of
            the (optionally projected) input and its pooled versions.
        """
        if self.conv1 is not None:
            x = self.conv1(x)
        if isinstance(self.kernel_sizes, int):
            # SPPF: the same pooling layer applied three times in sequence.
            y1 = self.poolings(x)
            y2 = self.poolings(y1)
            x = torch.cat([x, y1, y2, self.poolings(y2)], dim=1)
        else:
            # SPP: one pooling per kernel size, all on the same input.
            x = torch.cat(
                [x] + [pooling(x) for pooling in self.poolings], dim=1)
        x = self.conv2(x)
        return x