ssd_neck.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS


@MODELS.register_module()
class SSDNeck(BaseModule):
    """Extra layers of SSD backbone to generate multi-scale feature maps.

    Args:
        in_channels (Sequence[int]): Number of input channels per scale.
        out_channels (Sequence[int]): Number of output channels per scale.
        level_strides (Sequence[int]): Stride of 3x3 conv per level.
        level_paddings (Sequence[int]): Padding size of 3x3 conv per level.
        l2_norm_scale (float|None): L2 normalization layer init scale.
            If None, L2 normalization is not applied to the first input
            feature.
        last_kernel_size (int): Kernel size of the last conv layer.
            Default: 3.
        use_depthwise (bool): Whether to use DepthwiseSeparableConv.
            Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
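
    Example:
        A minimal sketch (the channel layout is illustrative, not taken
        from a published config): two backbone scales are passed through
        unchanged and two extra scales are generated on top of the last.

        >>> neck = SSDNeck(
        ...     in_channels=(512, 1024),
        ...     out_channels=(512, 1024, 512, 256),
        ...     level_strides=(2, 2),
        ...     level_paddings=(1, 1))
        >>> feats = (torch.rand(1, 512, 38, 38),
        ...          torch.rand(1, 1024, 19, 19))
        >>> [out.shape[1] for out in neck(feats)]
        [512, 1024, 512, 256]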
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 level_strides,
                 level_paddings,
                 l2_norm_scale=20.,
                 last_kernel_size=3,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 init_cfg=[
                     dict(
                         type='Xavier', distribution='uniform',
                         layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super(SSDNeck, self).__init__(init_cfg)
        assert len(out_channels) > len(in_channels)
        assert len(out_channels) - len(in_channels) == len(level_strides)
        assert len(level_strides) == len(level_paddings)
        assert in_channels == out_channels[:len(in_channels)]

        if l2_norm_scale:
            self.l2_norm = L2Norm(in_channels[0], l2_norm_scale)
            # Initialize the learnable L2Norm weight to l2_norm_scale via a
            # constant-init override appended to init_cfg.
            self.init_cfg += [
                dict(
                    type='Constant',
                    val=self.l2_norm.scale,
                    override=dict(name='l2_norm'))
            ]

        self.extra_layers = nn.ModuleList()
        extra_layer_channels = out_channels[len(in_channels):]
        second_conv = DepthwiseSeparableConvModule if \
            use_depthwise else ConvModule

        # Each extra level is a 1x1 channel-reduction conv followed by a
        # 3x3 conv (last_kernel_size for the final level) that changes the
        # spatial size; its input is the previous level's output.
        for i, (out_channel, stride, padding) in enumerate(
                zip(extra_layer_channels, level_strides, level_paddings)):
            kernel_size = last_kernel_size \
                if i == len(extra_layer_channels) - 1 else 3
            per_lvl_convs = nn.Sequential(
                ConvModule(
                    out_channels[len(in_channels) - 1 + i],
                    out_channel // 2,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
                second_conv(
                    out_channel // 2,
                    out_channel,
                    kernel_size,
                    stride=stride,
                    padding=padding,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            self.extra_layers.append(per_lvl_convs)

    def forward(self, inputs):
        """Forward function."""
        outs = [feat for feat in inputs]
        if hasattr(self, 'l2_norm'):
            outs[0] = self.l2_norm(outs[0])

        # Feed the deepest input feature through the extra layers,
        # appending one output per extra level.
        feat = outs[-1]
        for layer in self.extra_layers:
            feat = layer(feat)
            outs.append(feat)
        return tuple(outs)


class L2Norm(nn.Module):

    def __init__(self, n_dims, scale=20., eps=1e-10):
        """L2 normalization layer.

        Args:
            n_dims (int): Number of dimensions to be normalized.
            scale (float, optional): Initial value of the learnable
                per-channel scale. Defaults to 20.
            eps (float, optional): Used to avoid division by zero.
                Defaults to 1e-10.
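
        Example:
            A small sketch: with the weight constant-initialized to
            ``scale``, every spatial position of the output has a
            channel-wise L2 norm equal to that scale.

            >>> layer = L2Norm(n_dims=4, scale=10.)
            >>> _ = nn.init.constant_(layer.weight, layer.scale)
            >>> out = layer(torch.rand(1, 4, 2, 2))
            >>> out.pow(2).sum(1).sqrt().allclose(torch.full((1, 2, 2), 10.))
            True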
        """
        super(L2Norm, self).__init__()
        self.n_dims = n_dims
        # The weight is created uninitialized here; SSDNeck fills it with
        # l2_norm_scale through its init_cfg override.
        self.weight = nn.Parameter(torch.Tensor(self.n_dims))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        """Forward function."""
        # Compute the normalization in FP32 for numerical stability when
        # training in FP16.
        x_float = x.float()
        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
        return (self.weight[None, :, None, None].float().expand_as(x_float) *
                x_float / norm).type_as(x)
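

if __name__ == '__main__':
    # Minimal smoke-test sketch. The channel layout is illustrative
    # (loosely the SSD300 VGG setting), not read from any config file.
    neck = SSDNeck(
        in_channels=(512, 1024),
        out_channels=(512, 1024, 512, 256, 256, 256),
        level_strides=(2, 2, 1, 1),
        level_paddings=(1, 1, 0, 0))
    # Apply init_cfg, including the constant init of the L2Norm weight.
    neck.init_weights()
    feats = (torch.rand(1, 512, 38, 38), torch.rand(1, 1024, 19, 19))
    for i, out in enumerate(neck(feats)):
        print(f'level {i}: {tuple(out.shape)}')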