  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmcv.cnn import build_activation_layer, build_norm_layer
  7. from mmcv.cnn.bricks import DropPath
  8. from mmengine.logging import MMLogger
  9. from mmengine.model import BaseModule
  10. from mmengine.runner.checkpoint import CheckpointLoader
  11. from mmdet.registry import MODELS
  12. @MODELS.register_module()
  13. class LN2d(nn.Module):
  14. """A LayerNorm variant, popularized by Transformers, that performs
  15. pointwise mean and variance normalization over the channel dimension for
  16. inputs that have shape (batch_size, channels, height, width)."""
  17. def __init__(self, normalized_shape, eps=1e-6):
  18. super().__init__()
  19. self.weight = nn.Parameter(torch.ones(normalized_shape))
  20. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  21. self.eps = eps
  22. self.normalized_shape = (normalized_shape, )
  23. def forward(self, x):
  24. u = x.mean(1, keepdim=True)
  25. s = (x - u).pow(2).mean(1, keepdim=True)
  26. x = (x - u) / torch.sqrt(s + self.eps)
  27. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  28. return x
  29. def get_abs_pos(abs_pos, has_cls_token, hw):
  30. h, w = hw
  31. if has_cls_token:
  32. abs_pos = abs_pos[:, 1:]
  33. xy_num = abs_pos.shape[1]
  34. size = int(math.sqrt(xy_num))
  35. assert size * size == xy_num
  36. if size != h or size != w:
  37. new_abs_pos = F.interpolate(
  38. abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
  39. size=(h, w),
  40. mode='bicubic',
  41. align_corners=False,
  42. )
  43. return new_abs_pos.permute(0, 2, 3, 1)
  44. else:
  45. return abs_pos.reshape(1, h, w, -1)
  46. def get_rel_pos(q_size, k_size, rel_pos):
  47. """
  48. Get relative positional embeddings according to the relative positions
  49. of query and key sizes.
  50. Args:
  51. q_size (int): size of query q.
  52. k_size (int): size of key k.
  53. rel_pos (Tensor): relative position embeddings (L, C).
  54. Returns:
  55. Extracted positional embeddings according to relative positions.
  56. """
  57. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  58. # Interpolate rel pos if needed.
  59. if rel_pos.shape[0] != max_rel_dist:
  60. # Interpolate rel pos.
  61. rel_pos_resized = F.interpolate(
  62. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  63. size=max_rel_dist,
  64. mode='linear',
  65. )
  66. rel_pos_resized = rel_pos_resized.reshape(-1,
  67. max_rel_dist).permute(1, 0)
  68. else:
  69. rel_pos_resized = rel_pos
  70. # Scale the coords with short length if shapes for q and k are different.
  71. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  72. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  73. relative_coords = (q_coords -
  74. k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  75. return rel_pos_resized[relative_coords.long()]
  76. def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
  77. """
  78. Args:
  79. attn (Tensor): attention map.
  80. q (Tensor):
  81. query q in the attention layer with shape (B, q_h * q_w, C).
  82. rel_pos_h (Tensor):
  83. relative position embeddings (Lh, C) for height axis.
  84. rel_pos_w (Tensor):
  85. relative position embeddings (Lw, C) for width axis.
  86. q_size (Tuple):
  87. spatial sequence size of query q with (q_h, q_w).
  88. k_size (Tuple):
  89. spatial sequence size of key k with (k_h, k_w).
  90. Returns:
  91. attn (Tensor): attention map with added relative positional embeddings.
  92. """
  93. q_h, q_w = q_size
  94. k_h, k_w = k_size
  95. Rh = get_rel_pos(q_h, k_h, rel_pos_h)
  96. Rw = get_rel_pos(q_w, k_w, rel_pos_w)
  97. B, _, dim = q.shape
  98. r_q = q.reshape(B, q_h, q_w, dim)
  99. rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
  100. rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
  101. attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] +
  102. rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
  103. return attn
  104. def window_partition(x, window_size):
  105. """
  106. Args:
  107. x: (B, H, W, C)
  108. window_size (int): window size
  109. Returns:
  110. windows: (num_windows*B, window_size, window_size, C)
  111. """
  112. B, H, W, C = x.shape
  113. pad_h = (window_size - H % window_size) % window_size
  114. pad_w = (window_size - W % window_size) % window_size
  115. if pad_h > 0 or pad_w > 0:
  116. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  117. Hp, Wp = H + pad_h, W + pad_w
  118. x = x.view(B, Hp // window_size, window_size, Wp // window_size,
  119. window_size, C)
  120. windows = x.permute(0, 1, 3, 2, 4,
  121. 5).contiguous().view(-1, window_size, window_size, C)
  122. return windows, (Hp, Wp)
  123. def window_unpartition(windows, window_size, pad_hw, hw):
  124. """
  125. Args:
  126. windows: (num_windows*B, window_size, window_size, C)
  127. window_size (int): Window size
  128. H (int): Height of image
  129. W (int): Width of image
  130. Returns:
  131. x: (B, H, W, C)
  132. """
  133. Hp, Wp = pad_hw
  134. H, W = hw
  135. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  136. x = windows.view(B, Hp // window_size, Wp // window_size, window_size,
  137. window_size, -1)
  138. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
  139. if Hp > H or Wp > W:
  140. x = x[:, :H, :W, :].contiguous()
  141. return x
  142. class Attention(nn.Module):
  143. def __init__(self,
  144. dim,
  145. num_heads=8,
  146. qkv_bias=True,
  147. use_rel_pos=False,
  148. rel_pos_zero_init=True,
  149. input_size=None):
  150. super().__init__()
  151. self.num_heads = num_heads
  152. head_dim = dim // num_heads
  153. self.scale = head_dim**-0.5
  154. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  155. self.proj = nn.Linear(dim, dim)
  156. self.use_rel_pos = use_rel_pos
  157. if self.use_rel_pos:
  158. # initialize relative positional embeddings
  159. self.rel_pos_h = nn.Parameter(
  160. torch.zeros(2 * input_size[0] - 1, head_dim))
  161. self.rel_pos_w = nn.Parameter(
  162. torch.zeros(2 * input_size[1] - 1, head_dim))
  163. if not rel_pos_zero_init:
  164. nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
  165. nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
  166. def forward(self, x):
  167. B, H, W, _ = x.shape
  168. # qkv with shape (3, B, nHead, H * W, C)
  169. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,
  170. -1).permute(2, 0, 3, 1, 4)
  171. # q, k, v with shape (B * nHead, H * W, C)
  172. q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
  173. attn = (q * self.scale) @ k.transpose(-2, -1)
  174. if self.use_rel_pos:
  175. attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h,
  176. self.rel_pos_w, (H, W), (H, W))
  177. attn = attn.softmax(dim=-1)
  178. x = (attn @ v).view(B, self.num_heads, H, W,
  179. -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
  180. x = self.proj(x)
  181. return x
  182. class Mlp(nn.Module):
  183. """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
  184. def __init__(
  185. self,
  186. in_features,
  187. hidden_features=None,
  188. out_features=None,
  189. act_cfg=dict(type='GELU'),
  190. bias=True,
  191. drop=0.,
  192. ):
  193. super().__init__()
  194. out_features = out_features or in_features
  195. hidden_features = hidden_features or in_features
  196. self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
  197. self.act = build_activation_layer(act_cfg)
  198. self.drop1 = nn.Dropout(drop)
  199. self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
  200. self.drop2 = nn.Dropout(drop)
  201. def forward(self, x):
  202. x = self.fc1(x)
  203. x = self.act(x)
  204. x = self.drop1(x)
  205. x = self.fc2(x)
  206. x = self.drop2(x)
  207. return x
  208. class Block(nn.Module):
  209. def __init__(
  210. self,
  211. dim,
  212. num_heads,
  213. mlp_ratio=4.0,
  214. qkv_bias=True,
  215. drop_path=0.0,
  216. norm_cfg=dict(type='LN', eps=1e-6),
  217. act_cfg=dict(type='GELU'),
  218. use_rel_pos=False,
  219. rel_pos_zero_init=True,
  220. window_size=0,
  221. input_size=None,
  222. ):
  223. super().__init__()
  224. self.norm1 = build_norm_layer(norm_cfg, dim)[1]
  225. self.attn = Attention(
  226. dim,
  227. num_heads=num_heads,
  228. qkv_bias=qkv_bias,
  229. use_rel_pos=use_rel_pos,
  230. rel_pos_zero_init=rel_pos_zero_init,
  231. input_size=input_size if window_size == 0 else
  232. (window_size, window_size),
  233. )
  234. self.drop_path = DropPath(
  235. drop_path) if drop_path > 0. else nn.Identity()
  236. self.norm2 = build_norm_layer(norm_cfg, dim)[1]
  237. self.mlp = Mlp(
  238. in_features=dim,
  239. hidden_features=int(dim * mlp_ratio),
  240. act_cfg=act_cfg)
  241. self.window_size = window_size
  242. def forward(self, x):
  243. shortcut = x
  244. x = self.norm1(x)
  245. # Window partition
  246. if self.window_size > 0:
  247. H, W = x.shape[1], x.shape[2]
  248. x, pad_hw = window_partition(x, self.window_size)
  249. x = self.attn(x)
  250. # Reverse window partition
  251. if self.window_size > 0:
  252. x = window_unpartition(x, self.window_size, pad_hw, (H, W))
  253. x = shortcut + self.drop_path(x)
  254. x = x + self.drop_path(self.mlp(self.norm2(x)))
  255. return x
  256. class PatchEmbed(nn.Module):
  257. """Image to Patch Embedding."""
  258. def __init__(self,
  259. kernel_size=(16, 16),
  260. stride=(16, 16),
  261. padding=(0, 0),
  262. in_chans=3,
  263. embed_dim=768):
  264. """
  265. Args:
  266. kernel_size (Tuple): kernel size of the projection layer.
  267. stride (Tuple): stride of the projection layer.
  268. padding (Tuple): padding size of the projection layer.
  269. in_chans (int): Number of input image channels.
  270. embed_dim (int): embed_dim (int): Patch embedding dimension.
  271. """
  272. super().__init__()
  273. self.proj = nn.Conv2d(
  274. in_chans,
  275. embed_dim,
  276. kernel_size=kernel_size,
  277. stride=stride,
  278. padding=padding)
  279. def forward(self, x):
  280. x = self.proj(x)
  281. # B C H W -> B H W C
  282. x = x.permute(0, 2, 3, 1)
  283. return x
  284. @MODELS.register_module()
  285. class ViT(BaseModule):
  286. """Vision Transformer with support for patch or hybrid CNN input stage."""
  287. def __init__(self,
  288. img_size=1024,
  289. patch_size=16,
  290. in_chans=3,
  291. embed_dim=768,
  292. depth=12,
  293. num_heads=12,
  294. mlp_ratio=4.0,
  295. qkv_bias=True,
  296. drop_path_rate=0.0,
  297. norm_cfg=dict(type='LN', eps=1e-6),
  298. act_cfg=dict(type='GELU'),
  299. use_abs_pos=True,
  300. use_rel_pos=False,
  301. rel_pos_zero_init=True,
  302. window_size=0,
  303. window_block_indexes=(0, 1, 3, 4, 6, 7, 9, 10),
  304. pretrain_img_size=224,
  305. pretrain_use_cls_token=True,
  306. init_cfg=None):
  307. super().__init__()
  308. self.pretrain_use_cls_token = pretrain_use_cls_token
  309. self.init_cfg = init_cfg
  310. self.patch_embed = PatchEmbed(
  311. kernel_size=(patch_size, patch_size),
  312. stride=(patch_size, patch_size),
  313. in_chans=in_chans,
  314. embed_dim=embed_dim)
  315. if use_abs_pos:
  316. num_patches = (pretrain_img_size // patch_size) * (
  317. pretrain_img_size // patch_size)
  318. num_positions = (num_patches +
  319. 1) if pretrain_use_cls_token else num_patches
  320. self.pos_embed = nn.Parameter(
  321. torch.zeros(1, num_positions, embed_dim))
  322. else:
  323. self.pos_embed = None
  324. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
  325. self.blocks = nn.ModuleList([
  326. Block(
  327. dim=embed_dim,
  328. num_heads=num_heads,
  329. mlp_ratio=mlp_ratio,
  330. qkv_bias=qkv_bias,
  331. drop_path=dpr[i],
  332. norm_cfg=norm_cfg,
  333. act_cfg=act_cfg,
  334. use_rel_pos=use_rel_pos,
  335. rel_pos_zero_init=rel_pos_zero_init,
  336. window_size=window_size if i in window_block_indexes else 0,
  337. input_size=(img_size // patch_size, img_size // patch_size))
  338. for i in range(depth)
  339. ])
  340. if self.pos_embed is not None:
  341. nn.init.trunc_normal_(self.pos_embed, std=0.02)
  342. def _init_weights(self, m):
  343. if isinstance(m, nn.Linear):
  344. nn.init.trunc_normal_(m.weight, std=0.02)
  345. if isinstance(m, nn.Linear) and m.bias is not None:
  346. nn.init.constant_(m.bias, 0)
  347. elif isinstance(m, nn.LayerNorm):
  348. nn.init.constant_(m.bias, 0)
  349. nn.init.constant_(m.weight, 1.0)
  350. def init_weights(self):
  351. logger = MMLogger.get_current_instance()
  352. if self.init_cfg is None:
  353. logger.warn(f'No pre-trained weights for '
  354. f'{self.__class__.__name__}, '
  355. f'training start from scratch')
  356. self.apply(self._init_weights)
  357. else:
  358. assert 'checkpoint' in self.init_cfg, f'Only support ' \
  359. f'specify `Pretrained` in ' \
  360. f'`init_cfg` in ' \
  361. f'{self.__class__.__name__} '
  362. ckpt = CheckpointLoader.load_checkpoint(
  363. self.init_cfg.checkpoint, logger=logger, map_location='cpu')
  364. if 'model' in ckpt:
  365. _state_dict = ckpt['model']
  366. self.load_state_dict(_state_dict, False)
  367. def forward(self, x):
  368. x = self.patch_embed(x)
  369. if self.pos_embed is not None:
  370. x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token,
  371. (x.shape[1], x.shape[2]))
  372. for blk in self.blocks:
  373. x = blk(x)
  374. x = x.permute(0, 3, 1, 2)
  375. return x