l2_loss.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import numpy as np
import torch
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from .utils import weighted_loss


@weighted_loss
def l2_loss(pred: Tensor, target: Tensor) -> Tensor:
    """L2 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.

    Returns:
        torch.Tensor: Calculated loss.
    """
    assert pred.size() == target.size()
    loss = torch.abs(pred - target)**2
    return loss
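

# A note on the decorated signature (an assumption based on how the
# ``weighted_loss`` helper is commonly used for mmdet losses, not stated in
# this file): the decorator wraps the element-wise loss above so that the
# public ``l2_loss`` also accepts ``weight``, ``reduction`` ('none' | 'mean'
# | 'sum') and ``avg_factor`` keyword arguments, e.g.
#
#   pred = torch.rand(4, 4)
#   target = torch.rand(4, 4)
#   loss = l2_loss(pred, target, reduction='mean')  # scalar tensor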


@MODELS.register_module()
class L2Loss(BaseModule):
    """L2 loss.

    Args:
        neg_pos_ub (int, optional): The upper bound of the ratio of
            negative to positive samples. A value <= 0 means no bound.
            Defaults to -1.
        pos_margin (float, optional): Margin subtracted from positive
            predictions before computing the loss. A value <= 0 disables
            it. Defaults to -1.
        neg_margin (float, optional): Margin subtracted from negative
            predictions before computing the loss. A value <= 0 disables
            it. Defaults to -1.
        hard_mining (bool, optional): Whether to keep the negatives with
            the highest loss instead of a random subset when the number of
            negatives is capped. Defaults to False.
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 neg_pos_ub: int = -1,
                 pos_margin: float = -1,
                 neg_margin: float = -1,
                 hard_mining: bool = False,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0):
        super(L2Loss, self).__init__()
        self.neg_pos_ub = neg_pos_ub
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.hard_mining = hard_mining
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        pred, weight, avg_factor = self.update_weight(pred, target, weight,
                                                      avg_factor)
        loss_bbox = self.loss_weight * l2_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

    def update_weight(self, pred: Tensor, target: Tensor, weight: Tensor,
                      avg_factor: float) -> Tuple[Tensor, Tensor, float]:
        """Update the weight according to targets."""
        if weight is None:
            weight = target.new_ones(target.size())

        # Entries with a non-positive weight are marked as invalid (-1) so
        # they count neither as positives nor as negatives below.
        invalid_inds = weight <= 0
        target[invalid_inds] = -1
        pos_inds = target == 1
        neg_inds = target == 0

        # Optionally relax the predictions by a margin before the loss.
        if self.pos_margin > 0:
            pred[pos_inds] -= self.pos_margin
        if self.neg_margin > 0:
            pred[neg_inds] -= self.neg_margin
        pred = torch.clamp(pred, min=0, max=1)

        num_pos = int((target == 1).sum())
        num_neg = int((target == 0).sum())
        if self.neg_pos_ub > 0 and num_neg / (num_pos +
                                              1e-6) > self.neg_pos_ub:
            # Cap the number of negatives at ``neg_pos_ub`` times the number
            # of positives, keeping either the hardest negatives (highest
            # element-wise loss) or a random subset.
            num_neg = num_pos * self.neg_pos_ub
            neg_idx = torch.nonzero(target == 0, as_tuple=False)

            if self.hard_mining:
                costs = l2_loss(
                    pred, target, reduction='none')[neg_idx[:, 0],
                                                    neg_idx[:, 1]].detach()
                neg_idx = neg_idx[costs.topk(num_neg)[1], :]
            else:
                neg_idx = self.random_choice(neg_idx, num_neg)

            # Zero the weight of the negatives that were dropped.
            new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool()
            new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True

            invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds)
            weight[invalid_neg_inds] = 0

        avg_factor = (weight > 0).sum()
        return pred, weight, avg_factor
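
    # Illustration of the negative cap above (numbers are illustrative, not
    # from the original code): with ``neg_pos_ub=3``, 2 positives and 10
    # negatives, 10 / 2 > 3, so only 2 * 3 = 6 negatives keep a non-zero
    # weight. With ``hard_mining=True`` the kept negatives are those with
    # the highest element-wise loss; otherwise they are sampled at random.
    # ``avg_factor`` then becomes the number of entries with positive weight.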

    @staticmethod
    def random_choice(gallery: Union[list, np.ndarray, Tensor],
                      num: int) -> np.ndarray:
        """Randomly select some elements from the gallery.

        It seems that PyTorch's implementation is slower than numpy, so we
        use numpy to randperm the indices.
        """
        assert len(gallery) >= num
        if isinstance(gallery, list):
            gallery = np.array(gallery)
        cands = np.arange(len(gallery))
        np.random.shuffle(cands)
        rand_inds = cands[:num]
        if not isinstance(gallery, np.ndarray):
            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
        return gallery[rand_inds]
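

# Minimal usage sketch (an addition for illustration, not part of the
# original module; it assumes mmengine/mmdet are installed so the imports
# above resolve, and must be run in package context, e.g.
# ``python -m <package>.l2_loss``, because of the relative import).
if __name__ == '__main__':
    # Toy 2x4 match-score problem: 1 marks a matching pair, 0 a
    # non-matching pair.
    pred = torch.tensor([[0.9, 0.2, 0.1, 0.4],
                         [0.3, 0.8, 0.6, 0.1]])
    target = torch.tensor([[1., 0., 0., 0.],
                           [0., 1., 0., 0.]])
    # Cap negatives at 2x the positives and keep the hardest ones.
    loss_fn = L2Loss(neg_pos_ub=2, hard_mining=True, loss_weight=1.0)
    print(loss_fn(pred, target))  # scalar tensor (default reduction='mean')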