match_cost.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import abstractmethod
  3. from typing import Optional, Union
  4. import torch
  5. import torch.nn.functional as F
  6. from mmengine.structures import InstanceData
  7. from torch import Tensor
  8. from mmdet.registry import TASK_UTILS
  9. from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcywh
  10. class BaseMatchCost:
  11. """Base match cost class.
  12. Args:
  13. weight (Union[float, int]): Cost weight. Defaults to 1.
  14. """
  15. def __init__(self, weight: Union[float, int] = 1.) -> None:
  16. self.weight = weight
  17. @abstractmethod
  18. def __call__(self,
  19. pred_instances: InstanceData,
  20. gt_instances: InstanceData,
  21. img_meta: Optional[dict] = None,
  22. **kwargs) -> Tensor:
  23. """Compute match cost.
  24. Args:
  25. pred_instances (:obj:`InstanceData`): Instances of model
  26. predictions. It includes ``priors``, and the priors can
  27. be anchors or points, or the bboxes predicted by the
  28. previous stage, has shape (n, 4). The bboxes predicted by
  29. the current model or stage will be named ``bboxes``,
  30. ``labels``, and ``scores``, the same as the ``InstanceData``
  31. in other places.
  32. gt_instances (:obj:`InstanceData`): Ground truth of instance
  33. annotations. It usually includes ``bboxes``, with shape (k, 4),
  34. and ``labels``, with shape (k, ).
  35. img_meta (dict, optional): Image information.
  36. Returns:
  37. Tensor: Match Cost matrix of shape (num_preds, num_gts).
  38. """
  39. pass
  40. @TASK_UTILS.register_module()
  41. class BBoxL1Cost(BaseMatchCost):
  42. """BBoxL1Cost.
  43. Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy'
  44. and its coordinates are unnormalized.
  45. Args:
  46. box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN.
  47. Defaults to 'xyxy'.
  48. weight (Union[float, int]): Cost weight. Defaults to 1.
  49. Examples:
  50. >>> from mmdet.models.task_modules.assigners.
  51. ... match_costs.match_cost import BBoxL1Cost
  52. >>> import torch
  53. >>> self = BBoxL1Cost()
  54. >>> bbox_pred = torch.rand(1, 4)
  55. >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
  56. >>> factor = torch.tensor([10, 8, 10, 8])
  57. >>> self(bbox_pred, gt_bboxes, factor)
  58. tensor([[1.6172, 1.6422]])
  59. """
  60. def __init__(self,
  61. box_format: str = 'xyxy',
  62. weight: Union[float, int] = 1.) -> None:
  63. super().__init__(weight=weight)
  64. assert box_format in ['xyxy', 'xywh']
  65. self.box_format = box_format
  66. def __call__(self,
  67. pred_instances: InstanceData,
  68. gt_instances: InstanceData,
  69. img_meta: Optional[dict] = None,
  70. **kwargs) -> Tensor:
  71. """Compute match cost.
  72. Args:
  73. pred_instances (:obj:`InstanceData`): ``bboxes`` inside is
  74. predicted boxes with unnormalized coordinate
  75. (x, y, x, y).
  76. gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt
  77. bboxes with unnormalized coordinate (x, y, x, y).
  78. img_meta (Optional[dict]): Image information. Defaults to None.
  79. Returns:
  80. Tensor: Match Cost matrix of shape (num_preds, num_gts).
  81. """
  82. pred_bboxes = pred_instances.bboxes
  83. gt_bboxes = gt_instances.bboxes
  84. # convert box format
  85. if self.box_format == 'xywh':
  86. gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
  87. pred_bboxes = bbox_xyxy_to_cxcywh(pred_bboxes)
  88. # normalized
  89. img_h, img_w = img_meta['img_shape']
  90. factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
  91. img_h]).unsqueeze(0)
  92. gt_bboxes = gt_bboxes / factor
  93. pred_bboxes = pred_bboxes / factor
  94. bbox_cost = torch.cdist(pred_bboxes, gt_bboxes, p=1)
  95. return bbox_cost * self.weight
  96. @TASK_UTILS.register_module()
  97. class IoUCost(BaseMatchCost):
  98. """IoUCost.
  99. Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy'
  100. and its coordinates are unnormalized.
  101. Args:
  102. iou_mode (str): iou mode such as 'iou', 'giou'. Defaults to 'giou'.
  103. weight (Union[float, int]): Cost weight. Defaults to 1.
  104. Examples:
  105. >>> from mmdet.models.task_modules.assigners.
  106. ... match_costs.match_cost import IoUCost
  107. >>> import torch
  108. >>> self = IoUCost()
  109. >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
  110. >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
  111. >>> self(bboxes, gt_bboxes)
  112. tensor([[-0.1250, 0.1667],
  113. [ 0.1667, -0.5000]])
  114. """
  115. def __init__(self, iou_mode: str = 'giou', weight: Union[float, int] = 1.):
  116. super().__init__(weight=weight)
  117. self.iou_mode = iou_mode
  118. def __call__(self,
  119. pred_instances: InstanceData,
  120. gt_instances: InstanceData,
  121. img_meta: Optional[dict] = None,
  122. **kwargs):
  123. """Compute match cost.
  124. Args:
  125. pred_instances (:obj:`InstanceData`): ``bboxes`` inside is
  126. predicted boxes with unnormalized coordinate
  127. (x, y, x, y).
  128. gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt
  129. bboxes with unnormalized coordinate (x, y, x, y).
  130. img_meta (Optional[dict]): Image information. Defaults to None.
  131. Returns:
  132. Tensor: Match Cost matrix of shape (num_preds, num_gts).
  133. """
  134. pred_bboxes = pred_instances.bboxes
  135. gt_bboxes = gt_instances.bboxes
  136. overlaps = bbox_overlaps(
  137. pred_bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
  138. # The 1 is a constant that doesn't change the matching, so omitted.
  139. iou_cost = -overlaps
  140. return iou_cost * self.weight
  141. @TASK_UTILS.register_module()
  142. class ClassificationCost(BaseMatchCost):
  143. """ClsSoftmaxCost.
  144. Args:
  145. weight (Union[float, int]): Cost weight. Defaults to 1.
  146. Examples:
  147. >>> from mmdet.models.task_modules.assigners.
  148. ... match_costs.match_cost import ClassificationCost
  149. >>> import torch
  150. >>> self = ClassificationCost()
  151. >>> cls_pred = torch.rand(4, 3)
  152. >>> gt_labels = torch.tensor([0, 1, 2])
  153. >>> factor = torch.tensor([10, 8, 10, 8])
  154. >>> self(cls_pred, gt_labels)
  155. tensor([[-0.3430, -0.3525, -0.3045],
  156. [-0.3077, -0.2931, -0.3992],
  157. [-0.3664, -0.3455, -0.2881],
  158. [-0.3343, -0.2701, -0.3956]])
  159. """
  160. def __init__(self, weight: Union[float, int] = 1) -> None:
  161. super().__init__(weight=weight)
  162. def __call__(self,
  163. pred_instances: InstanceData,
  164. gt_instances: InstanceData,
  165. img_meta: Optional[dict] = None,
  166. **kwargs) -> Tensor:
  167. """Compute match cost.
  168. Args:
  169. pred_instances (:obj:`InstanceData`): ``scores`` inside is
  170. predicted classification logits, of shape
  171. (num_queries, num_class).
  172. gt_instances (:obj:`InstanceData`): ``labels`` inside should have
  173. shape (num_gt, ).
  174. img_meta (Optional[dict]): _description_. Defaults to None.
  175. Returns:
  176. Tensor: Match Cost matrix of shape (num_preds, num_gts).
  177. """
  178. pred_scores = pred_instances.scores
  179. gt_labels = gt_instances.labels
  180. pred_scores = pred_scores.softmax(-1)
  181. cls_cost = -pred_scores[:, gt_labels]
  182. return cls_cost * self.weight
  183. @TASK_UTILS.register_module()
  184. class FocalLossCost(BaseMatchCost):
  185. """FocalLossCost.
  186. Args:
  187. alpha (Union[float, int]): focal_loss alpha. Defaults to 0.25.
  188. gamma (Union[float, int]): focal_loss gamma. Defaults to 2.
  189. eps (float): Defaults to 1e-12.
  190. binary_input (bool): Whether the input is binary. Currently,
  191. binary_input = True is for masks input, binary_input = False
  192. is for label input. Defaults to False.
  193. weight (Union[float, int]): Cost weight. Defaults to 1.
  194. """
  195. def __init__(self,
  196. alpha: Union[float, int] = 0.25,
  197. gamma: Union[float, int] = 2,
  198. eps: float = 1e-12,
  199. binary_input: bool = False,
  200. weight: Union[float, int] = 1.) -> None:
  201. super().__init__(weight=weight)
  202. self.alpha = alpha
  203. self.gamma = gamma
  204. self.eps = eps
  205. self.binary_input = binary_input
  206. def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor:
  207. """
  208. Args:
  209. cls_pred (Tensor): Predicted classification logits, shape
  210. (num_queries, num_class).
  211. gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
  212. Returns:
  213. torch.Tensor: cls_cost value with weight
  214. """
  215. cls_pred = cls_pred.sigmoid()
  216. neg_cost = -(1 - cls_pred + self.eps).log() * (
  217. 1 - self.alpha) * cls_pred.pow(self.gamma)
  218. pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
  219. 1 - cls_pred).pow(self.gamma)
  220. cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
  221. return cls_cost * self.weight
  222. def _mask_focal_loss_cost(self, cls_pred, gt_labels) -> Tensor:
  223. """
  224. Args:
  225. cls_pred (Tensor): Predicted classification logits.
  226. in shape (num_queries, d1, ..., dn), dtype=torch.float32.
  227. gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
  228. dtype=torch.long. Labels should be binary.
  229. Returns:
  230. Tensor: Focal cost matrix with weight in shape\
  231. (num_queries, num_gt).
  232. """
  233. cls_pred = cls_pred.flatten(1)
  234. gt_labels = gt_labels.flatten(1).float()
  235. n = cls_pred.shape[1]
  236. cls_pred = cls_pred.sigmoid()
  237. neg_cost = -(1 - cls_pred + self.eps).log() * (
  238. 1 - self.alpha) * cls_pred.pow(self.gamma)
  239. pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
  240. 1 - cls_pred).pow(self.gamma)
  241. cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \
  242. torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
  243. return cls_cost / n * self.weight
  244. def __call__(self,
  245. pred_instances: InstanceData,
  246. gt_instances: InstanceData,
  247. img_meta: Optional[dict] = None,
  248. **kwargs) -> Tensor:
  249. """Compute match cost.
  250. Args:
  251. pred_instances (:obj:`InstanceData`): Predicted instances which
  252. must contain ``scores`` or ``masks``.
  253. gt_instances (:obj:`InstanceData`): Ground truth which must contain
  254. ``labels`` or ``mask``.
  255. img_meta (Optional[dict]): Image information. Defaults to None.
  256. Returns:
  257. Tensor: Match Cost matrix of shape (num_preds, num_gts).
  258. """
  259. if self.binary_input:
  260. pred_masks = pred_instances.masks
  261. gt_masks = gt_instances.masks
  262. return self._mask_focal_loss_cost(pred_masks, gt_masks)
  263. else:
  264. pred_scores = pred_instances.scores
  265. gt_labels = gt_instances.labels
  266. return self._focal_loss_cost(pred_scores, gt_labels)
  267. @TASK_UTILS.register_module()
  268. class DiceCost(BaseMatchCost):
  269. """Cost of mask assignments based on dice losses.
  270. Args:
  271. pred_act (bool): Whether to apply sigmoid to mask_pred.
  272. Defaults to False.
  273. eps (float): Defaults to 1e-3.
  274. naive_dice (bool): If True, use the naive dice loss
  275. in which the power of the number in the denominator is
  276. the first power. If False, use the second power that
  277. is adopted by K-Net and SOLO. Defaults to True.
  278. weight (Union[float, int]): Cost weight. Defaults to 1.
  279. """
  280. def __init__(self,
  281. pred_act: bool = False,
  282. eps: float = 1e-3,
  283. naive_dice: bool = True,
  284. weight: Union[float, int] = 1.) -> None:
  285. super().__init__(weight=weight)
  286. self.pred_act = pred_act
  287. self.eps = eps
  288. self.naive_dice = naive_dice
  289. def _binary_mask_dice_loss(self, mask_preds: Tensor,
  290. gt_masks: Tensor) -> Tensor:
  291. """
  292. Args:
  293. mask_preds (Tensor): Mask prediction in shape (num_queries, *).
  294. gt_masks (Tensor): Ground truth in shape (num_gt, *)
  295. store 0 or 1, 0 for negative class and 1 for
  296. positive class.
  297. Returns:
  298. Tensor: Dice cost matrix in shape (num_queries, num_gt).
  299. """
  300. mask_preds = mask_preds.flatten(1)
  301. gt_masks = gt_masks.flatten(1).float()
  302. numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
  303. if self.naive_dice:
  304. denominator = mask_preds.sum(-1)[:, None] + \
  305. gt_masks.sum(-1)[None, :]
  306. else:
  307. denominator = mask_preds.pow(2).sum(1)[:, None] + \
  308. gt_masks.pow(2).sum(1)[None, :]
  309. loss = 1 - (numerator + self.eps) / (denominator + self.eps)
  310. return loss
  311. def __call__(self,
  312. pred_instances: InstanceData,
  313. gt_instances: InstanceData,
  314. img_meta: Optional[dict] = None,
  315. **kwargs) -> Tensor:
  316. """Compute match cost.
  317. Args:
  318. pred_instances (:obj:`InstanceData`): Predicted instances which
  319. must contain ``masks``.
  320. gt_instances (:obj:`InstanceData`): Ground truth which must contain
  321. ``mask``.
  322. img_meta (Optional[dict]): Image information. Defaults to None.
  323. Returns:
  324. Tensor: Match Cost matrix of shape (num_preds, num_gts).
  325. """
  326. pred_masks = pred_instances.masks
  327. gt_masks = gt_instances.masks
  328. if self.pred_act:
  329. pred_masks = pred_masks.sigmoid()
  330. dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
  331. return dice_cost * self.weight
  332. @TASK_UTILS.register_module()
  333. class CrossEntropyLossCost(BaseMatchCost):
  334. """CrossEntropyLossCost.
  335. Args:
  336. use_sigmoid (bool): Whether the prediction uses sigmoid
  337. of softmax. Defaults to True.
  338. weight (Union[float, int]): Cost weight. Defaults to 1.
  339. """
  340. def __init__(self,
  341. use_sigmoid: bool = True,
  342. weight: Union[float, int] = 1.) -> None:
  343. super().__init__(weight=weight)
  344. self.use_sigmoid = use_sigmoid
  345. def _binary_cross_entropy(self, cls_pred: Tensor,
  346. gt_labels: Tensor) -> Tensor:
  347. """
  348. Args:
  349. cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
  350. (num_queries, *).
  351. gt_labels (Tensor): The learning label of prediction with
  352. shape (num_gt, *).
  353. Returns:
  354. Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
  355. """
  356. cls_pred = cls_pred.flatten(1).float()
  357. gt_labels = gt_labels.flatten(1).float()
  358. n = cls_pred.shape[1]
  359. pos = F.binary_cross_entropy_with_logits(
  360. cls_pred, torch.ones_like(cls_pred), reduction='none')
  361. neg = F.binary_cross_entropy_with_logits(
  362. cls_pred, torch.zeros_like(cls_pred), reduction='none')
  363. cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
  364. torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
  365. cls_cost = cls_cost / n
  366. return cls_cost
  367. def __call__(self,
  368. pred_instances: InstanceData,
  369. gt_instances: InstanceData,
  370. img_meta: Optional[dict] = None,
  371. **kwargs) -> Tensor:
  372. """Compute match cost.
  373. Args:
  374. pred_instances (:obj:`InstanceData`): Predicted instances which
  375. must contain ``scores`` or ``masks``.
  376. gt_instances (:obj:`InstanceData`): Ground truth which must contain
  377. ``labels`` or ``masks``.
  378. img_meta (Optional[dict]): Image information. Defaults to None.
  379. Returns:
  380. Tensor: Match Cost matrix of shape (num_preds, num_gts).
  381. """
  382. pred_masks = pred_instances.masks
  383. gt_masks = gt_instances.masks
  384. if self.use_sigmoid:
  385. cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
  386. else:
  387. raise NotImplementedError
  388. return cls_cost * self.weight