# yolov8.py (11 KB)
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.cnn import build_plugin_layer
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch.nn.modules.batchnorm import _BatchNorm

from ..layers import CSPLayerWithTwoConv, SPPFBottleneck
from ..utils import make_divisible, make_round
  12. @MODELS.register_module()
  13. class YOLOv8CSPDarknet(BaseModule):
  14. """CSP-Darknet backbone used in YOLOv8.
  15. Args:
  16. arch (str): Architecture of CSP-Darknet, from {P5}.
  17. Defaults to P5.
  18. last_stage_out_channels (int): Final layer output channel.
  19. Defaults to 1024.
  20. plugins (list[dict]): List of plugins for stages, each dict contains:
  21. - cfg (dict, required): Cfg dict to build plugin.
  22. - stages (tuple[bool], optional): Stages to apply plugin, length
  23. should be same as 'num_stages'.
  24. deepen_factor (float): Depth multiplier, multiply number of
  25. blocks in CSP layer by this amount. Defaults to 1.0.
  26. widen_factor (float): Width multiplier, multiply number of
  27. channels in each layer by this amount. Defaults to 1.0.
  28. input_channels (int): Number of input image channels. Defaults to: 3.
  29. out_indices (Tuple[int]): Output from which stages.
  30. Defaults to (2, 3, 4).
  31. frozen_stages (int): Stages to be frozen (stop grad and set eval
  32. mode). -1 means not freezing any parameters. Defaults to -1.
  33. norm_cfg (dict): Dictionary to construct and config norm layer.
  34. Defaults to dict(type='BN', requires_grad=True).
  35. act_cfg (dict): Config dict for activation layer.
  36. Defaults to dict(type='SiLU', inplace=True).
  37. norm_eval (bool): Whether to set norm layers to eval mode, namely,
  38. freeze running stats (mean and var). Note: Effect on Batch Norm
  39. and its variants only. Defaults to False.
  40. init_cfg (Union[dict,list[dict]], optional): Initialization config
  41. dict. Defaults to None.
  42. Example:
  43. >>> from mmyolo.models import YOLOv8CSPDarknet
  44. >>> import torch
  45. >>> model = YOLOv8CSPDarknet()
  46. >>> model.eval()
  47. >>> inputs = torch.rand(1, 3, 416, 416)
  48. >>> level_outputs = model(inputs)
  49. >>> for level_out in level_outputs:
  50. ... print(tuple(level_out.shape))
  51. ...
  52. (1, 256, 52, 52)
  53. (1, 512, 26, 26)
  54. (1, 1024, 13, 13)
  55. """
  56. # From left to right:
  57. # in_channels, out_channels, num_blocks, add_identity, use_spp
  58. # the final out_channels will be set according to the param.
  59. arch_settings = [[64, 128, 3, True, False], [128, 256, 6, True, False],
  60. [256, 512, 6, True, False], [512, None, 3, True, True]]
  61. def __init__(self,
  62. last_stage_out_channels: int = 1024,
  63. plugins: Union[dict, List[dict]] = None,
  64. deepen_factor: float = 1.0,
  65. widen_factor: float = 1.0,
  66. input_channels: int = 3,
  67. out_indices: List[int] = (2, 3, 4),
  68. frozen_stages: int = -1,
  69. norm_cfg: ConfigType = dict(
  70. type='BN', momentum=0.03, eps=0.001),
  71. act_cfg: ConfigType = dict(type='SiLU', inplace=True),
  72. norm_eval: bool = False,
  73. init_cfg: OptMultiConfig = None):
  74. super().__init__()
  75. self.arch_settings[-1][1] = last_stage_out_channels
  76. self.num_stages = len(self.arch_settings)
  77. self.input_channels = input_channels
  78. self.out_indices = out_indices
  79. self.frozen_stages = frozen_stages
  80. self.widen_factor = widen_factor
  81. self.deepen_factor = deepen_factor
  82. self.norm_eval = norm_eval
  83. self.norm_cfg = norm_cfg
  84. self.act_cfg = act_cfg
  85. self.plugins = plugins
  86. self.stem = self.build_stem_layer()
  87. self.layers = ['stem']
  88. for idx, setting in enumerate(self.arch_settings):
  89. stage = []
  90. stage += self.build_stage_layer(idx, setting)
  91. if plugins is not None:
  92. stage += self.make_stage_plugins(plugins, idx, setting)
  93. self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
  94. self.layers.append(f'stage{idx + 1}')
  95. # self.arch_settings[arch],
  96. # deepen_factor,
  97. # widen_factor,
  98. # input_channels=input_channels,
  99. # out_indices=out_indices,
  100. # plugins=plugins,
  101. # frozen_stages=frozen_stages,
  102. # norm_cfg=norm_cfg,
  103. # act_cfg=act_cfg,
  104. # norm_eval=norm_eval,
  105. # init_cfg=init_cfg
  106. def build_stem_layer(self) -> nn.Module:
  107. """Build a stem layer."""
  108. return ConvModule(
  109. self.input_channels,
  110. make_divisible(self.arch_settings[0][0], self.widen_factor),
  111. kernel_size=3,
  112. stride=2,
  113. padding=1,
  114. norm_cfg=self.norm_cfg,
  115. act_cfg=self.act_cfg)
  116. def build_stage_layer(self, stage_idx: int, setting: list) -> list:
  117. """Build a stage layer.
  118. Args:
  119. stage_idx (int): The index of a stage layer.
  120. setting (list): The architecture setting of a stage layer.
  121. """
  122. in_channels, out_channels, num_blocks, add_identity, use_spp = setting
  123. in_channels = make_divisible(in_channels, self.widen_factor)
  124. out_channels = make_divisible(out_channels, self.widen_factor)
  125. print(out_channels)
  126. num_blocks = make_round(num_blocks, self.deepen_factor)
  127. stage = []
  128. conv_layer = ConvModule(
  129. in_channels,
  130. out_channels,
  131. kernel_size=3,
  132. stride=2,
  133. padding=1,
  134. norm_cfg=self.norm_cfg,
  135. act_cfg=self.act_cfg)
  136. stage.append(conv_layer)
  137. csp_layer = CSPLayerWithTwoConv(
  138. out_channels,
  139. out_channels,
  140. num_blocks=num_blocks,
  141. add_identity=add_identity,
  142. norm_cfg=self.norm_cfg,
  143. act_cfg=self.act_cfg)
  144. stage.append(csp_layer)
  145. if use_spp:
  146. spp = SPPFBottleneck(
  147. out_channels,
  148. out_channels,
  149. kernel_sizes=5,
  150. norm_cfg=self.norm_cfg,
  151. act_cfg=self.act_cfg)
  152. stage.append(spp)
  153. return stage
  154. def make_stage_plugins(self, plugins, stage_idx, setting):
  155. """Make plugins for backbone ``stage_idx`` th stage.
  156. Currently we support to insert ``context_block``,
  157. ``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
  158. into the backbone.
  159. An example of plugins format could be:
  160. Examples:
  161. >>> plugins=[
  162. ... dict(cfg=dict(type='xxx', arg1='xxx'),
  163. ... stages=(False, True, True, True)),
  164. ... dict(cfg=dict(type='yyy'),
  165. ... stages=(True, True, True, True)),
  166. ... ]
  167. >>> model = YOLOv5CSPDarknet()
  168. >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
  169. >>> assert len(stage_plugins) == 1
  170. Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
  171. .. code-block:: none
  172. conv1 -> conv2 -> conv3 -> yyy
  173. Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
  174. .. code-block:: none
  175. conv1 -> conv2 -> conv3 -> xxx -> yyy
  176. Args:
  177. plugins (list[dict]): List of plugins cfg to build. The postfix is
  178. required if multiple same type plugins are inserted.
  179. stage_idx (int): Index of stage to build
  180. If stages is missing, the plugin would be applied to all
  181. stages.
  182. setting (list): The architecture setting of a stage layer.
  183. Returns:
  184. list[nn.Module]: Plugins for current stage
  185. """
  186. # TODO: It is not general enough to support any channel and needs
  187. # to be refactored
  188. in_channels = int(setting[1] * self.widen_factor)
  189. plugin_layers = []
  190. for plugin in plugins:
  191. plugin = plugin.copy()
  192. stages = plugin.pop('stages', None)
  193. assert stages is None or len(stages) == self.num_stages
  194. if stages is None or stages[stage_idx]:
  195. name, layer = build_plugin_layer(
  196. plugin['cfg'], in_channels=in_channels)
  197. plugin_layers.append(layer)
  198. return plugin_layers
  199. def init_weights(self):
  200. """Initialize the parameters."""
  201. if self.init_cfg is None:
  202. for m in self.modules():
  203. if isinstance(m, torch.nn.Conv2d):
  204. # In order to be consistent with the source code,
  205. # reset the Conv2d initialization parameters
  206. m.reset_parameters()
  207. else:
  208. super().init_weights()
  209. def make_stage_plugins(self, plugins, stage_idx, setting):
  210. """Make plugins for backbone ``stage_idx`` th stage.
  211. Currently we support to insert ``context_block``,
  212. ``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
  213. into the backbone.
  214. An example of plugins format could be:
  215. Examples:
  216. >>> plugins=[
  217. ... dict(cfg=dict(type='xxx', arg1='xxx'),
  218. ... stages=(False, True, True, True)),
  219. ... dict(cfg=dict(type='yyy'),
  220. ... stages=(True, True, True, True)),
  221. ... ]
  222. >>> model = YOLOv5CSPDarknet()
  223. >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
  224. >>> assert len(stage_plugins) == 1
  225. Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
  226. .. code-block:: none
  227. conv1 -> conv2 -> conv3 -> yyy
  228. Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
  229. .. code-block:: none
  230. conv1 -> conv2 -> conv3 -> xxx -> yyy
  231. Args:
  232. plugins (list[dict]): List of plugins cfg to build. The postfix is
  233. required if multiple same type plugins are inserted.
  234. stage_idx (int): Index of stage to build
  235. If stages is missing, the plugin would be applied to all
  236. stages.
  237. setting (list): The architecture setting of a stage layer.
  238. Returns:
  239. list[nn.Module]: Plugins for current stage
  240. """
  241. # TODO: It is not general enough to support any channel and needs
  242. # to be refactored
  243. in_channels = int(setting[1] * self.widen_factor)
  244. plugin_layers = []
  245. for plugin in plugins:
  246. plugin = plugin.copy()
  247. stages = plugin.pop('stages', None)
  248. assert stages is None or len(stages) == self.num_stages
  249. if stages is None or stages[stage_idx]:
  250. layer = build_plugin_layer(
  251. plugin['cfg'], in_channels=in_channels)[1]
  252. plugin_layers.append(layer)
  253. return plugin_layers
  254. def forward(self, x: torch.Tensor) -> tuple:
  255. """Forward batch_inputs from the data_preprocessor."""
  256. outs = []
  257. for i, layer_name in enumerate(self.layers):
  258. layer = getattr(self, layer_name)
  259. x = layer(x)
  260. if i in self.out_indices:
  261. outs.append(x)
  262. return tuple(outs)