margin_loss.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, Union
  3. import numpy as np
  4. import torch
  5. from mmengine.model import BaseModule
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from .mse_loss import mse_loss
  9. @MODELS.register_module()
  10. class MarginL2Loss(BaseModule):
  11. """L2 loss with margin.
  12. Args:
  13. neg_pos_ub (int, optional): The upper bound of negative to positive
  14. samples in hard mining. Defaults to -1.
  15. pos_margin (float, optional): The similarity margin for positive
  16. samples in hard mining. Defaults to -1.
  17. neg_margin (float, optional): The similarity margin for negative
  18. samples in hard mining. Defaults to -1.
  19. hard_mining (bool, optional): Whether to use hard mining. Defaults to
  20. False.
  21. reduction (str, optional): The method to reduce the loss.
  22. Options are "none", "mean" and "sum". Defaults to "mean".
  23. loss_weight (float, optional): The weight of loss. Defaults to 1.0.
  24. """
  25. def __init__(self,
  26. neg_pos_ub: int = -1,
  27. pos_margin: float = -1,
  28. neg_margin: float = -1,
  29. hard_mining: bool = False,
  30. reduction: str = 'mean',
  31. loss_weight: float = 1.0):
  32. super(MarginL2Loss, self).__init__()
  33. self.neg_pos_ub = neg_pos_ub
  34. self.pos_margin = pos_margin
  35. self.neg_margin = neg_margin
  36. self.hard_mining = hard_mining
  37. self.reduction = reduction
  38. self.loss_weight = loss_weight
  39. def forward(self,
  40. pred: Tensor,
  41. target: Tensor,
  42. weight: Optional[Tensor] = None,
  43. avg_factor: Optional[float] = None,
  44. reduction_override: Optional[str] = None) -> Tensor:
  45. """Forward function.
  46. Args:
  47. pred (torch.Tensor): The prediction.
  48. target (torch.Tensor): The learning target of the prediction.
  49. weight (torch.Tensor, optional): The weight of loss for each
  50. prediction. Defaults to None.
  51. avg_factor (float, optional): Average factor that is used to
  52. average the loss. Defaults to None.
  53. reduction_override (str, optional): The reduction method used to
  54. override the original reduction method of the loss.
  55. Defaults to None.
  56. """
  57. assert reduction_override in (None, 'none', 'mean', 'sum')
  58. reduction = (
  59. reduction_override if reduction_override else self.reduction)
  60. pred, weight, avg_factor = self.update_weight(pred, target, weight,
  61. avg_factor)
  62. loss_bbox = self.loss_weight * mse_loss(
  63. pred,
  64. target.float(),
  65. weight.float(),
  66. reduction=reduction,
  67. avg_factor=avg_factor)
  68. return loss_bbox
  69. def update_weight(self, pred: Tensor, target: Tensor, weight: Tensor,
  70. avg_factor: float) -> Tuple[Tensor, Tensor, float]:
  71. """Update the weight according to targets.
  72. Args:
  73. pred (torch.Tensor): The prediction.
  74. target (torch.Tensor): The learning target of the prediction.
  75. weight (torch.Tensor): The weight of loss for each prediction.
  76. avg_factor (float): Average factor that is used to average the
  77. loss.
  78. Returns:
  79. tuple[torch.Tensor]: The updated prediction, weight and average
  80. factor.
  81. """
  82. if weight is None:
  83. weight = target.new_ones(target.size())
  84. invalid_inds = weight <= 0
  85. target[invalid_inds] = -1
  86. pos_inds = target == 1
  87. neg_inds = target == 0
  88. if self.pos_margin > 0:
  89. pred[pos_inds] -= self.pos_margin
  90. if self.neg_margin > 0:
  91. pred[neg_inds] -= self.neg_margin
  92. pred = torch.clamp(pred, min=0, max=1)
  93. num_pos = int((target == 1).sum())
  94. num_neg = int((target == 0).sum())
  95. if self.neg_pos_ub > 0 and num_neg / (num_pos +
  96. 1e-6) > self.neg_pos_ub:
  97. num_neg = num_pos * self.neg_pos_ub
  98. neg_idx = torch.nonzero(target == 0, as_tuple=False)
  99. if self.hard_mining:
  100. costs = mse_loss(
  101. pred, target.float(),
  102. reduction='none')[neg_idx[:, 0], neg_idx[:, 1]].detach()
  103. neg_idx = neg_idx[costs.topk(num_neg)[1], :]
  104. else:
  105. neg_idx = self.random_choice(neg_idx, num_neg)
  106. new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool()
  107. new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True
  108. invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds)
  109. weight[invalid_neg_inds] = 0
  110. avg_factor = (weight > 0).sum()
  111. return pred, weight, avg_factor
  112. @staticmethod
  113. def random_choice(gallery: Union[list, np.ndarray, Tensor],
  114. num: int) -> np.ndarray:
  115. """Random select some elements from the gallery.
  116. It seems that Pytorch's implementation is slower than numpy so we use
  117. numpy to randperm the indices.
  118. Args:
  119. gallery (list | np.ndarray | torch.Tensor): The gallery from
  120. which to sample.
  121. num (int): The number of elements to sample.
  122. """
  123. assert len(gallery) >= num
  124. if isinstance(gallery, list):
  125. gallery = np.array(gallery)
  126. cands = np.arange(len(gallery))
  127. np.random.shuffle(cands)
  128. rand_inds = cands[:num]
  129. if not isinstance(gallery, np.ndarray):
  130. rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
  131. return gallery[rand_inds]