heuristic_fusion_head.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List
  3. import torch
  4. from mmengine.structures import InstanceData, PixelData
  5. from torch import Tensor
  6. from mmdet.evaluation.functional import INSTANCE_OFFSET
  7. from mmdet.registry import MODELS
  8. from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig, PixelList
  9. from .base_panoptic_fusion_head import BasePanopticFusionHead
  @MODELS.register_module()
  class HeuristicFusionHead(BasePanopticFusionHead):
      """Fusion Head with Heuristic method.

      Instance masks are pasted onto an id map in descending score order,
      dropping instances that are mostly covered by higher-scoring ones. All
      pixels not claimed by an instance take the argmax class of the semantic
      prediction, and stuff regions smaller than ``test_cfg.stuff_area_limit``
      pixels are relabelled to ``self.num_classes``.
      """

      def __init__(self,
                   num_things_classes: int = 80,
                   num_stuff_classes: int = 53,
                   test_cfg: OptConfigType = None,
                   init_cfg: OptMultiConfig = None,
                   **kwargs) -> None:
          # The heuristic fusion has no training loss (see ``loss``), so no
          # panoptic loss is configured on the base head.
          super().__init__(
              num_things_classes=num_things_classes,
              num_stuff_classes=num_stuff_classes,
              test_cfg=test_cfg,
              loss_panoptic=None,
              init_cfg=init_cfg,
              **kwargs)

      def loss(self, **kwargs) -> dict:
          """HeuristicFusionHead has no training loss."""
          return dict()

      def _lay_masks(self,
                     mask_results: InstanceData,
                     overlap_thr: float = 0.5) -> tuple:
          """Lay instance masks to a result map.

          Instances are visited from the highest to the lowest score. An
          instance is dropped if its mask is empty, or if more than
          ``overlap_thr`` of its area is already claimed by higher-scoring
          instances. A kept instance claims only the pixels that are still
          free. Non-zero mask entries count as foreground.

          Args:
              mask_results (:obj:`InstanceData`): Instance segmentation results,
                  each contains ``bboxes``, ``labels``, ``scores`` and ``masks``.
              overlap_thr (float): Threshold to determine whether two masks
                  overlap. default: 0.5.

          Returns:
              tuple[Tensor, Tensor]:

              - id_map (Tensor): Long tensor shaped like the last two dims of
                ``masks`` (H, W). 0 marks pixels covered by no kept instance;
                ``k`` (1-based) marks pixels of the k-th kept instance.
              - instance_labels (Tensor): Labels of the kept instances;
                ``instance_labels[k - 1]`` is the label of id ``k``.
          """
          bboxes = mask_results.bboxes
          scores = mask_results.scores
          labels = mask_results.labels
          masks = mask_results.masks

          id_map = torch.zeros(
              masks.shape[-2:], device=bboxes.device, dtype=torch.long)
          if bboxes.shape[0] == 0:
              return id_map, labels

          # Sort by score so higher-confidence instances claim pixels first.
          order = torch.argsort(-scores)
          labels = labels[order]
          # Work on boolean masks so ``&``/``~`` and boolean-mask indexing
          # below are well-defined regardless of the input mask dtype.
          segm_masks = masks[order].bool()

          left_labels = []
          for _cls, _mask in zip(labels, segm_masks):
              area = _mask.sum()
              if area == 0:
                  continue
              pasted = id_map > 0
              intersect = (_mask & pasted).sum()
              if (intersect / (area + 1e-5)) > overlap_thr:
                  continue
              # ids are 1-based and contiguous over the kept instances; write
              # in place instead of building a full-size id map per instance.
              id_map[_mask & ~pasted] = len(left_labels) + 1
              left_labels.append(_cls)

          if left_labels:
              instance_labels = torch.stack(left_labels)
          else:
              instance_labels = bboxes.new_zeros((0, ), dtype=torch.long)
          return id_map, instance_labels

      def _predict_single(self, mask_results: InstanceData, seg_preds: Tensor,
                          **kwargs) -> PixelData:
          """Fuse the results of instance and semantic segmentations.

          Args:
              mask_results (:obj:`InstanceData`): Instance segmentation results,
                  each contains ``bboxes``, ``labels``, ``scores`` and ``masks``.
              seg_preds (Tensor): The semantic segmentation results,
                  (num_stuff + 1, H, W).
              **kwargs: Unused.

          Returns:
              :obj:`PixelData`: Panoptic result whose ``sem_seg`` is an int
              tensor of shape (1, H, W). Instance pixels hold
              ``label + instance_id * INSTANCE_OFFSET``; every other pixel
              holds ``argmax channel + num_things_classes``, except stuff
              regions below ``test_cfg.stuff_area_limit`` pixels, which are
              set to ``self.num_classes``.
          """
          id_map, labels = self._lay_masks(mask_results,
                                           self.test_cfg.mask_overlap)

          # Semantic channel c is mapped to class id num_things_classes + c.
          pan_results = seg_preds.argmax(dim=0) + self.num_things_classes
          assert pan_results.ndim == 2

          instance_id = 1
          # _lay_masks only emits ids 1..len(labels), so instances it dropped
          # never appear in id_map and need not be scanned.
          for idx in range(len(labels)):
              _mask = id_map == (idx + 1)
              if _mask.sum() == 0:
                  continue
              # simply trust detection
              pan_results[_mask] = labels[idx] + instance_id * INSTANCE_OFFSET
              instance_id += 1

          ids, counts = torch.unique(
              pan_results % INSTANCE_OFFSET, return_counts=True)
          is_stuff = ids >= self.num_things_classes
          ignore_stuff_ids = ids[is_stuff][
              counts[is_stuff] < self.test_cfg.stuff_area_limit]
          # Relabel one id at a time rather than broadcasting an (H, W, K)
          # comparison; peak extra memory stays at a single (H, W) mask.
          # Instance pixels are >= INSTANCE_OFFSET, so they never match here.
          for stuff_id in ignore_stuff_ids:
              pan_results[pan_results == stuff_id] = self.num_classes

          return PixelData(sem_seg=pan_results[None].int())

      def predict(self, mask_results_list: InstanceList,
                  seg_preds_list: List[Tensor], **kwargs) -> PixelList:
          """Predict results by fusing the results of instance and semantic
          segmentations.

          Args:
              mask_results_list (list[:obj:`InstanceData`]): Instance
                  segmentation results, each contains ``bboxes``, ``labels``,
                  ``scores`` and ``masks``.
              seg_preds_list (list[Tensor]): Semantic segmentation results,
                  one per entry of ``mask_results_list`` (matched by index).
              **kwargs: Unused.

          Returns:
              List[PixelData]: Panoptic segmentation result, one per image.
          """
          return [
              self._predict_single(mask_results, seg_preds_list[i])
              for i, mask_results in enumerate(mask_results_list)
          ]