from typing import Optional,Tuple,Dict,List import math import numpy as np import types import torch import torch.nn.functional as F import torch.nn as nn from mmcv.cnn import build_norm_layer,ConvModule,build_activation_layer,build_conv_layer from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention from mmengine.model import BaseModule, ModuleList from torch import Tensor from mmcv.ops import MultiScaleDeformableAttention from mmdet.models.layers.transformer.utils import inverse_sigmoid from .deformable_detr_layers import DetrTransformerDecoder,DetrTransformerDecoderLayer from mmdet.utils import ConfigType, OptConfigType,OptMultiConfig from mmdet.registry import MODELS class SPD(nn.Module): # Changing the dimension of the Tensor def __init__(self, dimension=1): super().__init__() self.d = dimension def forward(self, x): return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1) class RepVGGBlock(BaseModule): """A block in RepVGG architecture, supporting optional normalization in the identity branch. This block consists of 3x3 and 1x1 convolutions, with an optional identity shortcut branch that includes normalization. Args: in_channels (int): The input channels of the block. out_channels (int): The output channels of the block. stride (int): The stride of the block. Defaults to 1. padding (int): The padding of the block. Defaults to 1. dilation (int): The dilation of the block. Defaults to 1. groups (int): The groups of the block. Defaults to 1. padding_mode (str): The padding mode of the block. Defaults to 'zeros'. norm_cfg (dict): The config dict for normalization layers. Defaults to dict(type='BN'). act_cfg (dict): The config dict for activation layers. Defaults to dict(type='ReLU'). without_branch_norm (bool): Whether to skip branch_norm. Defaults to True. init_cfg (dict): The config dict for initialization. Defaults to None. """ def __init__(self, in_channels: int, out_channels: int, stride: int = 1, padding: int = 1, dilation: int = 1, groups: int = 1, norm_cfg: OptConfigType = dict(type='BN', momentum=0.03, eps=0.001), act_cfg: OptConfigType = dict(type='ReLU',inplace=True), without_branch_norm: bool = True, init_cfg: OptConfigType = None): super(RepVGGBlock, self).__init__(init_cfg) self.in_channels = in_channels self.out_channels = out_channels self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups self.norm_cfg = norm_cfg self.act_cfg = act_cfg # judge if input shape and output shape are the same. # If true, add a normalized identity shortcut. self.branch_norm = None if out_channels == in_channels and stride == 1 and \ padding == dilation and not without_branch_norm: self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1] self.branch_3x3 = ConvModule( self.in_channels, self.out_channels, 3, stride=self.stride, padding=self.padding, groups=self.groups, dilation=self.dilation, norm_cfg=self.norm_cfg, act_cfg=None) self.branch_1x1 = ConvModule( self.in_channels, self.out_channels, 1, groups=self.groups, norm_cfg=self.norm_cfg, act_cfg=None) self.act = build_activation_layer(act_cfg) def forward(self, x: Tensor) -> Tensor: """Forward pass through the RepVGG block. The output is the sum of 3x3 and 1x1 convolution outputs, along with the normalized identity branch output, followed by activation. Args: x (Tensor): The input tensor. Returns: Tensor: The output tensor. """ if self.branch_norm is None: branch_norm_out = 0 else: branch_norm_out = self.branch_norm(x) out = self.branch_3x3(x) + self.branch_1x1(x) + branch_norm_out out = self.act(out) return out def _pad_1x1_to_3x3_tensor(self, kernel1x1): """Pad 1x1 tensor to 3x3. Args: kernel1x1 (Tensor): The input 1x1 kernel need to be padded. Returns: Tensor: 3x3 kernel after padded. """ if kernel1x1 is None: return 0 else: return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) def _fuse_bn_tensor(self, branch: nn.Module) -> Tensor: """Derives the equivalent kernel and bias of a specific branch layer. Args: branch (nn.Module): The layer that needs to be equivalently transformed, which can be nn.Sequential or nn.Batchnorm2d Returns: tuple: Equivalent kernel and bias """ if branch is None: return 0, 0 if isinstance(branch, ConvModule): kernel = branch.conv.weight running_mean = branch.bn.running_mean running_var = branch.bn.running_var gamma = branch.bn.weight beta = branch.bn.bias eps = branch.bn.eps else: assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d)) if not hasattr(self, 'id_tensor'): input_dim = self.in_channels // self.groups kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) for i in range(self.in_channels): kernel_value[i, i % input_dim, 1, 1] = 1 self.id_tensor = torch.from_numpy(kernel_value).to( branch.weight.device) kernel = self.id_tensor running_mean = branch.running_mean running_var = branch.running_var gamma = branch.weight beta = branch.bias eps = branch.eps std = (running_var + eps).sqrt() t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std def get_equivalent_kernel_bias(self): """Derives the equivalent kernel and bias in a differentiable way. Returns: tuple: Equivalent kernel and bias """ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3) kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1) kernelid, biasid = (0, 0) if self.branch_norm is None else \ self._fuse_bn_tensor(self.branch_norm) return (kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid) def switch_to_deploy(self, test_cfg: Optional[Dict] = None): """Switches the block to deployment mode. In deployment mode, the block uses a single convolution operation derived from the equivalent kernel and bias, replacing the original branches. This reduces computational complexity during inference. """ if getattr(self, 'deploy', False): return kernel, bias = self.get_equivalent_kernel_bias() self.conv_reparam = nn.Conv2d( in_channels=self.branch_3x3.conv.in_channels, out_channels=self.branch_3x3.conv.out_channels, kernel_size=self.branch_3x3.conv.kernel_size, stride=self.branch_3x3.conv.stride, padding=self.branch_3x3.conv.padding, dilation=self.branch_3x3.conv.dilation, groups=self.branch_3x3.conv.groups, bias=True) self.conv_reparam.weight.data = kernel self.conv_reparam.bias.data = bias for para in self.parameters(): para.detach_() self.__delattr__('branch_3x3') self.__delattr__('branch_1x1') if hasattr(self, 'branch_norm'): self.__delattr__('branch_norm') def _forward(self, x): return self.act(self.conv_reparam(x)) self.forward = types.MethodType(_forward, self) self.deploy = True class CSPRepLayer(BaseModule): """CSPRepLayer, a layer that combines Cross Stage Partial Networks with RepVGG Blocks. Args: in_channels (int): Number of input channels to the layer. out_channels (int): Number of output channels from the layer. num_blocks (int): The number of RepVGG blocks to be used in the layer. Defaults to 3. widen_factor (float): Expansion factor for intermediate channels. Determines the hidden channel size based on out_channels. Defaults to 1.0. norm_cfg (dict): Configuration for normalization layers. Defaults to Batch Normalization with trainable parameters. act_cfg (dict): Configuration for activation layers. Defaults to SiLU (Swish) with in-place operation. """ def __init__(self, in_channels: int, out_channels: int, num_blocks: int = 3, widen_factor: float = 1.0, norm_cfg: OptConfigType = dict(type='BN', requires_grad=True), act_cfg: OptConfigType = dict(type='SiLU', inplace=True)): super(CSPRepLayer, self).__init__() hidden_channels = int(out_channels * widen_factor) self.conv1 = ConvModule( in_channels, hidden_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=act_cfg) self.conv2 = ConvModule( in_channels, hidden_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=act_cfg) self.bottlenecks = nn.Sequential(*[ RepVGGBlock(hidden_channels, hidden_channels, act_cfg=act_cfg,norm_cfg=norm_cfg) for _ in range(num_blocks) ]) if hidden_channels != out_channels: self.conv3 = ConvModule( hidden_channels, out_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=act_cfg) else: self.conv3 = nn.Identity() def forward(self, x: Tensor) -> Tensor: """Forward function. Args: x (Tensor): The input tensor. Returns: Tensor: The output tensor. """ x_1 = self.conv1(x) x_1 = self.bottlenecks(x_1) x_2 = self.conv2(x) return self.conv3(x_1 + x_2) #Encoder and Encoder layer with embedded postion class EncoderLayer(BaseModule): def __init__(self, self_attn_cfg:OptConfigType=dict( embed_dims=256, num_heads=8, attn_drop=0, proj_drop=0, ), ffn_cfg:OptConfigType=dict( embed_dims=256, feedforward_channels=1024, num_fcs=2,ffn_drop=0, act_cfg=dict(type='ReLU',inpalce=True)), norm_cfg: OptConfigType = dict(type='LN'), init_cfg: OptConfigType = None, )->None: super().__init__(init_cfg) self.self_attn_cfg=self_attn_cfg if 'batch_first' not in self.self_attn_cfg: self.self_attn_cfg['batch_first'] = True else: assert self.self_attn_cfg['batch_first'] is True, 'First \ dimension of all DETRs in mmdet is `batch`, \ please set `batch_first` flag.' self.ffn_cfg = ffn_cfg self.norm_cfg = norm_cfg self._init_layers() def _init_layers(self)->None: #initialize the FFN and Multiheadattention layer self.self_attn = MultiheadAttention(**self.self_attn_cfg) self.embed_dims = self.self_attn.embed_dims self.ffn = FFN(**self.ffn_cfg) norms_list = [ build_norm_layer(self.norm_cfg, self.embed_dims)[1] for _ in range(2) ] self.norms = ModuleList(norms_list) def forward(self,query:Tensor,pos_embed:Tensor,key_padding_mask=None)->Tensor: query = self.self_attn( query=query, key=query, value=query, query_pos=pos_embed, key_pos=pos_embed, key_padding_mask=key_padding_mask, ) query=self.norms[0](query) query=self.ffn(query) query=self.norms[1](query) return query class Encoder(BaseModule): def __init__(self, num_layers,layer_cfg: ConfigType=None): super().__init__() self.num_layers=num_layers self.layer_cfg = layer_cfg self._init_layers() def _init_layers(self)->None: self.layers = ModuleList([ EncoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) self.embed_dims = self.layers[0].embed_dims def forward(self,query: Tensor, query_pos: Tensor, key_padding_mask: Tensor=None)->Tensor: for layer in self.layers: output=layer(query,query_pos,key_padding_mask) return output #RtdetrFPN class RTDETRFPN(BaseModule): """FPN of RTDETR. Args: in_channels (List[int], optional): The input channels of the feature maps. Defaults to [256, 256, 256]. out_channels (int, optional): The output dimension of the MLP. Defaults to 256. expansion (float, optional): The expansion of the CSPLayer. Defaults to 1.0. depth_mult (float, optional): The depth multiplier of the CSPLayer. Defaults to 1.0. upsample_cfg (dict): Config dict for interpolate layer. Default: `dict(scale_factor=2, mode='nearest')` conv_cfg (dict, optional): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (:obj:`ConfigDict` or dict, optional): The config dict for normalization layers. Defaults to dict(type='BN'). act_cfg (:obj:`ConfigDict` or dict, optional): The config dict for activation layers. Defaults to dict(type='SiLU', inplace=True). init_cfg (:obj:`ConfigDict` or dict or list[dict] or list[:obj:`ConfigDict`], optional): Initialization config dict. """ def __init__( self, in_channels: List[int] = [256, 256, 256], out_channels: int = 256, expansion: float = 1.0, depth_mult: float = 1.0, with_spd:bool=True, upsample_cfg: ConfigType = dict(scale_factor=2, mode='nearest'), conv_cfg: OptConfigType = None, norm_cfg: OptConfigType = dict(type='BN', requires_grad=True), act_cfg: OptConfigType = dict(type='SiLU', inplace=True), init_cfg: OptMultiConfig = dict( type='Kaiming', layer='Conv2d', a=math.sqrt(5), distribution='uniform', mode='fan_in', nonlinearity='leaky_relu') ) -> None: super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.out_channels = out_channels num_csp_blocks = round(3 * depth_mult) # top-down fpn self.upsample = nn.Upsample(**upsample_cfg) self.reduce_layers = nn.ModuleList() self.top_down_blocks = nn.ModuleList() for idx in range(len(in_channels) - 1, 0, -1): self.reduce_layers.append( ConvModule( in_channels[idx], in_channels[idx - 1], 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.top_down_blocks.append( CSPRepLayer( in_channels[idx - 1] * 2, in_channels[idx - 1], num_blocks=num_csp_blocks, widen_factor=expansion, norm_cfg=norm_cfg, act_cfg=act_cfg)) # build bottom-up blocks self.downsamples = nn.ModuleList() self.bottom_up_blocks = nn.ModuleList() self.with_spd=with_spd self.spd=SPD() self.with_spd_norm=build_norm_layer(norm_cfg, in_channels[idx]*4, postfix=1)[1] if self.with_spd: for idx in range(len(in_channels) - 1): self.downsamples.append( build_conv_layer( conv_cfg, in_channels[idx], in_channels[idx], 3, stride=1, padding=1, bias=False), ) self.bottom_up_blocks.append( CSPRepLayer( in_channels[idx] * 5, in_channels[idx + 1], num_blocks=num_csp_blocks, widen_factor=expansion, norm_cfg=norm_cfg, act_cfg=act_cfg)) else: for idx in range(len(in_channels) - 1): self.downsamples.append( ConvModule( in_channels[idx], in_channels[idx], 3, stride=2, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.bottom_up_blocks.append( CSPRepLayer( in_channels[idx] * 2, in_channels[idx + 1], num_blocks=num_csp_blocks, widen_factor=expansion, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.out_convs = nn.ModuleList() for i in range(len(in_channels)): self.out_convs.append( ConvModule( in_channels[i], out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)) def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]: """ Args: inputs (tuple[Tensor]): input features. Returns: tuple[Tensor]: FPN features. """ assert len(inputs) == len(self.in_channels) # top-down path inner_outs = [inputs[-1]] for idx in range(len(self.in_channels) - 1, 0, -1): feat_high = inner_outs[0] feat_low = inputs[idx - 1] feat_high = self.reduce_layers[len(self.in_channels) - 1 - idx]( feat_high) inner_outs[0] = feat_high upsample_feat = self.upsample(feat_high) inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( torch.cat([upsample_feat, feat_low], 1)) inner_outs.insert(0, inner_out) # bottom-up path outs = [inner_outs[0]] for idx in range(len(self.in_channels) - 1): feat_low = outs[-1] feat_high = inner_outs[idx + 1] downsample_feat = self.downsamples[idx](feat_low) if self.with_spd: downsample_feat = self.spd(downsample_feat) downsample_feat = self.with_spd_norm(downsample_feat) out = self.bottom_up_blocks[idx]( torch.cat([downsample_feat, feat_high], 1)) outs.append(out) # out convs for idx, conv in enumerate(self.out_convs): outs[idx] = conv(outs[idx]) return tuple(outs) #Instra-scale feature interaction and cross-sacle feature-fusion class SSFF(BaseModule): def __init__(self, in_channels:list, out_channels, ): super().__init__() self.in_channels=in_channels self.out_channels=out_channels self.convs = nn.ModuleList() for in_channel in in_channels: self.convs.append( ConvModule( in_channel, out_channels, 1, padding=0, conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='ReLU'))) self.conv3d=nn.Conv3d(out_channels,out_channels,kernel_size=(1,1,1)) self.bn3d=nn.BatchNorm3d(out_channels) self.act = nn.LeakyReLU(0.1) self.pool_3d = nn.MaxPool3d(kernel_size=(3,1,1)) def forward(self,inputs)->Tensor: outputs=[] for i in range(len(inputs)): feature=self.convs[i](inputs[i]) if i!=0: feature=F.interpolate(feature,inputs[0].size()[2:], mode='nearest') outputs.append(feature) for i in range(len(outputs)): outputs[i]=torch.unsqueeze(outputs[i], -3) combine=torch.cat(outputs,dim=2) conv_3d = self.act(self.bn3d(self.conv3d(combine))) output = self.pool_3d(conv_3d) output = torch.squeeze(output, 2) return output @MODELS.register_module() class HybridEncoder(BaseModule): def __init__(self, in_channels=[512,1024,2048], feat_strides=[8,16,32], hidden_dim=256, n_head=8, dim_feedforward_ratio=4, drop_out=0.0, enc_act:OptConfigType=dict(type='GELU'), use_encoder_idx=[2], num_encoder_layers=1, with_ssff:bool=False, with_spd:bool=False, pe_temperature=100*100, norm_cfg: OptConfigType = dict(type='BN', requires_grad=True), widen_factor=1, deepen_factor=1, eval_spatial_size=None, input_proj_cfg:OptConfigType=None, act_cfg: OptConfigType = dict(type='SiLU', inplace=True) ): super().__init__() self.in_channels = in_channels self.feat_strides = feat_strides self.hidden_dim = hidden_dim self.use_encoder_idx = use_encoder_idx self.num_encoder_layers = num_encoder_layers self.pe_temperature = pe_temperature self.eval_spatial_size = eval_spatial_size self.out_channels = [hidden_dim for _ in range(len(in_channels))] self.out_strides = feat_strides self.with_ssff=with_ssff #using channel mapper implemented in ChannelMapper self.input_proj = MODELS.build(input_proj_cfg)\ if input_proj_cfg is not None else nn.Identity() if self.with_ssff: self.ssff=SSFF(in_channels=[hidden_dim,hidden_dim,hidden_dim],out_channels=hidden_dim) #transformer encoder and position encoder # def __init__(self, # embed_dims, # num_heads, # attn_drop=0., # proj_drop=0., # dropout_layer=dict(type='Dropout', drop_prob=0.), # init_cfg=None, # batch_first=False, # **kwargs) #Multihead # def __init__(self, # embed_dims=256, # feedforward_channels=1024, # num_fcs=2, # act_cfg=dict(type='ReLU', inplace=True), # ffn_drop=0., # dropout_layer=None, # add_identity=True, # init_cfg=None, # layer_scale_init_value=0.): #FFN encoder_layer_opt = dict( self_attn_cfg=dict(embed_dims=hidden_dim, num_heads=n_head, attn_drop=drop_out, proj_drop=drop_out, ), ffn_cfg=dict(embed_dims=hidden_dim, feedforward_channels=hidden_dim*dim_feedforward_ratio, num_fcs=2, ffn_drop=drop_out, act_cfg=enc_act) ) self.encoder = nn.ModuleList([ Encoder(num_encoder_layers, layer_cfg=encoder_layer_opt) for _ in range(len(use_encoder_idx)) ]) self.fpn=RTDETRFPN(in_channels=[hidden_dim,hidden_dim,hidden_dim], out_channels=hidden_dim, expansion=widen_factor, depth_mult=deepen_factor, norm_cfg=norm_cfg, act_cfg=act_cfg, with_spd=with_spd ) self._reset_parameters() def _reset_parameters(self): if self.eval_spatial_size: for idx in self.use_encoder_idx: stride = self.feat_strides[idx] pos_embed = self.build_2d_sincos_position_embedding( self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride, self.hidden_dim, self.pe_temperature) setattr(self, f'pos_embed{idx}', pos_embed) # self.register_buffer(f'pos_embed{idx}', pos_embed) @staticmethod def build_2d_sincos_position_embedding( w: int, h: int, embed_dim: int = 256, temperature: float = 10000., device=None, ) -> Tensor: grid_w = torch.arange(w, dtype=torch.float32, device=device) grid_h = torch.arange(h, dtype=torch.float32, device=device) grid_w, grid_h = torch.meshgrid(grid_w, grid_h) assert embed_dim % 4 == 0, ('Embed dimension must be divisible by 4 ' 'for 2D sin-cos position embedding') pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32, device=device) omega = temperature**(omega / -pos_dim) out_w = grid_w.flatten()[..., None] @ omega[None] out_h = grid_h.flatten()[..., None] @ omega[None] pos_embd = [ torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h) ] return torch.cat(pos_embd, axis=1)[None, :, :] def forward(self,inputs:Tuple[Tensor])->Tuple[Tensor]: assert len(inputs)==len(self.in_channels) proj_feats=self.input_proj(inputs) proj_feats=list(proj_feats) if self.with_ssff: fuse_layer=self.ssff(proj_feats) proj_feats[len(proj_feats)-1]=fuse_layer # proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(inputs)] #encoder with position encoding if self.num_encoder_layers>0: for i,enc_idx in enumerate(self.use_encoder_idx): h,w=proj_feats[enc_idx].shape[2:] #B,C,H,W -> B,H*W,C src_flatten=proj_feats[enc_idx].flatten(2).permute(0,2,1).contiguous() if self.training or self.eval_spatial_size is None: pos_enc = self.build_2d_sincos_position_embedding( h, w, embed_dim=self.hidden_dim, temperature=self.pe_temperature, device=src_flatten.device) else: pos_enc=getattr(self, f'pos_embed{enc_idx}', None).to(src_flatten.device) memory = self.encoder[i]( src_flatten, query_pos=pos_enc) proj_feats[enc_idx] = memory.permute( 0, 2, 1).contiguous().reshape([-1, self.hidden_dim, h, w]) #fpn outs=self.fpn(tuple(proj_feats)) return outs #derived from detrTransformerDecoder check init in detr TransformerDecoder class RtDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): """Decoder layer of Deformable DETR.""" def _init_layers(self) -> None: """Initialize self_attn, cross-attn, ffn, and norms.""" self.self_attn = MultiheadAttention(**self.self_attn_cfg) self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) self.embed_dims = self.self_attn.embed_dims self.ffn = FFN(**self.ffn_cfg) norms_list = [ build_norm_layer(self.norm_cfg, self.embed_dims)[1] for _ in range(3) ] self.norms = ModuleList(norms_list) def with_pos_embed(self, tensor, pos): return tensor if pos is None else tensor + pos def forward(self, tgt:Tensor, referenc_point:Tensor, memory:Tensor, spartial_shapes:Tensor, level_start_index:Tensor, query_pos_embed:Tensor, attn_mask:Tensor=None, )->Tensor: #tgt is feature from backbone #reference point is 2d coodinates corresponding to features #memory is output from hybrid encoder #query embedding is embeding with refrence point #self attention tgt_after_attn=self.self_attn(query=tgt, key=tgt, value=tgt, query_pos=query_pos_embed, attn_mask=attn_mask) tgt=tgt+tgt_after_attn tgt=self.norms[0](tgt) #cross attention #level_start_index and spatial_shapes need to be tensor tgt_after_attn=self.cross_attn.forward( query=tgt,value=memory, reference_points=referenc_point, spatial_shapes=spartial_shapes, query_pos=query_pos_embed, level_start_index=level_start_index) tgt=tgt+tgt_after_attn tgt=self.norms[1](tgt) #feed forward tgt_after_attn=self.ffn(tgt) tgt=tgt+tgt_after_attn tgt=self.norms[2](tgt) return tgt class RtdetrDecoder(DetrTransformerDecoder): def _init_layers(self) -> None: self.layers = ModuleList([ RtDetrTransformerDecoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) self.embed_dims = self.layers[0].embed_dims self.eval_idx=self.num_layers-1 def forward(self, target:Tensor, memory:Tensor, memory_spatial_shapes:Tensor, memory_level_start_index:Tensor, ref_points_unact:Tensor, query_pos_head:nn.Module, #MLP bbox_head:ModuleList, score_head:ModuleList, attn_mask:Tensor=None, )->Tuple[Tensor]: output=target dec_out_bboxes=[] dec_out_logits=[] ref_points_detach = F.sigmoid(ref_points_unact) for i, layer in enumerate(self.layers): ref_points_input = ref_points_detach.unsqueeze(2) query_pos_embed = query_pos_head(ref_points_detach) # def forward(self, # tgt:Tensor, # referenc_point:Tensor, # memory:Tensor, # spartial_shapes:Tensor, # level_start_index:Tensor, # query_pos_embed:Tensor, # attn_mask:Tensor=None, # )->Tensor: output = layer(output, ref_points_input, memory, memory_spatial_shapes, memory_level_start_index, query_pos_embed,attn_mask) inter_ref_bbox=F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) if self.training: dec_out_logits.append(score_head[i](output)) if i == 0: dec_out_bboxes.append(inter_ref_bbox) else: dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))) elif i==self.eval_idx: dec_out_logits.append(score_head[i](output)) dec_out_bboxes.append(inter_ref_bbox) break ref_points_detach = inter_ref_bbox.detach( ) if self.training else inter_ref_bbox return tuple([torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)])