import math
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from mmcv.cnn import ConvModule, build_activation_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType
from ..layers import RtdetrDecoder, CdnQueryGenerator
from .dino_utils import get_contrastive_denoising_training_group


def _bias_init_with_prob(prob):
    """Initialize classification bias so that sigmoid(bias) == prob."""
    return float(-np.log((1 - prob) / prob))


@torch.no_grad()
def _linear_init(module: nn.Module) -> None:
    bound = 1 / math.sqrt(module.weight.shape[0])
    init.uniform_(module.weight, -bound, bound)
    if hasattr(module, 'bias') and module.bias is not None:
        init.uniform_(module.bias, -bound, bound)


class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 act=dict(type='ReLU')):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.act = nn.Identity() if act is None else build_activation_layer(act)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


@MODELS.register_module()
class RtDetrTransformer(BaseModule):

    def _build_input_proj_layer(self, backbone_feat_channels):
        # 1x1 convs project each backbone level to hidden_dim; extra 3x3
        # stride-2 convs synthesize the missing levels from the last map.
        self.input_proj = ModuleList()
        for in_channels in backbone_feat_channels:
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    kernel_size=1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None))
        in_channels = backbone_feat_channels[-1]
        for _ in range(self.num_levels - len(backbone_feat_channels)):
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    3,
                    2,
                    1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None))
            in_channels = self.hidden_dim

    def _generate_anchors(self,
                          spatial_shapes: list = None,
                          grid_size: float = 0.05,
                          dtype=torch.float32,
                          device='cpu'):
        if spatial_shapes is None:
            spatial_shapes = [[
                int(self.eval_spatial_size[0] / s),
                int(self.eval_spatial_size[1] / s)
            ] for s in self.feat_strides]
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(
                torch.arange(end=h, dtype=dtype),
                torch.arange(end=w, dtype=dtype),
                indexing='ij')
            grid_xy = torch.stack([grid_x, grid_y], -1)
            valid_WH = torch.tensor([w, h]).to(dtype)
            # normalized cell centers in (0, 1)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
            # anchor size doubles with each level
            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
            anchors.append(
                torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4))
        anchors = torch.concat(anchors, 1).to(device)
        valid_mask = ((anchors > self.eps) *
                      (anchors < 1 - self.eps)).all(-1, keepdim=True)
        # inverse sigmoid: the decoder refines boxes in logit space
        anchors = torch.log(anchors / (1 - anchors))
        anchors = torch.where(valid_mask, anchors, torch.inf)
        return anchors, valid_mask

    def _reset_parameter(self):
        bias_cls = _bias_init_with_prob(0.01)
        _linear_init(self.encoder_score_head)
        init.constant_(self.encoder_score_head.bias, bias_cls)
        init.constant_(self.encoder_bbox_head.layers[-1].weight, 0.)
        init.constant_(self.encoder_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.decoder_score_head, self.decoder_bbox_head):
            _linear_init(cls_)
            init.constant_(cls_.bias, bias_cls)
            init.constant_(reg_.layers[-1].weight, 0.)
            init.constant_(reg_.layers[-1].bias, 0.)
        _linear_init(self.encoder_output[0])
        init.xavier_uniform_(self.encoder_output[0].weight)
        if self.learnt_init_query:
            init.xavier_uniform_(self.tgt_embed.weight)
        init.xavier_uniform_(self.query_pos_head.layers[0].weight)
        init.xavier_uniform_(self.query_pos_head.layers[1].weight)
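    # A minimal sketch (not from the source) of the anchor encoding used in
    # `_generate_anchors` above, assuming a single 2x2 level with
    # grid_size=0.05: each normalized cell center (x, y) plus a fixed (w, h)
    # is pushed through the inverse sigmoid, so a later sigmoid recovers it.
    #
    #   >>> xy = (torch.tensor([[0., 0.], [1., 0.], [0., 1.], [1., 1.]]) + 0.5) / 2
    #   >>> wh = torch.full_like(xy, 0.05)
    #   >>> a = torch.concat([xy, wh], -1)
    #   >>> a_logit = torch.log(a / (1 - a))   # same transform as above
    #   >>> torch.sigmoid(a_logit).allclose(a)
    #   True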
    def _get_encoder_input(self, feats: list) -> Tuple[Tensor]:
        # project backbone features and synthesize any missing levels
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        if self.num_levels > len(proj_feats):
            len_srcs = len(proj_feats)
            for i in range(len_srcs, self.num_levels):
                if i == len_srcs:
                    proj_feats.append(self.input_proj[i](feats[-1]))
                else:
                    proj_feats.append(self.input_proj[i](proj_feats[-1]))

        # flatten every level into one token sequence for the encoder
        feat_flatten = []
        spatial_shapes = []
        for feat in proj_feats:
            spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
            # [b, c, h, w] -> [b, h*w, c]
            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
            # [num_levels, 2] each level
            spatial_shapes.append(spatial_shape)

        spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
        level_start_index = torch.cat((
            spatial_shapes.new_zeros((1, )),  # (num_levels, )
            spatial_shapes.prod(1).cumsum(0)[:-1]))
        # [b, l, c]
        feat_flatten = torch.concat(feat_flatten, 1)
        return feat_flatten, spatial_shapes, level_start_index

    def _get_decoder_input(self,
                           memory: Tensor,
                           spatial_shapes,
                           denoising_class=None,
                           denoising_bbox_unact=None):
        bs, _, _ = memory.shape
        # prepare input for decoder: regenerate anchors during training,
        # reuse the cached ones when the eval input size is fixed
        if self.training or self.eval_spatial_size is None:
            anchors, valid_mask = self._generate_anchors(
                spatial_shapes, device=memory.device)
        else:
            anchors = self.anchors.to(memory.device)
            valid_mask = self.valid_mask.to(memory.device)

        memory = valid_mask.to(memory.dtype) * memory
        output_memory = self.encoder_output(memory)
        enc_outputs_class = self.encoder_score_head(output_memory)
        enc_outputs_coord_unact = self.encoder_bbox_head(output_memory) + anchors

        # select the top-k encoder tokens (by max class score) as queries
        topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1)[1]

        # extract region proposal boxes
        reference_points_unact = enc_outputs_coord_unact.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(
                1, 1, enc_outputs_coord_unact.shape[-1]))
        enc_topk_bboxes = F.sigmoid(reference_points_unact)
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.concat(
                [denoising_bbox_unact, reference_points_unact], 1)
        enc_topk_logits = enc_outputs_class.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(
                1, 1, enc_outputs_class.shape[-1]))

        # extract region features
        if self.learnt_init_query:
            target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
        else:
            target = output_memory.gather(
                dim=1,
                index=topk_ind.unsqueeze(-1).repeat(
                    1, 1, output_memory.shape[-1]))
            target = target.detach()
        if denoising_class is not None:
            target = torch.concat([denoising_class, target], 1)

        return (target, reference_points_unact.detach(), enc_topk_bboxes,
                enc_topk_logits)
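    # Shape sketch for `_get_decoder_input` (illustrative values, assuming a
    # 640x640 input with strides 8/16/32, bs=2, hidden_dim=256,
    # num_queries=300, num_classes=80; sum(h*w) = 6400+1600+400 = 8400):
    #
    #   memory                   [2, 8400, 256]
    #   enc_outputs_class        [2, 8400, 80]
    #   enc_outputs_coord_unact  [2, 8400, 4]    (head output + anchor logits)
    #   topk_ind                 [2, 300]
    #   target                   [2, 300, 256]   (or learned tgt_embed)
    #   reference_points_unact   [2, 300(+dn), 4]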
    def __init__(self,
                 num_classes: int,
                 hidden_dim: int,
                 num_queries: int,
                 position_type: str = 'sine',
                 feat_channels: list = [256, 256, 256],
                 feat_strides: list = [8, 16, 32],
                 num_levels: int = 3,
                 num_crossattn_points: int = 4,
                 number_head: int = 8,
                 number_decoder_layer: int = 6,
                 dim_feedforward_ratio: int = 4,
                 dropout: float = 0.0,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 num_denoising: int = 100,
                 label_noise_ratio: float = 0.5,
                 box_noise_scale: float = 1.0,
                 learnt_init_query: bool = True,
                 eval_size: list = None,
                 eval_spatial_size: list = None,
                 eval_idx: int = -1,
                 eps: float = 1e-2):
        super().__init__()
        assert position_type in ['sine', 'learned'], \
            f'position_embed_type {position_type} is not supported!'
        assert len(feat_channels) <= num_levels
        assert len(feat_strides) == len(feat_channels)
        # pad strides for extra levels by doubling the last one
        for _ in range(num_levels - len(feat_strides)):
            feat_strides.append(feat_strides[-1] * 2)

        self.hidden_dim = hidden_dim
        self.number_head = number_head
        self.feat_strides = feat_strides
        self.num_levels = num_levels
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.eps = eps
        self.number_decoder_layer = number_decoder_layer
        self.eval_size = eval_size
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.eval_idx = eval_idx
        self.eval_spatial_size = eval_spatial_size

        # backbone feature projection
        self._build_input_proj_layer(feat_channels)

        # decoder layer configs: these dicts mirror the argument names of the
        # mmcv/mmdet self-attention, deformable cross-attention and FFN layers
        # consumed by RtdetrDecoder
        self_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            attn_drop=dropout,
            proj_drop=dropout,
            batch_first=True)
        cross_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            num_levels=num_levels,
            num_points=num_crossattn_points,
            dropout=dropout,
            batch_first=True)
        ffn_cfg = dict(
            embed_dims=hidden_dim,
            feedforward_channels=hidden_dim * dim_feedforward_ratio,
            num_fcs=2,
            ffn_drop=0,
            act_cfg=act_cfg)
        decode_layer_cfg = dict(
            self_attn_cfg=self_attn_cfg,
            cross_attn_cfg=cross_attn_cfg,
            ffn_cfg=ffn_cfg)
        self.decoder = RtdetrDecoder(
            num_layers=number_decoder_layer, layer_cfg=decode_layer_cfg)

        # denoising part
        if num_denoising > 0:
            self.dino = CdnQueryGenerator(
                num_classes=num_classes,
                embed_dims=hidden_dim,
                num_matching_queries=num_queries,
                label_noise_scale=label_noise_ratio,
                box_noise_scale=box_noise_scale,
                group_cfg=dict(
                    dynamic=True,
                    num_groups=None,
                    num_dn_queries=num_denoising))

        # decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
        self.query_pos_head = MLP(
            4, hidden_dim=hidden_dim, output_dim=hidden_dim, num_layers=2)

        # encoder head in transformer
        self.encoder_output = Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
        self.encoder_score_head = nn.Linear(hidden_dim, num_classes)
        self.encoder_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)

        # decoder head in transformer
        self.decoder_score_head = ModuleList(
            nn.Linear(hidden_dim, num_classes)
            for _ in range(number_decoder_layer))
        self.decoder_bbox_head = ModuleList(
            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
            for _ in range(number_decoder_layer))

        # cache anchors and the valid mask for fixed-size evaluation
        if self.eval_spatial_size:
            self.anchors, self.valid_mask = self._generate_anchors()

        # reset parameters of the encoder/decoder heads
        self._reset_parameter()
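    # A hedged instantiation sketch (not from the source): with the module
    # registered in MODELS, an mmdet config could build it as, e.g.,
    #
    #   transformer=dict(
    #       type='RtDetrTransformer',
    #       num_classes=80,
    #       hidden_dim=256,
    #       num_queries=300,
    #       feat_channels=[256, 256, 256],
    #       feat_strides=[8, 16, 32],
    #       num_levels=3)
    #
    # The keys mirror the __init__ signature above; the values are illustrative.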
    def forward(self, feats, pad_mask=None, gt_meta=None):
        memory, spatial_shapes, level_start_index = \
            self._get_encoder_input(feats)

        # contrastive denoising queries are only generated during training
        if self.training and self.num_denoising:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                self.dino(batch_data_samples=gt_meta)
        else:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                None, None, None, None

        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self._get_decoder_input(memory, spatial_shapes, denoising_class,
                                    denoising_bbox_unact)

        # per-layer cls and bbox predictions for the queries
        query, reference = self.decoder(
            target=target,
            memory=memory,
            memory_spatial_shapes=spatial_shapes,
            memory_level_start_index=level_start_index,
            ref_points_unact=init_ref_points_unact,
            query_pos_head=self.query_pos_head,
            bbox_head=self.decoder_bbox_head,
            score_head=self.decoder_score_head,
            attn_mask=attn_mask)

        return query, reference, enc_topk_bboxes, enc_topk_logits, dn_meta
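if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module). It
    # assumes RtdetrDecoder and the mmdet registries import cleanly and that
    # the decoder returns (query, reference); all values are illustrative.
    model = RtDetrTransformer(
        num_classes=80,
        hidden_dim=256,
        num_queries=300,
        feat_channels=[256, 256, 256],
        feat_strides=[8, 16, 32],
        num_levels=3,
        num_denoising=0,
        eval_spatial_size=[640, 640])
    model.eval()
    # three dummy feature levels for a 640x640 input, batch size 2
    feats = [torch.randn(2, 256, 640 // s, 640 // s) for s in (8, 16, 32)]
    with torch.no_grad():
        query, reference, enc_topk_bboxes, enc_topk_logits, dn_meta = \
            model(feats)
    # expected encoder proposals: [2, 300, 4] boxes, [2, 300, 80] logits
    print(enc_topk_bboxes.shape, enc_topk_logits.shape)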