# yolox_pafpn.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
  6. from mmengine.model import BaseModule
  7. from mmdet.registry import MODELS
  8. from ..layers import CSPLayer
  9. @MODELS.register_module()
  10. class YOLOXPAFPN(BaseModule):
  11. """Path Aggregation Network used in YOLOX.
  12. Args:
  13. in_channels (List[int]): Number of input channels per scale.
  14. out_channels (int): Number of output channels (used at each scale)
  15. num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 3
  16. use_depthwise (bool): Whether to depthwise separable convolution in
  17. blocks. Default: False
  18. upsample_cfg (dict): Config dict for interpolate layer.
  19. Default: `dict(scale_factor=2, mode='nearest')`
  20. conv_cfg (dict, optional): Config dict for convolution layer.
  21. Default: None, which means using conv2d.
  22. norm_cfg (dict): Config dict for normalization layer.
  23. Default: dict(type='BN')
  24. act_cfg (dict): Config dict for activation layer.
  25. Default: dict(type='Swish')
  26. init_cfg (dict or list[dict], optional): Initialization config dict.
  27. Default: None.
  28. """
  29. def __init__(self,
  30. in_channels,
  31. out_channels,
  32. num_csp_blocks=3,
  33. use_depthwise=False,
  34. upsample_cfg=dict(scale_factor=2, mode='nearest'),
  35. conv_cfg=None,
  36. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  37. act_cfg=dict(type='Swish'),
  38. init_cfg=dict(
  39. type='Kaiming',
  40. layer='Conv2d',
  41. a=math.sqrt(5),
  42. distribution='uniform',
  43. mode='fan_in',
  44. nonlinearity='leaky_relu')):
  45. super(YOLOXPAFPN, self).__init__(init_cfg)
  46. self.in_channels = in_channels
  47. self.out_channels = out_channels
  48. conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
  49. # build top-down blocks
  50. self.upsample = nn.Upsample(**upsample_cfg)
  51. self.reduce_layers = nn.ModuleList()
  52. self.top_down_blocks = nn.ModuleList()
  53. for idx in range(len(in_channels) - 1, 0, -1):
  54. self.reduce_layers.append(
  55. ConvModule(
  56. in_channels[idx],
  57. in_channels[idx - 1],
  58. 1,
  59. conv_cfg=conv_cfg,
  60. norm_cfg=norm_cfg,
  61. act_cfg=act_cfg))
  62. self.top_down_blocks.append(
  63. CSPLayer(
  64. in_channels[idx - 1] * 2,
  65. in_channels[idx - 1],
  66. num_blocks=num_csp_blocks,
  67. add_identity=False,
  68. use_depthwise=use_depthwise,
  69. conv_cfg=conv_cfg,
  70. norm_cfg=norm_cfg,
  71. act_cfg=act_cfg))
  72. # build bottom-up blocks
  73. self.downsamples = nn.ModuleList()
  74. self.bottom_up_blocks = nn.ModuleList()
  75. for idx in range(len(in_channels) - 1):
  76. self.downsamples.append(
  77. conv(
  78. in_channels[idx],
  79. in_channels[idx],
  80. 3,
  81. stride=2,
  82. padding=1,
  83. conv_cfg=conv_cfg,
  84. norm_cfg=norm_cfg,
  85. act_cfg=act_cfg))
  86. self.bottom_up_blocks.append(
  87. CSPLayer(
  88. in_channels[idx] * 2,
  89. in_channels[idx + 1],
  90. num_blocks=num_csp_blocks,
  91. add_identity=False,
  92. use_depthwise=use_depthwise,
  93. conv_cfg=conv_cfg,
  94. norm_cfg=norm_cfg,
  95. act_cfg=act_cfg))
  96. self.out_convs = nn.ModuleList()
  97. for i in range(len(in_channels)):
  98. self.out_convs.append(
  99. ConvModule(
  100. in_channels[i],
  101. out_channels,
  102. 1,
  103. conv_cfg=conv_cfg,
  104. norm_cfg=norm_cfg,
  105. act_cfg=act_cfg))
  106. def forward(self, inputs):
  107. """
  108. Args:
  109. inputs (tuple[Tensor]): input features.
  110. Returns:
  111. tuple[Tensor]: YOLOXPAFPN features.
  112. """
  113. assert len(inputs) == len(self.in_channels)
  114. # top-down path
  115. inner_outs = [inputs[-1]]
  116. for idx in range(len(self.in_channels) - 1, 0, -1):
  117. feat_heigh = inner_outs[0]
  118. feat_low = inputs[idx - 1]
  119. feat_heigh = self.reduce_layers[len(self.in_channels) - 1 - idx](
  120. feat_heigh)
  121. inner_outs[0] = feat_heigh
  122. upsample_feat = self.upsample(feat_heigh)
  123. inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
  124. torch.cat([upsample_feat, feat_low], 1))
  125. inner_outs.insert(0, inner_out)
  126. # bottom-up path
  127. outs = [inner_outs[0]]
  128. for idx in range(len(self.in_channels) - 1):
  129. feat_low = outs[-1]
  130. feat_height = inner_outs[idx + 1]
  131. downsample_feat = self.downsamples[idx](feat_low)
  132. out = self.bottom_up_blocks[idx](
  133. torch.cat([downsample_feat, feat_height], 1))
  134. outs.append(out)
  135. # out convs
  136. for idx, conv in enumerate(self.out_convs):
  137. outs[idx] = conv(outs[idx])
  138. return tuple(outs)