# rtdetr_transformer.py

import math
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from mmcv.cnn import ConvModule, build_activation_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType

from ..layers import CdnQueryGenerator, RtdetrDecoder
from .dino_utils import get_contrastive_denoising_training_group


def _bias_initial_with_prob(prob):
    """Return the bias value whose sigmoid equals ``prob``."""
    bias_init = float(-np.log((1 - prob) / prob))
    return bias_init


@torch.no_grad()
def _linear_init(module: nn.Module) -> None:
    """Uniformly initialize a linear layer's weight (and bias, if present)."""
    bound = 1 / math.sqrt(module.weight.shape[0])
    init.uniform_(module.weight, -bound, bound)
    if hasattr(module, 'bias') and module.bias is not None:
        init.uniform_(module.bias, -bound, bound)


class MLP(nn.Module):
    """A simple multi-layer perceptron used for the bbox and query-pos heads."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 act=dict(type='ReLU')):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.act = nn.Identity() if act is None else build_activation_layer(act)

    def forward(self, x):
        # Apply the activation after every layer except the last one.
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
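
# A minimal sketch of how the MLP above behaves, assuming a 3-layer bbox head
# with hidden_dim=256 (the shapes below are illustrative only):
#
#   bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#   x = torch.randn(2, 300, 256)   # (batch, queries, channels)
#   out = bbox_head(x)             # -> (2, 300, 4); no activation on the last layer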


@MODELS.register_module()
class RtDetrTransformer(BaseModule):

    def _build_input_proj_layer(self, backbone_feat_channels):
        """Project every backbone level to ``hidden_dim`` channels with 1x1
        convs, then add stride-2 3x3 convs for any extra pyramid levels."""
        self.input_proj = ModuleList()
        for in_channels in backbone_feat_channels:
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    kernel_size=1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None))
        in_channels = backbone_feat_channels[-1]
        for _ in range(self.num_levels - len(backbone_feat_channels)):
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    3,
                    2,
                    1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None))
            in_channels = self.hidden_dim

    def _generate_anchors(self,
                          spatial_shapes: list = None,
                          grid_size: float = 0.05,
                          dtype=torch.float32,
                          device='cpu'):
        """Generate multi-level anchors in inverse-sigmoid (logit) space plus a
        mask of anchors that lie safely inside the normalized image."""
        if spatial_shapes is None:
            spatial_shapes = [[
                int(self.eval_spatial_size[0] / s),
                int(self.eval_spatial_size[1] / s)
            ] for s in self.feat_strides]
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(
                torch.arange(end=h, dtype=dtype),
                torch.arange(end=w, dtype=dtype),
                indexing='ij')
            grid_xy = torch.stack([grid_x, grid_y], -1)
            valid_WH = torch.tensor([w, h]).to(dtype)
            # Normalized center coordinates of every cell at this level.
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
            # Anchor width/height doubles at each coarser level.
            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
            anchors.append(
                torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4))
        anchors = torch.concat(anchors, 1).to(device)
        valid_mask = ((anchors > self.eps) *
                      (anchors < 1 - self.eps)).all(-1, keepdim=True)
        # Map (0, 1) boxes to unbounded logit space; invalid anchors are set to
        # inf so that downstream sigmoid calls give degenerate boxes.
        anchors = torch.log(anchors / (1 - anchors))
        anchors = torch.where(valid_mask, anchors, torch.inf)
        return anchors, valid_mask
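
    # Worked example of the anchor construction above (a sketch, assuming a
    # single 2x2 level with grid_size=0.05): the cell centres become
    # (0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75), every anchor gets
    # width = height = 0.05, and the inverse-sigmoid step maps e.g. 0.25 to
    # log(0.25 / 0.75) ~= -1.0986, so a later sigmoid() recovers the
    # normalized box.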

    def _reset_parameter(self):
        """Initialize the encoder/decoder prediction heads."""
        bias_cls = _bias_initial_with_prob(0.01)
        _linear_init(self.encoder_score_head)
        init.constant_(self.encoder_score_head.bias, bias_cls)
        init.constant_(self.encoder_bbox_head.layers[-1].weight, 0.)
        init.constant_(self.encoder_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.decoder_score_head, self.decoder_bbox_head):
            _linear_init(cls_)
            init.constant_(cls_.bias, bias_cls)
            init.constant_(reg_.layers[-1].weight, 0.)
            init.constant_(reg_.layers[-1].bias, 0.)
        _linear_init(self.encoder_output[0])
        init.xavier_uniform_(self.encoder_output[0].weight)
        if self.learnt_init_query:
            init.xavier_uniform_(self.tgt_embed.weight)
        init.xavier_uniform_(self.query_pos_head.layers[0].weight)
        init.xavier_uniform_(self.query_pos_head.layers[1].weight)
        # for l in self.input_proj:
        #     init.xavier_uniform_(l.weight)

    def _get_encoder_input(self, feats: Tensor) -> Tuple[Tensor]:
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        if self.num_levels > len(proj_feats):
            len_srcs = len(proj_feats)
            for i in range(len_srcs, self.num_levels):
                if i == len_srcs:
                    proj_feats.append(self.input_proj[i](feats[-1]))
                else:
                    proj_feats.append(self.input_proj[i](proj_feats[-1]))
        # get input for encoder
        feat_flatten = []
        spatial_shapes = []
        for feat in proj_feats:
            spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
            # [b, c, h, w] -> [b, h*w, c]
            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
            # [num_levels, 2] each level
            spatial_shapes.append(spatial_shape)
        spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
        level_start_index = torch.cat((
            spatial_shapes.new_zeros((1, )),  # (num_levels, )
            spatial_shapes.prod(1).cumsum(0)[:-1]))
        # [b, l, c]
        feat_flatten = torch.concat(feat_flatten, 1)
        return (feat_flatten, spatial_shapes, level_start_index)
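
    # Shape bookkeeping for the encoder input (a sketch, assuming a 640x640
    # input, strides [8, 16, 32] and hidden_dim C):
    #   spatial_shapes    -> [[80, 80], [40, 40], [20, 20]]
    #   level_start_index -> [0, 6400, 8000]
    #   feat_flatten      -> (b, 80*80 + 40*40 + 20*20, C) = (b, 8400, C)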

    def _get_decoder_input(self,
                           memory: Tensor,
                           spatial_shapes,
                           denoising_class=None,
                           denoising_bbox_unact=None):
        bs, _, _ = memory.shape
        # prepare input for decoder
        if self.training or self.eval_size is None:
            anchors, valid_mask = self._generate_anchors(
                spatial_shapes, device=memory.device)
        else:
            anchors, valid_mask = self.anchors.to(
                memory.device), self.valid_mask.to(memory.device)
        memory = valid_mask.to(memory.dtype) * memory
        output_memory = self.encoder_output(memory)
        enc_outputs_class = self.encoder_score_head(output_memory)
        enc_outputs_coord_unact = self.encoder_bbox_head(output_memory) + anchors
        # select the top-k encoder tokens (by max class score) as queries
        topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1)[1]
        # extract region proposal boxes
        reference_points_unact = enc_outputs_coord_unact.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(
                1, 1, enc_outputs_coord_unact.shape[-1]))
        enc_topk_bboxes = F.sigmoid(reference_points_unact)
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.concat(
                [denoising_bbox_unact, reference_points_unact], 1)
        enc_topk_logits = enc_outputs_class.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(1, 1,
                                                enc_outputs_class.shape[-1]))
        # extract region features
        if self.learnt_init_query:
            target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
        else:
            target = output_memory.gather(
                dim=1,
                index=topk_ind.unsqueeze(-1).repeat(1, 1,
                                                    output_memory.shape[-1]))
            target = target.detach()
        if denoising_class is not None:
            target = torch.concat([denoising_class, target], 1)
        return (target, reference_points_unact.detach(), enc_topk_bboxes,
                enc_topk_logits)
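
    # Query selection in a nutshell (a sketch, assuming num_queries=300 and no
    # denoising): the top 300 encoder tokens by max class score provide both
    # the initial reference boxes (reference_points_unact, shape (b, 300, 4),
    # still in logit space) and, when learnt_init_query is False, the initial
    # decoder embeddings (target, shape (b, 300, hidden_dim)); denoising
    # queries, if any, are concatenated in front of both.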

    def __init__(self,
                 num_classes: int,
                 hidden_dim: int,
                 num_queries: int,
                 position_type: str = 'sine',
                 feat_channels: list = [256, 256, 256],
                 feat_strides: list = [8, 16, 24],
                 num_levels: int = 3,
                 num_crossattn_points: int = 4,
                 number_head: int = 8,
                 number_decoder_layer: int = 6,
                 dim_feedforward_ratio: int = 4,
                 dropout: float = 0.0,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 num_denoising: int = 100,
                 label_noise_ratio: float = 0.5,
                 box_noise_scale: float = 1.0,
                 learnt_init_query: bool = True,
                 eval_size: list = None,
                 eval_spatial_size: list = None,
                 eval_idx: int = -1,
                 eps: float = 1e-2):
        super().__init__()
        assert position_type in ['sine', 'learned'], \
            f'ValueError: position_embed_type not supported {position_type}!'
        assert len(feat_channels) <= num_levels
        assert len(feat_strides) == len(feat_channels)
        # If fewer strides than levels are given, extend by doubling the last.
        for _ in range(num_levels - len(feat_strides)):
            feat_strides.append(feat_strides[-1] * 2)
        self.hidden_dim = hidden_dim
        self.number_head = number_head
        self.feat_strides = feat_strides
        self.num_levels = num_levels
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.eps = eps
        self.number_decoder_layer = number_decoder_layer
        self.eval_size = eval_size
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.eval_idx = eval_idx
        self.eval_spatial_size = eval_spatial_size
        # backbone feature projection
        self._build_input_proj_layer(feat_channels)
        # Transformer decoder layers. The config keys follow mmdet's
        # MultiheadAttention, MultiScaleDeformableAttention and FFN modules.
        self_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            attn_drop=dropout,
            proj_drop=dropout,
            batch_first=True)
        cross_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            num_levels=num_levels,
            num_points=num_crossattn_points,
            dropout=dropout,
            batch_first=True)
        ffn_cfg = dict(
            embed_dims=hidden_dim,
            feedforward_channels=hidden_dim * dim_feedforward_ratio,
            num_fcs=2,
            ffn_drop=0,
            act_cfg=act_cfg)
        decode_layer_cfg = dict(
            self_attn_cfg=self_attn_cfg,
            cross_attn_cfg=cross_attn_cfg,
            ffn_cfg=ffn_cfg)
        self.decoder = RtdetrDecoder(
            num_layers=number_decoder_layer, layer_cfg=decode_layer_cfg)
        # denoising part (DN-DETR/DINO-style contrastive denoising queries)
        if num_denoising > 0:
            self.dino = CdnQueryGenerator(
                num_classes=num_classes,
                embed_dims=hidden_dim,
                num_matching_queries=num_queries,
                label_noise_scale=label_noise_ratio,
                box_noise_scale=box_noise_scale,
                group_cfg=dict(
                    dynamic=True,
                    num_groups=None,
                    num_dn_queries=num_denoising))
        # decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
        self.query_pos_head = MLP(
            4, hidden_dim=hidden_dim, output_dim=hidden_dim, num_layers=2)
        # encoder head in transformer
        self.encoder_output = Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
        self.encoder_score_head = nn.Linear(hidden_dim, num_classes)
        self.encoder_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
        # decoder head in transformer
        self.decoder_score_head = ModuleList(
            nn.Linear(hidden_dim, num_classes)
            for _ in range(number_decoder_layer))
        self.decoder_bbox_head = ModuleList(
            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
            for _ in range(number_decoder_layer))
        # pre-compute anchors when a fixed evaluation size is given
        if self.eval_spatial_size:
            self.anchors, self.valid_mask = self._generate_anchors()
        # reset parameters for the encoder and decoder heads
        self._reset_parameter()

    def forward(self, feats, pad_mask=None, gt_meta=None):
        # project and flatten the backbone features
        (memory, spatial_shapes,
         level_start_index) = self._get_encoder_input(feats)
        # generate denoising queries during training
        if self.training and self.num_denoising:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                self.dino(batch_data_samples=gt_meta)
        else:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                None, None, None, None
        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self._get_decoder_input(memory, spatial_shapes, denoising_class,
                                    denoising_bbox_unact)
        # per-layer classification and box refinement in the decoder
        query, reference = self.decoder(
            target=target,
            memory=memory,
            memory_spatial_shapes=spatial_shapes,
            memory_level_start_index=level_start_index,
            ref_points_unact=init_ref_points_unact,
            query_pos_head=self.query_pos_head,
            bbox_head=self.decoder_bbox_head,
            score_head=self.decoder_score_head,
            attn_mask=attn_mask)
        return (query, reference, enc_topk_bboxes, enc_topk_logits, dn_meta)
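

# Minimal usage sketch (the config keys and channel sizes below are
# illustrative, not taken from a released config):
#
#   transformer = MODELS.build(
#       dict(
#           type='RtDetrTransformer',
#           num_classes=80,
#           hidden_dim=256,
#           num_queries=300,
#           feat_channels=[256, 256, 256],
#           feat_strides=[8, 16, 32],
#           num_levels=3))
#   feats = [torch.randn(2, 256, 80, 80),
#            torch.randn(2, 256, 40, 40),
#            torch.randn(2, 256, 20, 20)]
#   query, reference, enc_bboxes, enc_logits, dn_meta = transformer(feats)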