# rtmdet_tta.py
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import TestTimeAug

from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import Pad, RandomFlip, Resize
from mmdet.models.test_time_augs.det_tta import DetTTAModel

  8. tta_model = dict(
  9. type=DetTTAModel,
  10. tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.6), max_per_img=100))
  11. img_scales = [(640, 640), (320, 320), (960, 960)]
  12. tta_pipeline = [
  13. dict(type=LoadImageFromFile, backend_args=None),
  14. dict(
  15. type=TestTimeAug,
  16. transforms=[
  17. [dict(type=Resize, scale=s, keep_ratio=True) for s in img_scales],
  18. [
  19. # ``RandomFlip`` must be placed before ``Pad``, otherwise
  20. # bounding box coordinates after flipping cannot be
  21. # recovered correctly.
  22. dict(type=RandomFlip, prob=1.),
  23. dict(type=RandomFlip, prob=0.)
  24. ],
  25. [
  26. dict(
  27. type=Pad,
  28. size=(960, 960),
  29. pad_val=dict(img=(114, 114, 114))),
  30. ],
  31. [dict(type=LoadAnnotations, with_bbox=True)],
  32. [
  33. dict(
  34. type=PackDetInputs,
  35. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  36. 'scale_factor', 'flip', 'flip_direction'))
  37. ]
  38. ])
  39. ]