123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Sequence
- import torch
- from mmengine.evaluator import BaseMetric
- from mmdet.registry import METRICS
@METRICS.register_module()
class RefSegMetric(BaseMetric):
    """Referring Expression Segmentation Metric.

    Accumulates mask intersection/union statistics over all processed samples
    and reports them as percentages:

    - ``cIoU``: total intersection divided by total union, both summed over
      every processed sample.
    - ``mIoU``: sum of the per-mask IoU values divided by the number of masks.

    Args:
        metric (Sequence): Names of the metrics to report. Must be a
            non-empty subset of ``('cIoU', 'mIoU')``. Defaults to both.
        **kwargs: Forwarded unchanged to ``BaseMetric.__init__``.
    """

    def __init__(self, metric: Sequence = ('cIoU', 'mIoU'), **kwargs):
        super().__init__(**kwargs)
        assert set(metric).issubset(['cIoU', 'mIoU']), \
            f'Only support cIoU and mIoU, but got {metric}'
        assert len(metric) > 0, 'metrics should not be empty'
        self.metrics = metric

    def compute_iou(self, pred_seg: torch.Tensor,
                    gt_seg: torch.Tensor) -> tuple:
        """Return the element-wise intersection and union of two masks.

        Despite the name, no ratio is computed here; callers reduce the two
        returned masks themselves.

        Args:
            pred_seg (torch.Tensor): Predicted mask(s); must support ``&`` and
                ``|`` (``process`` passes boolean tensors).
            gt_seg (torch.Tensor): Ground-truth mask(s), combined element-wise
                with ``pred_seg``.

        Returns:
            tuple: ``(overlap, union)`` where ``overlap = pred_seg & gt_seg``
            and ``union = pred_seg | gt_seg``.
        """
        overlap = pred_seg & gt_seg
        union = pred_seg | gt_seg
        return overlap, union

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
        One tuple ``(overlap_sum, union_sum, iou_sum, num_masks)`` is appended
        per data sample.

        Args:
            data_batch (dict): A batch of data from the dataloader. Not used
                by this metric.
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each item must provide ``['pred_instances']['masks']`` (an
                object with ``.bool()``) and ``['gt_masks']`` (an object with
                ``.to_tensor(dtype, device)``).
        """
        for data_sample in data_samples:
            # NOTE(review): masks are presumably (num_masks, H, W) -- confirm
            # against the model head / dataset pipeline.
            pred_label = data_sample['pred_instances']['masks'].bool()
            label = data_sample['gt_masks'].to_tensor(
                pred_label.dtype, pred_label.device).bool()
            # calculate iou
            overlap, union = self.compute_iou(pred_label, label)
            bs = len(pred_label)
            # ``flatten(1)`` rather than ``reshape(bs, -1)``: reshape cannot
            # infer ``-1`` for a tensor with zero elements and raises, so a
            # sample with no masks used to abort the whole evaluation.
            # flatten computes the merged size explicitly and gives the same
            # result for every non-empty input.
            iou = overlap.flatten(1).sum(-1) * 1.0 / union.flatten(1).sum(-1)
            # A mask pair with an empty union yields 0 / 0 = NaN; score it 0.
            iou = torch.nan_to_num_(iou, nan=0.0)
            self.results.append((overlap.sum(), union.sum(), iou.sum(), bs))

    def compute_metrics(self, results: list) -> dict:
        """Reduce the per-sample statistics to the requested metrics.

        Args:
            results (list): The 4-tuples appended by ``process``:
                ``(overlap_sum, union_sum, iou_sum, num_masks)``.

        Returns:
            dict: Maps each name in ``self.metrics`` to its value scaled by
            100. Values are whatever ``sum`` and the division produce from the
            stored entries (tensors when ``process`` stored tensors).
        """
        # Transpose the list of 4-tuples into four parallel sequences.
        results = tuple(zip(*results))
        assert len(results) == 4
        cum_i = sum(results[0])
        cum_u = sum(results[1])
        iou = sum(results[2])
        seg_total = sum(results[3])

        metrics = {}
        if 'cIoU' in self.metrics:
            metrics['cIoU'] = cum_i * 100 / cum_u
        if 'mIoU' in self.metrics:
            metrics['mIoU'] = iou * 100 / seg_total
        return metrics
|