# diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py
  1. _base_ = [
  2. 'mmdet::_base_/datasets/coco_detection.py',
  3. 'mmdet::_base_/schedules/schedule_1x.py',
  4. 'mmdet::_base_/default_runtime.py'
  5. ]
  6. custom_imports = dict(
  7. imports=['projects.DiffusionDet.diffusiondet'], allow_failed_imports=False)
  8. # model settings
  9. model = dict(
  10. type='DiffusionDet',
  11. data_preprocessor=dict(
  12. type='DetDataPreprocessor',
  13. mean=[123.675, 116.28, 103.53],
  14. std=[58.395, 57.12, 57.375],
  15. bgr_to_rgb=True,
  16. pad_size_divisor=32),
  17. backbone=dict(
  18. type='ResNet',
  19. depth=50,
  20. num_stages=4,
  21. out_indices=(0, 1, 2, 3),
  22. frozen_stages=1,
  23. norm_cfg=dict(type='BN', requires_grad=True),
  24. norm_eval=True,
  25. style='pytorch',
  26. init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
  27. neck=dict(
  28. type='FPN',
  29. in_channels=[256, 512, 1024, 2048],
  30. out_channels=256,
  31. num_outs=4),
  32. bbox_head=dict(
  33. type='DynamicDiffusionDetHead',
  34. num_classes=80,
  35. feat_channels=256,
  36. num_proposals=500,
  37. num_heads=6,
  38. deep_supervision=True,
  39. prior_prob=0.01,
  40. snr_scale=2.0,
  41. sampling_timesteps=1,
  42. ddim_sampling_eta=1.0,
  43. single_head=dict(
  44. type='SingleDiffusionDetHead',
  45. num_cls_convs=1,
  46. num_reg_convs=3,
  47. dim_feedforward=2048,
  48. num_heads=8,
  49. dropout=0.0,
  50. act_cfg=dict(type='ReLU', inplace=True),
  51. dynamic_conv=dict(dynamic_dim=64, dynamic_num=2)),
  52. roi_extractor=dict(
  53. type='SingleRoIExtractor',
  54. roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=2),
  55. out_channels=256,
  56. featmap_strides=[4, 8, 16, 32]),
  57. # criterion
  58. criterion=dict(
  59. type='DiffusionDetCriterion',
  60. num_classes=80,
  61. assigner=dict(
  62. type='DiffusionDetMatcher',
  63. match_costs=[
  64. dict(
  65. type='FocalLossCost',
  66. alpha=0.25,
  67. gamma=2.0,
  68. weight=2.0,
  69. eps=1e-8),
  70. dict(type='BBoxL1Cost', weight=5.0, box_format='xyxy'),
  71. dict(type='IoUCost', iou_mode='giou', weight=2.0)
  72. ],
  73. center_radius=2.5,
  74. candidate_topk=5),
  75. loss_cls=dict(
  76. type='FocalLoss',
  77. use_sigmoid=True,
  78. alpha=0.25,
  79. gamma=2.0,
  80. reduction='sum',
  81. loss_weight=2.0),
  82. loss_bbox=dict(type='L1Loss', reduction='sum', loss_weight=5.0),
  83. loss_giou=dict(type='GIoULoss', reduction='sum',
  84. loss_weight=2.0))),
  85. test_cfg=dict(
  86. use_nms=True,
  87. score_thr=0.5,
  88. min_bbox_size=0,
  89. nms=dict(type='nms', iou_threshold=0.5),
  90. ))
  91. backend = 'pillow'
  92. train_pipeline = [
  93. dict(
  94. type='LoadImageFromFile',
  95. backend_args=_base_.backend_args,
  96. imdecode_backend=backend),
  97. dict(type='LoadAnnotations', with_bbox=True),
  98. dict(type='RandomFlip', prob=0.5),
  99. dict(
  100. type='RandomChoice',
  101. transforms=[[
  102. dict(
  103. type='RandomChoiceResize',
  104. scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
  105. (608, 1333), (640, 1333), (672, 1333), (704, 1333),
  106. (736, 1333), (768, 1333), (800, 1333)],
  107. keep_ratio=True,
  108. backend=backend),
  109. ],
  110. [
  111. dict(
  112. type='RandomChoiceResize',
  113. scales=[(400, 1333), (500, 1333), (600, 1333)],
  114. keep_ratio=True,
  115. backend=backend),
  116. dict(
  117. type='RandomCrop',
  118. crop_type='absolute_range',
  119. crop_size=(384, 600),
  120. allow_negative_crop=True),
  121. dict(
  122. type='RandomChoiceResize',
  123. scales=[(480, 1333), (512, 1333), (544, 1333),
  124. (576, 1333), (608, 1333), (640, 1333),
  125. (672, 1333), (704, 1333), (736, 1333),
  126. (768, 1333), (800, 1333)],
  127. keep_ratio=True,
  128. backend=backend)
  129. ]]),
  130. dict(type='PackDetInputs')
  131. ]
  132. test_pipeline = [
  133. dict(
  134. type='LoadImageFromFile',
  135. backend_args=_base_.backend_args,
  136. imdecode_backend=backend),
  137. dict(type='Resize', scale=(1333, 800), keep_ratio=True, backend=backend),
  138. # If you don't have a gt annotation, delete the pipeline
  139. dict(type='LoadAnnotations', with_bbox=True),
  140. dict(
  141. type='PackDetInputs',
  142. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  143. 'scale_factor'))
  144. ]
  145. train_dataloader = dict(
  146. sampler=dict(type='InfiniteSampler'),
  147. dataset=dict(
  148. filter_cfg=dict(filter_empty_gt=False, min_size=1e-5),
  149. pipeline=train_pipeline))
  150. val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
  151. test_dataloader = val_dataloader
  152. # optimizer
  153. optim_wrapper = dict(
  154. type='OptimWrapper',
  155. optimizer=dict(
  156. _delete_=True, type='AdamW', lr=0.000025, weight_decay=0.0001),
  157. clip_grad=dict(max_norm=1.0, norm_type=2))
  158. train_cfg = dict(
  159. _delete_=True,
  160. type='IterBasedTrainLoop',
  161. max_iters=450000,
  162. val_interval=75000)
  163. # learning rate
  164. param_scheduler = [
  165. dict(
  166. type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=1000),
  167. dict(
  168. type='MultiStepLR',
  169. begin=0,
  170. end=450000,
  171. by_epoch=False,
  172. milestones=[350000, 420000],
  173. gamma=0.1)
  174. ]
  175. default_hooks = dict(
  176. checkpoint=dict(by_epoch=False, interval=75000, max_keep_ckpts=3))
  177. log_processor = dict(by_epoch=False)