nasfcos_fpn.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from mmcv.cnn import ConvModule
  5. from mmcv.ops.merge_cells import ConcatCell
  6. from mmengine.model import BaseModule, caffe2_xavier_init
  7. from mmdet.registry import MODELS
  8. @MODELS.register_module()
  9. class NASFCOS_FPN(BaseModule):
  10. """FPN structure in NASFPN.
  11. Implementation of paper `NAS-FCOS: Fast Neural Architecture Search for
  12. Object Detection <https://arxiv.org/abs/1906.04423>`_
  13. Args:
  14. in_channels (List[int]): Number of input channels per scale.
  15. out_channels (int): Number of output channels (used at each scale)
  16. num_outs (int): Number of output scales.
  17. start_level (int): Index of the start input backbone level used to
  18. build the feature pyramid. Default: 0.
  19. end_level (int): Index of the end input backbone level (exclusive) to
  20. build the feature pyramid. Default: -1, which means the last level.
  21. add_extra_convs (bool): It decides whether to add conv
  22. layers on top of the original feature maps. Default to False.
  23. If True, its actual mode is specified by `extra_convs_on_inputs`.
  24. conv_cfg (dict): dictionary to construct and config conv layer.
  25. norm_cfg (dict): dictionary to construct and config norm layer.
  26. init_cfg (dict or list[dict], optional): Initialization config dict.
  27. Default: None
  28. """
  29. def __init__(self,
  30. in_channels,
  31. out_channels,
  32. num_outs,
  33. start_level=1,
  34. end_level=-1,
  35. add_extra_convs=False,
  36. conv_cfg=None,
  37. norm_cfg=None,
  38. init_cfg=None):
  39. assert init_cfg is None, 'To prevent abnormal initialization ' \
  40. 'behavior, init_cfg is not allowed to be set'
  41. super(NASFCOS_FPN, self).__init__(init_cfg)
  42. assert isinstance(in_channels, list)
  43. self.in_channels = in_channels
  44. self.out_channels = out_channels
  45. self.num_ins = len(in_channels)
  46. self.num_outs = num_outs
  47. self.norm_cfg = norm_cfg
  48. self.conv_cfg = conv_cfg
  49. if end_level == -1 or end_level == self.num_ins - 1:
  50. self.backbone_end_level = self.num_ins
  51. assert num_outs >= self.num_ins - start_level
  52. else:
  53. # if end_level is not the last level, no extra level is allowed
  54. self.backbone_end_level = end_level + 1
  55. assert end_level < self.num_ins
  56. assert num_outs == end_level - start_level + 1
  57. self.start_level = start_level
  58. self.end_level = end_level
  59. self.add_extra_convs = add_extra_convs
  60. self.adapt_convs = nn.ModuleList()
  61. for i in range(self.start_level, self.backbone_end_level):
  62. adapt_conv = ConvModule(
  63. in_channels[i],
  64. out_channels,
  65. 1,
  66. stride=1,
  67. padding=0,
  68. bias=False,
  69. norm_cfg=dict(type='BN'),
  70. act_cfg=dict(type='ReLU', inplace=False))
  71. self.adapt_convs.append(adapt_conv)
  72. # C2 is omitted according to the paper
  73. extra_levels = num_outs - self.backbone_end_level + self.start_level
  74. def build_concat_cell(with_input1_conv, with_input2_conv):
  75. cell_conv_cfg = dict(
  76. kernel_size=1, padding=0, bias=False, groups=out_channels)
  77. return ConcatCell(
  78. in_channels=out_channels,
  79. out_channels=out_channels,
  80. with_out_conv=True,
  81. out_conv_cfg=cell_conv_cfg,
  82. out_norm_cfg=dict(type='BN'),
  83. out_conv_order=('norm', 'act', 'conv'),
  84. with_input1_conv=with_input1_conv,
  85. with_input2_conv=with_input2_conv,
  86. input_conv_cfg=conv_cfg,
  87. input_norm_cfg=norm_cfg,
  88. upsample_mode='nearest')
  89. # Denote c3=f0, c4=f1, c5=f2 for convince
  90. self.fpn = nn.ModuleDict()
  91. self.fpn['c22_1'] = build_concat_cell(True, True)
  92. self.fpn['c22_2'] = build_concat_cell(True, True)
  93. self.fpn['c32'] = build_concat_cell(True, False)
  94. self.fpn['c02'] = build_concat_cell(True, False)
  95. self.fpn['c42'] = build_concat_cell(True, True)
  96. self.fpn['c36'] = build_concat_cell(True, True)
  97. self.fpn['c61'] = build_concat_cell(True, True) # f9
  98. self.extra_downsamples = nn.ModuleList()
  99. for i in range(extra_levels):
  100. extra_act_cfg = None if i == 0 \
  101. else dict(type='ReLU', inplace=False)
  102. self.extra_downsamples.append(
  103. ConvModule(
  104. out_channels,
  105. out_channels,
  106. 3,
  107. stride=2,
  108. padding=1,
  109. act_cfg=extra_act_cfg,
  110. order=('act', 'norm', 'conv')))
  111. def forward(self, inputs):
  112. """Forward function."""
  113. feats = [
  114. adapt_conv(inputs[i + self.start_level])
  115. for i, adapt_conv in enumerate(self.adapt_convs)
  116. ]
  117. for (i, module_name) in enumerate(self.fpn):
  118. idx_1, idx_2 = int(module_name[1]), int(module_name[2])
  119. res = self.fpn[module_name](feats[idx_1], feats[idx_2])
  120. feats.append(res)
  121. ret = []
  122. for (idx, input_idx) in zip([9, 8, 7], [1, 2, 3]): # add P3, P4, P5
  123. feats1, feats2 = feats[idx], feats[5]
  124. feats2_resize = F.interpolate(
  125. feats2,
  126. size=feats1.size()[2:],
  127. mode='bilinear',
  128. align_corners=False)
  129. feats_sum = feats1 + feats2_resize
  130. ret.append(
  131. F.interpolate(
  132. feats_sum,
  133. size=inputs[input_idx].size()[2:],
  134. mode='bilinear',
  135. align_corners=False))
  136. for submodule in self.extra_downsamples:
  137. ret.append(submodule(ret[-1]))
  138. return tuple(ret)
  139. def init_weights(self):
  140. """Initialize the weights of module."""
  141. super(NASFCOS_FPN, self).init_weights()
  142. for module in self.fpn.values():
  143. if hasattr(module, 'conv_out'):
  144. caffe2_xavier_init(module.out_conv.conv)
  145. for modules in [
  146. self.adapt_convs.modules(),
  147. self.extra_downsamples.modules()
  148. ]:
  149. for module in modules:
  150. if isinstance(module, nn.Conv2d):
  151. caffe2_xavier_init(module)