123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule, build_norm_layer
- from mmengine.model import BaseModule
- from torch import Tensor
- from mmdet.registry import MODELS
- from mmdet.utils import MultiConfig, OptConfigType
@MODELS.register_module()
class SimpleFPN(BaseModule):
    """Simple Feature Pyramid Network for ViTDet.

    A single backbone feature map is expanded into four scales: two stacked
    stride-2 transposed convs (``fpn1``), one stride-2 transposed conv
    (``fpn2``), identity (``fpn3``) and a stride-2 max pooling (``fpn4``).
    Each scale is then passed through a 1x1 lateral conv followed by a 3x3
    output conv. Extra output levels, if requested, are obtained by
    stride-2 subsampling of the last output.

    Args:
        backbone_channel (int): Number of channels of the feature map given
            to ``forward``.
        in_channels (list[int]): Input channels of the lateral convs, one
            entry per scale. ``forward`` always builds exactly four scales
            with ``backbone_channel // 4``, ``backbone_channel // 2``,
            ``backbone_channel`` and ``backbone_channel`` channels, so the
            entries are expected to follow that order and the list should
            not hold more than four items.
        out_channels (int): Number of channels of every output feature map.
        num_outs (int): Number of output scales. Values larger than
            ``len(in_channels)`` add subsampled extra levels.
        conv_cfg (dict, optional): Config forwarded to the lateral and
            output ``ConvModule`` layers. Defaults to None.
        norm_cfg (dict, optional): Config of the norm layers. It is
            forwarded to the ``ConvModule`` layers and also used for the
            norm layer inside ``fpn1``. When None, ``fpn1`` uses an identity
            layer in place of the norm. Defaults to None.
        act_cfg (dict, optional): Activation config forwarded to the
            ``ConvModule`` layers. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config
            passed to ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 backbone_channel: int,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.backbone_channel = backbone_channel
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        half_channel = self.backbone_channel // 2
        quarter_channel = self.backbone_channel // 4

        # Two stacked stride-2 transposed convs.
        # ``build_norm_layer`` only accepts a dict config, so with the
        # default ``norm_cfg=None`` an identity layer is used instead. It
        # occupies the same slot, which keeps the indices of the
        # surrounding layers inside the Sequential unchanged.
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel, half_channel, 2, 2),
            (build_norm_layer(norm_cfg, half_channel)[1]
             if norm_cfg is not None else nn.Identity()),
            nn.GELU(),
            nn.ConvTranspose2d(half_channel, quarter_channel, 2, 2))
        # One stride-2 transposed conv.
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel, half_channel, 2, 2))
        # Feature map passed through unchanged.
        self.fpn3 = nn.Sequential(nn.Identity())
        # Stride-2 max pooling.
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.num_ins):
            # 1x1 conv mapping the i-th scale to ``out_channels``.
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            # 3x3 conv applied on top of the lateral output.
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, input: Tensor) -> tuple:
        """Forward function.

        Args:
            input (Tensor): Features from the upstream network, 4D-tensor.

        Returns:
            tuple: Feature maps, each is a 4D-tensor.
        """
        # build the four pyramid scales from the single input feature map
        inputs = [
            self.fpn1(input),
            self.fpn2(input),
            self.fpn3(input),
            self.fpn4(input),
        ]

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build outputs
        # part 1: from original levels
        outs = [
            fpn_conv(laterals[i])
            for i, fpn_conv in enumerate(self.fpn_convs)
        ]

        # part 2: add extra levels by subsampling the last output; the range
        # is empty when ``num_outs`` does not exceed the number of inputs
        for _ in range(self.num_outs - self.num_ins):
            outs.append(F.max_pool2d(outs[-1], 1, stride=2))
        return tuple(outs)
|