# YOLOX single-class ("barcode") detector config with TTA base.
#
# All settings are plain ``dict``/list literals; the registry type names
# ('YOLOX', 'CSPDarknet', ...) are resolved by the framework at build time.
default_scope = 'mmdet'
_base_ = ['./yolox_tta.py']

# yolox_s_8xb8-300e_coco
# NOTE(review): the name above says yolox_s; confirm that the deepen/widen
# factors and channel widths below match the intended model variant.
model = dict(
    type='YOLOX',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        pad_size_divisor=32,
        # Multi-scale batch resizing is currently disabled:
        # batch_augments=[
        #     dict(
        #         type='BatchSyncRandomResize',
        #         random_size_range=(320, 640),
        #         size_divisor=32,
        #         interval=10)
        # ]
    ),
    backbone=dict(
        type='CSPDarknet',
        deepen_factor=0.33,
        widen_factor=0.375,
        out_indices=(2, 3, 4),
        use_depthwise=False,
        dcn=False,
        spp_kernal_sizes=(5, 9, 13),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
    ),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[96, 192, 384],
        out_channels=96,
        num_csp_blocks=1,
        tr=False,
        use_depthwise=False,
        upsample_cfg=dict(scale_factor=2, mode='nearest'),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish')),
    bbox_head=dict(
        type='YOLOXHead',
        num_classes=1,
        in_channels=96,
        feat_channels=96,
        stacked_convs=2,
        strides=(8, 16, 32),
        use_depthwise=False,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
        # dcn_on_last_conv=True,
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox=dict(
            # SIoU is used for the bbox regression loss instead of plain IoU,
            # aiming for better bbox performance.
            type='SIoULoss',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
    # candidate_topk is 10 here (an earlier note suggested trying 5).
    train_cfg=dict(
        assigner=dict(
            type='SimOTAAssigner', center_radius=2, candidate_topk=10)),
    # In order to align the source code, the threshold of the val phase is
    # 0.01, and the threshold of the test phase is 0.001.
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.5)))

img_scale = (640, 640)  # width, height

dataset_type = 'CocoDataset'
data_root = '../../../media/tricolops/T7/coco_format/'
# Used both as the train dataloader batch size and as the reference batch
# size in ``auto_scale_lr``.
base_batch_size = 16
metainfo = {
    'classes': ('barcode',),
    'palette': [
        (220, 20, 60),
    ]
}
backend_args = None

train_pipeline = [
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='RandomAffine',
        scaling_ratio_range=(0.1, 2),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MixUp',
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    # Multi-scale training is not handled in this pipeline; the only
    # multi-scale setting in this file is the commented-out
    # ``batch_augments`` of the data preprocessor.
    # Resize and Pad are for the last ``num_last_epochs`` epochs when Mosaic,
    # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        # If the image is three-channel, the pad value needs
        # to be set separately for each channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='PackDetInputs')
]

train_dataset = dict(
    # use MultiImageMixDataset wrapper to support mosaic and mixup
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='Train/Train.json',
        metainfo=metainfo,
        data_prefix=dict(img='Train/'),
        pipeline=[
            dict(type='LoadImageFromFile', backend_args=backend_args),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        backend_args=backend_args),
    pipeline=train_pipeline)

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=base_batch_size,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='Val/Val.json',
        data_prefix=dict(img='Val/'),
        metainfo=metainfo,
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
# Testing reuses the validation dataloader (same object).
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'Val/Val.json',
    metric='bbox',
)
test_evaluator = val_evaluator

# training schedule
max_epochs = 500
# Length of the final phase: the constant-LR stage below starts at
# ``max_epochs - num_last_epochs`` and YOLOXModeSwitchHook receives the same
# value.
num_last_epochs = 50

# NOTE: these top-level loop configs are distinct from the ``train_cfg`` /
# ``test_cfg`` keys nested inside ``model`` above.
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=20)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning rate
base_lr = 0.0025
param_scheduler = [
    dict(
        # use quadratic formula to warm up 5 epochs
        # and lr is updated by iteration
        # TODO: fix default scope in get function
        type='mmdet.QuadraticWarmupLR',
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        # use cosine lr from epoch 5 to epoch
        # ``max_epochs - num_last_epochs`` (450 with the values above)
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=5,
        T_max=max_epochs - num_last_epochs,
        end=max_epochs - num_last_epochs,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        # use fixed lr during the last ``num_last_epochs`` (50) epochs
        type='ConstantLR',
        by_epoch=True,
        factor=1,
        begin=max_epochs - num_last_epochs,
        end=max_epochs,
    )
]

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.9,
        weight_decay=0.0005,
        nesterov=True),
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` is the reference batch size; here it equals the
#       train dataloader batch size defined above.
auto_scale_lr = dict(enable=False, base_batch_size=base_batch_size)

# default_runtime
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)

log_level = 'INFO'
load_from = None
resume = False

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(type='SyncNormHook', priority=48),
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        priority=49)
]
# load_from = 'epoch_1000_yolox.pth'