simple_fpn.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmcv.cnn import ConvModule, build_norm_layer
  6. from mmengine.model import BaseModule
  7. from torch import Tensor
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import MultiConfig, OptConfigType
  10. @MODELS.register_module()
  11. class SimpleFPN(BaseModule):
  12. """Simple Feature Pyramid Network for ViTDet."""
  13. def __init__(self,
  14. backbone_channel: int,
  15. in_channels: List[int],
  16. out_channels: int,
  17. num_outs: int,
  18. conv_cfg: OptConfigType = None,
  19. norm_cfg: OptConfigType = None,
  20. act_cfg: OptConfigType = None,
  21. init_cfg: MultiConfig = None) -> None:
  22. super().__init__(init_cfg=init_cfg)
  23. assert isinstance(in_channels, list)
  24. self.backbone_channel = backbone_channel
  25. self.in_channels = in_channels
  26. self.out_channels = out_channels
  27. self.num_ins = len(in_channels)
  28. self.num_outs = num_outs
  29. self.fpn1 = nn.Sequential(
  30. nn.ConvTranspose2d(self.backbone_channel,
  31. self.backbone_channel // 2, 2, 2),
  32. build_norm_layer(norm_cfg, self.backbone_channel // 2)[1],
  33. nn.GELU(),
  34. nn.ConvTranspose2d(self.backbone_channel // 2,
  35. self.backbone_channel // 4, 2, 2))
  36. self.fpn2 = nn.Sequential(
  37. nn.ConvTranspose2d(self.backbone_channel,
  38. self.backbone_channel // 2, 2, 2))
  39. self.fpn3 = nn.Sequential(nn.Identity())
  40. self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
  41. self.lateral_convs = nn.ModuleList()
  42. self.fpn_convs = nn.ModuleList()
  43. for i in range(self.num_ins):
  44. l_conv = ConvModule(
  45. in_channels[i],
  46. out_channels,
  47. 1,
  48. conv_cfg=conv_cfg,
  49. norm_cfg=norm_cfg,
  50. act_cfg=act_cfg,
  51. inplace=False)
  52. fpn_conv = ConvModule(
  53. out_channels,
  54. out_channels,
  55. 3,
  56. padding=1,
  57. conv_cfg=conv_cfg,
  58. norm_cfg=norm_cfg,
  59. act_cfg=act_cfg,
  60. inplace=False)
  61. self.lateral_convs.append(l_conv)
  62. self.fpn_convs.append(fpn_conv)
  63. def forward(self, input: Tensor) -> tuple:
  64. """Forward function.
  65. Args:
  66. inputs (Tensor): Features from the upstream network, 4D-tensor
  67. Returns:
  68. tuple: Feature maps, each is a 4D-tensor.
  69. """
  70. # build FPN
  71. inputs = []
  72. inputs.append(self.fpn1(input))
  73. inputs.append(self.fpn2(input))
  74. inputs.append(self.fpn3(input))
  75. inputs.append(self.fpn4(input))
  76. # build laterals
  77. laterals = [
  78. lateral_conv(inputs[i])
  79. for i, lateral_conv in enumerate(self.lateral_convs)
  80. ]
  81. # build outputs
  82. # part 1: from original levels
  83. outs = [self.fpn_convs[i](laterals[i]) for i in range(self.num_ins)]
  84. # part 2: add extra levels
  85. if self.num_outs > len(outs):
  86. for i in range(self.num_outs - self.num_ins):
  87. outs.append(F.max_pool2d(outs[-1], 1, stride=2))
  88. return tuple(outs)