mask2former_r50_8xb2-8e_youtubevis2019.py

_base_ = ['../_base_/datasets/youtube_vis.py', '../_base_/default_runtime.py']

num_classes = 40  # number of foreground categories in YouTube-VIS 2019
num_frames = 2  # frames per video clip fed to the model during training
model = dict(
    type='Mask2FormerVideo',
    data_preprocessor=dict(
        type='TrackDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_mask=True,
        pad_size_divisor=32),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    track_head=dict(
        type='Mask2FormerTrackHead',
        in_channels=[256, 512, 1024, 2048],  # pass to pixel_decoder inside
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_classes=num_classes,
        num_queries=100,
        num_frames=num_frames,
        num_transformer_feat_level=3,
        pixel_decoder=dict(
            type='MSDeformAttnPixelDecoder',
            num_outs=3,
            norm_cfg=dict(type='GN', num_groups=32),
            act_cfg=dict(type='ReLU'),
            encoder=dict(  # DeformableDetrTransformerEncoder
                num_layers=6,
                layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
                    self_attn_cfg=dict(  # MultiScaleDeformableAttention
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        im2col_step=128,
                        dropout=0.0,
                        batch_first=True),
                    ffn_cfg=dict(
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type='ReLU', inplace=True)))),
            positional_encoding=dict(num_feats=128, normalize=True)),
        enforce_decoder_input_project=False,
        positional_encoding=dict(
            type='SinePositionalEncoding3D', num_feats=128, normalize=True),
        transformer_decoder=dict(  # Mask2FormerTransformerDecoder
            return_intermediate=True,
            num_layers=9,
            layer_cfg=dict(  # Mask2FormerTransformerDecoderLayer
                self_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                cross_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                ffn_cfg=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    ffn_drop=0.0,
                    act_cfg=dict(type='ReLU', inplace=True))),
            init_cfg=None),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0),
        train_cfg=dict(
            num_points=12544,  # 112 * 112 points sampled for the mask losses
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='HungarianAssigner',
                match_costs=[
                    dict(type='ClassificationCost', weight=2.0),
                    dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
                ]),
            sampler=dict(type='MaskPseudoSampler'))),
    init_cfg=dict(  # initialize from COCO-pretrained Mask2Former R-50 weights
        type='Pretrained',
        checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
        'mask2former/mask2former_r50_8xb2-lsj-50e_coco/'
        'mask2former_r50_8xb2-lsj-50e_coco_20220506_191028-41b088b6.pth'))

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.0001,
        weight_decay=0.05,
        eps=1e-8,
        betas=(0.9, 0.999)),
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi,
        },
        norm_decay_mult=0.0),
    clip_grad=dict(max_norm=0.01, norm_type=2))

# learning policy
max_iters = 6000
param_scheduler = dict(
    type='MultiStepLR',
    begin=0,
    end=max_iters,
    by_epoch=False,
    milestones=[4000],
    gamma=0.1)

# runtime settings
train_cfg = dict(
    # val_interval exceeds max_iters, so validation is not run during training
    type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='TrackLocalVisualizer', vis_backends=vis_backends, name='visualizer')

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, save_last=True, interval=2000),
    visualization=dict(type='TrackVisualizationHook', draw=False))

log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False)

# evaluator
val_evaluator = dict(
    type='YouTubeVISMetric',
    metric='youtube_vis_ap',
    outfile_prefix='./youtube_vis_results',
    # YouTube-VIS val annotations are withheld, so only format and dump the
    # results for submission to the evaluation server instead of computing AP.
    format_only=True)
test_evaluator = val_evaluator
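
A minimal launch sketch, assuming an MMDetection v3.x environment and that this file sits at its usual configs/mask2former_vis/ path; the config path and work_dir below are placeholders, and the standard entry points remain tools/train.py and tools/dist_train.sh:

from mmengine.config import Config
from mmengine.runner import Runner

# Load this config and pick an output directory (placeholder paths).
cfg = Config.fromfile(
    'configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2019.py')
cfg.work_dir = './work_dirs/mask2former_r50_8xb2-8e_youtubevis2019'

# Build the runner and start iteration-based training; validation is skipped
# because val_interval (6001) exceeds max_iters (6000).
runner = Runner.from_cfg(cfg)
runner.train()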