# retinanet_r50_fpn.py (2.4 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmcv.ops import nms
  3. from torch.nn import BatchNorm2d
  4. from mmdet.models import (FPN, DetDataPreprocessor, FocalLoss, L1Loss, ResNet,
  5. RetinaHead, RetinaNet)
  6. from mmdet.models.task_modules import (AnchorGenerator, DeltaXYWHBBoxCoder,
  7. MaxIoUAssigner, PseudoSampler)
  8. # model settings
  9. model = dict(
  10. type=RetinaNet,
  11. data_preprocessor=dict(
  12. type=DetDataPreprocessor,
  13. mean=[123.675, 116.28, 103.53],
  14. std=[58.395, 57.12, 57.375],
  15. bgr_to_rgb=True,
  16. pad_size_divisor=32),
  17. backbone=dict(
  18. type=ResNet,
  19. depth=50,
  20. num_stages=4,
  21. out_indices=(0, 1, 2, 3),
  22. frozen_stages=1,
  23. norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
  24. norm_eval=True,
  25. style='pytorch',
  26. init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
  27. neck=dict(
  28. type=FPN,
  29. in_channels=[256, 512, 1024, 2048],
  30. out_channels=256,
  31. start_level=1,
  32. add_extra_convs='on_input',
  33. num_outs=5),
  34. bbox_head=dict(
  35. type=RetinaHead,
  36. num_classes=80,
  37. in_channels=256,
  38. stacked_convs=4,
  39. feat_channels=256,
  40. anchor_generator=dict(
  41. type=AnchorGenerator,
  42. octave_base_scale=4,
  43. scales_per_octave=3,
  44. ratios=[0.5, 1.0, 2.0],
  45. strides=[8, 16, 32, 64, 128]),
  46. bbox_coder=dict(
  47. type=DeltaXYWHBBoxCoder,
  48. target_means=[.0, .0, .0, .0],
  49. target_stds=[1.0, 1.0, 1.0, 1.0]),
  50. loss_cls=dict(
  51. type=FocalLoss,
  52. use_sigmoid=True,
  53. gamma=2.0,
  54. alpha=0.25,
  55. loss_weight=1.0),
  56. loss_bbox=dict(type=L1Loss, loss_weight=1.0)),
  57. # model training and testing settings
  58. train_cfg=dict(
  59. assigner=dict(
  60. type=MaxIoUAssigner,
  61. pos_iou_thr=0.5,
  62. neg_iou_thr=0.4,
  63. min_pos_iou=0,
  64. ignore_iof_thr=-1),
  65. sampler=dict(
  66. type=PseudoSampler), # Focal loss should use PseudoSampler
  67. allowed_border=-1,
  68. pos_weight=-1,
  69. debug=False),
  70. test_cfg=dict(
  71. nms_pre=1000,
  72. min_bbox_size=0,
  73. score_thr=0.05,
  74. nms=dict(type=nms, iou_threshold=0.5),
  75. max_per_img=100))