# Source file: bd_yolox_8xb2e50.py (7.9 KB)
  1. #yolo tta
  2. default_scope = 'mmdet'
  3. _base_=['./yolox_tta.py']
  4. #yolox_s_8xb8-300e_coco
  5. model = dict(
  6. type='YOLOX',
  7. data_preprocessor=dict(
  8. type='DetDataPreprocessor',
  9. pad_size_divisor=32,
  10. # batch_augments=[
  11. # dict(
  12. # type='BatchSyncRandomResize',
  13. # random_size_range=(320, 640),
  14. # size_divisor=32,
  15. # interval=10)
  16. # ]
  17. ),
  18. backbone=dict(
  19. type='CSPDarknet',
  20. deepen_factor=0.33,
  21. widen_factor=0.375,
  22. out_indices=(2, 3, 4),
  23. use_depthwise=False,
  24. dcn=False,
  25. spp_kernal_sizes=(5, 9, 13),
  26. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  27. act_cfg=dict(type='Swish'),
  28. ),
  29. neck=dict(
  30. type='YOLOXPAFPN',
  31. in_channels=[96, 192, 384],
  32. out_channels=96,
  33. num_csp_blocks=1,
  34. tr=False,
  35. use_depthwise=False,
  36. upsample_cfg=dict(scale_factor=2, mode='nearest'),
  37. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  38. act_cfg=dict(type='Swish')),
  39. bbox_head=dict(
  40. type='YOLOXHead',
  41. num_classes=1,
  42. in_channels=96,
  43. feat_channels=96,
  44. stacked_convs=2,
  45. strides=(8, 16, 32),
  46. use_depthwise=False,
  47. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  48. act_cfg=dict(type='Swish'),
  49. # dcn_on_last_conv=True,
  50. loss_cls=dict(
  51. type='CrossEntropyLoss',
  52. use_sigmoid=True,
  53. reduction='sum',
  54. loss_weight=1.0),
  55. loss_bbox=dict(
  56. #use giouloss intead of Iou to get better performance on bbox loss
  57. type='SIoULoss',
  58. eps=1e-16,
  59. reduction='sum',
  60. loss_weight=5.0),
  61. loss_obj=dict(
  62. type='CrossEntropyLoss',
  63. use_sigmoid=True,
  64. reduction='sum',
  65. loss_weight=1.0),
  66. loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
  67. #get top 5 instead of top 10
  68. train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2,candidate_topk=10)),
  69. # In order to align the source code, the threshold of the val phase is
  70. # 0.01, and the threshold of the test phase is 0.001.
  71. test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.5)))
  72. img_scale = (640, 640) # width, height
  73. dataset_type = 'CocoDataset'
  74. data_root = '../../../media/tricolops/T7/coco_format/'
  75. base_batch_size=16
  76. metainfo = {
  77. 'classes': ('barcode',),
  78. 'palette': [
  79. (220, 20, 60),
  80. ]
  81. }
  82. backend_args=None
  83. train_pipeline = [
  84. dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
  85. dict(
  86. type='RandomAffine',
  87. scaling_ratio_range=(0.1, 2),
  88. # img_scale is (width, height)
  89. border=(-img_scale[0] // 2, -img_scale[1] // 2)),
  90. dict(
  91. type='MixUp',
  92. img_scale=img_scale,
  93. ratio_range=(0.8, 1.6),
  94. pad_val=114.0),
  95. dict(type='YOLOXHSVRandomAug'),
  96. dict(type='RandomFlip', prob=0.5),
  97. # According to the official implementation, multi-scale
  98. # training is not considered here but in the
  99. # 'mmdet/models/detectors/yolox.py'.
  100. # Resize and Pad are for the last 15 epochs when Mosaic,
  101. # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
  102. dict(type='Resize', scale=img_scale, keep_ratio=True),
  103. dict(
  104. type='Pad',
  105. pad_to_square=True,
  106. # If the image is three-channel, the pad value needs
  107. # to be set separately for each channel.
  108. pad_val=dict(img=(114.0, 114.0, 114.0))),
  109. dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
  110. dict(type='PackDetInputs')
  111. ]
  112. train_dataset = dict(
  113. # use MultiImageMixDataset wrapper to support mosaic and mixup
  114. type='MultiImageMixDataset',
  115. dataset=dict(
  116. type=dataset_type,
  117. data_root=data_root,
  118. ann_file='Train/Train.json',
  119. metainfo=metainfo,
  120. data_prefix=dict(img='Train/'),
  121. pipeline=[
  122. dict(type='LoadImageFromFile', backend_args=backend_args),
  123. dict(type='LoadAnnotations', with_bbox=True)
  124. ],
  125. filter_cfg=dict(filter_empty_gt=False, min_size=32),
  126. backend_args=backend_args),
  127. pipeline=train_pipeline)
  128. test_pipeline = [
  129. dict(type='LoadImageFromFile', backend_args=backend_args),
  130. dict(type='Resize', scale=img_scale, keep_ratio=True),
  131. dict(
  132. type='Pad',
  133. pad_to_square=True,
  134. pad_val=dict(img=(114.0, 114.0, 114.0))),
  135. dict(type='LoadAnnotations', with_bbox=True),
  136. dict(
  137. type='PackDetInputs',
  138. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  139. 'scale_factor'))
  140. ]
  141. train_dataloader = dict(
  142. batch_size=base_batch_size,
  143. num_workers=4,
  144. persistent_workers=True,
  145. sampler=dict(type='DefaultSampler', shuffle=True),
  146. dataset=train_dataset)
  147. val_dataloader = dict(
  148. batch_size=1,
  149. num_workers=4,
  150. persistent_workers=True,
  151. drop_last=False,
  152. sampler=dict(type='DefaultSampler', shuffle=False),
  153. dataset=dict(
  154. type=dataset_type,
  155. data_root=data_root,
  156. ann_file='Val/Val.json',
  157. data_prefix=dict(img='Val/'),
  158. metainfo=metainfo,
  159. test_mode=True,
  160. pipeline=test_pipeline,
  161. backend_args=backend_args))
  162. test_dataloader = val_dataloader
  163. val_evaluator = dict(
  164. type='CocoMetric',
  165. ann_file=data_root + 'Val/Val.json',
  166. metric='bbox',)
  167. test_evaluator = val_evaluator
  168. #schedule 1x
  169. # training schedule for 1x
  170. max_epochs=500
  171. num_last_epochs = 50
  172. train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=20)
  173. val_cfg = dict(type='ValLoop')
  174. test_cfg = dict(type='TestLoop')
  175. # learning rate
  176. base_lr=0.0025
  177. param_scheduler = [
  178. dict(
  179. # use quadratic formula to warm up 5 epochs
  180. # and lr is updated by iteration
  181. # TODO: fix default scope in get function
  182. type='mmdet.QuadraticWarmupLR',
  183. by_epoch=True,
  184. begin=0,
  185. end=5,
  186. convert_to_iter_based=True),
  187. dict(
  188. # use cosine lr from 5 to 285 epoch
  189. type='CosineAnnealingLR',
  190. eta_min=base_lr * 0.05,
  191. begin=5,
  192. T_max=max_epochs - num_last_epochs,
  193. end=max_epochs - num_last_epochs,
  194. by_epoch=True,
  195. convert_to_iter_based=True),
  196. dict(
  197. # use fixed lr during last 15 epochs
  198. type='ConstantLR',
  199. by_epoch=True,
  200. factor=1,
  201. begin=max_epochs - num_last_epochs,
  202. end=max_epochs,
  203. )
  204. ]
  205. # optimizer
  206. optim_wrapper = dict(
  207. type='OptimWrapper',
  208. optimizer=dict(type='SGD', lr=base_lr, momentum=0.9, weight_decay=0.0005, nesterov=True),
  209. paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
  210. # Default setting for scaling LR automatically
  211. # - `enable` means enable scaling LR automatically
  212. # or not by default.
  213. # - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
  214. auto_scale_lr = dict(enable=False, base_batch_size=base_batch_size)
  215. #default_runtime
  216. default_hooks = dict(
  217. timer=dict(type='IterTimerHook'),
  218. logger=dict(type='LoggerHook', interval=50),
  219. param_scheduler=dict(type='ParamSchedulerHook'),
  220. checkpoint=dict(type='CheckpointHook', interval=10),
  221. sampler_seed=dict(type='DistSamplerSeedHook'),
  222. visualization=dict(type='DetVisualizationHook'))
  223. env_cfg = dict(
  224. cudnn_benchmark=False,
  225. mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
  226. dist_cfg=dict(backend='nccl'),
  227. )
  228. vis_backends = [dict(type='LocalVisBackend')]
  229. visualizer = dict(
  230. type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
  231. log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
  232. log_level = 'INFO'
  233. load_from = None
  234. resume = False
  235. custom_hooks = [
  236. dict(
  237. type='YOLOXModeSwitchHook',
  238. num_last_epochs=num_last_epochs,
  239. priority=48),
  240. dict(type='SyncNormHook', priority=48),
  241. dict(
  242. type='EMAHook',
  243. ema_type='ExpMomentumEMA',
  244. momentum=0.0001,
  245. update_buffers=True,
  246. priority=49)
  247. ]
  248. # load_from = 'epoch_1000_yolox.pth'