123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
# Inherit the YOLOX test-time-augmentation (TTA) settings; mmengine merges
# the values in this file on top of the ones from the base file.
_base_ = ['./yolox_tta.py']
# Unscoped registry type names (e.g. 'YOLOX', 'CocoDataset') resolve in mmdet.
default_scope = 'mmdet'
# Single-class YOLOX detector, adapted from mmdet's yolox_s_8xb8-300e_coco.
# Note that the scaling factors used here (deepen_factor=0.33,
# widen_factor=0.375, 96/192/384 channels) are the YOLOX-tiny sizes, not the
# YOLOX-s ones (YOLOX-s uses widen_factor=0.5).
model = dict(
    type='YOLOX',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        pad_size_divisor=32,
        # Multi-scale training is currently disabled. To enable it, restore:
        # batch_augments=[
        #     dict(
        #         type='BatchSyncRandomResize',
        #         random_size_range=(320, 640),
        #         size_divisor=32,
        #         interval=10)
        # ]
    ),
    backbone=dict(
        type='CSPDarknet',
        deepen_factor=0.33,
        widen_factor=0.375,
        out_indices=(2, 3, 4),
        use_depthwise=False,
        # NOTE(review): `dcn` is not an argument of upstream mmdet's
        # CSPDarknet; presumably handled by a customized backbone — confirm.
        dcn=False,
        # "kernal" (sic) is the upstream argument spelling; do not "fix" it.
        spp_kernal_sizes=(5, 9, 13),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
    ),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[96, 192, 384],
        out_channels=96,
        num_csp_blocks=1,
        # NOTE(review): `tr` is not an argument of upstream mmdet's
        # YOLOXPAFPN; presumably handled by a customized neck — confirm.
        tr=False,
        use_depthwise=False,
        upsample_cfg=dict(scale_factor=2, mode='nearest'),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish')),
    bbox_head=dict(
        type='YOLOXHead',
        num_classes=1,  # only 'barcode'
        in_channels=96,
        feat_channels=96,
        stacked_convs=2,
        strides=(8, 16, 32),
        use_depthwise=False,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
        # dcn_on_last_conv=True,
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox=dict(
            # SIoU loss replaces the IoU loss of upstream YOLOX to improve
            # box regression.
            # NOTE(review): SIoULoss is not part of upstream mmdet; it is
            # presumably registered by project code — confirm.
            type='SIoULoss',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
    # SimOTA label assignment. center_radius=2 is tighter than upstream
    # YOLOX (2.5). candidate_topk=10 is the SimOTAAssigner default.
    # NOTE(review): an older comment here said top 5 was intended, but the
    # value is 10 — confirm which one is wanted.
    train_cfg=dict(
        assigner=dict(
            type='SimOTAAssigner', center_radius=2, candidate_topk=10)),
    # Upstream YOLOX uses score_thr=0.01 for validation and 0.001 for test.
    # This config uses 0.01 for both phases, with NMS IoU 0.5 (upstream 0.65).
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.5)))
# Network input resolution, given as (width, height).
img_scale = (640, 640)
dataset_type = 'CocoDataset'
# NOTE(review): this is a relative path, so it resolves against the directory
# training is launched from. It was presumably meant to be the absolute
# '/media/tricolops/T7/coco_format/' — confirm.
data_root = '../../../media/tricolops/T7/coco_format/'
# Training batch size per GPU; also the reference size for LR auto-scaling.
base_batch_size = 16
# Single-class dataset. Class names must match the categories in the COCO
# annotation files.
metainfo = dict(
    classes=('barcode',),
    palette=[(220, 20, 60)],
)
# None -> mmengine's default (local disk) file backend.
backend_args = None
# Training augmentations. Mosaic and MixUp need extra images, so this
# pipeline runs inside the MultiImageMixDataset wrapper. Image and annotation
# loading happens in the wrapped dataset's own pipeline.
train_pipeline = [
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='RandomAffine',
        scaling_ratio_range=(0.1, 2),
        # img_scale is (width, height). The negative border shrinks the
        # Mosaic output (2x img_scale) back down to img_scale.
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MixUp',
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    # According to the official implementation, multi-scale training is not
    # done in this pipeline but at the model level (batch augments in the
    # data preprocessor).
    # Resize and Pad take over for the last `num_last_epochs` epochs (50 in
    # this config), once YOLOXModeSwitchHook has switched off Mosaic,
    # RandomAffine and MixUp.
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        # For three-channel images the pad value is given per channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    # Remove degenerate boxes (smaller than 1x1 px) left by the geometric
    # augmentations.
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='PackDetInputs')
]
# The COCO training split, wrapped in MultiImageMixDataset so that Mosaic and
# MixUp can sample extra images. The inner pipeline only loads the image and
# its boxes. The wrapper then applies train_pipeline.
train_dataset = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='Train/Train.json',
        metainfo=metainfo,
        data_prefix=dict(img='Train/'),
        pipeline=[
            dict(type='LoadImageFromFile', backend_args=backend_args),
            dict(type='LoadAnnotations', with_bbox=True),
        ],
        # filter_empty_gt=False: keep images that have no boxes.
        # min_size=32: drop images whose shorter side is below 32 px.
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        backend_args=backend_args,
    ),
    pipeline=train_pipeline,
)
# Evaluation pipeline: aspect-preserving resize to img_scale, then a square
# pad with gray (114) — the same letterboxing used at the end of training.
test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(type='Pad',
         pad_to_square=True,
         pad_val=dict(img=(114.0, 114.0, 114.0))),
    # Loaded after Resize/Pad, so GT boxes stay in original-image coordinates.
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                    'scale_factor')),
]
train_dataloader = dict(
    batch_size=base_batch_size,  # per GPU
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)

# Validation: one image per batch, deterministic order, nothing dropped.
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='Val/Val.json',
        data_prefix=dict(img='Val/'),
        metainfo=metainfo,
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
# The validation split is reused as the test split (same dict object).
test_dataloader = val_dataloader

# COCO-style bbox mAP, computed against the validation annotation file.
val_evaluator = dict(
    type='CocoMetric',
    ann_file=f'{data_root}Val/Val.json',
    metric='bbox')
test_evaluator = val_evaluator
# Training schedule: 500 epochs (not the 12-epoch "1x" schedule). The final
# `num_last_epochs` epochs run without Mosaic/RandomAffine/MixUp, switched by
# YOLOXModeSwitchHook. Validation runs every 20 epochs.
max_epochs = 500
num_last_epochs = 50
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=20)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning rate
base_lr = 0.0025
param_scheduler = [
    dict(
        # Quadratic warm-up over epochs [0, 5), updated per iteration.
        # TODO: fix default scope in get function
        type='mmdet.QuadraticWarmupLR',
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        # Cosine decay towards 5% of base_lr, from epoch 5 up to
        # epoch max_epochs - num_last_epochs (450 here).
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=5,
        T_max=max_epochs - num_last_epochs,
        end=max_epochs - num_last_epochs,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        # factor=1 holds the LR fixed for the last num_last_epochs epochs
        # (50 here).
        type='ConstantLR',
        by_epoch=True,
        factor=1,
        begin=max_epochs - num_last_epochs,
        end=max_epochs,
    )
]

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD', lr=base_lr, momentum=0.9, weight_decay=0.0005,
        nesterov=True),
    # No weight decay on normalization layers or biases.
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
# Automatic LR scaling (disabled).
# - `enable`: when True, mmengine multiplies the optimizer lr by
#   (world_size * train_dataloader.batch_size) / base_batch_size.
# - `base_batch_size`: the total batch size base_lr was chosen for. Here it is
#   the shared `base_batch_size` variable (16), which is also the per-GPU
#   training batch size.
auto_scale_lr = dict(enable=False, base_batch_size=base_batch_size)
# ---- Default runtime settings ----
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),  # log every 50 iterations
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10),  # every 10 epochs
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'),
)
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer',
    vis_backends=vis_backends,
    name='visualizer',
)
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
log_level = 'INFO'
# Train from scratch (no checkpoint loaded) and do not resume a previous run.
load_from = None
resume = False
custom_hooks = [
    # For the final num_last_epochs epochs: turns off Mosaic, RandomAffine
    # and MixUp, and switches the head to also use the L1 box loss.
    dict(type='YOLOXModeSwitchHook',
         num_last_epochs=num_last_epochs,
         priority=48),
    # Synchronizes normalization-layer states before validation.
    dict(type='SyncNormHook', priority=48),
    # Maintains an exponential-moving-average copy of the model weights.
    dict(type='EMAHook',
         ema_type='ExpMomentumEMA',
         momentum=0.0001,
         update_buffers=True,
         priority=49),
]
# To fine-tune from an earlier run instead of training from scratch:
# load_from = 'epoch_1000_yolox.pth'
|