# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch.utils.checkpoint import checkpoint

from mmdet.registry import MODELS


@MODELS.register_module()
class HRFPN(BaseModule):
    """HRFPN (High Resolution Feature Pyramids).

    Paper: `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    Args:
        in_channels (list): Number of channels for each branch.
        out_channels (int): Output channels of the feature pyramids.
        num_outs (int): Number of output stages. Defaults to 5.
        pooling_type (str): Pooling used to generate the feature pyramids,
            one of 'MAX' or 'AVG'. Defaults to 'AVG'.
        conv_cfg (dict, optional): Dictionary to construct and config the
            conv layer. Defaults to None.
        norm_cfg (dict, optional): Dictionary to construct and config the
            norm layer. Defaults to None.
        with_cp (bool): Use checkpointing or not. Checkpointing saves some
            memory at the cost of slower training. Defaults to False.
        stride (int): Stride of the 3x3 convolutional layers. Defaults to 1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs=5,
                 pooling_type='AVG',
                 conv_cfg=None,
                 norm_cfg=None,
                 with_cp=False,
                 stride=1,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super(HRFPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # 1x1 conv that fuses the concatenated multi-branch features into a
        # single map with `out_channels` channels.
        self.reduction_conv = ConvModule(
            sum(in_channels),
            out_channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            act_cfg=None)

        # One 3x3 conv per output pyramid level.
        self.fpn_convs = nn.ModuleList()
        for i in range(self.num_outs):
            self.fpn_convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    act_cfg=None))

        if pooling_type == 'MAX':
            self.pooling = F.max_pool2d
        else:
            self.pooling = F.avg_pool2d

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == self.num_ins
        # Upsample every branch to the resolution of the highest-resolution
        # branch, then concatenate along the channel dimension.
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        # Build the pyramid by repeatedly downsampling the fused map, then
        # refine each level with its own 3x3 conv.
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []
        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs)
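

if __name__ == '__main__':
    # Minimal smoke test (not part of the original file): builds the neck
    # directly and runs dummy HRNet-style multi-branch features through it.
    # The channel counts and spatial sizes below are illustrative assumptions,
    # not tied to any particular HRNet configuration. In a real mmdet config
    # the neck would instead be built from the registry, e.g.
    # MODELS.build(dict(type='HRFPN', in_channels=[...], out_channels=256)).
    neck = HRFPN(in_channels=[32, 64, 128, 256], out_channels=256, num_outs=5)
    feats = [
        torch.randn(2, 32, 64, 64),   # stride-4 branch
        torch.randn(2, 64, 32, 32),   # stride-8 branch
        torch.randn(2, 128, 16, 16),  # stride-16 branch
        torch.randn(2, 256, 8, 8),    # stride-32 branch
    ]
    outs = neck(feats)
    # Expect five 256-channel levels: 64x64, 32x32, 16x16, 8x8 and 4x4.
    for out in outs:
        print(out.shape)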