# rtdetr_layers.py

from typing import Dict, List, Optional, Tuple
import math
import types

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import BaseModule, ModuleList

from mmdet.models.layers.transformer.utils import inverse_sigmoid
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from .deformable_detr_layers import (DetrTransformerDecoder,
                                     DetrTransformerDecoderLayer)

class SPD(nn.Module):
    """Space-to-depth layer.

    Rearranges each 2x2 spatial block into the channel dimension, so an
    input of shape (N, C, H, W) becomes (N, 4 * C, H / 2, W / 2).
    """

    def __init__(self, dimension: int = 1):
        super().__init__()
        self.d = dimension

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat([
            x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2],
            x[..., 1::2, 1::2]
        ], 1)
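

# A minimal shape check for SPD (illustrative only, not part of the original
# module): each 2x2 spatial block moves into channels, so 8 channels at
# 32x32 become 32 channels at 16x16.
#
#   >>> x = torch.randn(1, 8, 32, 32)
#   >>> SPD()(x).shape
#   torch.Size([1, 32, 16, 16])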


class RepVGGBlock(BaseModule):
    """A block in the RepVGG architecture, supporting optional normalization
    in the identity branch.

    The block consists of parallel 3x3 and 1x1 convolutions, plus an optional
    identity shortcut branch that only applies normalization. At deploy time
    the branches can be fused into a single 3x3 convolution.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): The stride of the block. Defaults to 1.
        padding (int): The padding of the block. Defaults to 1.
        dilation (int): The dilation of the block. Defaults to 1.
        groups (int): The groups of the block. Defaults to 1.
        norm_cfg (dict): The config dict for normalization layers.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): The config dict for activation layers.
            Defaults to dict(type='ReLU', inplace=True).
        without_branch_norm (bool): Whether to skip the normalized identity
            branch. Defaults to True.
        init_cfg (dict): The config dict for initialization. Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 padding: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 without_branch_norm: bool = True,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # If the input and output shapes match, add a normalized identity
        # shortcut branch.
        self.branch_norm = None
        if out_channels == in_channels and stride == 1 and \
                padding == dilation and not without_branch_norm:
            self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]

        self.branch_3x3 = ConvModule(
            self.in_channels,
            self.out_channels,
            3,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
            dilation=self.dilation,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.branch_1x1 = ConvModule(
            self.in_channels,
            self.out_channels,
            1,
            groups=self.groups,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.act = build_activation_layer(act_cfg)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the RepVGG block.

        The output is the sum of the 3x3 and 1x1 convolution outputs and the
        normalized identity branch output, followed by activation.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        if self.branch_norm is None:
            branch_norm_out = 0
        else:
            branch_norm_out = self.branch_norm(x)
        out = self.branch_3x3(x) + self.branch_1x1(x) + branch_norm_out
        out = self.act(out)
        return out

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Pad a 1x1 kernel to 3x3.

        Args:
            kernel1x1 (Tensor): The 1x1 kernel that needs to be padded.

        Returns:
            Tensor: The 3x3 kernel after padding.
        """
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch: nn.Module) -> Tuple[Tensor, Tensor]:
        """Derive the equivalent kernel and bias of a specific branch.

        Args:
            branch (nn.Module): The layer to be equivalently transformed,
                which can be a ConvModule or a batch normalization layer.

        Returns:
            tuple: Equivalent kernel and bias.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvModule):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d))
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def get_equivalent_kernel_bias(self):
        """Derive the equivalent kernel and bias in a differentiable way.

        Returns:
            tuple: Equivalent kernel and bias.
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1)
        kernelid, biasid = (0, 0) if self.branch_norm is None else \
            self._fuse_bn_tensor(self.branch_norm)
        return (kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
                bias3x3 + bias1x1 + biasid)

    def switch_to_deploy(self, test_cfg: Optional[Dict] = None):
        """Switch the block to deployment mode.

        In deployment mode, the block uses a single convolution derived from
        the equivalent kernel and bias, replacing the original branches. This
        reduces computational cost during inference.
        """
        if getattr(self, 'deploy', False):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv_reparam = nn.Conv2d(
            in_channels=self.branch_3x3.conv.in_channels,
            out_channels=self.branch_3x3.conv.out_channels,
            kernel_size=self.branch_3x3.conv.kernel_size,
            stride=self.branch_3x3.conv.stride,
            padding=self.branch_3x3.conv.padding,
            dilation=self.branch_3x3.conv.dilation,
            groups=self.branch_3x3.conv.groups,
            bias=True)
        self.conv_reparam.weight.data = kernel
        self.conv_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('branch_3x3')
        self.__delattr__('branch_1x1')
        if hasattr(self, 'branch_norm'):
            self.__delattr__('branch_norm')

        def _forward(self, x):
            return self.act(self.conv_reparam(x))

        self.forward = types.MethodType(_forward, self)
        self.deploy = True
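

# A quick equivalence check for the re-parameterization (a minimal sketch,
# illustrative only): in eval mode, the single fused convolution produced by
# switch_to_deploy() should reproduce the multi-branch output up to float
# tolerance.
#
#   >>> block = RepVGGBlock(64, 64).eval()
#   >>> x = torch.randn(1, 64, 56, 56)
#   >>> with torch.no_grad():
#   ...     y_train = block(x)
#   ...     block.switch_to_deploy()
#   ...     y_deploy = block(x)
#   >>> torch.allclose(y_train, y_deploy, atol=1e-4)
#   True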


class CSPRepLayer(BaseModule):
    """CSPRepLayer, a layer that combines Cross Stage Partial networks with
    RepVGG blocks.

    Args:
        in_channels (int): Number of input channels to the layer.
        out_channels (int): Number of output channels from the layer.
        num_blocks (int): The number of RepVGG blocks to be used in the
            layer. Defaults to 3.
        widen_factor (float): Expansion factor for intermediate channels.
            Determines the hidden channel size based on out_channels.
            Defaults to 1.0.
        norm_cfg (dict): Configuration for normalization layers.
            Defaults to Batch Normalization with trainable parameters.
        act_cfg (dict): Configuration for activation layers.
            Defaults to SiLU (Swish) with in-place operation.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_blocks: int = 3,
                 widen_factor: float = 1.0,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: OptConfigType = dict(type='SiLU', inplace=True)):
        super().__init__()
        hidden_channels = int(out_channels * widen_factor)
        self.conv1 = ConvModule(
            in_channels,
            hidden_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = ConvModule(
            in_channels,
            hidden_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.bottlenecks = nn.Sequential(*[
            RepVGGBlock(
                hidden_channels,
                hidden_channels,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for _ in range(num_blocks)
        ])
        if hidden_channels != out_channels:
            self.conv3 = ConvModule(
                hidden_channels,
                out_channels,
                kernel_size=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.conv3 = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        x_1 = self.conv1(x)
        x_1 = self.bottlenecks(x_1)
        x_2 = self.conv2(x)
        return self.conv3(x_1 + x_2)
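

# A minimal usage sketch for CSPRepLayer (illustrative only): the two 1x1
# projections run in parallel, only one path goes through the RepVGG
# bottlenecks, and their sum is projected to out_channels.
#
#   >>> layer = CSPRepLayer(512, 256, num_blocks=3, widen_factor=0.5)
#   >>> layer(torch.randn(1, 512, 20, 20)).shape
#   torch.Size([1, 256, 20, 20])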


# Encoder and encoder layer with embedded position encoding.
class EncoderLayer(BaseModule):

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     attn_drop=0,
                     proj_drop=0),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0,
                     act_cfg=dict(type='ReLU', inplace=True)),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg)
        self.self_attn_cfg = self_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, \
                'First dimension of all DETRs in mmdet is `batch`, ' \
                'please set `batch_first` flag.'
        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize the FFN and multi-head self-attention layers."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                pos_embed: Tensor,
                key_padding_mask=None) -> Tensor:
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=pos_embed,
            key_pos=pos_embed,
            key_padding_mask=key_padding_mask)
        query = self.norms[0](query)
        query = self.ffn(query)
        query = self.norms[1](query)
        return query


class Encoder(BaseModule):

    def __init__(self, num_layers, layer_cfg: ConfigType = None):
        super().__init__()
        self.num_layers = num_layers
        self.layer_cfg = layer_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        self.layers = ModuleList([
            EncoderLayer(**self.layer_cfg) for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self,
                query: Tensor,
                query_pos: Tensor,
                key_padding_mask: Tensor = None) -> Tensor:
        # Feed each layer's output into the next layer instead of
        # re-encoding the original query every time.
        for layer in self.layers:
            query = layer(query, query_pos, key_padding_mask)
        return query
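

# A minimal sketch of running the encoder on flattened tokens (illustrative
# only, using the EncoderLayer defaults via an empty layer_cfg): a
# 256-channel, 20x20 feature map is flattened to (batch, H * W, C) before
# self-attention, and the shape is preserved.
#
#   >>> enc = Encoder(num_layers=1, layer_cfg=dict())
#   >>> tokens = torch.randn(2, 400, 256)   # 20x20 map, flattened
#   >>> pos = torch.zeros(2, 400, 256)      # position embedding
#   >>> enc(tokens, pos).shape
#   torch.Size([2, 400, 256])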


# RT-DETR FPN.
class RTDETRFPN(BaseModule):
    """FPN of RT-DETR.

    Args:
        in_channels (List[int], optional): The input channels of the
            feature maps. Defaults to [256, 256, 256].
        out_channels (int, optional): The output dimension of the MLP.
            Defaults to 256.
        expansion (float, optional): The expansion of the CSPLayer.
            Defaults to 1.0.
        depth_mult (float, optional): The depth multiplier of the CSPLayer.
            Defaults to 1.0.
        with_spd (bool): Whether to use SPD (space-to-depth) downsampling
            in the bottom-up path. Defaults to True.
        upsample_cfg (dict): Config dict for the interpolate layer.
            Default: `dict(scale_factor=2, mode='nearest')`.
        conv_cfg (dict, optional): Config dict for the convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            normalization layers. Defaults to dict(type='BN').
        act_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            activation layers. Defaults to dict(type='SiLU', inplace=True).
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization config dict.
    """
    def __init__(
        self,
        in_channels: List[int] = [256, 256, 256],
        out_channels: int = 256,
        expansion: float = 1.0,
        depth_mult: float = 1.0,
        with_spd: bool = True,
        upsample_cfg: ConfigType = dict(scale_factor=2, mode='nearest'),
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
        act_cfg: OptConfigType = dict(type='SiLU', inplace=True),
        init_cfg: OptMultiConfig = dict(
            type='Kaiming',
            layer='Conv2d',
            a=math.sqrt(5),
            distribution='uniform',
            mode='fan_in',
            nonlinearity='leaky_relu')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        num_csp_blocks = round(3 * depth_mult)

        # top-down path
        self.upsample = nn.Upsample(**upsample_cfg)
        self.reduce_layers = nn.ModuleList()
        self.top_down_blocks = nn.ModuleList()
        for idx in range(len(in_channels) - 1, 0, -1):
            self.reduce_layers.append(
                ConvModule(
                    in_channels[idx],
                    in_channels[idx - 1],
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            self.top_down_blocks.append(
                CSPRepLayer(
                    in_channels[idx - 1] * 2,
                    in_channels[idx - 1],
                    num_blocks=num_csp_blocks,
                    widen_factor=expansion,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

        # bottom-up path
        self.downsamples = nn.ModuleList()
        self.bottom_up_blocks = nn.ModuleList()
        self.with_spd = with_spd
        self.spd = SPD()
        # SPD quadruples the channels of the stride-1 downsample conv; a
        # uniform channel count across levels is assumed, so the shared norm
        # is sized from in_channels[0].
        self.with_spd_norm = build_norm_layer(
            norm_cfg, in_channels[0] * 4, postfix=1)[1]
        if self.with_spd:
            for idx in range(len(in_channels) - 1):
                self.downsamples.append(
                    build_conv_layer(
                        conv_cfg,
                        in_channels[idx],
                        in_channels[idx],
                        3,
                        stride=1,
                        padding=1,
                        bias=False))
                self.bottom_up_blocks.append(
                    CSPRepLayer(
                        in_channels[idx] * 5,
                        in_channels[idx + 1],
                        num_blocks=num_csp_blocks,
                        widen_factor=expansion,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            for idx in range(len(in_channels) - 1):
                self.downsamples.append(
                    ConvModule(
                        in_channels[idx],
                        in_channels[idx],
                        3,
                        stride=2,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
                self.bottom_up_blocks.append(
                    CSPRepLayer(
                        in_channels[idx] * 2,
                        in_channels[idx + 1],
                        num_blocks=num_csp_blocks,
                        widen_factor=expansion,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))

        self.out_convs = nn.ModuleList()
        for i in range(len(in_channels)):
            self.out_convs.append(
                ConvModule(
                    in_channels[i],
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None))
    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """
        Args:
            inputs (tuple[Tensor]): The input features.

        Returns:
            tuple[Tensor]: The FPN features.
        """
        assert len(inputs) == len(self.in_channels)

        # top-down path
        inner_outs = [inputs[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = inputs[idx - 1]
            feat_high = self.reduce_layers[len(self.in_channels) - 1 - idx](
                feat_high)
            inner_outs[0] = feat_high
            upsample_feat = self.upsample(feat_high)
            inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
                torch.cat([upsample_feat, feat_low], 1))
            inner_outs.insert(0, inner_out)

        # bottom-up path
        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = outs[-1]
            feat_high = inner_outs[idx + 1]
            downsample_feat = self.downsamples[idx](feat_low)
            if self.with_spd:
                downsample_feat = self.spd(downsample_feat)
                downsample_feat = self.with_spd_norm(downsample_feat)
            out = self.bottom_up_blocks[idx](
                torch.cat([downsample_feat, feat_high], 1))
            outs.append(out)

        # output projections
        for idx, conv in enumerate(self.out_convs):
            outs[idx] = conv(outs[idx])
        return tuple(outs)
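

# A minimal shape sketch for RTDETRFPN (illustrative only, assuming three
# uniform 256-channel levels at strides 8/16/32): per-level spatial sizes
# are preserved through the top-down and bottom-up paths.
#
#   >>> fpn = RTDETRFPN(in_channels=[256, 256, 256], out_channels=256)
#   >>> feats = (torch.randn(1, 256, 80, 80),
#   ...          torch.randn(1, 256, 40, 40),
#   ...          torch.randn(1, 256, 20, 20))
#   >>> [f.shape[-2:] for f in fpn(feats)]
#   [torch.Size([80, 80]), torch.Size([40, 40]), torch.Size([20, 20])]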


# Intra-scale feature interaction and cross-scale feature fusion.
class SSFF(BaseModule):

    def __init__(self, in_channels: list, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.convs = nn.ModuleList()
        for in_channel in in_channels:
            self.convs.append(
                ConvModule(
                    in_channel,
                    out_channels,
                    1,
                    padding=0,
                    conv_cfg=None,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=dict(type='ReLU')))
        self.conv3d = nn.Conv3d(
            out_channels, out_channels, kernel_size=(1, 1, 1))
        self.bn3d = nn.BatchNorm3d(out_channels)
        self.act = nn.LeakyReLU(0.1)
        self.pool_3d = nn.MaxPool3d(kernel_size=(3, 1, 1))

    def forward(self, inputs) -> Tensor:
        # Project every level to out_channels and resize all levels to the
        # resolution of the first (largest) level.
        outputs = []
        for i in range(len(inputs)):
            feature = self.convs[i](inputs[i])
            if i != 0:
                feature = F.interpolate(
                    feature, inputs[0].size()[2:], mode='nearest')
            outputs.append(feature)
        # Stack the levels along a new depth axis, fuse them with a 3D conv,
        # then max-pool over the scale axis back to a 2D map.
        for i in range(len(outputs)):
            outputs[i] = torch.unsqueeze(outputs[i], -3)
        combine = torch.cat(outputs, dim=2)
        conv_3d = self.act(self.bn3d(self.conv3d(combine)))
        output = self.pool_3d(conv_3d)
        output = torch.squeeze(output, 2)
        return output
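

# A minimal shape sketch for SSFF (illustrative only): three levels are
# projected to a common width, upsampled to the first level's resolution,
# stacked along a depth axis of size 3, and pooled back to a 2D map.
#
#   >>> ssff = SSFF(in_channels=[256, 256, 256], out_channels=256)
#   >>> feats = [torch.randn(1, 256, 80, 80),
#   ...          torch.randn(1, 256, 40, 40),
#   ...          torch.randn(1, 256, 20, 20)]
#   >>> ssff(feats).shape
#   torch.Size([1, 256, 80, 80])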


@MODELS.register_module()
class HybridEncoder(BaseModule):

    def __init__(self,
                 in_channels=[512, 1024, 2048],
                 feat_strides=[8, 16, 32],
                 hidden_dim=256,
                 n_head=8,
                 dim_feedforward_ratio=4,
                 drop_out=0.0,
                 enc_act: OptConfigType = dict(type='GELU'),
                 use_encoder_idx=[2],
                 num_encoder_layers=1,
                 with_ssff: bool = False,
                 with_spd: bool = False,
                 pe_temperature=100 * 100,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 widen_factor=1,
                 deepen_factor=1,
                 eval_spatial_size=None,
                 input_proj_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = dict(type='SiLU', inplace=True)):
        super().__init__()
        self.in_channels = in_channels
        self.feat_strides = feat_strides
        self.hidden_dim = hidden_dim
        self.use_encoder_idx = use_encoder_idx
        self.num_encoder_layers = num_encoder_layers
        self.pe_temperature = pe_temperature
        self.eval_spatial_size = eval_spatial_size
        self.out_channels = [hidden_dim for _ in range(len(in_channels))]
        self.out_strides = feat_strides
        self.with_ssff = with_ssff

        # Project backbone features to hidden_dim, e.g. using a ChannelMapper
        # built from input_proj_cfg.
        self.input_proj = MODELS.build(input_proj_cfg) \
            if input_proj_cfg is not None else nn.Identity()
        if self.with_ssff:
            self.ssff = SSFF(
                in_channels=[hidden_dim, hidden_dim, hidden_dim],
                out_channels=hidden_dim)

        # Transformer encoder and position encoding.
        encoder_layer_opt = dict(
            self_attn_cfg=dict(
                embed_dims=hidden_dim,
                num_heads=n_head,
                attn_drop=drop_out,
                proj_drop=drop_out),
            ffn_cfg=dict(
                embed_dims=hidden_dim,
                feedforward_channels=hidden_dim * dim_feedforward_ratio,
                num_fcs=2,
                ffn_drop=drop_out,
                act_cfg=enc_act))
        self.encoder = nn.ModuleList([
            Encoder(num_encoder_layers, layer_cfg=encoder_layer_opt)
            for _ in range(len(use_encoder_idx))
        ])
        self.fpn = RTDETRFPN(
            in_channels=[hidden_dim, hidden_dim, hidden_dim],
            out_channels=hidden_dim,
            expansion=widen_factor,
            depth_mult=deepen_factor,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            with_spd=with_spd)
        self._reset_parameters()

    def _reset_parameters(self):
        if self.eval_spatial_size:
            for idx in self.use_encoder_idx:
                stride = self.feat_strides[idx]
                pos_embed = self.build_2d_sincos_position_embedding(
                    self.eval_spatial_size[1] // stride,
                    self.eval_spatial_size[0] // stride,
                    self.hidden_dim, self.pe_temperature)
                setattr(self, f'pos_embed{idx}', pos_embed)
    @staticmethod
    def build_2d_sincos_position_embedding(
        w: int,
        h: int,
        embed_dim: int = 256,
        temperature: float = 10000.,
        device=None,
    ) -> Tensor:
        assert embed_dim % 4 == 0, ('Embed dimension must be divisible by 4 '
                                    'for 2D sin-cos position embedding')
        grid_w = torch.arange(w, dtype=torch.float32, device=device)
        grid_h = torch.arange(h, dtype=torch.float32, device=device)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32, device=device)
        omega = temperature**(omega / -pos_dim)
        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]
        pos_embd = [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ]
        return torch.cat(pos_embd, dim=1)[None, :, :]
    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        assert len(inputs) == len(self.in_channels)
        proj_feats = self.input_proj(inputs)
        proj_feats = list(proj_feats)
        if self.with_ssff:
            fuse_layer = self.ssff(proj_feats)
            proj_feats[-1] = fuse_layer

        # Encoder with position encoding.
        if self.num_encoder_layers > 0:
            for i, enc_idx in enumerate(self.use_encoder_idx):
                h, w = proj_feats[enc_idx].shape[2:]
                # (B, C, H, W) -> (B, H*W, C)
                src_flatten = proj_feats[enc_idx].flatten(2).permute(
                    0, 2, 1).contiguous()
                if self.training or self.eval_spatial_size is None:
                    # Note the (w, h) argument order, matching the
                    # precomputed embeddings in _reset_parameters.
                    pos_enc = self.build_2d_sincos_position_embedding(
                        w,
                        h,
                        embed_dim=self.hidden_dim,
                        temperature=self.pe_temperature,
                        device=src_flatten.device)
                else:
                    pos_enc = getattr(self, f'pos_embed{enc_idx}',
                                      None).to(src_flatten.device)
                memory = self.encoder[i](src_flatten, query_pos=pos_enc)
                proj_feats[enc_idx] = memory.permute(0, 2, 1).contiguous(
                ).reshape([-1, self.hidden_dim, h, w])

        # FPN
        outs = self.fpn(tuple(proj_feats))
        return outs
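

# A minimal check of the sin-cos position embedding shape (illustrative
# only): for a 20x20 map and 256 dims, the table is (1, H * W, C), with
# sin/cos pairs for both axes concatenated along the channel dimension.
#
#   >>> pe = HybridEncoder.build_2d_sincos_position_embedding(
#   ...     20, 20, embed_dim=256, temperature=10000.)
#   >>> pe.shape
#   torch.Size([1, 400, 256])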


# Derived from DetrTransformerDecoderLayer; see its __init__ for the
# expected cfg fields.
class RtDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Decoder layer of RT-DETR."""

    def _init_layers(self) -> None:
        """Initialize self-attention, cross-attention, FFN, and norms."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt: Tensor,
                reference_points: Tensor,
                memory: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                query_pos_embed: Tensor,
                attn_mask: Tensor = None) -> Tensor:
        # tgt: the object queries (features selected from the encoder).
        # reference_points: normalized 2d coordinates matching the queries.
        # memory: the flattened output of the hybrid encoder.
        # query_pos_embed: position embedding derived from reference points.

        # Self-attention. MMCV's MultiheadAttention already adds the
        # residual connection internally.
        tgt = self.self_attn(
            query=tgt,
            key=tgt,
            value=tgt,
            query_pos=query_pos_embed,
            attn_mask=attn_mask)
        tgt = self.norms[0](tgt)

        # Cross-attention; spatial_shapes and level_start_index must be
        # tensors. MultiScaleDeformableAttention also adds the residual
        # internally.
        tgt = self.cross_attn(
            query=tgt,
            value=memory,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            query_pos=query_pos_embed,
            level_start_index=level_start_index)
        tgt = self.norms[1](tgt)

        # Feed-forward network; FFN adds the residual internally when
        # add_identity=True (the default).
        tgt = self.ffn(tgt)
        tgt = self.norms[2](tgt)
        return tgt


class RtdetrDecoder(DetrTransformerDecoder):

    def _init_layers(self) -> None:
        self.layers = ModuleList([
            RtDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.eval_idx = self.num_layers - 1

    def forward(self,
                target: Tensor,
                memory: Tensor,
                memory_spatial_shapes: Tensor,
                memory_level_start_index: Tensor,
                ref_points_unact: Tensor,
                query_pos_head: nn.Module,
                bbox_head: ModuleList,
                score_head: ModuleList,
                attn_mask: Tensor = None) -> Tuple[Tensor]:
        output = target
        dec_out_bboxes = []
        dec_out_logits = []
        ref_points_detach = torch.sigmoid(ref_points_unact)
        for i, layer in enumerate(self.layers):
            ref_points_input = ref_points_detach.unsqueeze(2)
            # query_pos_head is an MLP mapping reference points to position
            # embeddings.
            query_pos_embed = query_pos_head(ref_points_detach)
            output = layer(output, ref_points_input, memory,
                           memory_spatial_shapes, memory_level_start_index,
                           query_pos_embed, attn_mask)
            # Iterative box refinement: predict a delta in inverse-sigmoid
            # space and re-apply the sigmoid.
            inter_ref_bbox = torch.sigmoid(
                bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
            if self.training:
                dec_out_logits.append(score_head[i](output))
                if i == 0:
                    dec_out_bboxes.append(inter_ref_bbox)
                else:
                    dec_out_bboxes.append(
                        torch.sigmoid(bbox_head[i](output) +
                                      inverse_sigmoid(ref_points_detach)))
            elif i == self.eval_idx:
                dec_out_logits.append(score_head[i](output))
                dec_out_bboxes.append(inter_ref_bbox)
                break
            ref_points_detach = inter_ref_bbox.detach(
            ) if self.training else inter_ref_bbox
        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
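

# The refinement step above, in isolation (a minimal sketch, illustrative
# only): each layer predicts a delta in inverse-sigmoid (logit) space and
# re-applies the sigmoid, which keeps boxes in [0, 1] while allowing
# unbounded updates. A zero delta leaves the reference box unchanged.
#
#   >>> ref = torch.tensor([[0.25, 0.25, 0.10, 0.10]])   # cxcywh in [0, 1]
#   >>> delta = torch.zeros(1, 4)                        # a zero head output
#   >>> torch.allclose(torch.sigmoid(delta + inverse_sigmoid(ref)), ref,
#   ...                atol=1e-4)
#   True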