# Config file: masktrack-rcnn_mask-rcnn_r50_fpn_8xb1-12e_youtubevis2019.py
  1. _base_ = [
  2. '../_base_/models/mask-rcnn_r50_fpn.py',
  3. '../_base_/datasets/youtube_vis.py', '../_base_/default_runtime.py'
  4. ]
  5. detector = _base_.model
  6. detector.pop('data_preprocessor')
  7. detector.roi_head.bbox_head.update(dict(num_classes=40))
  8. detector.roi_head.mask_head.update(dict(num_classes=40))
  9. detector.train_cfg.rpn.sampler.update(dict(num=64))
  10. detector.train_cfg.rpn_proposal.update(dict(nms_pre=200, max_per_img=200))
  11. detector.train_cfg.rcnn.sampler.update(dict(num=128))
  12. detector.test_cfg.rpn.update(dict(nms_pre=200, max_per_img=200))
  13. detector.test_cfg.rcnn.update(dict(score_thr=0.01))
  14. detector['init_cfg'] = dict(
  15. type='Pretrained',
  16. checkpoint= # noqa: E251
  17. 'https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_1x_coco/mask_rcnn_r50_fpn_1x_coco_20200205-d4b0c5d6.pth' # noqa: E501
  18. )
  19. del _base_.model
  20. model = dict(
  21. type='MaskTrackRCNN',
  22. data_preprocessor=dict(
  23. type='TrackDataPreprocessor',
  24. mean=[123.675, 116.28, 103.53],
  25. std=[58.395, 57.12, 57.375],
  26. bgr_to_rgb=True,
  27. pad_mask=True,
  28. pad_size_divisor=32),
  29. detector=detector,
  30. track_head=dict(
  31. type='RoITrackHead',
  32. roi_extractor=dict(
  33. type='SingleRoIExtractor',
  34. roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
  35. out_channels=256,
  36. featmap_strides=[4, 8, 16, 32]),
  37. embed_head=dict(
  38. type='RoIEmbedHead',
  39. num_fcs=2,
  40. roi_feat_size=7,
  41. in_channels=256,
  42. fc_out_channels=1024),
  43. train_cfg=dict(
  44. assigner=dict(
  45. type='MaxIoUAssigner',
  46. pos_iou_thr=0.5,
  47. neg_iou_thr=0.5,
  48. min_pos_iou=0.5,
  49. match_low_quality=True,
  50. ignore_iof_thr=-1),
  51. sampler=dict(
  52. type='RandomSampler',
  53. num=128,
  54. pos_fraction=0.25,
  55. neg_pos_ub=-1,
  56. add_gt_as_proposals=True),
  57. pos_weight=-1,
  58. debug=False)),
  59. tracker=dict(
  60. type='MaskTrackRCNNTracker',
  61. match_weights=dict(det_score=1.0, iou=2.0, det_label=10.0),
  62. num_frames_retain=20))
  63. dataset_type = 'YouTubeVISDataset'
  64. data_root = 'data/youtube_vis_2019/'
  65. dataset_version = data_root[-5:-1] # 2019 or 2021
  66. # train_dataloader
  67. train_dataloader = dict(
  68. _delete_=True,
  69. batch_size=1,
  70. num_workers=2,
  71. persistent_workers=True,
  72. sampler=dict(type='TrackImgSampler'), # image-based sampling
  73. batch_sampler=dict(type='TrackAspectRatioBatchSampler'),
  74. dataset=dict(
  75. type=dataset_type,
  76. data_root=data_root,
  77. dataset_version=dataset_version,
  78. ann_file='annotations/youtube_vis_2019_train.json',
  79. data_prefix=dict(img_path='train/JPEGImages'),
  80. pipeline=_base_.train_pipeline))
  81. # optimizer
  82. optim_wrapper = dict(
  83. type='OptimWrapper',
  84. optimizer=dict(type='SGD', lr=0.00125, momentum=0.9, weight_decay=0.0001),
  85. clip_grad=dict(max_norm=35, norm_type=2))
  86. # learning policy
  87. param_scheduler = [
  88. dict(
  89. type='LinearLR',
  90. start_factor=1.0 / 3.0,
  91. by_epoch=False,
  92. begin=0,
  93. end=500),
  94. dict(
  95. type='MultiStepLR',
  96. begin=0,
  97. end=12,
  98. by_epoch=True,
  99. milestones=[8, 11],
  100. gamma=0.1)
  101. ]
  102. # visualizer
  103. default_hooks = dict(
  104. visualization=dict(type='TrackVisualizationHook', draw=False))
  105. vis_backends = [dict(type='LocalVisBackend')]
  106. visualizer = dict(
  107. type='TrackLocalVisualizer', vis_backends=vis_backends, name='visualizer')
  108. # runtime settings
  109. train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_begin=13)
  110. val_cfg = dict(type='ValLoop')
  111. test_cfg = dict(type='TestLoop')
  112. # evaluator
  113. val_evaluator = dict(
  114. type='YouTubeVISMetric',
  115. metric='youtube_vis_ap',
  116. outfile_prefix='./youtube_vis_results',
  117. format_only=True)
  118. test_evaluator = val_evaluator
  119. del detector