# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.logging import MMLogger
from mmengine.model import (BaseModule, ModuleList, Sequential, constant_init,
                            normal_init, trunc_normal_init)
from mmengine.model.weight_init import trunc_normal_
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn.modules.utils import _pair as to_2tuple

from mmdet.registry import MODELS
from ..layers import PatchEmbed, nchw_to_nlc, nlc_to_nchw


class MixFFN(BaseModule):
    """An implementation of MixFFN of PVT.

    The differences between MixFFN & FFN:
        1. Use 1X1 Conv to replace Linear layer.
        2. Introduce 3X3 Depth-wise Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
            Default: None.
        use_conv (bool): If True, add 3x3 DWConv between two Linear layers.
            Default: False.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 use_conv=False,
                 init_cfg=None):
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        if use_conv:
            # 3x3 depth wise conv to provide positional encode information
            dw_conv = Conv2d(
                in_channels=feedforward_channels,
                out_channels=feedforward_channels,
                kernel_size=3,
                stride=1,
                padding=(3 - 1) // 2,
                bias=True,
                groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        drop = nn.Dropout(ffn_drop)
        layers = [fc1, activate, drop, fc2, drop]
        if use_conv:
            layers.insert(1, dw_conv)
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        out = nlc_to_nchw(x, hw_shape)
        out = self.layers(out)
        out = nchw_to_nlc(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
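

# Illustrative usage sketch (not part of the original file): MixFFN consumes a
# flattened sequence in (B, N, C) layout plus the spatial shape it was
# flattened from, reshapes to NCHW internally for the 1x1 (and optional 3x3
# depth-wise) convolutions, and returns the same (B, N, C) layout, e.g.
#
#     ffn = MixFFN(embed_dims=64, feedforward_channels=256, use_conv=True)
#     x = torch.rand(2, 7 * 7, 64)      # (B, N, C) with N = H * W
#     out = ffn(x, hw_shape=(7, 7))     # residual added, shape stays (2, 49, 64)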


class SpatialReductionAttention(MultiheadAttention):
    """An implementation of Spatial Reduction Attention of PVT.

    This module is modified from MultiheadAttention which is a module from
    mmcv.cnn.bricks.transformer.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            batch_first=batch_first,
            dropout_layer=dropout_layer,
            bias=qkv_bias,
            init_cfg=init_cfg)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmdet import digit_version, mmcv_version
        if mmcv_version < digit_version('1.3.17'):
            warnings.warn('The legacy version of forward function in '
                          'SpatialReductionAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer be supported in '
                          'the future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward

    def forward(self, x, hw_shape, identity=None):
        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_queries, batch,
        # embed_dims), we should adjust the shape of dataflow from
        # batch_first (batch, num_queries, embed_dims) to num_queries_first
        # (num_queries, batch, embed_dims), and recover ``attn_output``
        # from num_queries_first to batch_first.
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

    def legacy_forward(self, x, hw_shape, identity=None):
        """multi head attention forward in mmcv version < 1.3.17."""
        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        return identity + self.dropout_layer(self.proj_drop(out))
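

# Illustrative usage sketch (not part of the original file): with sr_ratio > 1
# the key/value sequence is downsampled by a strided conv before attention, so
# the attention operates on (H * W / sr_ratio ** 2) keys instead of H * W, e.g.
#
#     attn = SpatialReductionAttention(embed_dims=64, num_heads=1, sr_ratio=8)
#     x = torch.rand(2, 56 * 56, 64)      # 3136 query tokens
#     out = attn(x, hw_shape=(56, 56))    # keys/values reduced to 7 * 7 = 49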


class PVTEncoderLayer(BaseModule):
    """Implements one encoder layer in PVT.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): stochastic depth rate. Default: 0.0.
        qkv_bias (bool): enable bias for qkv if True.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 use_conv_ffn=False,
                 init_cfg=None):
        super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)

        # The ret[0] of build_norm_layer is norm name.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.attn = SpatialReductionAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        # The ret[0] of build_norm_layer is norm name.
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            use_conv=use_conv_ffn,
            act_cfg=act_cfg)

    def forward(self, x, hw_shape):
        x = self.attn(self.norm1(x), hw_shape, identity=x)
        x = self.ffn(self.norm2(x), hw_shape, identity=x)

        return x
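

# Illustrative usage sketch (not part of the original file): one encoder layer
# applies pre-norm spatial-reduction attention followed by a pre-norm MixFFN,
# both with residual connections, keeping the (B, N, C) shape unchanged, e.g.
#
#     layer = PVTEncoderLayer(
#         embed_dims=64, num_heads=1, feedforward_channels=512, sr_ratio=8)
#     x = torch.rand(2, 56 * 56, 64)
#     out = layer(x, hw_shape=(56, 56))   # shape stays (2, 3136, 64)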


class AbsolutePositionEmbedding(BaseModule):
    """An implementation of the absolute position embedding in PVT.

    Args:
        pos_shape (int): The shape of the absolute position embedding.
        pos_dim (int): The dimension of the absolute position embedding.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
    """

    def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(pos_shape, int):
            pos_shape = to_2tuple(pos_shape)
        elif isinstance(pos_shape, tuple):
            if len(pos_shape) == 1:
                pos_shape = to_2tuple(pos_shape[0])
            assert len(pos_shape) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pos_shape)}'
        self.pos_shape = pos_shape
        self.pos_dim = pos_dim

        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
        self.drop = nn.Dropout(p=drop_rate)

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)

    def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
        """Resize pos_embed weights.

        Resize pos_embed using bilinear interpolate method.

        Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'bilinear'``.

        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = self.pos_shape
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
        pos_embed_weight = F.interpolate(
            pos_embed_weight, size=input_shape, mode=mode)
        pos_embed_weight = torch.flatten(pos_embed_weight,
                                         2).transpose(1, 2).contiguous()
        pos_embed = pos_embed_weight

        return pos_embed

    def forward(self, x, hw_shape, mode='bilinear'):
        pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
        return self.drop(x + pos_embed)
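

# Illustrative usage sketch (not part of the original file): the learned table
# is stored for `pos_shape` (e.g. the pretraining resolution) and bilinearly
# resized to the current feature-map size before being added, e.g.
#
#     pos = AbsolutePositionEmbedding(pos_shape=56, pos_dim=64)
#     x = torch.rand(2, 32 * 32, 64)
#     out = pos(x, hw_shape=(32, 32))     # 56x56 table resized to 32x32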


@MODELS.register_module()
class PyramidVisionTransformer(BaseModule):
    """Pyramid Vision Transformer (PVT)

    Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
    Dense Prediction without Convolutions
    <https://arxiv.org/pdf/2102.12122.pdf>`_.

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Default: 224.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Embedding dimension. Default: 64.
        num_stages (int): The number of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer encode
            layer. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 5, 8].
        patch_sizes (Sequence[int]): The patch_size of each patch embedding.
            Default: [4, 2, 2, 2].
        strides (Sequence[int]): The stride of each patch embedding.
            Default: [4, 2, 2, 2].
        paddings (Sequence[int]): The padding of each patch embedding.
            Default: [0, 0, 0, 0].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages.
            Default: (0, 1, 2, 3).
        mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
            embedding dim of each transformer encode layer.
            Default: [8, 8, 4, 4].
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): stochastic depth rate. Default 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Default: True.
        norm_after_stage (bool): If True, add a norm layer after each stage.
            Default: False.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        pretrained (str, optional): model pretrained path. Default: None.
        convert_weights (bool): The flag indicates whether the
            pre-trained model is from the original repo. We may need
            to convert some keys to make it compatible.
            Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 5, 8],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 paddings=[0, 0, 0, 0],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratios=[8, 8, 4, 4],
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=True,
                 norm_after_stage=False,
                 use_conv_ffn=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrained=None,
                 convert_weights=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.convert_weights = convert_weights
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            self.init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims

        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        assert num_stages == len(num_layers) == len(num_heads) \
               == len(patch_sizes) == len(strides) == len(sr_ratios)

        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages
        self.pretrained = pretrained

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic num_layer decay rule

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=paddings[i],
                bias=True,
                norm_cfg=norm_cfg)
            layers = ModuleList()
            if use_abs_pos_embed:
                pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
                pos_embed = AbsolutePositionEmbedding(
                    pos_shape=pos_shape,
                    pos_dim=embed_dims_i,
                    drop_rate=drop_rate)
                layers.append(pos_embed)
            layers.extend([
                PVTEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratios[i] * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i],
                    use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
            ])
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            if norm_after_stage:
                norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            else:
                norm = nn.Identity()
            self.layers.append(ModuleList([patch_embed, layers, norm]))
            cur += num_layer

    def init_weights(self):
        logger = MMLogger.get_current_instance()
        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
                elif isinstance(m, AbsolutePositionEmbedding):
                    m.init_weights()
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` in ' \
                f'{self.__class__.__name__}'
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
            logger.warn(f'Load pre-trained model for '
                        f'{self.__class__.__name__} from original repo')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
            if self.convert_weights:
                # pvt backbones are not supported by mmpretrain, so we need
                # to convert pre-trained weights to match this implementation.
                state_dict = pvt_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

    def forward(self, x):
        outs = []

        for i, layer in enumerate(self.layers):
            x, hw_shape = layer[0](x)

            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
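

# Illustrative usage sketch (not part of the original file): with the default
# configuration the stage widths are embed_dims * num_heads[i] = (64, 128, 320,
# 512) and the four outputs have strides 4, 8, 16 and 32, e.g.
#
#     model = PyramidVisionTransformer()
#     model.init_weights()
#     feats = model(torch.rand(1, 3, 224, 224))
#     # [(1, 64, 56, 56), (1, 128, 28, 28), (1, 320, 14, 14), (1, 512, 7, 7)]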


@MODELS.register_module()
class PyramidVisionTransformerV2(PyramidVisionTransformer):
    """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
    Transformer <https://arxiv.org/pdf/2106.13797.pdf>`_."""

    def __init__(self, **kwargs):
        super(PyramidVisionTransformerV2, self).__init__(
            patch_sizes=[7, 3, 3, 3],
            paddings=[3, 1, 1, 1],
            use_abs_pos_embed=False,
            norm_after_stage=True,
            use_conv_ffn=True,
            **kwargs)
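

# Illustrative usage sketch (not part of the original file): PVTv2 reuses the
# v1 backbone but with overlapping patch embeddings (7x7/s4 then 3x3/s2),
# convolutional FFNs and per-stage norms instead of absolute position
# embeddings; a small PVTv2-b0-like variant could be built as
#
#     backbone = PyramidVisionTransformerV2(
#         embed_dims=32, num_layers=[2, 2, 2, 2])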


def pvt_convert(ckpt):
    """Convert the key names of an official PVT checkpoint to the format
    expected by this implementation."""
    new_ckpt = OrderedDict()
    # Process the concat between q linear weights and kv linear weights
    use_abs_pos_embed = False
    use_conv_ffn = False
    for k in ckpt.keys():
        if k.startswith('pos_embed'):
            use_abs_pos_embed = True
        if k.find('dwconv') >= 0:
            use_conv_ffn = True
    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        if k.startswith('norm.'):
            continue
        if k.startswith('cls_token'):
            continue
        if k.startswith('pos_embed'):
            stage_i = int(k.replace('pos_embed', ''))
            new_k = k.replace(f'pos_embed{stage_i}',
                              f'layers.{stage_i - 1}.1.0.pos_embed')
            if stage_i == 4 and v.size(1) == 50:  # 1 (cls token) + 7 * 7
                new_v = v[:, 1:, :]  # remove cls token
            else:
                new_v = v
        elif k.startswith('patch_embed'):
            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
            new_k = k.replace(f'patch_embed{stage_i}',
                              f'layers.{stage_i - 1}.0')
            new_v = v
            if 'proj.' in new_k:
                new_k = new_k.replace('proj.', 'projection.')
        elif k.startswith('block'):
            stage_i = int(k.split('.')[0].replace('block', ''))
            layer_i = int(k.split('.')[1])
            new_layer_i = layer_i + use_abs_pos_embed
            new_k = k.replace(f'block{stage_i}.{layer_i}',
                              f'layers.{stage_i - 1}.1.{new_layer_i}')
            new_v = v
            if 'attn.q.' in new_k:
                sub_item_k = k.replace('q.', 'kv.')
                new_k = new_k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
            elif 'attn.kv.' in new_k:
                continue
            elif 'attn.proj.' in new_k:
                new_k = new_k.replace('proj.', 'attn.out_proj.')
            elif 'attn.sr.' in new_k:
                new_k = new_k.replace('sr.', 'sr.')
            elif 'mlp.' in new_k:
                string = f'{new_k}-'
                new_k = new_k.replace('mlp.', 'ffn.layers.')
                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
                    new_v = v.reshape((*v.shape, 1, 1))
                new_k = new_k.replace('fc1.', '0.')
                new_k = new_k.replace('dwconv.dwconv.', '1.')
                if use_conv_ffn:
                    new_k = new_k.replace('fc2.', '4.')
                else:
                    new_k = new_k.replace('fc2.', '3.')
                string += f'{new_k} {v.shape}-{new_v.shape}'
        elif k.startswith('norm'):
            stage_i = int(k[4])
            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
            new_v = v
        else:
            new_k = k
            new_v = v
        new_ckpt[new_k] = new_v

    return new_ckpt
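

# Illustrative mapping example (not part of the original file): for a PVTv1
# checkpoint that contains absolute position embeddings, a key such as
# 'block1.0.attn.q.weight' is concatenated with its 'kv' counterpart and
# renamed to 'layers.0.1.1.attn.attn.in_proj_weight' (the extra '.1.' offset
# accounts for the AbsolutePositionEmbedding module prepended to the stage).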