fpn.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, MultiConfig, OptConfigType

@MODELS.register_module()
class FPN(BaseModule):
    r"""Feature Pyramid Network.

    This is an implementation of paper `Feature Pyramid Networks for Object
    Detection <https://arxiv.org/abs/1612.03144>`_.

    Args:
        in_channels (list[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Defaults to 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Defaults to -1, which means the
            last level.
        add_extra_convs (bool | str): If bool, it decides whether to add conv
            layers on top of the original feature maps. Defaults to False.
            If True, it is equivalent to `add_extra_convs='on_input'`.
            If str, it specifies the source feature map of the extra convs.
            Only the following options are allowed:

            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
            - 'on_lateral': Last feature map after lateral convs.
            - 'on_output': The last output feature map after fpn convs.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Defaults to False.
        no_norm_on_lateral (bool): Whether to apply norm on lateral.
            Defaults to False.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer. Defaults to None.
        act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            activation layer in ConvModule. Defaults to None.
        upsample_cfg (:obj:`ConfigDict` or dict, optional): Config dict
            for interpolate layer. Defaults to dict(mode='nearest').
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = FPN(in_channels, 11, len(in_channels)).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_outs: int,
        start_level: int = 0,
        end_level: int = -1,
        add_extra_convs: Union[bool, str] = False,
        relu_before_extra_convs: bool = False,
        no_norm_on_lateral: bool = False,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        act_cfg: OptConfigType = None,
        upsample_cfg: ConfigType = dict(mode='nearest'),
        init_cfg: MultiConfig = dict(
            type='Xavier', layer='Conv2d', distribution='uniform')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
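        # e.g. ResNet stages [256, 512, 1024, 2048] with start_level=1 and
        # num_outs=5 (a typical RetinaNet config) use backbone levels 1-3
        # and leave room for two extra pyramid levels on top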
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        assert isinstance(add_extra_convs, (str, bool))
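        # normalize the bool shorthand: a bare `True` means 'on_input'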
        if isinstance(add_extra_convs, str):
            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
        elif add_extra_convs:  # True
            self.add_extra_convs = 'on_input'

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
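        # Each used backbone level gets a 1x1 lateral conv that unifies the
        # channel width and a 3x3 conv that smooths the merged map (reducing
        # the aliasing introduced by upsampling, as noted in the FPN paper).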
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
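        # e.g. num_outs=5 with backbone levels 1-3 used gives extra_levels=2,
        # i.e. the stride-2 convs producing RetinaNet's P6 and P7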
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                if i == 0 and self.add_extra_convs == 'on_input':
                    in_channels = self.in_channels[self.backbone_end_level - 1]
                else:
                    in_channels = out_channels
                extra_fpn_conv = ConvModule(
                    in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=False)
                self.fpn_convs.append(extra_fpn_conv)

    def forward(self, inputs: Tuple[Tensor]) -> tuple:
        """Forward function.

        Args:
            inputs (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: Feature maps, each is a 4D-tensor.
        """
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
            # it cannot co-exist with `size` in `F.interpolate`.
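            # Passing an explicit `size` (the default branch below) keeps the
            # shapes aligned even when a level's resolution is not an exact
            # multiple of the next one's.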
            if 'scale_factor' in self.upsample_cfg:
                # fix runtime error of "+=" inplace operation in PyTorch 1.10
                laterals[i - 1] = laterals[i - 1] + F.interpolate(
                    laterals[i], **self.upsample_cfg)
            else:
                prev_shape = laterals[i - 1].shape[2:]
                laterals[i - 1] = laterals[i - 1] + F.interpolate(
                    laterals[i], size=prev_shape, **self.upsample_cfg)

        # build outputs
        # part 1: from original levels
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
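                    # max_pool2d with kernel_size=1 and stride=2 is plain
                    # subsampling: it keeps every other pixel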
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    extra_source = inputs[self.backbone_end_level - 1]
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                for i in range(used_backbone_levels + 1, self.num_outs):
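                    # the original RetinaNet computes P7 as Conv(ReLU(P6)),
                    # which corresponds to relu_before_extra_convs=True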
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
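

if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the upstream module:
    # it builds a RetinaNet-style pyramid with two extra stride-2 convs fed
    # by the last backbone feature ('on_input'). The channel counts and
    # spatial sizes below are arbitrary toy values.
    import torch

    in_channels = [2, 3, 5, 7]
    scales = [64, 32, 16, 8]
    inputs = [torch.rand(1, c, s, s) for c, s in zip(in_channels, scales)]
    fpn = FPN(
        in_channels,
        out_channels=11,
        num_outs=5,
        start_level=1,
        add_extra_convs='on_input').eval()
    outputs = fpn(inputs)
    # Expect 5 maps: three from backbone levels 1-3 plus two extra levels,
    # each halving the previous resolution: 32, 16, 8, 4, 2.
    for i, out in enumerate(outputs):
        print(f'outputs[{i}].shape = {out.shape}')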