rtmdet_l_8xb32_300e_coco.py

# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0
from mmengine.config import read_base

with read_base():
    from .._base_.default_runtime import *
    from .._base_.schedules.schedule_1x import *
    from .._base_.datasets.coco_detection import *
    from .rtmdet_tta import *

from mmcv.ops import nms
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import RandomResize
from mmengine.hooks.ema_hook import EMAHook
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR
from torch.nn import SyncBatchNorm
from torch.nn.modules.activation import SiLU
from torch.optim.adamw import AdamW

from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic,
                                                  Pad, RandomCrop, RandomFlip,
                                                  Resize, YOLOXHSVRandomAug)
from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook
from mmdet.models.backbones.cspnext import CSPNeXt
from mmdet.models.data_preprocessors.data_preprocessor import \
    DetDataPreprocessor
from mmdet.models.dense_heads.rtmdet_head import RTMDetSepBNHead
from mmdet.models.detectors.rtmdet import RTMDet
from mmdet.models.layers.ema import ExpMomentumEMA
from mmdet.models.losses.gfocal_loss import QualityFocalLoss
from mmdet.models.losses.iou_loss import GIoULoss
from mmdet.models.necks.cspnext_pafpn import CSPNeXtPAFPN
from mmdet.models.task_modules.assigners.dynamic_soft_label_assigner import \
    DynamicSoftLabelAssigner
from mmdet.models.task_modules.coders.distance_point_bbox_coder import \
    DistancePointBBoxCoder
from mmdet.models.task_modules.prior_generators.point_generator import \
    MlvlPointGenerator
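
# RTMDet-L model: CSPNeXt backbone and CSPNeXtPAFPN neck at full depth and
# width (deepen_factor=1, widen_factor=1), with a separated-BN RTMDet head
# that shares conv weights across levels. SyncBatchNorm and SiLU are used
# throughout.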
model = dict(
    type=RTMDet,
    data_preprocessor=dict(
        type=DetDataPreprocessor,
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False,
        batch_augments=None),
    backbone=dict(
        type=CSPNeXt,
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=1,
        widen_factor=1,
        channel_attention=True,
        norm_cfg=dict(type=SyncBatchNorm),
        act_cfg=dict(type=SiLU, inplace=True)),
    neck=dict(
        type=CSPNeXtPAFPN,
        in_channels=[256, 512, 1024],
        out_channels=256,
        num_csp_blocks=3,
        expand_ratio=0.5,
        norm_cfg=dict(type=SyncBatchNorm),
        act_cfg=dict(type=SiLU, inplace=True)),
    bbox_head=dict(
        type=RTMDetSepBNHead,
        num_classes=80,
        in_channels=256,
        stacked_convs=2,
        feat_channels=256,
        anchor_generator=dict(
            type=MlvlPointGenerator, offset=0, strides=[8, 16, 32]),
        bbox_coder=dict(type=DistancePointBBoxCoder),
        loss_cls=dict(
            type=QualityFocalLoss, use_sigmoid=True, beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type=GIoULoss, loss_weight=2.0),
        with_objectness=False,
        exp_on_reg=True,
        share_conv=True,
        pred_kernel_size=1,
        norm_cfg=dict(type=SyncBatchNorm),
        act_cfg=dict(type=SiLU, inplace=True)),
    train_cfg=dict(
        assigner=dict(type=DynamicSoftLabelAssigner, topk=13),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=30000,
        min_bbox_size=0,
        score_thr=0.001,
        nms=dict(type=nms, iou_threshold=0.65),
        max_per_img=300),
)
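
# Stage-1 training pipeline: heavy augmentation with cached Mosaic, random
# resize (0.1x-2.0x of 1280x1280), random crop back to 640x640, HSV jitter,
# random flip, padding with value 114 and cached MixUp.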
train_pipeline = [
    dict(type=LoadImageFromFile, backend_args=backend_args),
    dict(type=LoadAnnotations, with_bbox=True),
    dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0),
    dict(
        type=RandomResize,
        scale=(1280, 1280),
        ratio_range=(0.1, 2.0),
        resize_type=Resize,
        keep_ratio=True),
    dict(type=RandomCrop, crop_size=(640, 640)),
    dict(type=YOLOXHSVRandomAug),
    dict(type=RandomFlip, prob=0.5),
    dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(
        type=CachedMixUp,
        img_scale=(640, 640),
        ratio_range=(1.0, 1.0),
        max_cached_images=20,
        pad_val=(114, 114, 114)),
    dict(type=PackDetInputs)
]
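
# Stage-2 training pipeline: Mosaic and MixUp are dropped for the final
# `stage2_num_epochs` epochs; the switch is performed by PipelineSwitchHook
# in `custom_hooks` below.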
train_pipeline_stage2 = [
    dict(type=LoadImageFromFile, backend_args=backend_args),
    dict(type=LoadAnnotations, with_bbox=True),
    dict(
        type=RandomResize,
        scale=(640, 640),
        ratio_range=(0.1, 2.0),
        resize_type=Resize,
        keep_ratio=True),
    dict(type=RandomCrop, crop_size=(640, 640)),
    dict(type=YOLOXHSVRandomAug),
    dict(type=RandomFlip, prob=0.5),
    dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(type=PackDetInputs)
]
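
# Test pipeline: aspect-ratio-preserving resize to 640x640 followed by
# padding to a fixed square input.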
test_pipeline = [
    dict(type=LoadImageFromFile, backend_args=backend_args),
    dict(type=Resize, scale=(640, 640), keep_ratio=True),
    dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(type=LoadAnnotations, with_bbox=True),
    dict(
        type=PackDetInputs,
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
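
# Dataloaders: batch size 32 per GPU; with the 8 GPUs implied by the config
# name (8xb32) the effective batch size is 256.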
train_dataloader.update(
    dict(
        batch_size=32,
        num_workers=10,
        batch_sampler=None,
        pin_memory=True,
        dataset=dict(pipeline=train_pipeline)))
val_dataloader.update(
    dict(batch_size=5, num_workers=10, dataset=dict(pipeline=test_pipeline)))
test_dataloader = val_dataloader
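
# Training schedule: 300 epochs with validation every 10 epochs, switching to
# validation every epoch for the final 20 (stage-2) epochs via
# dynamic_intervals.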
max_epochs = 300
stage2_num_epochs = 20
base_lr = 0.004
interval = 10

train_cfg.update(
    dict(
        max_epochs=max_epochs,
        val_interval=interval,
        dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)]))
val_evaluator.update(dict(proposal_nums=(100, 1, 10)))
test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=dict(type=AdamW, lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
    dict(
        type=LinearLR, start_factor=1.0e-5, by_epoch=False, begin=0, end=1000),
    dict(
        # use cosine lr from 150 to 300 epoch
        type=CosineAnnealingLR,
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# hooks
default_hooks.update(
    dict(
        checkpoint=dict(
            interval=interval,
            max_keep_ckpts=3  # only keep latest 3 checkpoints
        )))
custom_hooks = [
    dict(
        type=EMAHook,
        ema_type=ExpMomentumEMA,
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    dict(
        type=PipelineSwitchHook,
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]
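
# Usage (a minimal sketch, not part of the config itself): this file is meant
# to be consumed by MMEngine's Runner rather than executed directly. Assuming
# the standard MMDetection layout, training could be launched with, e.g.:
#
#   from mmengine.config import Config
#   from mmengine.runner import Runner
#
#   cfg = Config.fromfile('path/to/rtmdet_l_8xb32_300e_coco.py')
#   cfg.work_dir = './work_dirs/rtmdet_l_8xb32_300e_coco'
#   runner = Runner.from_cfg(cfg)
#   runner.train()
#
# or via the tools/train.py / tools/dist_train.sh entry points shipped with
# MMDetection.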