bifpn.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from typing import List
  2. import torch
  3. import torch.nn as nn
  4. from mmcv.cnn.bricks import Swish
  5. from mmengine.model import BaseModule
  6. from mmdet.registry import MODELS
  7. from mmdet.utils import MultiConfig, OptConfigType
  8. from .utils import DepthWiseConvBlock, DownChannelBlock, MaxPool2dSamePadding
  9. class BiFPNStage(nn.Module):
  10. """
  11. in_channels: List[int], input dim for P3, P4, P5
  12. out_channels: int, output dim for P2 - P7
  13. first_time: int, whether is the first bifpnstage
  14. conv_bn_act_pattern: bool, whether use conv_bn_act_pattern
  15. norm_cfg: (:obj:`ConfigDict` or dict, optional): Config dict for
  16. normalization layer.
  17. epsilon: float, hyperparameter in fusion features
  18. """
  19. def __init__(self,
  20. in_channels: List[int],
  21. out_channels: int,
  22. first_time: bool = False,
  23. apply_bn_for_resampling: bool = True,
  24. conv_bn_act_pattern: bool = False,
  25. norm_cfg: OptConfigType = dict(
  26. type='BN', momentum=1e-2, eps=1e-3),
  27. epsilon: float = 1e-4) -> None:
  28. super().__init__()
  29. assert isinstance(in_channels, list)
  30. self.in_channels = in_channels
  31. self.out_channels = out_channels
  32. self.first_time = first_time
  33. self.apply_bn_for_resampling = apply_bn_for_resampling
  34. self.conv_bn_act_pattern = conv_bn_act_pattern
  35. self.norm_cfg = norm_cfg
  36. self.epsilon = epsilon
  37. if self.first_time:
  38. self.p5_down_channel = DownChannelBlock(
  39. self.in_channels[-1],
  40. self.out_channels,
  41. apply_norm=self.apply_bn_for_resampling,
  42. conv_bn_act_pattern=self.conv_bn_act_pattern,
  43. norm_cfg=norm_cfg)
  44. self.p4_down_channel = DownChannelBlock(
  45. self.in_channels[-2],
  46. self.out_channels,
  47. apply_norm=self.apply_bn_for_resampling,
  48. conv_bn_act_pattern=self.conv_bn_act_pattern,
  49. norm_cfg=norm_cfg)
  50. self.p3_down_channel = DownChannelBlock(
  51. self.in_channels[-3],
  52. self.out_channels,
  53. apply_norm=self.apply_bn_for_resampling,
  54. conv_bn_act_pattern=self.conv_bn_act_pattern,
  55. norm_cfg=norm_cfg)
  56. self.p5_to_p6 = nn.Sequential(
  57. DownChannelBlock(
  58. self.in_channels[-1],
  59. self.out_channels,
  60. apply_norm=self.apply_bn_for_resampling,
  61. conv_bn_act_pattern=self.conv_bn_act_pattern,
  62. norm_cfg=norm_cfg), MaxPool2dSamePadding(3, 2))
  63. self.p6_to_p7 = MaxPool2dSamePadding(3, 2)
  64. self.p4_level_connection = DownChannelBlock(
  65. self.in_channels[-2],
  66. self.out_channels,
  67. apply_norm=self.apply_bn_for_resampling,
  68. conv_bn_act_pattern=self.conv_bn_act_pattern,
  69. norm_cfg=norm_cfg)
  70. self.p5_level_connection = DownChannelBlock(
  71. self.in_channels[-1],
  72. self.out_channels,
  73. apply_norm=self.apply_bn_for_resampling,
  74. conv_bn_act_pattern=self.conv_bn_act_pattern,
  75. norm_cfg=norm_cfg)
  76. self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
  77. self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
  78. self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
  79. self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
  80. # bottom to up: feature map down_sample module
  81. self.p4_down_sample = MaxPool2dSamePadding(3, 2)
  82. self.p5_down_sample = MaxPool2dSamePadding(3, 2)
  83. self.p6_down_sample = MaxPool2dSamePadding(3, 2)
  84. self.p7_down_sample = MaxPool2dSamePadding(3, 2)
  85. # Fuse Conv Layers
  86. self.conv6_up = DepthWiseConvBlock(
  87. out_channels,
  88. out_channels,
  89. apply_norm=self.apply_bn_for_resampling,
  90. conv_bn_act_pattern=self.conv_bn_act_pattern,
  91. norm_cfg=norm_cfg)
  92. self.conv5_up = DepthWiseConvBlock(
  93. out_channels,
  94. out_channels,
  95. apply_norm=self.apply_bn_for_resampling,
  96. conv_bn_act_pattern=self.conv_bn_act_pattern,
  97. norm_cfg=norm_cfg)
  98. self.conv4_up = DepthWiseConvBlock(
  99. out_channels,
  100. out_channels,
  101. apply_norm=self.apply_bn_for_resampling,
  102. conv_bn_act_pattern=self.conv_bn_act_pattern,
  103. norm_cfg=norm_cfg)
  104. self.conv3_up = DepthWiseConvBlock(
  105. out_channels,
  106. out_channels,
  107. apply_norm=self.apply_bn_for_resampling,
  108. conv_bn_act_pattern=self.conv_bn_act_pattern,
  109. norm_cfg=norm_cfg)
  110. self.conv4_down = DepthWiseConvBlock(
  111. out_channels,
  112. out_channels,
  113. apply_norm=self.apply_bn_for_resampling,
  114. conv_bn_act_pattern=self.conv_bn_act_pattern,
  115. norm_cfg=norm_cfg)
  116. self.conv5_down = DepthWiseConvBlock(
  117. out_channels,
  118. out_channels,
  119. apply_norm=self.apply_bn_for_resampling,
  120. conv_bn_act_pattern=self.conv_bn_act_pattern,
  121. norm_cfg=norm_cfg)
  122. self.conv6_down = DepthWiseConvBlock(
  123. out_channels,
  124. out_channels,
  125. apply_norm=self.apply_bn_for_resampling,
  126. conv_bn_act_pattern=self.conv_bn_act_pattern,
  127. norm_cfg=norm_cfg)
  128. self.conv7_down = DepthWiseConvBlock(
  129. out_channels,
  130. out_channels,
  131. apply_norm=self.apply_bn_for_resampling,
  132. conv_bn_act_pattern=self.conv_bn_act_pattern,
  133. norm_cfg=norm_cfg)
  134. # weights
  135. self.p6_w1 = nn.Parameter(
  136. torch.ones(2, dtype=torch.float32), requires_grad=True)
  137. self.p6_w1_relu = nn.ReLU()
  138. self.p5_w1 = nn.Parameter(
  139. torch.ones(2, dtype=torch.float32), requires_grad=True)
  140. self.p5_w1_relu = nn.ReLU()
  141. self.p4_w1 = nn.Parameter(
  142. torch.ones(2, dtype=torch.float32), requires_grad=True)
  143. self.p4_w1_relu = nn.ReLU()
  144. self.p3_w1 = nn.Parameter(
  145. torch.ones(2, dtype=torch.float32), requires_grad=True)
  146. self.p3_w1_relu = nn.ReLU()
  147. self.p4_w2 = nn.Parameter(
  148. torch.ones(3, dtype=torch.float32), requires_grad=True)
  149. self.p4_w2_relu = nn.ReLU()
  150. self.p5_w2 = nn.Parameter(
  151. torch.ones(3, dtype=torch.float32), requires_grad=True)
  152. self.p5_w2_relu = nn.ReLU()
  153. self.p6_w2 = nn.Parameter(
  154. torch.ones(3, dtype=torch.float32), requires_grad=True)
  155. self.p6_w2_relu = nn.ReLU()
  156. self.p7_w2 = nn.Parameter(
  157. torch.ones(2, dtype=torch.float32), requires_grad=True)
  158. self.p7_w2_relu = nn.ReLU()
  159. self.swish = Swish()
  160. def combine(self, x):
  161. if not self.conv_bn_act_pattern:
  162. x = self.swish(x)
  163. return x
  164. def forward(self, x):
  165. if self.first_time:
  166. p3, p4, p5 = x
  167. # build feature map P6
  168. p6_in = self.p5_to_p6(p5)
  169. # build feature map P7
  170. p7_in = self.p6_to_p7(p6_in)
  171. p3_in = self.p3_down_channel(p3)
  172. p4_in = self.p4_down_channel(p4)
  173. p5_in = self.p5_down_channel(p5)
  174. else:
  175. p3_in, p4_in, p5_in, p6_in, p7_in = x
  176. # Weights for P6_0 and P7_0 to P6_1
  177. p6_w1 = self.p6_w1_relu(self.p6_w1)
  178. weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
  179. # Connections for P6_0 and P7_0 to P6_1 respectively
  180. p6_up = self.conv6_up(
  181. self.combine(weight[0] * p6_in +
  182. weight[1] * self.p6_upsample(p7_in)))
  183. # Weights for P5_0 and P6_1 to P5_1
  184. p5_w1 = self.p5_w1_relu(self.p5_w1)
  185. weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
  186. # Connections for P5_0 and P6_1 to P5_1 respectively
  187. p5_up = self.conv5_up(
  188. self.combine(weight[0] * p5_in +
  189. weight[1] * self.p5_upsample(p6_up)))
  190. # Weights for P4_0 and P5_1 to P4_1
  191. p4_w1 = self.p4_w1_relu(self.p4_w1)
  192. weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
  193. # Connections for P4_0 and P5_1 to P4_1 respectively
  194. p4_up = self.conv4_up(
  195. self.combine(weight[0] * p4_in +
  196. weight[1] * self.p4_upsample(p5_up)))
  197. # Weights for P3_0 and P4_1 to P3_2
  198. p3_w1 = self.p3_w1_relu(self.p3_w1)
  199. weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
  200. # Connections for P3_0 and P4_1 to P3_2 respectively
  201. p3_out = self.conv3_up(
  202. self.combine(weight[0] * p3_in +
  203. weight[1] * self.p3_upsample(p4_up)))
  204. if self.first_time:
  205. p4_in = self.p4_level_connection(p4)
  206. p5_in = self.p5_level_connection(p5)
  207. # Weights for P4_0, P4_1 and P3_2 to P4_2
  208. p4_w2 = self.p4_w2_relu(self.p4_w2)
  209. weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
  210. # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
  211. p4_out = self.conv4_down(
  212. self.combine(weight[0] * p4_in + weight[1] * p4_up +
  213. weight[2] * self.p4_down_sample(p3_out)))
  214. # Weights for P5_0, P5_1 and P4_2 to P5_2
  215. p5_w2 = self.p5_w2_relu(self.p5_w2)
  216. weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
  217. # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
  218. p5_out = self.conv5_down(
  219. self.combine(weight[0] * p5_in + weight[1] * p5_up +
  220. weight[2] * self.p5_down_sample(p4_out)))
  221. # Weights for P6_0, P6_1 and P5_2 to P6_2
  222. p6_w2 = self.p6_w2_relu(self.p6_w2)
  223. weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
  224. # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
  225. p6_out = self.conv6_down(
  226. self.combine(weight[0] * p6_in + weight[1] * p6_up +
  227. weight[2] * self.p6_down_sample(p5_out)))
  228. # Weights for P7_0 and P6_2 to P7_2
  229. p7_w2 = self.p7_w2_relu(self.p7_w2)
  230. weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
  231. # Connections for P7_0 and P6_2 to P7_2
  232. p7_out = self.conv7_down(
  233. self.combine(weight[0] * p7_in +
  234. weight[1] * self.p7_down_sample(p6_out)))
  235. return p3_out, p4_out, p5_out, p6_out, p7_out
  236. @MODELS.register_module()
  237. class BiFPN(BaseModule):
  238. """
  239. num_stages: int, bifpn number of repeats
  240. in_channels: List[int], input dim for P3, P4, P5
  241. out_channels: int, output dim for P2 - P7
  242. start_level: int, Index of input features in backbone
  243. epsilon: float, hyperparameter in fusion features
  244. apply_bn_for_resampling: bool, whether use bn after resampling
  245. conv_bn_act_pattern: bool, whether use conv_bn_act_pattern
  246. norm_cfg: (:obj:`ConfigDict` or dict, optional): Config dict for
  247. normalization layer.
  248. init_cfg: MultiConfig: init method
  249. """
  250. def __init__(self,
  251. num_stages: int,
  252. in_channels: List[int],
  253. out_channels: int,
  254. start_level: int = 0,
  255. epsilon: float = 1e-4,
  256. apply_bn_for_resampling: bool = True,
  257. conv_bn_act_pattern: bool = False,
  258. norm_cfg: OptConfigType = dict(
  259. type='BN', momentum=1e-2, eps=1e-3),
  260. init_cfg: MultiConfig = None) -> None:
  261. super().__init__(init_cfg=init_cfg)
  262. self.start_level = start_level
  263. self.bifpn = nn.Sequential(*[
  264. BiFPNStage(
  265. in_channels=in_channels,
  266. out_channels=out_channels,
  267. first_time=True if _ == 0 else False,
  268. apply_bn_for_resampling=apply_bn_for_resampling,
  269. conv_bn_act_pattern=conv_bn_act_pattern,
  270. norm_cfg=norm_cfg,
  271. epsilon=epsilon) for _ in range(num_stages)
  272. ])
  273. def forward(self, x):
  274. x = x[self.start_level:]
  275. x = self.bifpn(x)
  276. return x