nas_fpn.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops.merge_cells import GlobalPoolingCell, SumCell
from mmengine.model import BaseModule, ModuleList
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType


@MODELS.register_module()
class NASFPN(BaseModule):
    """NAS-FPN.

    Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture
    for Object Detection <https://arxiv.org/abs/1904.07392>`_

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Defaults to 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Defaults to -1, which means the
            last level.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
    """
    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_outs: int,
        stack_times: int,
        start_level: int = 0,
        end_level: int = -1,
        norm_cfg: OptConfigType = None,
        init_cfg: MultiConfig = dict(type='Caffe2Xavier', layer='Conv2d')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)  # num of input feature levels
        self.num_outs = num_outs  # num of output feature levels
        self.stack_times = stack_times
        self.norm_cfg = norm_cfg

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level

        # add lateral connections
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None)
            self.lateral_convs.append(l_conv)

        # add extra downsample layers (stride-2 pooling or conv)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            extra_conv = ConvModule(
                out_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
            self.extra_downsamples.append(
                nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))
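
        # Worked example (illustrative): with three backbone levels kept
        # after start_level (P3-P5) and num_outs=5, extra_levels = 5 - 3 = 2,
        # so two 1x1-conv + MaxPool2d(2, 2) branches are stacked on P5 to
        # produce P6 and P7, each at half the resolution of the level below.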

        # add NAS FPN connections
        self.fpn_stages = ModuleList()
        for _ in range(self.stack_times):
            stage = nn.ModuleDict()
            # gp(p6, p4) -> p4_1
            stage['gp_64_4'] = GlobalPoolingCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p4_1, p4) -> p4_2
            stage['sum_44_4'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p4_2, p3) -> p3_out
            stage['sum_43_3'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p3_out, p4_2) -> p4_out
            stage['sum_34_4'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p5, gp(p4_out, p3_out)) -> p5_out
            stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
            stage['sum_55_5'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p7, gp(p5_out, p4_2)) -> p7_out
            stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
            stage['sum_77_7'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # gp(p7_out, p5_out) -> p6_out
            stage['gp_75_6'] = GlobalPoolingCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            self.fpn_stages.append(stage)
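
        # Summary of the per-stage wiring built above, i.e. the merging-cell
        # graph found by the NAS search (GP = GlobalPoolingCell,
        # SUM = SumCell):
        #
        #   p4_1   = GP(p6, p4)
        #   p4_2   = SUM(p4_1, p4)
        #   p3_out = SUM(p4_2, p3)
        #   p4_out = SUM(p3_out, p4_2)
        #   p5_out = SUM(p5, GP(p4_out, p3_out))
        #   p7_out = SUM(p7, GP(p5_out, p4_2))
        #   p6_out = GP(p7_out, p5_out)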

    def forward(self, inputs: Tuple[Tensor]) -> tuple:
        """Forward function.

        Args:
            inputs (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: Feature maps, each is a 4D-tensor.
        """
        # build P3-P5
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        # build P6-P7 on top of P5
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        p3, p4, p5, p6, p7 = feats

        for stage in self.fpn_stages:
            # gp(p6, p4) -> p4_1
            p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
            # sum(p4_1, p4) -> p4_2
            p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
            # sum(p4_2, p3) -> p3_out
            p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
            # sum(p3_out, p4_2) -> p4_out
            p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
            # sum(p5, gp(p4_out, p3_out)) -> p5_out
            p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
            p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
            # sum(p7, gp(p5_out, p4_2)) -> p7_out
            p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
            p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
            # gp(p7_out, p5_out) -> p6_out
            p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])

        return p3, p4, p5, p6, p7
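

# A minimal smoke-test sketch, assuming mmdet, mmcv and mmengine are
# installed; the channel and shape values below are illustrative assumptions,
# not part of the original module.
if __name__ == '__main__':
    import torch

    neck = NASFPN(
        in_channels=[512, 1024, 2048],  # e.g. ResNet-50 C3-C5 widths
        out_channels=256,
        num_outs=5,
        stack_times=3)
    # Dummy C3-C5 features for a 640x640 image (strides 8, 16, 32).
    inputs = tuple(
        torch.rand(1, c, s, s)
        for c, s in zip([512, 1024, 2048], [80, 40, 20]))
    outs = neck(inputs)
    for out in outs:
        # Expect 5 levels: (1, 256, 80, 80) down to (1, 256, 5, 5).
        print(tuple(out.shape))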