# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import tempfile
from typing import List, Optional

from mmengine.evaluator import BaseMetric
from mmengine.utils import track_iter_progress
from pycocotools.coco import COCO

from mmdet.registry import METRICS

try:
    from pycocoevalcap.eval import COCOEvalCap
except ImportError:
    COCOEvalCap = None


@METRICS.register_module()
class COCOCaptionMetric(BaseMetric):
    """COCO caption evaluation wrapper.

    Saves the generated captions, converts them into COCO result format and
    calls the COCO caption API (``pycocoevalcap``) to compute caption
    metrics.

    Args:
        ann_file (str): Path to the COCO-format caption ground-truth json
            file, loaded for evaluation.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 ann_file: str,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        if COCOEvalCap is None:
            raise RuntimeError(
                'COCOEvalCap is not installed, please install it by: '
                'pip install pycocoevalcap')

        super().__init__(collect_device=collect_device, prefix=prefix)
        self.ann_file = ann_file

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()

            result['caption'] = data_sample['pred_caption']
            result['image_id'] = int(data_sample['img_id'])

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are the corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
        with tempfile.TemporaryDirectory() as temp_dir:
            eval_result_file = save_result(
                result=results,
                result_dir=temp_dir,
                filename='caption_pred',
                remove_duplicate='image_id',
            )

            coco_val = coco_caption_eval(eval_result_file, self.ann_file)

        return coco_val


def save_result(result, result_dir, filename, remove_duplicate=''):
    """Save predictions as a json file for evaluation."""
    # combine results from all processes and optionally drop duplicate
    # entries, keyed by `remove_duplicate` (e.g. 'image_id')
    if remove_duplicate:
        result_new = []
        id_list = []
        for res in track_iter_progress(result):
            if res[remove_duplicate] not in id_list:
                id_list.append(res[remove_duplicate])
                result_new.append(res)
        result = result_new

    final_result_file_url = os.path.join(result_dir, f'{filename}.json')
    print(f'result file saved to {final_result_file_url}')
    with open(final_result_file_url, 'w') as f:
        json.dump(result, f)

    return final_result_file_url
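
# For reference, a minimal sketch of the prediction json that `save_result`
# writes and `COCO.loadRes` below consumes: a list of dicts with `image_id`
# and `caption` keys. The ids and captions here are hypothetical placeholders:
#
#     [
#         {"image_id": 397133, "caption": "A man riding a bike."},
#         {"image_id": 37777, "caption": "Two dogs playing on a beach."}
#     ]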


def coco_caption_eval(results_file, ann_file):
    """Evaluate the prediction json file against the ground-truth json."""
    # create coco object and coco_result object
    coco = COCO(ann_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # make sure the image ids are the same
    coco_eval.params['image_id'] = coco_result.getImgIds()

    # This may take some time on the first run
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval.eval
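

# Minimal usage sketch (kept as a comment so the module stays import-safe).
# The annotation paths and sample values below are placeholders, not files or
# data shipped with this module:
#
#     # 1) As an evaluator in an MMEngine-style config, via the registry:
#     val_evaluator = dict(
#         type='COCOCaptionMetric',
#         ann_file='data/coco/annotations/caption_val_gt.json')
#
#     # 2) Direct, programmatic use through the BaseMetric interface:
#     metric = COCOCaptionMetric(ann_file='path/to/caption_gt.json')
#     metric.process(
#         data_batch=None,
#         data_samples=[dict(pred_caption='A man riding a bike.',
#                            img_id=397133)])
#     print(metric.evaluate(size=1))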