123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
def mask_matrix_nms(masks,
                    labels,
                    scores,
                    filter_thr=-1,
                    nms_pre=-1,
                    max_num=-1,
                    kernel='gaussian',
                    sigma=2.0,
                    mask_area=None):
    """Matrix NMS for multi-class masks.

    Instead of discarding overlapping masks, every score is multiplied by
    a decay coefficient derived from the IoU with higher-scored masks of
    the same label. Masks with different labels never decay each other.

    Args:
        masks (Tensor): Has shape (num_instances, h, w).
        labels (Tensor): Labels of corresponding masks,
            has shape (num_instances,).
        scores (Tensor): Mask scores of corresponding masks,
            has shape (num_instances,).
        filter_thr (float): Score threshold to filter the masks
            after matrix nms. Default: -1, which means do not
            use filter_thr.
        nms_pre (int): The max number of instances to do the matrix nms.
            Default: -1, which means do not use nms_pre.
        max_num (int, optional): If there are more than max_num masks after
            matrix nms, only top max_num will be kept. Default: -1, which
            means do not use max_num.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): Multiplier applied to the squared IoU inside the
            exponent of the gaussian kernel.
        mask_area (Tensor, optional): Per-mask area, has shape
            (num_instances,). Default: None, in which case it is computed
            as ``masks.sum((1, 2))``.

    Returns:
        tuple(Tensor): Processed mask results.

            - scores (Tensor): Updated scores sorted in descending order,
              has shape (n,).
            - labels (Tensor): Remained labels, has shape (n,).
            - masks (Tensor): Remained masks, has shape (n, h, w).
            - keep_inds (Tensor): The indices of the remaining masks in
              the input masks, has shape (n,).

    Raises:
        AssertionError: If ``masks``, ``labels``, ``scores`` (and
            ``mask_area`` when given) do not have the same length.
        NotImplementedError: If ``kernel`` is neither 'gaussian' nor
            'linear'.
    """
    assert len(labels) == len(masks) == len(scores)
    if len(labels) == 0:
        return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
            0, *masks.shape[-2:]), labels.new_zeros(0)
    if mask_area is None:
        mask_area = masks.sum((1, 2)).float()
    else:
        assert len(masks) == len(mask_area)

    # sort and keep top nms_pre
    scores, sort_inds = torch.sort(scores, descending=True)
    if nms_pre > 0 and len(sort_inds) > nms_pre:
        sort_inds = sort_inds[:nms_pre]
        scores = scores[:nms_pre]
    # keep_inds maps every surviving row back to its input position.
    keep_inds = sort_inds
    masks = masks[sort_inds]
    mask_area = mask_area[sort_inds]
    labels = labels[sort_inds]

    num_masks = len(labels)
    flatten_masks = masks.reshape(num_masks, -1).float()
    # inter.
    inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
    expanded_mask_area = mask_area.expand(num_masks, num_masks)
    union_matrix = (
        expanded_mask_area + expanded_mask_area.transpose(1, 0) -
        inter_matrix)
    # Two empty masks have a zero union; 0 / 0 would yield NaN and poison
    # every score through the max/min below. Their intersection is zero
    # as well, so dividing by 1 gives the intended IoU of 0.
    union_matrix = union_matrix.masked_fill(union_matrix == 0, 1)
    # Upper triangle iou matrix: entry (i, j) with i < j relates mask j to
    # the higher-scored mask i.
    iou_matrix = (inter_matrix / union_matrix).triu(diagonal=1)

    # label_specific matrix.
    expanded_labels = labels.expand(num_masks, num_masks)
    # Upper triangle label matrix.
    label_matrix = (expanded_labels == expanded_labels.transpose(
        1, 0)).triu(diagonal=1)

    # IoU decay: only same-label pairs contribute.
    decay_iou = iou_matrix * label_matrix

    # IoU compensation: for each mask, the largest IoU with any
    # higher-scored mask of the same label, broadcast along its row.
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou.expand(num_masks,
                                           num_masks).transpose(1, 0)

    # Calculate the decay_coefficient
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        compensate_denominator = 1 - compensate_iou
        decay_matrix = (1 - decay_iou) / compensate_denominator
        # A mask whose compensate IoU is exactly 1 is itself fully
        # suppressed, so it must not decay others. Its row is x / 0: inf
        # when x > 0 but NaN when x == 0, and NaN would win the min below.
        # Make the whole row neutral (inf) instead.
        decay_matrix = decay_matrix.masked_fill(
            compensate_denominator == 0, float('inf'))
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError(
            f'{kernel} kernel is not supported in matrix nms!')
    # update the score.
    scores = scores * decay_coefficient

    if filter_thr > 0:
        keep = scores >= filter_thr
        keep_inds = keep_inds[keep]
        if not keep.any():
            return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
                0, *masks.shape[-2:]), labels.new_zeros(0)
        masks = masks[keep]
        scores = scores[keep]
        labels = labels[keep]

    # sort and keep top max_num
    scores, sort_inds = torch.sort(scores, descending=True)
    keep_inds = keep_inds[sort_inds]
    if max_num > 0 and len(sort_inds) > max_num:
        sort_inds = sort_inds[:max_num]
        keep_inds = keep_inds[:max_num]
        scores = scores[:max_num]
    masks = masks[sort_inds]
    labels = labels[sort_inds]

    return scores, labels, masks, keep_inds
|