refseg_metric.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Sequence
  3. import torch
  4. from mmengine.evaluator import BaseMetric
  5. from mmdet.registry import METRICS
  6. @METRICS.register_module()
  7. class RefSegMetric(BaseMetric):
  8. """Referring Expression Segmentation Metric."""
  9. def __init__(self, metric: Sequence = ('cIoU', 'mIoU'), **kwargs):
  10. super().__init__(**kwargs)
  11. assert set(metric).issubset(['cIoU', 'mIoU']), \
  12. f'Only support cIoU and mIoU, but got {metric}'
  13. assert len(metric) > 0, 'metrics should not be empty'
  14. self.metrics = metric
  15. def compute_iou(self, pred_seg: torch.Tensor,
  16. gt_seg: torch.Tensor) -> tuple:
  17. overlap = pred_seg & gt_seg
  18. union = pred_seg | gt_seg
  19. return overlap, union
  20. def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
  21. """Process one batch of data and data_samples.
  22. The processed results should be stored in ``self.results``, which will
  23. be used to compute the metrics when all batches have been processed.
  24. Args:
  25. data_batch (dict): A batch of data from the dataloader.
  26. data_samples (Sequence[dict]): A batch of outputs from the model.
  27. """
  28. for data_sample in data_samples:
  29. pred_label = data_sample['pred_instances']['masks'].bool()
  30. label = data_sample['gt_masks'].to_tensor(
  31. pred_label.dtype, pred_label.device).bool()
  32. # calculate iou
  33. overlap, union = self.compute_iou(pred_label, label)
  34. bs = len(pred_label)
  35. iou = overlap.reshape(bs, -1).sum(-1) * 1.0 / union.reshape(
  36. bs, -1).sum(-1)
  37. iou = torch.nan_to_num_(iou, nan=0.0)
  38. self.results.append((overlap.sum(), union.sum(), iou.sum(), bs))
  39. def compute_metrics(self, results: list) -> dict:
  40. results = tuple(zip(*results))
  41. assert len(results) == 4
  42. cum_i = sum(results[0])
  43. cum_u = sum(results[1])
  44. iou = sum(results[2])
  45. seg_total = sum(results[3])
  46. metrics = {}
  47. if 'cIoU' in self.metrics:
  48. metrics['cIoU'] = cum_i * 100 / cum_u
  49. if 'mIoU' in self.metrics:
  50. metrics['mIoU'] = iou * 100 / seg_total
  51. return metrics