# deformable-detr_r50_16xb2-50e_coco.py
  1. _base_ = [
  2. '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
  3. ]
  4. model = dict(
  5. type='DeformableDETR',
  6. num_queries=300,
  7. num_feature_levels=4,
  8. with_box_refine=False,
  9. as_two_stage=False,
  10. data_preprocessor=dict(
  11. type='DetDataPreprocessor',
  12. mean=[123.675, 116.28, 103.53],
  13. std=[58.395, 57.12, 57.375],
  14. bgr_to_rgb=True,
  15. pad_size_divisor=1),
  16. backbone=dict(
  17. type='ResNet',
  18. depth=50,
  19. num_stages=4,
  20. out_indices=(1, 2, 3),
  21. frozen_stages=1,
  22. norm_cfg=dict(type='BN', requires_grad=False),
  23. norm_eval=True,
  24. style='pytorch',
  25. init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
  26. neck=dict(
  27. type='ChannelMapper',
  28. in_channels=[512, 1024, 2048],
  29. kernel_size=1,
  30. out_channels=256,
  31. act_cfg=None,
  32. norm_cfg=dict(type='GN', num_groups=32),
  33. num_outs=4),
  34. encoder=dict( # DeformableDetrTransformerEncoder
  35. num_layers=6,
  36. layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
  37. self_attn_cfg=dict( # MultiScaleDeformableAttention
  38. embed_dims=256,
  39. batch_first=True),
  40. ffn_cfg=dict(
  41. embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
  42. decoder=dict( # DeformableDetrTransformerDecoder
  43. num_layers=6,
  44. return_intermediate=True,
  45. layer_cfg=dict( # DeformableDetrTransformerDecoderLayer
  46. self_attn_cfg=dict( # MultiheadAttention
  47. embed_dims=256,
  48. num_heads=8,
  49. dropout=0.1,
  50. batch_first=True),
  51. cross_attn_cfg=dict( # MultiScaleDeformableAttention
  52. embed_dims=256,
  53. batch_first=True),
  54. ffn_cfg=dict(
  55. embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
  56. post_norm_cfg=None),
  57. positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5),
  58. bbox_head=dict(
  59. type='DeformableDETRHead',
  60. num_classes=80,
  61. sync_cls_avg_factor=True,
  62. loss_cls=dict(
  63. type='FocalLoss',
  64. use_sigmoid=True,
  65. gamma=2.0,
  66. alpha=0.25,
  67. loss_weight=2.0),
  68. loss_bbox=dict(type='L1Loss', loss_weight=5.0),
  69. loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
  70. # training and testing settings
  71. train_cfg=dict(
  72. assigner=dict(
  73. type='HungarianAssigner',
  74. match_costs=[
  75. dict(type='FocalLossCost', weight=2.0),
  76. dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
  77. dict(type='IoUCost', iou_mode='giou', weight=2.0)
  78. ])),
  79. test_cfg=dict(max_per_img=100))
  80. # train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
  81. # from the default setting in mmdet.
  82. train_pipeline = [
  83. dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
  84. dict(type='LoadAnnotations', with_bbox=True),
  85. dict(type='RandomFlip', prob=0.5),
  86. # dict(
  87. # type='RandomChoice',
  88. # transforms=[
  89. # [
  90. # dict(
  91. # type='RandomChoiceResize',
  92. # scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
  93. # (608, 1333), (640, 1333), (672, 1333), (704, 1333),
  94. # (736, 1333), (768, 1333), (800, 1333)],
  95. # keep_ratio=True)
  96. # ],
  97. # [
  98. # dict(
  99. # type='RandomChoiceResize',
  100. # # The radio of all image in train dataset < 7
  101. # # follow the original implement
  102. # scales=[(400, 4200), (500, 4200), (600, 4200)],
  103. # keep_ratio=True),
  104. # dict(
  105. # type='RandomCrop',
  106. # crop_type='absolute_range',
  107. # crop_size=(384, 600),
  108. # allow_negative_crop=True),
  109. # dict(
  110. # type='RandomChoiceResize',
  111. # scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
  112. # (608, 1333), (640, 1333), (672, 1333), (704, 1333),
  113. # (736, 1333), (768, 1333), (800, 1333)],
  114. # keep_ratio=True)
  115. # ]
  116. # ]),
  117. dict(
  118. type='RandomResize',
  119. scale=(640, 640),
  120. ratio_range=(0.5, 2.0),
  121. keep_ratio=True),
  122. dict(type='RandomCrop', crop_size=(640, 640)),
  123. dict(type='PackDetInputs')
  124. ]
  125. train_dataloader = dict(
  126. dataset=dict(
  127. filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))
  128. # optimizer
  129. optim_wrapper = dict(
  130. type='OptimWrapper',
  131. optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
  132. clip_grad=dict(max_norm=0.1, norm_type=2),
  133. paramwise_cfg=dict(
  134. custom_keys={
  135. 'backbone': dict(lr_mult=0.1),
  136. 'sampling_offsets': dict(lr_mult=0.1),
  137. 'reference_points': dict(lr_mult=0.1)
  138. }))
  139. # learning policy
  140. max_epochs = 50
  141. train_cfg = dict(
  142. type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
  143. val_cfg = dict(type='ValLoop')
  144. test_cfg = dict(type='TestLoop')
  145. param_scheduler = [
  146. dict(
  147. type='MultiStepLR',
  148. begin=0,
  149. end=max_epochs,
  150. by_epoch=True,
  151. milestones=[40],
  152. gamma=0.1)
  153. ]
  154. # NOTE: `auto_scale_lr` is for automatically scaling LR,
  155. # USER SHOULD NOT CHANGE ITS VALUES.
  156. # base_batch_size = (16 GPUs) x (2 samples per GPU)
  157. auto_scale_lr = dict(base_batch_size=32)