# qdtrack_faster-rcnn_r50_fpn_4e_base.py
  1. _base_ = [
  2. '../_base_/models/faster-rcnn_r50_fpn.py', '../_base_/default_runtime.py'
  3. ]
  4. detector = _base_.model
  5. detector.pop('data_preprocessor')
  6. detector['backbone'].update(
  7. dict(
  8. norm_cfg=dict(type='BN', requires_grad=False),
  9. style='caffe',
  10. init_cfg=dict(
  11. type='Pretrained',
  12. checkpoint='open-mmlab://detectron2/resnet50_caffe')))
  13. detector.rpn_head.loss_bbox.update(
  14. dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0))
  15. detector.rpn_head.bbox_coder.update(dict(clip_border=False))
  16. detector.roi_head.bbox_head.update(dict(num_classes=1))
  17. detector.roi_head.bbox_head.bbox_coder.update(dict(clip_border=False))
  18. detector['init_cfg'] = dict(
  19. type='Pretrained',
  20. checkpoint= # noqa: E251
  21. 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
  22. 'faster_rcnn_r50_fpn_1x_coco-person/'
  23. 'faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth'
  24. # noqa: E501
  25. )
  26. del _base_.model
  27. model = dict(
  28. type='QDTrack',
  29. data_preprocessor=dict(
  30. type='TrackDataPreprocessor',
  31. mean=[103.530, 116.280, 123.675],
  32. std=[1.0, 1.0, 1.0],
  33. bgr_to_rgb=False,
  34. pad_size_divisor=32),
  35. detector=detector,
  36. track_head=dict(
  37. type='QuasiDenseTrackHead',
  38. roi_extractor=dict(
  39. type='SingleRoIExtractor',
  40. roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
  41. out_channels=256,
  42. featmap_strides=[4, 8, 16, 32]),
  43. embed_head=dict(
  44. type='QuasiDenseEmbedHead',
  45. num_convs=4,
  46. num_fcs=1,
  47. embed_channels=256,
  48. norm_cfg=dict(type='GN', num_groups=32),
  49. loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),
  50. loss_track_aux=dict(
  51. type='MarginL2Loss',
  52. neg_pos_ub=3,
  53. pos_margin=0,
  54. neg_margin=0.1,
  55. hard_mining=True,
  56. loss_weight=1.0)),
  57. loss_bbox=dict(type='L1Loss', loss_weight=1.0),
  58. train_cfg=dict(
  59. assigner=dict(
  60. type='MaxIoUAssigner',
  61. pos_iou_thr=0.7,
  62. neg_iou_thr=0.5,
  63. min_pos_iou=0.5,
  64. match_low_quality=False,
  65. ignore_iof_thr=-1),
  66. sampler=dict(
  67. type='CombinedSampler',
  68. num=256,
  69. pos_fraction=0.5,
  70. neg_pos_ub=3,
  71. add_gt_as_proposals=True,
  72. pos_sampler=dict(type='InstanceBalancedPosSampler'),
  73. neg_sampler=dict(type='RandomSampler')))),
  74. tracker=dict(
  75. type='QuasiDenseTracker',
  76. init_score_thr=0.9,
  77. obj_score_thr=0.5,
  78. match_score_thr=0.5,
  79. memo_tracklet_frames=30,
  80. memo_backdrop_frames=1,
  81. memo_momentum=0.8,
  82. nms_conf_thr=0.5,
  83. nms_backdrop_iou_thr=0.3,
  84. nms_class_iou_thr=0.7,
  85. with_cats=True,
  86. match_metric='bisoftmax'))
  87. # optimizer
  88. optim_wrapper = dict(
  89. type='OptimWrapper',
  90. optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
  91. clip_grad=dict(max_norm=35, norm_type=2))
  92. # learning policy
  93. param_scheduler = [
  94. dict(type='MultiStepLR', begin=0, end=4, by_epoch=True, milestones=[3])
  95. ]
  96. # runtime settings
  97. train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=4, val_interval=4)
  98. val_cfg = dict(type='ValLoop')
  99. test_cfg = dict(type='TestLoop')
  100. default_hooks = dict(
  101. logger=dict(type='LoggerHook', interval=50),
  102. visualization=dict(type='TrackVisualizationHook', draw=False))
  103. vis_backends = [dict(type='LocalVisBackend')]
  104. visualizer = dict(
  105. type='TrackLocalVisualizer', vis_backends=vis_backends, name='visualizer')
  106. # custom hooks
  107. custom_hooks = [
  108. # Synchronize model buffers such as running_mean and running_var in BN
  109. # at the end of each epoch
  110. dict(type='SyncBuffersHook')
  111. ]