matrix_nms.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. def mask_matrix_nms(masks,
  4. labels,
  5. scores,
  6. filter_thr=-1,
  7. nms_pre=-1,
  8. max_num=-1,
  9. kernel='gaussian',
  10. sigma=2.0,
  11. mask_area=None):
  12. """Matrix NMS for multi-class masks.
  13. Args:
  14. masks (Tensor): Has shape (num_instances, h, w)
  15. labels (Tensor): Labels of corresponding masks,
  16. has shape (num_instances,).
  17. scores (Tensor): Mask scores of corresponding masks,
  18. has shape (num_instances).
  19. filter_thr (float): Score threshold to filter the masks
  20. after matrix nms. Default: -1, which means do not
  21. use filter_thr.
  22. nms_pre (int): The max number of instances to do the matrix nms.
  23. Default: -1, which means do not use nms_pre.
  24. max_num (int, optional): If there are more than max_num masks after
  25. matrix, only top max_num will be kept. Default: -1, which means
  26. do not use max_num.
  27. kernel (str): 'linear' or 'gaussian'.
  28. sigma (float): std in gaussian method.
  29. mask_area (Tensor): The sum of seg_masks.
  30. Returns:
  31. tuple(Tensor): Processed mask results.
  32. - scores (Tensor): Updated scores, has shape (n,).
  33. - labels (Tensor): Remained labels, has shape (n,).
  34. - masks (Tensor): Remained masks, has shape (n, w, h).
  35. - keep_inds (Tensor): The indices number of
  36. the remaining mask in the input mask, has shape (n,).
  37. """
  38. assert len(labels) == len(masks) == len(scores)
  39. if len(labels) == 0:
  40. return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
  41. 0, *masks.shape[-2:]), labels.new_zeros(0)
  42. if mask_area is None:
  43. mask_area = masks.sum((1, 2)).float()
  44. else:
  45. assert len(masks) == len(mask_area)
  46. # sort and keep top nms_pre
  47. scores, sort_inds = torch.sort(scores, descending=True)
  48. keep_inds = sort_inds
  49. if nms_pre > 0 and len(sort_inds) > nms_pre:
  50. sort_inds = sort_inds[:nms_pre]
  51. keep_inds = keep_inds[:nms_pre]
  52. scores = scores[:nms_pre]
  53. masks = masks[sort_inds]
  54. mask_area = mask_area[sort_inds]
  55. labels = labels[sort_inds]
  56. num_masks = len(labels)
  57. flatten_masks = masks.reshape(num_masks, -1).float()
  58. # inter.
  59. inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
  60. expanded_mask_area = mask_area.expand(num_masks, num_masks)
  61. # Upper triangle iou matrix.
  62. iou_matrix = (inter_matrix /
  63. (expanded_mask_area + expanded_mask_area.transpose(1, 0) -
  64. inter_matrix)).triu(diagonal=1)
  65. # label_specific matrix.
  66. expanded_labels = labels.expand(num_masks, num_masks)
  67. # Upper triangle label matrix.
  68. label_matrix = (expanded_labels == expanded_labels.transpose(
  69. 1, 0)).triu(diagonal=1)
  70. # IoU compensation
  71. compensate_iou, _ = (iou_matrix * label_matrix).max(0)
  72. compensate_iou = compensate_iou.expand(num_masks,
  73. num_masks).transpose(1, 0)
  74. # IoU decay
  75. decay_iou = iou_matrix * label_matrix
  76. # Calculate the decay_coefficient
  77. if kernel == 'gaussian':
  78. decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
  79. compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
  80. decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
  81. elif kernel == 'linear':
  82. decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
  83. decay_coefficient, _ = decay_matrix.min(0)
  84. else:
  85. raise NotImplementedError(
  86. f'{kernel} kernel is not supported in matrix nms!')
  87. # update the score.
  88. scores = scores * decay_coefficient
  89. if filter_thr > 0:
  90. keep = scores >= filter_thr
  91. keep_inds = keep_inds[keep]
  92. if not keep.any():
  93. return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
  94. 0, *masks.shape[-2:]), labels.new_zeros(0)
  95. masks = masks[keep]
  96. scores = scores[keep]
  97. labels = labels[keep]
  98. # sort and keep top max_num
  99. scores, sort_inds = torch.sort(scores, descending=True)
  100. keep_inds = keep_inds[sort_inds]
  101. if max_num > 0 and len(sort_inds) > max_num:
  102. sort_inds = sort_inds[:max_num]
  103. keep_inds = keep_inds[:max_num]
  104. scores = scores[:max_num]
  105. masks = masks[sort_inds]
  106. labels = labels[sort_inds]
  107. return scores, labels, masks, keep_inds