import math
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from mmcv.cnn import ConvModule, build_activation_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType

from ..layers import CdnQueryGenerator, RtdetrDecoder
from .dino_utils import get_contrastive_denoising_training_group
def _bias_initial_with_prob(prob):
    """Initial bias value such that ``sigmoid(bias)`` equals ``prob``."""
    bias_init = float(-np.log((1 - prob) / prob))
    return bias_init
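
# For example, ``_bias_initial_with_prob(0.01)`` returns -log(0.99 / 0.01)
# ≈ -4.595, so a freshly initialised classification head predicts
# sigmoid(-4.595) ≈ 0.01 for every class, the usual focal-loss prior.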
@torch.no_grad()
def _linear_init(module: nn.Module) -> None:
    # nn.Linear weights are [out_features, in_features] in PyTorch, so the
    # fan_in used for the uniform bound is shape[1].
    bound = 1 / math.sqrt(module.weight.shape[1])
    init.uniform_(module.weight, -bound, bound)
    if hasattr(module, 'bias') and module.bias is not None:
        init.uniform_(module.bias, -bound, bound)
class MLP(nn.Module):
    """A simple multi-layer perceptron with ``num_layers`` linear layers
    and an activation after every layer except the last."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 act=dict(type='ReLU')):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k)
            for n, k in zip([input_dim] + h, h + [output_dim]))
        self.act = nn.Identity() if act is None else build_activation_layer(act)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
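
# A quick shape check (illustrative values only):
#   mlp = MLP(input_dim=4, hidden_dim=256, output_dim=256, num_layers=2)
#   mlp(torch.rand(2, 300, 4)).shape  # -> torch.Size([2, 300, 256])
# This is how ``query_pos_head`` below lifts 4-d reference boxes to
# ``hidden_dim``-d query position embeddings.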
@MODELS.register_module()
class RtDetrTransformer(BaseModule):
    """RT-DETR transformer: projects multi-level backbone features, selects
    the top-k encoder proposals as initial decoder queries and refines them
    with a deformable decoder, optionally with contrastive denoising (CDN)
    queries during training."""
    def _build_input_proj_layer(self, backbone_feat_channels):
        # 1x1 Conv + BN projecting every backbone level to ``hidden_dim``.
        self.input_proj = ModuleList()
        for in_channels in backbone_feat_channels:
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    kernel_size=1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None))
        # Extra 3x3 stride-2 convs synthesise the missing pyramid levels
        # when ``num_levels`` exceeds the number of backbone outputs.
        in_channels = backbone_feat_channels[-1]
        for _ in range(self.num_levels - len(backbone_feat_channels)):
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None))
            in_channels = self.hidden_dim
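
    # For a ResNet-50-like backbone with feat_channels=[512, 1024, 2048] and
    # num_levels=3, this builds three 1x1 Conv-BN projections to hidden_dim;
    # with num_levels=4 a fourth 3x3 stride-2 Conv-BN would downsample the
    # last level once more.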
    def _generate_anchors(self,
                          spatial_shapes: list = None,
                          grid_size: float = 0.05,
                          dtype=torch.float32,
                          device='cpu'):
        if spatial_shapes is None:
            spatial_shapes = [[
                int(self.eval_spatial_size[0] / s),
                int(self.eval_spatial_size[1] / s)
            ] for s in self.feat_strides]
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            h, w = int(h), int(w)  # spatial_shapes may be a list or a tensor
            grid_y, grid_x = torch.meshgrid(
                torch.arange(end=h, dtype=dtype),
                torch.arange(end=w, dtype=dtype),
                indexing='ij')
            grid_xy = torch.stack([grid_x, grid_y], -1)
            valid_WH = torch.tensor([w, h], dtype=dtype)
            # Cell centres normalised to [0, 1].
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
            # Anchor width/height doubles at each pyramid level.
            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
            anchors.append(
                torch.cat([grid_xy, wh], -1).reshape(-1, h * w, 4))
        anchors = torch.cat(anchors, 1).to(device)
        valid_mask = ((anchors > self.eps) &
                      (anchors < 1 - self.eps)).all(-1, keepdim=True)
        # Inverse sigmoid; out-of-range anchors are masked out with +inf.
        anchors = torch.log(anchors / (1 - anchors))
        anchors = torch.where(valid_mask, anchors, torch.inf)
        return anchors, valid_mask
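
    # Worked example: with eval_spatial_size=[640, 640] and feat_strides
    # [8, 16, 32], the grids are 80x80, 40x40 and 20x20, giving
    # 6400 + 1600 + 400 = 8400 anchors in a [1, 8400, 4] tensor (logit
    # space), one per position of the flattened encoder memory.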
    def _reset_parameter(self):
        bias_cls = _bias_initial_with_prob(0.01)
        _linear_init(self.encoder_score_head)
        init.constant_(self.encoder_score_head.bias, bias_cls)
        init.constant_(self.encoder_bbox_head.layers[-1].weight, 0.)
        init.constant_(self.encoder_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.decoder_score_head, self.decoder_bbox_head):
            _linear_init(cls_)
            init.constant_(cls_.bias, bias_cls)
            init.constant_(reg_.layers[-1].weight, 0.)
            init.constant_(reg_.layers[-1].bias, 0.)
        _linear_init(self.encoder_output[0])
        # xavier re-initialises the weight; the bias from _linear_init stays.
        init.xavier_uniform_(self.encoder_output[0].weight)
        if self.learnt_init_query:
            init.xavier_uniform_(self.tgt_embed.weight)
        init.xavier_uniform_(self.query_pos_head.layers[0].weight)
        init.xavier_uniform_(self.query_pos_head.layers[1].weight)
    def _get_encoder_input(self, feats: Tensor) -> Tuple[Tensor]:
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        if self.num_levels > len(proj_feats):
            len_srcs = len(proj_feats)
            # Downsample the last level to fill the remaining pyramid levels.
            for i in range(len_srcs, self.num_levels):
                if i == len_srcs:
                    proj_feats.append(self.input_proj[i](feats[-1]))
                else:
                    proj_feats.append(self.input_proj[i](proj_feats[-1]))
        # Flatten every level into a single token sequence.
        feat_flatten = []
        spatial_shapes = []
        for feat in proj_feats:
            spatial_shape = torch.tensor(
                feat.shape[2:], dtype=torch.long, device=feat.device)
            # [b, c, h, w] -> [b, h*w, c]
            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
            # [num_levels, 2] after concatenation
            spatial_shapes.append(spatial_shape)
        spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
        level_start_index = torch.cat((
            spatial_shapes.new_zeros((1, )),  # first level starts at 0
            spatial_shapes.prod(1).cumsum(0)[:-1]))
        # [b, l, c]
        feat_flatten = torch.cat(feat_flatten, 1)
        return (feat_flatten, spatial_shapes, level_start_index)
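
    # Continuing the 640x640 example: feat_flatten is [b, 8400, hidden_dim],
    # spatial_shapes is [[80, 80], [40, 40], [20, 20]] and level_start_index
    # is [0, 6400, 8000].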
    def _get_decoder_input(self,
                           memory: Tensor,
                           spatial_shapes,
                           denoising_class=None,
                           denoising_bbox_unact=None):
        bs, _, _ = memory.shape
        # Prepare input for the decoder.
        if self.training or self.eval_spatial_size is None:
            anchors, valid_mask = self._generate_anchors(
                spatial_shapes, device=memory.device)
        else:
            # Reuse the anchors precomputed in __init__ for the fixed eval size.
            anchors = self.anchors.to(memory.device)
            valid_mask = self.valid_mask.to(memory.device)
        memory = valid_mask.to(memory.dtype) * memory
        output_memory = self.encoder_output(memory)
        enc_outputs_class = self.encoder_score_head(output_memory)
        enc_outputs_coord_unact = self.encoder_bbox_head(output_memory) + anchors
        # Keep the ``num_queries`` proposals with the highest class score.
        topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1)[1]
        # Extract region proposal boxes.
        reference_points_unact = enc_outputs_coord_unact.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(
                1, 1, enc_outputs_coord_unact.shape[-1]))
        enc_topk_bboxes = reference_points_unact.sigmoid()
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.cat(
                [denoising_bbox_unact, reference_points_unact], 1)
        enc_topk_logits = enc_outputs_class.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(1, 1,
                                                enc_outputs_class.shape[-1]))
        # Extract region features.
        if self.learnt_init_query:
            target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
        else:
            target = output_memory.gather(
                dim=1,
                index=topk_ind.unsqueeze(-1).repeat(1, 1,
                                                    output_memory.shape[-1]))
            target = target.detach()
        if denoising_class is not None:
            target = torch.cat([denoising_class, target], 1)
        return target, reference_points_unact.detach(), enc_topk_bboxes, \
            enc_topk_logits
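
    # During training the CDN queries are prepended, so ``target`` becomes
    # [bs, num_dn_queries + num_queries, hidden_dim]. The reference points
    # are detached so the decoder refines them without backpropagating into
    # the query-selection head, which receives its own supervision via the
    # returned ``enc_topk_bboxes``/``enc_topk_logits``.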
    def __init__(self,
                 num_classes: int,
                 hidden_dim: int,
                 num_queries: int,
                 position_type: str = 'sine',
                 feat_channels: list = [256, 256, 256],
                 feat_strides: list = [8, 16, 32],
                 num_levels: int = 3,
                 num_crossattn_points: int = 4,
                 number_head: int = 8,
                 number_decoder_layer: int = 6,
                 dim_feedforward_ratio: int = 4,
                 dropout: float = 0.0,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 num_denoising: int = 100,
                 label_noise_ratio: float = 0.5,
                 box_noise_scale: float = 1.0,
                 learnt_init_query: bool = True,
                 eval_size: list = None,
                 eval_spatial_size: list = None,
                 eval_idx: int = -1,
                 eps: float = 1e-2):
        super().__init__()
        assert position_type in ['sine', 'learned'], \
            f'position_type {position_type} is not supported!'
        assert len(feat_channels) <= num_levels
        assert len(feat_strides) == len(feat_channels)
        # Copy before extending so the mutable default list is not mutated
        # across instances.
        feat_strides = list(feat_strides)
        for _ in range(num_levels - len(feat_strides)):
            feat_strides.append(feat_strides[-1] * 2)
        self.hidden_dim = hidden_dim
        self.number_head = number_head
        self.feat_strides = feat_strides
        self.num_levels = num_levels
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.eps = eps
        self.number_decoder_layer = number_decoder_layer
        self.eval_size = eval_size
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.eval_idx = eval_idx
        self.eval_spatial_size = eval_spatial_size
        # Backbone feature projection.
        self._build_input_proj_layer(feat_channels)
        # Transformer decoder layer configs; the dicts follow the signatures
        # of mmcv's MultiheadAttention, MultiScaleDeformableAttention and
        # FFN respectively.
        self_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            attn_drop=dropout,
            proj_drop=dropout,
            batch_first=True)
        cross_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            num_levels=num_levels,
            num_points=num_crossattn_points,
            dropout=dropout,
            batch_first=True)
        ffn_cfg = dict(
            embed_dims=hidden_dim,
            feedforward_channels=hidden_dim * dim_feedforward_ratio,
            num_fcs=2,
            ffn_drop=0,
            act_cfg=act_cfg)
        decode_layer_cfg = dict(
            self_attn_cfg=self_attn_cfg,
            cross_attn_cfg=cross_attn_cfg,
            ffn_cfg=ffn_cfg)
        self.decoder = RtdetrDecoder(
            num_layers=number_decoder_layer, layer_cfg=decode_layer_cfg)
        # Denoising part: contrastive denoising (CDN) query generator, as in
        # DINO.
        if num_denoising > 0:
            self.dino = CdnQueryGenerator(
                num_classes=num_classes,
                embed_dims=hidden_dim,
                num_matching_queries=num_queries,
                label_noise_scale=label_noise_ratio,
                box_noise_scale=box_noise_scale,
                group_cfg=dict(
                    dynamic=True,
                    num_groups=None,
                    num_dn_queries=num_denoising))
        # Decoder embedding.
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
        self.query_pos_head = MLP(
            4, hidden_dim=hidden_dim, output_dim=hidden_dim, num_layers=2)
        # Encoder (query selection) head.
        self.encoder_output = Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
        self.encoder_score_head = nn.Linear(hidden_dim, num_classes)
        self.encoder_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
        # Decoder heads: one classification and one bbox head per layer.
        self.decoder_score_head = ModuleList(
            nn.Linear(hidden_dim, num_classes)
            for _ in range(number_decoder_layer))
        self.decoder_bbox_head = ModuleList(
            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
            for _ in range(number_decoder_layer))
        # Precompute anchors when a fixed eval size is given.
        if self.eval_spatial_size:
            self.anchors, self.valid_mask = self._generate_anchors()
        # Reset parameters of the encoder and decoder heads.
        self._reset_parameter()
    def forward(self, feats, pad_mask=None, gt_meta=None):
        (memory, spatial_shapes,
         level_start_index) = self._get_encoder_input(feats)
        if self.training and self.num_denoising:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                self.dino(batch_data_samples=gt_meta)
        else:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                None, None, None, None
        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self._get_decoder_input(memory, spatial_shapes, denoising_class,
                                    denoising_bbox_unact)
        # Per-layer class and bbox predictions for the queries along with the
        # refined reference points.
        query, reference = self.decoder(
            target=target,
            memory=memory,
            memory_spatial_shapes=spatial_shapes,
            memory_level_start_index=level_start_index,
            ref_points_unact=init_ref_points_unact,
            query_pos_head=self.query_pos_head,
            bbox_head=self.decoder_bbox_head,
            score_head=self.decoder_score_head,
            attn_mask=attn_mask)
        return (query, reference, enc_topk_bboxes, enc_topk_logits, dn_meta)
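
# A minimal smoke-test sketch, not part of the module: it assumes the mmdet/
# mmcv dependencies above are installed, that the custom ``RtdetrDecoder``
# and ``CdnQueryGenerator`` layers are importable, and that the decoder
# returns a ``(query, reference)`` pair as used in ``forward`` above. Shapes
# follow a 640x640 input with ResNet-50-like channels [512, 1024, 2048].
if __name__ == '__main__':
    model = RtDetrTransformer(
        num_classes=80,
        hidden_dim=256,
        num_queries=300,
        feat_channels=[512, 1024, 2048],
        feat_strides=[8, 16, 32],
        num_levels=3,
        eval_spatial_size=[640, 640])
    model.eval()
    feats = [
        torch.randn(2, c, 640 // s, 640 // s)
        for c, s in zip([512, 1024, 2048], [8, 16, 32])
    ]
    with torch.no_grad():
        query, reference, enc_bboxes, enc_logits, dn_meta = model(feats)
    print(enc_bboxes.shape)  # expected: torch.Size([2, 300, 4])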