# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType


@MODELS.register_module()
class SimpleFPN(BaseModule):
    """Simple Feature Pyramid Network for ViTDet.

    Builds a four-level pyramid from a single feature map, then applies a
    lateral 1x1 ConvModule and an output 3x3 ConvModule to each used level:

    - level 0 (``fpn1``): two kernel-2/stride-2 transposed convs, i.e. 4x
      spatial upsampling, ``backbone_channel // 4`` channels.
    - level 1 (``fpn2``): one kernel-2/stride-2 transposed conv, i.e. 2x
      upsampling, ``backbone_channel // 2`` channels.
    - level 2 (``fpn3``): identity, ``backbone_channel`` channels.
    - level 3 (``fpn4``): 2x2/stride-2 max pooling, i.e. 2x downsampling,
      ``backbone_channel`` channels.

    Args:
        backbone_channel (int): Number of channels of the input feature map.
        in_channels (List[int]): Input channels of each lateral conv. Entry
            ``i`` must match the channel count of pyramid level ``i`` listed
            above. At most 4 entries; only the first ``len(in_channels)``
            levels feed the outputs.
        out_channels (int): Number of channels of every output level.
        num_outs (int): Number of output levels. Levels beyond
            ``len(in_channels)`` are produced by stride-2 subsampling of the
            previous output. If smaller than ``len(in_channels)``, all
            ``len(in_channels)`` levels are still returned.
        conv_cfg (dict, optional): Conv config for the ConvModules.
        norm_cfg (dict, optional): Norm config used inside ``fpn1`` and by
            the ConvModules. When None, ``fpn1`` applies no normalization.
        act_cfg (dict, optional): Activation config for the ConvModules.
        init_cfg (dict or list[dict], optional): Initialization config.
    """

    def __init__(self,
                 backbone_channel: int,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        # forward() only ever builds 4 pyramid levels; more lateral convs
        # would index past them, so reject that configuration up front.
        if len(in_channels) > 4:
            raise ValueError(
                'SimpleFPN builds only 4 pyramid levels, but got '
                f'{len(in_channels)} in_channels entries')
        self.backbone_channel = backbone_channel
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        half_channel = self.backbone_channel // 2
        quarter_channel = self.backbone_channel // 4

        # build_norm_layer requires a dict config, so fall back to Identity
        # when no norm is configured. Keeping a module at index 1 leaves the
        # submodule indices of fpn1 unchanged.
        if norm_cfg is None:
            fpn1_norm = nn.Identity()
        else:
            fpn1_norm = build_norm_layer(norm_cfg, half_channel)[1]

        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel, half_channel, 2, 2),
            fpn1_norm,
            nn.GELU(),
            nn.ConvTranspose2d(half_channel, quarter_channel, 2, 2))
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel, half_channel, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for level_in_channels in in_channels:
            l_conv = ConvModule(
                level_in_channels,
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, input: Tensor) -> tuple:
        """Forward function.

        Args:
            input (Tensor): Single feature map from the upstream network, a
                4D tensor with ``backbone_channel`` channels (presumably
                (N, C, H, W) -- TODO confirm against the backbone).

        Returns:
            tuple: Feature maps, each a 4D tensor with ``out_channels``
            channels, ordered from highest to lowest resolution.
        """
        # build FPN inputs: all four levels are computed regardless of how
        # many lateral convs exist.
        inputs = [
            self.fpn1(input),
            self.fpn2(input),
            self.fpn3(input),
            self.fpn4(input),
        ]

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # part 1: outputs from the original levels
        outs = [self.fpn_convs[i](laterals[i]) for i in range(self.num_ins)]

        # part 2: extra levels via kernel-1, stride-2 max pooling, which
        # simply subsamples every other pixel of the previous output.
        # range() is empty when num_outs <= num_ins.
        for _ in range(self.num_outs - len(outs)):
            outs.append(F.max_pool2d(outs[-1], 1, stride=2))
        return tuple(outs)