utils.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
import types
import warnings
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, Linear, build_activation_layer,
                      build_conv_layer, build_norm_layer)
from mmcv.cnn.bricks.drop import Dropout
from mmengine.model import BaseModule, ModuleList
from mmengine.utils import to_2tuple
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig


def nlc_to_nchw(x: Tensor, hw_shape: Sequence[int]) -> Tensor:
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()
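
# A minimal round-trip sketch for the two layout helpers above (illustrative
# only, not part of the original file):
# >>> feat = torch.rand(2, 16 * 20, 256)        # [N, L, C]
# >>> spatial = nlc_to_nchw(feat, (16, 20))     # [N, C, H, W] == (2, 256, 16, 20)
# >>> assert nchw_to_nlc(spatial).shape == feat.shape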


def coordinate_to_encoding(coord_tensor: Tensor,
                           num_feats: int = 128,
                           temperature: int = 10000,
                           scale: float = 2 * math.pi):
    """Convert coordinate tensor to positional encoding.

    Args:
        coord_tensor (Tensor): Coordinate tensor to be converted to
            positional encoding. With the last dimension as 2 or 4.
        num_feats (int, optional): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value. Defaults to 128.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        scale (float, optional): A scale factor that scales the position
            embedding. Defaults to 2*pi.

    Returns:
        Tensor: Returned encoded positional tensor.
    """
    dim_t = torch.arange(
        num_feats, dtype=torch.float32, device=coord_tensor.device)
    dim_t = temperature**(2 * (dim_t // 2) / num_feats)
    x_embed = coord_tensor[..., 0] * scale
    y_embed = coord_tensor[..., 1] * scale
    pos_x = x_embed[..., None] / dim_t
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()),
                        dim=-1).flatten(2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()),
                        dim=-1).flatten(2)
    if coord_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=-1)
    elif coord_tensor.size(-1) == 4:
        w_embed = coord_tensor[..., 2] * scale
        pos_w = w_embed[..., None] / dim_t
        pos_w = torch.stack((pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()),
                            dim=-1).flatten(2)
        h_embed = coord_tensor[..., 3] * scale
        pos_h = h_embed[..., None] / dim_t
        pos_h = torch.stack((pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()),
                            dim=-1).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)
    else:
        raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
            coord_tensor.size(-1)))
    return pos
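
# Shape sketch (illustrative only): with the default ``num_feats=128``, a batch
# of normalized (cx, cy) reference points maps to a 256-d sine/cosine embedding.
# >>> ref_points = torch.rand(2, 100, 2)         # (bs, num_queries, 2) in [0, 1]
# >>> coordinate_to_encoding(ref_points).shape   # torch.Size([2, 100, 256])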


def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to apply the inverse to.
        eps (float): A small value to avoid numerical overflow.
            Defaults to 1e-5.

    Returns:
        Tensor: The tensor after applying the inverse of sigmoid, with the
        same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
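
# Numerical sketch (illustrative only): ``inverse_sigmoid`` is the clamped
# logit, so it undoes ``torch.sigmoid`` away from the saturated ends.
# >>> t = torch.tensor([-3.0, 0.0, 3.0])
# >>> torch.allclose(inverse_sigmoid(torch.sigmoid(t)), t, atol=1e-4)  # True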


class AdaptivePadding(nn.Module):
    """Applies padding to input (if needed) so that the input can get fully
    covered by the filter you specified. It supports two modes, "same" and
    "corner". The "same" mode is the same as the "SAME" padding mode in
    TensorFlow, padding zeros around the input. The "corner" mode pads zeros
    to the bottom right.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Support "same" and "corner", "corner" mode
            would pad zero to bottom right, and "same" mode would
            pad zero around input. Default: "corner".

    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
        super(AdaptivePadding, self).__init__()
        assert padding in ('same', 'corner')
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        dilation = to_2tuple(dilation)
        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The type of conv layer used for embedding.
            Default: "Conv2d".
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int): The stride of embedding conv. Default: 16
            (if set to None, it would be set as `kernel_size`).
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to calculate the out size. Only works when `dynamic_size`
            is False. Default: None.
        init_cfg (`mmengine.ConfigDict`, optional): The Config for
            initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 conv_type: str = 'Conv2d',
                 kernel_size: int = 16,
                 stride: int = 16,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: int = 1,
                 bias: bool = True,
                 norm_cfg: OptConfigType = None,
                 input_size: Union[int, tuple] = None,
                 init_cfg: OptConfigType = None) -> None:
        super(PatchEmbed, self).__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (out_h, out_w).
        """
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
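
# Usage sketch (illustrative only): the default 16x16 / stride-16 conv turns a
# 224x224 RGB image into 14 * 14 = 196 patch tokens of dimension 768.
# >>> patch_embed = PatchEmbed()
# >>> tokens, out_size = patch_embed(torch.rand(1, 3, 224, 224))
# >>> tokens.shape, out_size  # (torch.Size([1, 196, 768]), (14, 14))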


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups feature map by kernel_size, and applies norm and linear
    layers to the grouped feature map. Our implementation uses `nn.Unfold` to
    merge patches, which is about 25% faster than the original implementation.
    However, we need to modify pretrained models for compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in the
            unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Default: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Optional[Union[int, tuple]] = 2,
                 stride: Optional[Union[int, tuple]] = None,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: Optional[Union[int, tuple]] = 1,
                 bias: Optional[bool] = False,
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        if stride:
            stride = stride
        else:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x: Tensor,
                input_size: Tuple[int]) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
                Default: None.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), \
            f'Expect input_size is `Sequence` but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size
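
# Usage sketch (illustrative only): Swin-style 2x2 merging halves the spatial
# resolution and, with these channel settings, doubles the channel dimension.
# >>> merge = PatchMerging(in_channels=96, out_channels=192)
# >>> x = torch.rand(1, 56 * 56, 96)
# >>> out, out_size = merge(x, input_size=(56, 56))
# >>> out.shape, out_size  # (torch.Size([1, 784, 192]), (28, 28))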


class ConditionalAttention(BaseModule):
    """A wrapper of conditional attention, dropout and residual connection.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after the attention.
            Default: 0.0.
        cross_attn (bool): Whether the attention module is for cross attention.
            Default: False
        keep_query_pos (bool): Whether to transform query_pos before cross
            attention.
            Default: False.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default: True.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 cross_attn: bool = False,
                 keep_query_pos: bool = False,
                 batch_first: bool = True,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        assert batch_first is True, 'Set `batch_first`\
        to False is NOT supported in ConditionalAttention. \
        First dimension of all DETRs in mmdet is `batch`, \
        please set `batch_first` to True.'

        self.cross_attn = cross_attn
        self.keep_query_pos = keep_query_pos
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn_drop = Dropout(attn_drop)
        self.proj_drop = Dropout(proj_drop)

        self._init_layers()

    def _init_layers(self):
        """Initialize layers for qkv projection."""
        embed_dims = self.embed_dims
        self.qcontent_proj = Linear(embed_dims, embed_dims)
        self.qpos_proj = Linear(embed_dims, embed_dims)
        self.kcontent_proj = Linear(embed_dims, embed_dims)
        self.kpos_proj = Linear(embed_dims, embed_dims)
        self.v_proj = Linear(embed_dims, embed_dims)
        if self.cross_attn:
            self.qpos_sine_proj = Linear(embed_dims, embed_dims)
        self.out_proj = Linear(embed_dims, embed_dims)

        nn.init.constant_(self.out_proj.bias, 0.)

    def forward_attn(self,
                     query: Tensor,
                     key: Tensor,
                     value: Tensor,
                     attn_mask: Tensor = None,
                     key_padding_mask: Tensor = None) -> Tuple[Tensor]:
        """Forward process for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                embed_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                embed_dims].
                If None, the `query` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tuple[Tensor]: Attention outputs of shape :math:`(N, L, E)`,
            where :math:`N` is the batch size, :math:`L` is the target
            sequence length, and :math:`E` is the embedding dimension
            `embed_dim`; and attention weights per head of shape
            :math:`(num_heads, L, S)`, where :math:`S` is the source
            sequence length.
        """
        assert key.size(1) == value.size(1), \
            'key, value must have the same sequence length'
        assert query.size(0) == key.size(0) == value.size(0), \
            'batch size must be equal for query, key, value'
        assert query.size(2) == key.size(2), \
            'q_dims, k_dims must be equal'
        assert value.size(2) == self.embed_dims, \
            'v_dims must be equal to embed_dims'

        bs, tgt_len, hidden_dims = query.size()
        _, src_len, _ = key.size()
        head_dims = hidden_dims // self.num_heads
        v_head_dims = self.embed_dims // self.num_heads
        assert head_dims * self.num_heads == hidden_dims, \
            'hidden_dims must be divisible by num_heads'
        scaling = float(head_dims)**-0.5

        q = query * scaling
        k = key
        v = value

        if attn_mask is not None:
            assert attn_mask.dtype == torch.float32 or \
                   attn_mask.dtype == torch.float64 or \
                   attn_mask.dtype == torch.float16 or \
                   attn_mask.dtype == torch.uint8 or \
                   attn_mask.dtype == torch.bool, \
                   'Only float, byte, and bool types are supported for ' \
                   'attn_mask'

            if attn_mask.dtype == torch.uint8:
                warnings.warn('Byte tensor for attn_mask is deprecated. '
                              'Use bool tensor instead.')
                attn_mask = attn_mask.to(torch.bool)
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(1), key.size(1)]:
                    raise RuntimeError(
                        'The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                        bs * self.num_heads,
                        query.size(1),
                        key.size(1)
                ]:
                    raise RuntimeError(
                        'The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError(
                    "attn_mask's dimension {} is not supported".format(
                        attn_mask.dim()))
            # attn_mask's dim is 3 now.

        if key_padding_mask is not None and key_padding_mask.dtype == int:
            key_padding_mask = key_padding_mask.to(torch.bool)

        q = q.contiguous().view(bs, tgt_len, self.num_heads,
                                head_dims).permute(0, 2, 1, 3).flatten(0, 1)
        if k is not None:
            k = k.contiguous().view(bs, src_len, self.num_heads,
                                    head_dims).permute(0, 2, 1,
                                                       3).flatten(0, 1)
        if v is not None:
            v = v.contiguous().view(bs, src_len, self.num_heads,
                                    v_head_dims).permute(0, 2, 1,
                                                         3).flatten(0, 1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bs
            assert key_padding_mask.size(1) == src_len

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bs * self.num_heads, tgt_len, src_len
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bs, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(
                bs * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(
            attn_output_weights -
            attn_output_weights.max(dim=-1, keepdim=True)[0],
            dim=-1)
        attn_output_weights = self.attn_drop(attn_output_weights)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(
            attn_output.size()) == [bs * self.num_heads, tgt_len, v_head_dims]
        attn_output = attn_output.view(bs, self.num_heads, tgt_len,
                                       v_head_dims).permute(0, 2, 1,
                                                            3).flatten(2)
        attn_output = self.out_proj(attn_output)

        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bs, self.num_heads,
                                                       tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / self.num_heads

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor = None,
                ref_sine_embed: Tensor = None,
                key_pos: Tensor = None,
                attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                is_first: bool = False) -> Tensor:
        """Forward function for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                embed_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                embed_dims].
                If None, the `query` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query in self
                attention, with the same shape as `x`. If not None, it will
                be added to `x` before forward function.
                Defaults to None.
            ref_sine_embed (Tensor): The positional encoding for query in
                cross attention, with the same shape as `x`. If not None, it
                will be added to `x` before forward function.
                Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.
            is_first (bool): An indicator to tell whether the current layer
                is the first layer of the decoder.
                Defaults to False.

        Returns:
            Tensor: forwarded results with shape
            [bs, num_queries, embed_dims].
        """
        if self.cross_attn:
            q_content = self.qcontent_proj(query)
            k_content = self.kcontent_proj(key)
            v = self.v_proj(key)

            bs, nq, c = q_content.size()
            _, hw, _ = k_content.size()

            k_pos = self.kpos_proj(key_pos)
            if is_first or self.keep_query_pos:
                q_pos = self.qpos_proj(query_pos)
                q = q_content + q_pos
                k = k_content + k_pos
            else:
                q = q_content
                k = k_content
            q = q.view(bs, nq, self.num_heads, c // self.num_heads)
            query_sine_embed = self.qpos_sine_proj(ref_sine_embed)
            query_sine_embed = query_sine_embed.view(bs, nq, self.num_heads,
                                                     c // self.num_heads)
            q = torch.cat([q, query_sine_embed], dim=3).view(bs, nq, 2 * c)
            k = k.view(bs, hw, self.num_heads, c // self.num_heads)
            k_pos = k_pos.view(bs, hw, self.num_heads, c // self.num_heads)
            k = torch.cat([k, k_pos], dim=3).view(bs, hw, 2 * c)
            ca_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(ca_output)
        else:
            q_content = self.qcontent_proj(query)
            q_pos = self.qpos_proj(query_pos)
            k_content = self.kcontent_proj(query)
            k_pos = self.kpos_proj(query_pos)
            v = self.v_proj(query)
            q = q_content if q_pos is None else q_content + q_pos
            k = k_content if k_pos is None else k_content + k_pos
            sa_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(sa_output)

        return query
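
# Usage sketch for the self-attention mode (illustrative only): with
# ``cross_attn=False`` the content and positional projections both come from
# the query tensor, and the output keeps the input shape.
# >>> self_attn = ConditionalAttention(embed_dims=256, num_heads=8)
# >>> q = torch.rand(2, 100, 256)
# >>> q_pos = torch.rand(2, 100, 256)
# >>> self_attn(query=q, key=q, query_pos=q_pos).shape  # torch.Size([2, 100, 256])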


class MLP(BaseModule):
    """Very simple multi-layer perceptron (also called FFN) with relu. Mostly
    used in DETR series detectors.

    Args:
        input_dim (int): Feature dim of the input tensor.
        hidden_dim (int): Feature dim of the hidden layer.
        output_dim (int): Feature dim of the output tensor.
        num_layers (int): Number of FFN layers. As the last
            layer of MLP only contains FFN (Linear).
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_layers: int) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = ModuleList(
            Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of MLP.

        Args:
            x (Tensor): The input feature, has shape
                (num_queries, bs, input_dim).

        Returns:
            Tensor: The output feature, has shape
                (num_queries, bs, output_dim).
        """
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
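
# Usage sketch (illustrative only): a 3-layer MLP as typically used for
# DETR-style box regression, mapping 256-d query embeddings to 4 box values.
# >>> bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
# >>> bbox_head(torch.rand(100, 2, 256)).shape  # torch.Size([100, 2, 4])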
  672. """A block in RepVGG architecture, supporting optional normalization in the
  673. identity branch.
  674. This block consists of 3x3 and 1x1 convolutions, with an optional identity
  675. shortcut branch that includes normalization.
  676. Args:
  677. in_channels (int): The input channels of the block.
  678. out_channels (int): The output channels of the block.
  679. stride (int): The stride of the block. Defaults to 1.
  680. padding (int): The padding of the block. Defaults to 1.
  681. dilation (int): The dilation of the block. Defaults to 1.
  682. groups (int): The groups of the block. Defaults to 1.
  683. padding_mode (str): The padding mode of the block. Defaults to 'zeros'.
  684. norm_cfg (dict): The config dict for normalization layers.
  685. Defaults to dict(type='BN').
  686. act_cfg (dict): The config dict for activation layers.
  687. Defaults to dict(type='ReLU').
  688. without_branch_norm (bool): Whether to skip branch_norm.
  689. Defaults to True.
  690. init_cfg (dict): The config dict for initialization. Defaults to None.
  691. """
  692. def __init__(self,
  693. in_channels: int,
  694. out_channels: int,
  695. stride: int = 1,
  696. padding: int = 1,
  697. dilation: int = 1,
  698. groups: int = 1,
  699. norm_cfg: OptConfigType = dict(type='BN'),
  700. act_cfg: OptConfigType = dict(type='ReLU'),
  701. without_branch_norm: bool = True,
  702. init_cfg: OptConfigType = None):
  703. super(RepVGGBlock, self).__init__(init_cfg)
  704. self.in_channels = in_channels
  705. self.out_channels = out_channels
  706. self.stride = stride
  707. self.padding = padding
  708. self.dilation = dilation
  709. self.groups = groups
  710. self.norm_cfg = norm_cfg
  711. self.act_cfg = act_cfg
  712. # judge if input shape and output shape are the same.
  713. # If true, add a normalized identity shortcut.
  714. self.branch_norm = None
  715. if out_channels == in_channels and stride == 1 and \
  716. padding == dilation and not without_branch_norm:
  717. self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
  718. self.branch_3x3 = ConvModule(
  719. self.in_channels,
  720. self.out_channels,
  721. 3,
  722. stride=self.stride,
  723. padding=self.padding,
  724. groups=self.groups,
  725. dilation=self.dilation,
  726. norm_cfg=self.norm_cfg,
  727. act_cfg=None)
  728. self.branch_1x1 = ConvModule(
  729. self.in_channels,
  730. self.out_channels,
  731. 1,
  732. groups=self.groups,
  733. norm_cfg=self.norm_cfg,
  734. act_cfg=None)
  735. self.act = build_activation_layer(act_cfg)
  736. def forward(self, x: Tensor) -> Tensor:
  737. """Forward pass through the RepVGG block.
  738. The output is the sum of 3x3 and 1x1 convolution outputs,
  739. along with the normalized identity branch output, followed by
  740. activation.
  741. Args:
  742. x (Tensor): The input tensor.
  743. Returns:
  744. Tensor: The output tensor.
  745. """
  746. if self.branch_norm is None:
  747. branch_norm_out = 0
  748. else:
  749. branch_norm_out = self.branch_norm(x)
  750. out = self.branch_3x3(x) + self.branch_1x1(x) + branch_norm_out
  751. out = self.act(out)
  752. return out
  753. def _pad_1x1_to_3x3_tensor(self, kernel1x1):
  754. """Pad 1x1 tensor to 3x3.
  755. Args:
  756. kernel1x1 (Tensor): The input 1x1 kernel need to be padded.
  757. Returns:
  758. Tensor: 3x3 kernel after padded.
  759. """
  760. if kernel1x1 is None:
  761. return 0
  762. else:
  763. return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
  764. def _fuse_bn_tensor(self, branch: nn.Module) -> Tensor:
  765. """Derives the equivalent kernel and bias of a specific branch layer.
  766. Args:
  767. branch (nn.Module): The layer that needs to be equivalently
  768. transformed, which can be nn.Sequential or nn.Batchnorm2d
  769. Returns:
  770. tuple: Equivalent kernel and bias
  771. """
  772. if branch is None:
  773. return 0, 0
  774. if isinstance(branch, ConvModule):
  775. kernel = branch.conv.weight
  776. running_mean = branch.bn.running_mean
  777. running_var = branch.bn.running_var
  778. gamma = branch.bn.weight
  779. beta = branch.bn.bias
  780. eps = branch.bn.eps
  781. else:
  782. assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d))
  783. if not hasattr(self, 'id_tensor'):
  784. input_dim = self.in_channels // self.groups
  785. kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
  786. dtype=np.float32)
  787. for i in range(self.in_channels):
  788. kernel_value[i, i % input_dim, 1, 1] = 1
  789. self.id_tensor = torch.from_numpy(kernel_value).to(
  790. branch.weight.device)
  791. kernel = self.id_tensor
  792. running_mean = branch.running_mean
  793. running_var = branch.running_var
  794. gamma = branch.weight
  795. beta = branch.bias
  796. eps = branch.eps
  797. std = (running_var + eps).sqrt()
  798. t = (gamma / std).reshape(-1, 1, 1, 1)
  799. return kernel * t, beta - running_mean * gamma / std
  800. def get_equivalent_kernel_bias(self):
  801. """Derives the equivalent kernel and bias in a differentiable way.
  802. Returns:
  803. tuple: Equivalent kernel and bias
  804. """
  805. kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3)
  806. kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1)
  807. kernelid, biasid = (0, 0) if self.branch_norm is None else \
  808. self._fuse_bn_tensor(self.branch_norm)
  809. return (kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
  810. bias3x3 + bias1x1 + biasid)
  811. def switch_to_deploy(self, test_cfg: Optional[Dict] = None):
  812. """Switches the block to deployment mode.
  813. In deployment mode, the block uses a single convolution operation
  814. derived from the equivalent kernel and bias, replacing the original
  815. branches. This reduces computational complexity during inference.
  816. """
  817. if getattr(self, 'deploy', False):
  818. return
  819. kernel, bias = self.get_equivalent_kernel_bias()
  820. self.conv_reparam = nn.Conv2d(
  821. in_channels=self.branch_3x3.conv.in_channels,
  822. out_channels=self.branch_3x3.conv.out_channels,
  823. kernel_size=self.branch_3x3.conv.kernel_size,
  824. stride=self.branch_3x3.conv.stride,
  825. padding=self.branch_3x3.conv.padding,
  826. dilation=self.branch_3x3.conv.dilation,
  827. groups=self.branch_3x3.conv.groups,
  828. bias=True)
  829. self.conv_reparam.weight.data = kernel
  830. self.conv_reparam.bias.data = bias
  831. for para in self.parameters():
  832. para.detach_()
  833. self.__delattr__('branch_3x3')
  834. self.__delattr__('branch_1x1')
  835. if hasattr(self, 'branch_norm'):
  836. self.__delattr__('branch_norm')
  837. def _forward(self, x):
  838. return self.act(self.conv_reparam(x))
  839. self.forward = types.MethodType(_forward, self)
  840. self.deploy = True
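
# Re-parameterization sketch (illustrative only): in eval mode the single 3x3
# convolution produced by ``switch_to_deploy`` should reproduce the
# multi-branch output up to numerical error.
# >>> block = RepVGGBlock(64, 64).eval()
# >>> x = torch.rand(1, 64, 32, 32)
# >>> y_multi_branch = block(x)
# >>> block.switch_to_deploy()
# >>> torch.allclose(y_multi_branch, block(x), atol=1e-5)  # True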
  841. """Position encoding with sine and cosine functions.
  842. See `End-to-End Object Detection with Transformers
  843. <https://arxiv.org/pdf/2005.12872>`_ for details.
  844. Args:
  845. num_feats (int): The feature dimension for each position
  846. along x-axis or y-axis. Note the final returned dimension
  847. for each position is 2 times of this value.
  848. temperature (int, optional): The temperature used for scaling
  849. the position embedding. Defaults to 10000.
  850. normalize (bool, optional): Whether to normalize the position
  851. embedding. Defaults to False.
  852. scale (float, optional): A scale factor that scales the position
  853. embedding. The scale will be used only when `normalize` is True.
  854. Defaults to 2*pi.
  855. eps (float, optional): A value added to the denominator for
  856. numerical stability. Defaults to 1e-6.
  857. offset (float): offset add to embed when do the normalization.
  858. Defaults to 0.
  859. init_cfg (dict or list[dict], optional): Initialization config dict.
  860. Defaults to None
  861. """
  862. def __init__(self,
  863. num_feats: int,
  864. temperature: int = 10000,
  865. normalize: bool = False,
  866. scale: float = 2 * math.pi,
  867. eps: float = 1e-6,
  868. offset: float = 0.,
  869. init_cfg: OptMultiConfig = None) -> None:
  870. super().__init__(init_cfg=init_cfg)
  871. if normalize:
  872. assert isinstance(scale, (float, int)), 'when normalize is set,' \
  873. 'scale should be provided and in float or int type, ' \
  874. f'found {type(scale)}'
  875. self.num_feats = num_feats
  876. self.temperature = temperature
  877. self.normalize = normalize
  878. self.scale = scale
  879. self.eps = eps
  880. self.offset = offset
  881. def forward(self, mask: Tensor) -> Tensor:
  882. """Forward function for `SinePositionalEncoding`.
  883. Args:
  884. mask (Tensor): ByteTensor mask. Non-zero values representing
  885. ignored positions, while zero values means valid positions
  886. for this image. Shape [bs, h, w].
  887. Returns:
  888. pos (Tensor): Returned position embedding with shape
  889. [bs, num_feats*2, h, w].
  890. """
  891. # For convenience of exporting to ONNX, it's required to convert
  892. # `masks` from bool to int.
  893. mask = mask.to(torch.int)
  894. not_mask = 1 - mask # logical_not
  895. y_embed = not_mask.cumsum(1, dtype=torch.float32)
  896. x_embed = not_mask.cumsum(2, dtype=torch.float32)
  897. if self.normalize:
  898. y_embed = (y_embed + self.offset) / \
  899. (y_embed[:, -1:, :] + self.eps) * self.scale
  900. x_embed = (x_embed + self.offset) / \
  901. (x_embed[:, :, -1:] + self.eps) * self.scale
  902. dim_t = torch.arange(
  903. self.num_feats, dtype=torch.float32, device=mask.device)
  904. dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
  905. pos_x = x_embed[:, :, :, None] / dim_t
  906. pos_y = y_embed[:, :, :, None] / dim_t
  907. # use `view` instead of `flatten` for dynamically exporting to ONNX
  908. B, H, W = mask.size()
  909. pos_x = torch.stack(
  910. (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
  911. dim=4).view(B, H, W, -1)
  912. pos_y = torch.stack(
  913. (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
  914. dim=4).view(B, H, W, -1)
  915. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  916. return pos
  917. def __repr__(self) -> str:
  918. """str: a string that describes the module"""
  919. repr_str = self.__class__.__name__
  920. repr_str += f'(num_feats={self.num_feats}, '
  921. repr_str += f'temperature={self.temperature}, '
  922. repr_str += f'normalize={self.normalize}, '
  923. repr_str += f'scale={self.scale}, '
  924. repr_str += f'eps={self.eps})'
  925. return repr_str
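
# Usage sketch (illustrative only): an all-valid (all-zero) mask over a 32x32
# feature map yields a dense sine/cosine embedding with 2 * num_feats channels.
# >>> pos_enc = SinePositionalEncoding(num_feats=128, normalize=True)
# >>> pos = pos_enc(torch.zeros(1, 32, 32, dtype=torch.bool))
# >>> pos.shape  # torch.Size([1, 256, 32, 32])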


class RtdetrSinePositionalEncoding(BaseModule):

    def __init__(self,
                 embed_dim: int = 256,
                 temperature: int = 10000,
                 normalize: bool = False,
                 scale: float = 2 * math.pi,
                 eps: float = 1e-6,
                 offset: float = 0.,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        if normalize:
            assert isinstance(scale, (float, int)), 'when normalize is set,' \
                'scale should be provided and in float or int type, ' \
                f'found {type(scale)}'
        self.embed_dim = embed_dim
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.eps = eps
        self.offset = offset

    def forward(self, height, width):
        grid_w = torch.arange(int(width), dtype=torch.float32, device='cuda')
        grid_h = torch.arange(int(height), dtype=torch.float32, device='cuda')
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
        assert self.embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos ' \
            'position embedding'
        pos_dim = self.embed_dim // 4
        omega = torch.arange(
            pos_dim, dtype=torch.float32, device='cuda') / pos_dim
        omega = self.temperature**(omega / -pos_dim)
        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]
        return torch.concat(
            [
                torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),
                torch.cos(out_h)
            ],
            axis=1)[None, :, :]
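
# Usage sketch (illustrative only): the grid is built directly on the GPU (the
# module hard-codes device='cuda'), so a CUDA device is required here.
# >>> pos_enc = RtdetrSinePositionalEncoding(embed_dim=256)
# >>> pos_enc(height=10, width=10).shape  # torch.Size([1, 100, 256])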


@MODELS.register_module()
class DynamicConv(BaseModule):
    """Implements Dynamic Convolution.

    This module generates parameters for each sample and
    uses bmm to implement 1*1 convolution. Code is modified
    from the `official github repo <https://github.com/PeizeSun/
    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        with_proj (bool): Project two-dimensional feature to
            one-dimensional feature. Defaults to True.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 256,
                 feat_channels: int = 64,
                 out_channels: Optional[int] = None,
                 input_feat_shape: int = 7,
                 with_proj: bool = True,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super(DynamicConv, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.with_proj = with_proj
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        num_output = self.out_channels * input_feat_shape**2
        if self.with_proj:
            self.fc_layer = nn.Linear(num_output, self.out_channels)
            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor:
        """Forward function for `DynamicConv`.

        Args:
            param_feature (Tensor): The feature that can be used
                to generate the parameters, has shape
                (num_all_proposals, in_channels).
            input_feature (Tensor): Feature that
                interacts with parameters, has shape
                (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature has shape
            (num_all_proposals, out_channels).
        """
        input_feature = input_feature.flatten(2).permute(2, 0, 1)

        input_feature = input_feature.permute(1, 0, 2)
        parameters = self.dynamic_layer(param_feature)

        param_in = parameters[:, :self.num_params_in].view(
            -1, self.in_channels, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels, self.out_channels)

        # input_feature has shape (num_all_proposals, H*W, in_channels)
        # param_in has shape (num_all_proposals, in_channels, feat_channels)
        # feature has shape (num_all_proposals, H*W, feat_channels)
        features = torch.bmm(input_feature, param_in)
        features = self.norm_in(features)
        features = self.activation(features)

        # param_out has shape (batch_size, feat_channels, out_channels)
        features = torch.bmm(features, param_out)
        features = self.norm_out(features)
        features = self.activation(features)

        if self.with_proj:
            features = features.flatten(1)
            features = self.fc_layer(features)
            features = self.fc_norm(features)
            features = self.activation(features)

        return features
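
# Usage sketch (illustrative only): with the defaults, each proposal embedding
# generates its own 1x1-convolution parameters and is mixed with its 7x7 RoI
# feature, producing one 256-d vector per proposal.
# >>> dyn_conv = DynamicConv()
# >>> proposal_feat = torch.rand(100, 256)     # (num_proposals, in_channels)
# >>> roi_feat = torch.rand(100, 256, 7, 7)    # (num_proposals, C, H, W)
# >>> dyn_conv(proposal_feat, roi_feat).shape  # torch.Size([100, 256])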