# Copyright (c) OpenMMLab. All rights reserved.
  2. from mmcv.transforms import LoadImageFromFile
  3. from mmengine.dataset.sampler import DefaultSampler
  4. from mmdet.datasets import AspectRatioBatchSampler, CocoDataset
  5. from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs,
  6. RandomFlip, Resize)
  7. from mmdet.evaluation import CocoMetric
  8. # dataset settings
  9. dataset_type = CocoDataset
  10. data_root = 'data/coco/'
  11. # Example to use different file client
  12. # Method 1: simply set the data root and let the file I/O module
  13. # automatically infer from prefix (not support LMDB and Memcache yet)
  14. # data_root = 's3://openmmlab/datasets/detection/coco/'
  15. # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
  16. # backend_args = dict(
  17. # backend='petrel',
  18. # path_mapping=dict({
  19. # './data/': 's3://openmmlab/datasets/detection/',
  20. # 'data/': 's3://openmmlab/datasets/detection/'
  21. # }))
  22. backend_args = None
  23. train_pipeline = [
  24. dict(type=LoadImageFromFile, backend_args=backend_args),
  25. dict(type=LoadAnnotations, with_bbox=True),
  26. dict(type=Resize, scale=(1333, 800), keep_ratio=True),
  27. dict(type=RandomFlip, prob=0.5),
  28. dict(type=PackDetInputs)
  29. ]
  30. test_pipeline = [
  31. dict(type=LoadImageFromFile, backend_args=backend_args),
  32. dict(type=Resize, scale=(1333, 800), keep_ratio=True),
  33. # If you don't have a gt annotation, delete the pipeline
  34. dict(type=LoadAnnotations, with_bbox=True),
  35. dict(
  36. type=PackDetInputs,
  37. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  38. 'scale_factor'))
  39. ]
  40. train_dataloader = dict(
  41. batch_size=2,
  42. num_workers=2,
  43. persistent_workers=True,
  44. sampler=dict(type=DefaultSampler, shuffle=True),
  45. batch_sampler=dict(type=AspectRatioBatchSampler),
  46. dataset=dict(
  47. type=dataset_type,
  48. data_root=data_root,
  49. ann_file='annotations/instances_train2017.json',
  50. data_prefix=dict(img='train2017/'),
  51. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  52. pipeline=train_pipeline,
  53. backend_args=backend_args))
  54. val_dataloader = dict(
  55. batch_size=1,
  56. num_workers=2,
  57. persistent_workers=True,
  58. drop_last=False,
  59. sampler=dict(type=DefaultSampler, shuffle=False),
  60. dataset=dict(
  61. type=dataset_type,
  62. data_root=data_root,
  63. ann_file='annotations/instances_val2017.json',
  64. data_prefix=dict(img='val2017/'),
  65. test_mode=True,
  66. pipeline=test_pipeline,
  67. backend_args=backend_args))
  68. test_dataloader = val_dataloader
  69. val_evaluator = dict(
  70. type=CocoMetric,
  71. ann_file=data_root + 'annotations/instances_val2017.json',
  72. metric='bbox',
  73. format_only=False,
  74. backend_args=backend_args)
  75. test_evaluator = val_evaluator
# To run inference on the test-dev split and format the output results for
# submission, use the settings below instead.
# Note: the dataset's `ann_file` is relative to `data_root` (the dataset joins
# the two), whereas CocoMetric needs the full path. Also remove
# LoadAnnotations from `test_pipeline`, since test-dev ships no ground truth.
# test_dataloader = dict(
#     batch_size=1,
#     num_workers=2,
#     persistent_workers=True,
#     drop_last=False,
#     sampler=dict(type=DefaultSampler, shuffle=False),
#     dataset=dict(
#         type=dataset_type,
#         data_root=data_root,
#         ann_file='annotations/image_info_test-dev2017.json',
#         data_prefix=dict(img='test2017/'),
#         test_mode=True,
#         pipeline=test_pipeline,
#         backend_args=backend_args))
# test_evaluator = dict(
#     type=CocoMetric,
#     metric='bbox',
#     format_only=True,
#     ann_file=data_root + 'annotations/image_info_test-dev2017.json',
#     outfile_prefix='./work_dirs/coco_detection/test',
#     backend_args=backend_args)