regnet.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import numpy as np
  4. import torch.nn as nn
  5. from mmcv.cnn import build_conv_layer, build_norm_layer
  6. from mmdet.registry import MODELS
  7. from .resnet import ResNet
  8. from .resnext import Bottleneck
  9. @MODELS.register_module()
  10. class RegNet(ResNet):
  11. """RegNet backbone.
  12. More details can be found in `paper <https://arxiv.org/abs/2003.13678>`_ .
  13. Args:
  14. arch (dict): The parameter of RegNets.
  15. - w0 (int): initial width
  16. - wa (float): slope of width
  17. - wm (float): quantization parameter to quantize the width
  18. - depth (int): depth of the backbone
  19. - group_w (int): width of group
  20. - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
  21. strides (Sequence[int]): Strides of the first block of each stage.
  22. base_channels (int): Base channels after stem layer.
  23. in_channels (int): Number of input image channels. Default: 3.
  24. dilations (Sequence[int]): Dilation of each stage.
  25. out_indices (Sequence[int]): Output from which stages.
  26. style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
  27. layer is the 3x3 conv layer, otherwise the stride-two layer is
  28. the first 1x1 conv layer.
  29. frozen_stages (int): Stages to be frozen (all param fixed). -1 means
  30. not freezing any parameters.
  31. norm_cfg (dict): dictionary to construct and config norm layer.
  32. norm_eval (bool): Whether to set norm layers to eval mode, namely,
  33. freeze running stats (mean and var). Note: Effect on Batch Norm
  34. and its variants only.
  35. with_cp (bool): Use checkpoint or not. Using checkpoint will save some
  36. memory while slowing down the training speed.
  37. zero_init_residual (bool): whether to use zero init for last norm layer
  38. in resblocks to let them behave as identity.
  39. pretrained (str, optional): model pretrained path. Default: None
  40. init_cfg (dict or list[dict], optional): Initialization config dict.
  41. Default: None
  42. Example:
  43. >>> from mmdet.models import RegNet
  44. >>> import torch
  45. >>> self = RegNet(
  46. arch=dict(
  47. w0=88,
  48. wa=26.31,
  49. wm=2.25,
  50. group_w=48,
  51. depth=25,
  52. bot_mul=1.0))
  53. >>> self.eval()
  54. >>> inputs = torch.rand(1, 3, 32, 32)
  55. >>> level_outputs = self.forward(inputs)
  56. >>> for level_out in level_outputs:
  57. ... print(tuple(level_out.shape))
  58. (1, 96, 8, 8)
  59. (1, 192, 4, 4)
  60. (1, 432, 2, 2)
  61. (1, 1008, 1, 1)
  62. """
  63. arch_settings = {
  64. 'regnetx_400mf':
  65. dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
  66. 'regnetx_800mf':
  67. dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
  68. 'regnetx_1.6gf':
  69. dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
  70. 'regnetx_3.2gf':
  71. dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
  72. 'regnetx_4.0gf':
  73. dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
  74. 'regnetx_6.4gf':
  75. dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
  76. 'regnetx_8.0gf':
  77. dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
  78. 'regnetx_12gf':
  79. dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
  80. }
  81. def __init__(self,
  82. arch,
  83. in_channels=3,
  84. stem_channels=32,
  85. base_channels=32,
  86. strides=(2, 2, 2, 2),
  87. dilations=(1, 1, 1, 1),
  88. out_indices=(0, 1, 2, 3),
  89. style='pytorch',
  90. deep_stem=False,
  91. avg_down=False,
  92. frozen_stages=-1,
  93. conv_cfg=None,
  94. norm_cfg=dict(type='BN', requires_grad=True),
  95. norm_eval=True,
  96. dcn=None,
  97. stage_with_dcn=(False, False, False, False),
  98. plugins=None,
  99. with_cp=False,
  100. zero_init_residual=True,
  101. pretrained=None,
  102. init_cfg=None):
  103. super(ResNet, self).__init__(init_cfg)
  104. # Generate RegNet parameters first
  105. if isinstance(arch, str):
  106. assert arch in self.arch_settings, \
  107. f'"arch": "{arch}" is not one of the' \
  108. ' arch_settings'
  109. arch = self.arch_settings[arch]
  110. elif not isinstance(arch, dict):
  111. raise ValueError('Expect "arch" to be either a string '
  112. f'or a dict, got {type(arch)}')
  113. widths, num_stages = self.generate_regnet(
  114. arch['w0'],
  115. arch['wa'],
  116. arch['wm'],
  117. arch['depth'],
  118. )
  119. # Convert to per stage format
  120. stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
  121. # Generate group widths and bot muls
  122. group_widths = [arch['group_w'] for _ in range(num_stages)]
  123. self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
  124. # Adjust the compatibility of stage_widths and group_widths
  125. stage_widths, group_widths = self.adjust_width_group(
  126. stage_widths, self.bottleneck_ratio, group_widths)
  127. # Group params by stage
  128. self.stage_widths = stage_widths
  129. self.group_widths = group_widths
  130. self.depth = sum(stage_blocks)
  131. self.stem_channels = stem_channels
  132. self.base_channels = base_channels
  133. self.num_stages = num_stages
  134. assert num_stages >= 1 and num_stages <= 4
  135. self.strides = strides
  136. self.dilations = dilations
  137. assert len(strides) == len(dilations) == num_stages
  138. self.out_indices = out_indices
  139. assert max(out_indices) < num_stages
  140. self.style = style
  141. self.deep_stem = deep_stem
  142. self.avg_down = avg_down
  143. self.frozen_stages = frozen_stages
  144. self.conv_cfg = conv_cfg
  145. self.norm_cfg = norm_cfg
  146. self.with_cp = with_cp
  147. self.norm_eval = norm_eval
  148. self.dcn = dcn
  149. self.stage_with_dcn = stage_with_dcn
  150. if dcn is not None:
  151. assert len(stage_with_dcn) == num_stages
  152. self.plugins = plugins
  153. self.zero_init_residual = zero_init_residual
  154. self.block = Bottleneck
  155. expansion_bak = self.block.expansion
  156. self.block.expansion = 1
  157. self.stage_blocks = stage_blocks[:num_stages]
  158. self._make_stem_layer(in_channels, stem_channels)
  159. block_init_cfg = None
  160. assert not (init_cfg and pretrained), \
  161. 'init_cfg and pretrained cannot be specified at the same time'
  162. if isinstance(pretrained, str):
  163. warnings.warn('DeprecationWarning: pretrained is deprecated, '
  164. 'please use "init_cfg" instead')
  165. self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
  166. elif pretrained is None:
  167. if init_cfg is None:
  168. self.init_cfg = [
  169. dict(type='Kaiming', layer='Conv2d'),
  170. dict(
  171. type='Constant',
  172. val=1,
  173. layer=['_BatchNorm', 'GroupNorm'])
  174. ]
  175. if self.zero_init_residual:
  176. block_init_cfg = dict(
  177. type='Constant', val=0, override=dict(name='norm3'))
  178. else:
  179. raise TypeError('pretrained must be a str or None')
  180. self.inplanes = stem_channels
  181. self.res_layers = []
  182. for i, num_blocks in enumerate(self.stage_blocks):
  183. stride = self.strides[i]
  184. dilation = self.dilations[i]
  185. group_width = self.group_widths[i]
  186. width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
  187. stage_groups = width // group_width
  188. dcn = self.dcn if self.stage_with_dcn[i] else None
  189. if self.plugins is not None:
  190. stage_plugins = self.make_stage_plugins(self.plugins, i)
  191. else:
  192. stage_plugins = None
  193. res_layer = self.make_res_layer(
  194. block=self.block,
  195. inplanes=self.inplanes,
  196. planes=self.stage_widths[i],
  197. num_blocks=num_blocks,
  198. stride=stride,
  199. dilation=dilation,
  200. style=self.style,
  201. avg_down=self.avg_down,
  202. with_cp=self.with_cp,
  203. conv_cfg=self.conv_cfg,
  204. norm_cfg=self.norm_cfg,
  205. dcn=dcn,
  206. plugins=stage_plugins,
  207. groups=stage_groups,
  208. base_width=group_width,
  209. base_channels=self.stage_widths[i],
  210. init_cfg=block_init_cfg)
  211. self.inplanes = self.stage_widths[i]
  212. layer_name = f'layer{i + 1}'
  213. self.add_module(layer_name, res_layer)
  214. self.res_layers.append(layer_name)
  215. self._freeze_stages()
  216. self.feat_dim = stage_widths[-1]
  217. self.block.expansion = expansion_bak
  218. def _make_stem_layer(self, in_channels, base_channels):
  219. self.conv1 = build_conv_layer(
  220. self.conv_cfg,
  221. in_channels,
  222. base_channels,
  223. kernel_size=3,
  224. stride=2,
  225. padding=1,
  226. bias=False)
  227. self.norm1_name, norm1 = build_norm_layer(
  228. self.norm_cfg, base_channels, postfix=1)
  229. self.add_module(self.norm1_name, norm1)
  230. self.relu = nn.ReLU(inplace=True)
  231. def generate_regnet(self,
  232. initial_width,
  233. width_slope,
  234. width_parameter,
  235. depth,
  236. divisor=8):
  237. """Generates per block width from RegNet parameters.
  238. Args:
  239. initial_width ([int]): Initial width of the backbone
  240. width_slope ([float]): Slope of the quantized linear function
  241. width_parameter ([int]): Parameter used to quantize the width.
  242. depth ([int]): Depth of the backbone.
  243. divisor (int, optional): The divisor of channels. Defaults to 8.
  244. Returns:
  245. list, int: return a list of widths of each stage and the number \
  246. of stages
  247. """
  248. assert width_slope >= 0
  249. assert initial_width > 0
  250. assert width_parameter > 1
  251. assert initial_width % divisor == 0
  252. widths_cont = np.arange(depth) * width_slope + initial_width
  253. ks = np.round(
  254. np.log(widths_cont / initial_width) / np.log(width_parameter))
  255. widths = initial_width * np.power(width_parameter, ks)
  256. widths = np.round(np.divide(widths, divisor)) * divisor
  257. num_stages = len(np.unique(widths))
  258. widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
  259. return widths, num_stages
  260. @staticmethod
  261. def quantize_float(number, divisor):
  262. """Converts a float to closest non-zero int divisible by divisor.
  263. Args:
  264. number (int): Original number to be quantized.
  265. divisor (int): Divisor used to quantize the number.
  266. Returns:
  267. int: quantized number that is divisible by devisor.
  268. """
  269. return int(round(number / divisor) * divisor)
  270. def adjust_width_group(self, widths, bottleneck_ratio, groups):
  271. """Adjusts the compatibility of widths and groups.
  272. Args:
  273. widths (list[int]): Width of each stage.
  274. bottleneck_ratio (float): Bottleneck ratio.
  275. groups (int): number of groups in each stage
  276. Returns:
  277. tuple(list): The adjusted widths and groups of each stage.
  278. """
  279. bottleneck_width = [
  280. int(w * b) for w, b in zip(widths, bottleneck_ratio)
  281. ]
  282. groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
  283. bottleneck_width = [
  284. self.quantize_float(w_bot, g)
  285. for w_bot, g in zip(bottleneck_width, groups)
  286. ]
  287. widths = [
  288. int(w_bot / b)
  289. for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
  290. ]
  291. return widths, groups
  292. def get_stages_from_blocks(self, widths):
  293. """Gets widths/stage_blocks of network at each stage.
  294. Args:
  295. widths (list[int]): Width in each stage.
  296. Returns:
  297. tuple(list): width and depth of each stage
  298. """
  299. width_diff = [
  300. width != width_prev
  301. for width, width_prev in zip(widths + [0], [0] + widths)
  302. ]
  303. stage_widths = [
  304. width for width, diff in zip(widths, width_diff[:-1]) if diff
  305. ]
  306. stage_blocks = np.diff([
  307. depth for depth, diff in zip(range(len(width_diff)), width_diff)
  308. if diff
  309. ]).tolist()
  310. return stage_widths, stage_blocks
  311. def forward(self, x):
  312. """Forward function."""
  313. x = self.conv1(x)
  314. x = self.norm1(x)
  315. x = self.relu(x)
  316. outs = []
  317. for i, layer_name in enumerate(self.res_layers):
  318. res_layer = getattr(self, layer_name)
  319. x = res_layer(x)
  320. if i in self.out_indices:
  321. outs.append(x)
  322. return tuple(outs)