focalnet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.utils.checkpoint as checkpoint
  5. from mmcv.cnn.bricks import DropPath
  6. from mmdet.registry import MODELS
  7. # modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/backbone/focal_dw.py # noqa
  8. @MODELS.register_module()
  9. class FocalNet(nn.Module):
  10. def __init__(
  11. self,
  12. patch_size=4,
  13. in_chans=3,
  14. embed_dim=96,
  15. depths=[2, 2, 6, 2],
  16. mlp_ratio=4.,
  17. drop_rate=0.,
  18. drop_path_rate=0.3,
  19. norm_layer=nn.LayerNorm,
  20. patch_norm=True,
  21. out_indices=[0, 1, 2, 3],
  22. frozen_stages=-1,
  23. focal_levels=[3, 3, 3, 3],
  24. focal_windows=[3, 3, 3, 3],
  25. use_pre_norms=[False, False, False, False],
  26. use_conv_embed=True,
  27. use_postln=True,
  28. use_postln_in_modulation=False,
  29. scaling_modulator=True,
  30. use_layerscale=True,
  31. use_checkpoint=False,
  32. ):
  33. super().__init__()
  34. self.num_layers = len(depths)
  35. self.embed_dim = embed_dim
  36. self.patch_norm = patch_norm
  37. self.out_indices = out_indices
  38. self.frozen_stages = frozen_stages
  39. # split image into non-overlapping patches
  40. self.patch_embed = PatchEmbed(
  41. patch_size=patch_size,
  42. in_chans=in_chans,
  43. embed_dim=embed_dim,
  44. norm_layer=norm_layer if self.patch_norm else None,
  45. use_conv_embed=use_conv_embed,
  46. is_stem=True,
  47. use_pre_norm=False)
  48. self.pos_drop = nn.Dropout(p=drop_rate)
  49. dpr = [
  50. x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
  51. ]
  52. self.layers = nn.ModuleList()
  53. for i_layer in range(self.num_layers):
  54. layer = BasicLayer(
  55. dim=int(embed_dim * 2**i_layer),
  56. depth=depths[i_layer],
  57. mlp_ratio=mlp_ratio,
  58. drop=drop_rate,
  59. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  60. norm_layer=norm_layer,
  61. downsample=PatchEmbed if
  62. (i_layer < self.num_layers - 1) else None,
  63. focal_window=focal_windows[i_layer],
  64. focal_level=focal_levels[i_layer],
  65. use_pre_norm=use_pre_norms[i_layer],
  66. use_conv_embed=use_conv_embed,
  67. use_postln=use_postln,
  68. use_postln_in_modulation=use_postln_in_modulation,
  69. scaling_modulator=scaling_modulator,
  70. use_layerscale=use_layerscale,
  71. use_checkpoint=use_checkpoint)
  72. self.layers.append(layer)
  73. num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
  74. self.num_features = num_features
  75. # add a norm layer for each output
  76. for i_layer in self.out_indices:
  77. layer = norm_layer(num_features[i_layer])
  78. layer_name = f'norm{i_layer}'
  79. self.add_module(layer_name, layer)
  80. def forward(self, x):
  81. x = self.patch_embed(x)
  82. Wh, Ww = x.size(2), x.size(3)
  83. x = x.flatten(2).transpose(1, 2)
  84. x = self.pos_drop(x)
  85. outs = {}
  86. for i in range(self.num_layers):
  87. layer = self.layers[i]
  88. x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
  89. if i in self.out_indices:
  90. norm_layer = getattr(self, f'norm{i}')
  91. x_out = norm_layer(x_out)
  92. out = x_out.view(-1, H, W,
  93. self.num_features[i]).permute(0, 3, 1,
  94. 2).contiguous()
  95. outs['res{}'.format(i + 2)] = out
  96. return outs
  97. class Mlp(nn.Module):
  98. """Multilayer perceptron."""
  99. def __init__(self,
  100. in_features,
  101. hidden_features=None,
  102. out_features=None,
  103. act_layer=nn.GELU,
  104. drop=0.):
  105. super().__init__()
  106. out_features = out_features or in_features
  107. hidden_features = hidden_features or in_features
  108. self.fc1 = nn.Linear(in_features, hidden_features)
  109. self.act = act_layer()
  110. self.fc2 = nn.Linear(hidden_features, out_features)
  111. self.drop = nn.Dropout(drop)
  112. def forward(self, x):
  113. x = self.fc1(x)
  114. x = self.act(x)
  115. x = self.drop(x)
  116. x = self.fc2(x)
  117. x = self.drop(x)
  118. return x
  119. class FocalModulation(nn.Module):
  120. """Focal Modulation.
  121. Args:
  122. dim (int): Number of input channels.
  123. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  124. focal_level (int): Number of focal levels
  125. focal_window (int): Focal window size at focal level 1
  126. focal_factor (int, default=2): Step to increase the focal window
  127. """
  128. def __init__(self,
  129. dim,
  130. proj_drop=0.,
  131. focal_level=2,
  132. focal_window=7,
  133. focal_factor=2,
  134. use_postln_in_modulation=False,
  135. scaling_modulator=False):
  136. super().__init__()
  137. self.dim = dim
  138. self.focal_level = focal_level
  139. self.focal_window = focal_window
  140. self.focal_factor = focal_factor
  141. self.use_postln_in_modulation = use_postln_in_modulation
  142. self.scaling_modulator = scaling_modulator
  143. self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True)
  144. self.h = nn.Conv2d(
  145. dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
  146. self.act = nn.GELU()
  147. self.proj = nn.Linear(dim, dim)
  148. self.proj_drop = nn.Dropout(proj_drop)
  149. self.focal_layers = nn.ModuleList()
  150. if self.use_postln_in_modulation:
  151. self.ln = nn.LayerNorm(dim)
  152. for k in range(self.focal_level):
  153. kernel_size = self.focal_factor * k + self.focal_window
  154. self.focal_layers.append(
  155. nn.Sequential(
  156. nn.Conv2d(
  157. dim,
  158. dim,
  159. kernel_size=kernel_size,
  160. stride=1,
  161. groups=dim,
  162. padding=kernel_size // 2,
  163. bias=False),
  164. nn.GELU(),
  165. ))
  166. def forward(self, x):
  167. """Forward function.
  168. Args:
  169. x: input features with shape of (B, H, W, C)
  170. """
  171. B, nH, nW, C = x.shape
  172. x = self.f(x)
  173. x = x.permute(0, 3, 1, 2).contiguous()
  174. q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1)
  175. ctx_all = 0
  176. for level in range(self.focal_level):
  177. ctx = self.focal_layers[level](ctx)
  178. ctx_all = ctx_all + ctx * gates[:, level:level + 1]
  179. ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
  180. ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
  181. if self.scaling_modulator:
  182. ctx_all = ctx_all / (self.focal_level + 1)
  183. x_out = q * self.h(ctx_all)
  184. x_out = x_out.permute(0, 2, 3, 1).contiguous()
  185. if self.use_postln_in_modulation:
  186. x_out = self.ln(x_out)
  187. x_out = self.proj(x_out)
  188. x_out = self.proj_drop(x_out)
  189. return x_out
  190. class FocalModulationBlock(nn.Module):
  191. """Focal Modulation Block.
  192. Args:
  193. dim (int): Number of input channels.
  194. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  195. drop (float, optional): Dropout rate. Default: 0.0
  196. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  197. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  198. norm_layer (nn.Module, optional): Normalization layer.
  199. Default: nn.LayerNorm
  200. focal_level (int): number of focal levels
  201. focal_window (int): focal kernel size at level 1
  202. """
  203. def __init__(self,
  204. dim,
  205. mlp_ratio=4.,
  206. drop=0.,
  207. drop_path=0.,
  208. act_layer=nn.GELU,
  209. norm_layer=nn.LayerNorm,
  210. focal_level=2,
  211. focal_window=9,
  212. use_postln=False,
  213. use_postln_in_modulation=False,
  214. scaling_modulator=False,
  215. use_layerscale=False,
  216. layerscale_value=1e-4):
  217. super().__init__()
  218. self.dim = dim
  219. self.mlp_ratio = mlp_ratio
  220. self.focal_window = focal_window
  221. self.focal_level = focal_level
  222. self.use_postln = use_postln
  223. self.use_layerscale = use_layerscale
  224. self.dw1 = nn.Conv2d(
  225. dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
  226. self.norm1 = norm_layer(dim)
  227. self.modulation = FocalModulation(
  228. dim,
  229. focal_window=self.focal_window,
  230. focal_level=self.focal_level,
  231. proj_drop=drop,
  232. use_postln_in_modulation=use_postln_in_modulation,
  233. scaling_modulator=scaling_modulator)
  234. self.dw2 = nn.Conv2d(
  235. dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
  236. self.drop_path = DropPath(
  237. drop_path) if drop_path > 0. else nn.Identity()
  238. self.norm2 = norm_layer(dim)
  239. mlp_hidden_dim = int(dim * mlp_ratio)
  240. self.mlp = Mlp(
  241. in_features=dim,
  242. hidden_features=mlp_hidden_dim,
  243. act_layer=act_layer,
  244. drop=drop)
  245. self.H = None
  246. self.W = None
  247. self.gamma_1 = 1.0
  248. self.gamma_2 = 1.0
  249. if self.use_layerscale:
  250. self.gamma_1 = nn.Parameter(
  251. layerscale_value * torch.ones(dim), requires_grad=True)
  252. self.gamma_2 = nn.Parameter(
  253. layerscale_value * torch.ones(dim), requires_grad=True)
  254. def forward(self, x):
  255. """Forward function.
  256. Args:
  257. x: Input feature, tensor size (B, H*W, C).
  258. H, W: Spatial resolution of the input feature.
  259. """
  260. B, L, C = x.shape
  261. H, W = self.H, self.W
  262. assert L == H * W, 'input feature has wrong size'
  263. x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
  264. x = x + self.dw1(x)
  265. x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
  266. shortcut = x
  267. if not self.use_postln:
  268. x = self.norm1(x)
  269. x = x.view(B, H, W, C)
  270. # FM
  271. x = self.modulation(x).view(B, H * W, C)
  272. x = shortcut + self.drop_path(self.gamma_1 * x)
  273. if self.use_postln:
  274. x = self.norm1(x)
  275. x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
  276. x = x + self.dw2(x)
  277. x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
  278. if not self.use_postln:
  279. x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
  280. else:
  281. x = x + self.drop_path(self.gamma_2 * self.mlp(x))
  282. x = self.norm2(x)
  283. return x
  284. class BasicLayer(nn.Module):
  285. """A basic focal modulation layer for one stage.
  286. Args:
  287. dim (int): Number of feature channels
  288. depth (int): Depths of this stage.
  289. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  290. Default: 4.
  291. drop (float, optional): Dropout rate. Default: 0.0
  292. drop_path (float | tuple[float], optional): Stochastic depth rate.
  293. Default: 0.0
  294. norm_layer (nn.Module, optional): Normalization layer.
  295. Default: nn.LayerNorm
  296. downsample (nn.Module | None, optional): Downsample layer at the
  297. end of the layer. Default: None
  298. focal_level (int): Number of focal levels
  299. focal_window (int): Focal window size at focal level 1
  300. use_conv_embed (bool): Use overlapped convolution for patch
  301. embedding or now. Default: False
  302. use_checkpoint (bool): Whether to use checkpointing to save memory.
  303. Default: False
  304. """
  305. def __init__(
  306. self,
  307. dim,
  308. depth,
  309. mlp_ratio=4.,
  310. drop=0.,
  311. drop_path=0.,
  312. norm_layer=nn.LayerNorm,
  313. downsample=None,
  314. focal_window=9,
  315. focal_level=2,
  316. use_conv_embed=False,
  317. use_postln=False,
  318. use_postln_in_modulation=False,
  319. scaling_modulator=False,
  320. use_layerscale=False,
  321. use_checkpoint=False,
  322. use_pre_norm=False,
  323. ):
  324. super().__init__()
  325. self.depth = depth
  326. self.use_checkpoint = use_checkpoint
  327. # build blocks
  328. self.blocks = nn.ModuleList([
  329. FocalModulationBlock(
  330. dim=dim,
  331. mlp_ratio=mlp_ratio,
  332. drop=drop,
  333. drop_path=drop_path[i]
  334. if isinstance(drop_path, list) else drop_path,
  335. focal_window=focal_window,
  336. focal_level=focal_level,
  337. use_postln=use_postln,
  338. use_postln_in_modulation=use_postln_in_modulation,
  339. scaling_modulator=scaling_modulator,
  340. use_layerscale=use_layerscale,
  341. norm_layer=norm_layer) for i in range(depth)
  342. ])
  343. # patch merging layer
  344. if downsample is not None:
  345. self.downsample = downsample(
  346. patch_size=2,
  347. in_chans=dim,
  348. embed_dim=2 * dim,
  349. use_conv_embed=use_conv_embed,
  350. norm_layer=norm_layer,
  351. is_stem=False,
  352. use_pre_norm=use_pre_norm)
  353. else:
  354. self.downsample = None
  355. def forward(self, x, H, W):
  356. """Forward function.
  357. Args:
  358. x: Input feature, tensor size (B, H*W, C).
  359. H, W: Spatial resolution of the input feature.
  360. """
  361. for blk in self.blocks:
  362. blk.H, blk.W = H, W
  363. if self.use_checkpoint:
  364. x = checkpoint.checkpoint(blk, x)
  365. else:
  366. x = blk(x)
  367. if self.downsample is not None:
  368. x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
  369. x_down = self.downsample(x_reshaped)
  370. x_down = x_down.flatten(2).transpose(1, 2)
  371. Wh, Ww = (H + 1) // 2, (W + 1) // 2
  372. return x, H, W, x_down, Wh, Ww
  373. else:
  374. return x, H, W, x, H, W
  375. class PatchEmbed(nn.Module):
  376. """Image to Patch Embedding.
  377. Args:
  378. patch_size (int): Patch token size. Default: 4.
  379. in_chans (int): Number of input image channels. Default: 3.
  380. embed_dim (int): Number of linear projection output channels.
  381. Default: 96.
  382. norm_layer (nn.Module, optional): Normalization layer.
  383. Default: None
  384. use_conv_embed (bool): Whether use overlapped convolution for
  385. patch embedding. Default: False
  386. is_stem (bool): Is the stem block or not.
  387. """
  388. def __init__(self,
  389. patch_size=4,
  390. in_chans=3,
  391. embed_dim=96,
  392. norm_layer=None,
  393. use_conv_embed=False,
  394. is_stem=False,
  395. use_pre_norm=False):
  396. super().__init__()
  397. patch_size = (patch_size, patch_size)
  398. self.patch_size = patch_size
  399. self.in_chans = in_chans
  400. self.embed_dim = embed_dim
  401. self.use_pre_norm = use_pre_norm
  402. if use_conv_embed:
  403. # if we choose to use conv embedding,
  404. # then we treat the stem and non-stem differently
  405. if is_stem:
  406. kernel_size = 7
  407. padding = 3
  408. stride = 4
  409. else:
  410. kernel_size = 3
  411. padding = 1
  412. stride = 2
  413. self.proj = nn.Conv2d(
  414. in_chans,
  415. embed_dim,
  416. kernel_size=kernel_size,
  417. stride=stride,
  418. padding=padding)
  419. else:
  420. self.proj = nn.Conv2d(
  421. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  422. if self.use_pre_norm:
  423. if norm_layer is not None:
  424. self.norm = norm_layer(in_chans)
  425. else:
  426. self.norm = None
  427. else:
  428. if norm_layer is not None:
  429. self.norm = norm_layer(embed_dim)
  430. else:
  431. self.norm = None
  432. def forward(self, x):
  433. """Forward function."""
  434. B, C, H, W = x.size()
  435. if W % self.patch_size[1] != 0:
  436. x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
  437. if H % self.patch_size[0] != 0:
  438. x = F.pad(x,
  439. (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
  440. if self.use_pre_norm:
  441. if self.norm is not None:
  442. x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
  443. x = self.norm(x).transpose(1, 2).view(B, C, H, W)
  444. x = self.proj(x)
  445. else:
  446. x = self.proj(x) # B C Wh Ww
  447. if self.norm is not None:
  448. Wh, Ww = x.size(2), x.size(3)
  449. x = x.flatten(2).transpose(1, 2)
  450. x = self.norm(x)
  451. x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
  452. return x