# glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py (2.5 KB)
  1. _base_ = [
  2. '../_base_/datasets/coco_detection.py',
  3. '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
  4. ]
  5. lang_model_name = 'bert-base-uncased'
  6. model = dict(
  7. type='GLIP',
  8. data_preprocessor=dict(
  9. type='DetDataPreprocessor',
  10. mean=[103.53, 116.28, 123.675],
  11. std=[57.375, 57.12, 58.395],
  12. bgr_to_rgb=False,
  13. pad_size_divisor=32),
  14. backbone=dict(
  15. type='SwinTransformer',
  16. embed_dims=96,
  17. depths=[2, 2, 6, 2],
  18. num_heads=[3, 6, 12, 24],
  19. window_size=7,
  20. mlp_ratio=4,
  21. qkv_bias=True,
  22. qk_scale=None,
  23. drop_rate=0.,
  24. attn_drop_rate=0.,
  25. drop_path_rate=0.2,
  26. patch_norm=True,
  27. out_indices=(1, 2, 3),
  28. with_cp=False,
  29. convert_weights=False),
  30. neck=dict(
  31. type='FPN',
  32. in_channels=[192, 384, 768],
  33. out_channels=256,
  34. start_level=0,
  35. relu_before_extra_convs=True,
  36. add_extra_convs='on_output',
  37. num_outs=5),
  38. bbox_head=dict(
  39. type='ATSSVLFusionHead',
  40. lang_model_name=lang_model_name,
  41. num_classes=80,
  42. in_channels=256,
  43. feat_channels=256,
  44. anchor_generator=dict(
  45. type='AnchorGenerator',
  46. ratios=[1.0],
  47. octave_base_scale=8,
  48. scales_per_octave=1,
  49. strides=[8, 16, 32, 64, 128],
  50. center_offset=0.5),
  51. bbox_coder=dict(
  52. type='DeltaXYWHBBoxCoderForGLIP',
  53. target_means=[.0, .0, .0, .0],
  54. target_stds=[0.1, 0.1, 0.2, 0.2]),
  55. ),
  56. language_model=dict(type='BertModel', name=lang_model_name),
  57. train_cfg=dict(
  58. assigner=dict(type='ATSSAssigner', topk=9),
  59. allowed_border=-1,
  60. pos_weight=-1,
  61. debug=False),
  62. test_cfg=dict(
  63. nms_pre=1000,
  64. min_bbox_size=0,
  65. score_thr=0.05,
  66. nms=dict(type='nms', iou_threshold=0.6),
  67. max_per_img=100))
  68. test_pipeline = [
  69. dict(
  70. type='LoadImageFromFile',
  71. backend_args=_base_.backend_args,
  72. imdecode_backend='pillow'),
  73. dict(
  74. type='FixScaleResize',
  75. scale=(800, 1333),
  76. keep_ratio=True,
  77. backend='pillow'),
  78. dict(type='LoadAnnotations', with_bbox=True),
  79. dict(
  80. type='PackDetInputs',
  81. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  82. 'scale_factor', 'text', 'custom_entities'))
  83. ]
  84. val_dataloader = dict(
  85. dataset=dict(pipeline=test_pipeline, return_classes=True))
  86. test_dataloader = val_dataloader