# rtmdet_s_8xb32_300e_coco.py (3.0 KB)

# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# Requires mmcv >= 2.0.1
# Requires mmengine >= 0.8.0
  5. from mmengine.config import read_base
  6. with read_base():
  7. from .rtmdet_l_8xb32_300e_coco import *
  8. from mmcv.transforms.loading import LoadImageFromFile
  9. from mmcv.transforms.processing import RandomResize
  10. from mmengine.hooks.ema_hook import EMAHook
  11. from mmdet.datasets.transforms.formatting import PackDetInputs
  12. from mmdet.datasets.transforms.loading import LoadAnnotations
  13. from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic,
  14. Pad, RandomCrop, RandomFlip,
  15. Resize, YOLOXHSVRandomAug)
  16. from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook
  17. from mmdet.models.layers.ema import ExpMomentumEMA
  18. checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa
  19. model.update(
  20. dict(
  21. backbone=dict(
  22. deepen_factor=0.33,
  23. widen_factor=0.5,
  24. init_cfg=dict(
  25. type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
  26. neck=dict(
  27. in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
  28. bbox_head=dict(in_channels=128, feat_channels=128, exp_on_reg=False)))
  29. train_pipeline = [
  30. dict(type=LoadImageFromFile, backend_args=backend_args),
  31. dict(type=LoadAnnotations, with_bbox=True),
  32. dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0),
  33. dict(
  34. type=RandomResize,
  35. scale=(1280, 1280),
  36. ratio_range=(0.5, 2.0),
  37. resize_type=Resize,
  38. keep_ratio=True),
  39. dict(type=RandomCrop, crop_size=(640, 640)),
  40. dict(type=YOLOXHSVRandomAug),
  41. dict(type=RandomFlip, prob=0.5),
  42. dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
  43. dict(
  44. type=CachedMixUp,
  45. img_scale=(640, 640),
  46. ratio_range=(1.0, 1.0),
  47. max_cached_images=20,
  48. pad_val=(114, 114, 114)),
  49. dict(type=PackDetInputs)
  50. ]
  51. train_pipeline_stage2 = [
  52. dict(type=LoadImageFromFile, backend_args=backend_args),
  53. dict(type=LoadAnnotations, with_bbox=True),
  54. dict(
  55. type=RandomResize,
  56. scale=(640, 640),
  57. ratio_range=(0.5, 2.0),
  58. resize_type=Resize,
  59. keep_ratio=True),
  60. dict(type=RandomCrop, crop_size=(640, 640)),
  61. dict(type=YOLOXHSVRandomAug),
  62. dict(type=RandomFlip, prob=0.5),
  63. dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
  64. dict(type=PackDetInputs)
  65. ]
  66. train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
  67. custom_hooks = [
  68. dict(
  69. type=EMAHook,
  70. ema_type=ExpMomentumEMA,
  71. momentum=0.0002,
  72. update_buffers=True,
  73. priority=49),
  74. dict(
  75. type=PipelineSwitchHook,
  76. switch_epoch=280,
  77. switch_pipeline=train_pipeline_stage2)
  78. ]