# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.ops import RoIAlign, nms
from torch.nn import BatchNorm2d

from mmdet.models.backbones.resnet import ResNet
from mmdet.models.data_preprocessors.data_preprocessor import \
    DetDataPreprocessor
from mmdet.models.dense_heads.rpn_head import RPNHead
from mmdet.models.detectors.mask_rcnn import MaskRCNN
from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss
from mmdet.models.losses.smooth_l1_loss import L1Loss
from mmdet.models.necks.fpn import FPN
from mmdet.models.roi_heads.bbox_heads.convfc_bbox_head import \
    Shared2FCBBoxHead
from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead
from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import \
    SingleRoIExtractor
from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner
from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \
    DeltaXYWHBBoxCoder
from mmdet.models.task_modules.prior_generators.anchor_generator import \
    AnchorGenerator
from mmdet.models.task_modules.samplers.random_sampler import RandomSampler

# model settings
#
# Mask R-CNN assembled from a ResNet (depth 50) backbone, an FPN neck, an RPN
# head and a StandardRoIHead carrying both a bbox branch and a mask branch.
# Components are referenced through the class objects imported above; the
# only string-typed entry is the backbone's `init_cfg` ('Pretrained').
model = dict(
    type=MaskRCNN,
    # Input normalisation and padding applied before the backbone.
    data_preprocessor=dict(
        type=DetDataPreprocessor,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_mask=True,
        pad_size_divisor=32),
    backbone=dict(
        type=ResNet,
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    # The four `in_channels` entries pair with the four backbone
    # `out_indices`; `num_outs=5` matches the five RPN anchor strides below.
    neck=dict(
        type=FPN,
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type=RPNHead,
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type=AnchorGenerator,
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type=DeltaXYWHBBoxCoder,
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type=CrossEntropyLoss, use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type=L1Loss, loss_weight=1.0)),
    roi_head=dict(
        type=StandardRoIHead,
        # Box branch: 7x7 RoIAlign features; `roi_feat_size` in the bbox
        # head is set to the same 7.
        bbox_roi_extractor=dict(
            type=SingleRoIExtractor,
            roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type=Shared2FCBBoxHead,
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type=DeltaXYWHBBoxCoder,
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type=L1Loss, loss_weight=1.0)),
        # Mask branch: 14x14 RoIAlign features over the same feature-map
        # strides as the box branch.
        mask_roi_extractor=dict(
            type=SingleRoIExtractor,
            roi_layer=dict(type=RoIAlign, output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type=FCNMaskHead,
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type=CrossEntropyLoss, use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type=MaxIoUAssigner,
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type=RandomSampler,
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type=nms, iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type=MaxIoUAssigner,
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type=RandomSampler,
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type=nms, iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type=nms, iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))