dump_det_results.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. from typing import Sequence
  4. from mmengine.evaluator import DumpResults
  5. from mmengine.evaluator.metric import _to_cpu
  6. from mmdet.registry import METRICS
  7. from mmdet.structures.mask import encode_mask_results
  8. @METRICS.register_module()
  9. class DumpDetResults(DumpResults):
  10. """Dump model predictions to a pickle file for offline evaluation.
  11. Different from `DumpResults` in MMEngine, it compresses instance
  12. segmentation masks into RLE format.
  13. Args:
  14. out_file_path (str): Path of the dumped file. Must end with '.pkl'
  15. or '.pickle'.
  16. collect_device (str): Device name used for collecting results from
  17. different ranks during distributed training. Must be 'cpu' or
  18. 'gpu'. Defaults to 'cpu'.
  19. """
  20. def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
  21. """transfer tensors in predictions to CPU."""
  22. data_samples = _to_cpu(data_samples)
  23. for data_sample in data_samples:
  24. # remove gt
  25. data_sample.pop('gt_instances', None)
  26. data_sample.pop('ignored_instances', None)
  27. data_sample.pop('gt_panoptic_seg', None)
  28. if 'pred_instances' in data_sample:
  29. pred = data_sample['pred_instances']
  30. # encode mask to RLE
  31. if 'masks' in pred:
  32. pred['masks'] = encode_mask_results(pred['masks'].numpy())
  33. if 'pred_panoptic_seg' in data_sample:
  34. warnings.warn(
  35. 'Panoptic segmentation map will not be compressed. '
  36. 'The dumped file will be extremely large! '
  37. 'Suggest using `CocoPanopticMetric` to save the coco '
  38. 'format json and segmentation png files directly.')
  39. self.results.extend(data_samples)