hrfpn.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch.utils.checkpoint import checkpoint

from mmdet.registry import MODELS


@MODELS.register_module()
class HRFPN(BaseModule):
    """HRFPN (High Resolution Feature Pyramids)

    paper: `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    Args:
        in_channels (list): number of channels for each branch.
        out_channels (int): output channels of feature pyramids.
        num_outs (int): number of output stages.
        pooling_type (str): pooling for generating feature pyramids
            from {MAX, AVG}.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        stride (int): stride of the 3x3 convolutional layers. Defaults to 1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs=5,
                 pooling_type='AVG',
                 conv_cfg=None,
                 norm_cfg=None,
                 with_cp=False,
                 stride=1,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super(HRFPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.reduction_conv = ConvModule(
            sum(in_channels),
            out_channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            act_cfg=None)

        self.fpn_convs = nn.ModuleList()
        for i in range(self.num_outs):
            self.fpn_convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    act_cfg=None))

        if pooling_type == 'MAX':
            self.pooling = F.max_pool2d
        else:
            self.pooling = F.avg_pool2d

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == self.num_ins
        # Upsample every branch to the resolution of the first (highest
        # resolution) branch and concatenate along the channel dimension.
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        # Fuse the concatenated features down to `out_channels` with a 1x1
        # conv, optionally under gradient checkpointing to save memory.
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        # Build the pyramid by repeatedly pooling the fused map, then refine
        # each level with its own 3x3 conv.
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []
        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs)
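

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal example of
# instantiating HRFPN directly and running a forward pass on dummy
# multi-scale features. The branch widths and input sizes below are
# illustrative assumptions (HRNet-style strides whose resolutions halve per
# level), not values taken from any particular mmdetection config; in a real
# config the module is typically built from a dict such as
# dict(type='HRFPN', in_channels=[...], out_channels=256).
if __name__ == '__main__':
    in_channels = [18, 36, 72, 144]  # assumed HRNet-like branch widths
    neck = HRFPN(in_channels=in_channels, out_channels=256, num_outs=5)
    neck.eval()

    # Four input branches whose spatial sizes halve from level to level
    # (64x64, 32x32, 16x16, 8x8 in this toy example).
    feats = [
        torch.randn(1, c, 64 // 2**i, 64 // 2**i)
        for i, c in enumerate(in_channels)
    ]
    with torch.no_grad():
        outs = neck(feats)
    for level, out in enumerate(outs):
        print(level, tuple(out.shape))
    # Expected: 5 levels, 256 channels each, spatial size halving per level.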