# retinanet_tta.py
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import TestTimeAug

from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import RandomFlip, Resize
from mmdet.models.test_time_augs.det_tta import DetTTAModel

  8. tta_model = dict(
  9. type=DetTTAModel,
  10. tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))
  11. img_scales = [(1333, 800), (666, 400), (2000, 1200)]
  12. tta_pipeline = [
  13. dict(type=LoadImageFromFile, backend_args=None),
  14. dict(
  15. type=TestTimeAug,
  16. transforms=[
  17. [dict(type=Resize, scale=s, keep_ratio=True) for s in img_scales],
  18. [dict(type=RandomFlip, prob=1.),
  19. dict(type=RandomFlip, prob=0.)],
  20. [dict(type=LoadAnnotations, with_bbox=True)],
  21. [
  22. dict(
  23. type=PackDetInputs,
  24. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  25. 'scale_factor', 'flip', 'flip_direction'))
  26. ]
  27. ])
  28. ]