gaussian_focal_loss.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Union
  3. import torch.nn as nn
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from .utils import weight_reduce_loss, weighted_loss
  7. @weighted_loss
  8. def gaussian_focal_loss(pred: Tensor,
  9. gaussian_target: Tensor,
  10. alpha: float = 2.0,
  11. gamma: float = 4.0,
  12. pos_weight: float = 1.0,
  13. neg_weight: float = 1.0) -> Tensor:
  14. """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
  15. distribution.
  16. Args:
  17. pred (torch.Tensor): The prediction.
  18. gaussian_target (torch.Tensor): The learning target of the prediction
  19. in gaussian distribution.
  20. alpha (float, optional): A balanced form for Focal Loss.
  21. Defaults to 2.0.
  22. gamma (float, optional): The gamma for calculating the modulating
  23. factor. Defaults to 4.0.
  24. pos_weight(float): Positive sample loss weight. Defaults to 1.0.
  25. neg_weight(float): Negative sample loss weight. Defaults to 1.0.
  26. """
  27. eps = 1e-12
  28. pos_weights = gaussian_target.eq(1)
  29. neg_weights = (1 - gaussian_target).pow(gamma)
  30. pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
  31. neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
  32. return pos_weight * pos_loss + neg_weight * neg_loss
  33. def gaussian_focal_loss_with_pos_inds(
  34. pred: Tensor,
  35. gaussian_target: Tensor,
  36. pos_inds: Tensor,
  37. pos_labels: Tensor,
  38. alpha: float = 2.0,
  39. gamma: float = 4.0,
  40. pos_weight: float = 1.0,
  41. neg_weight: float = 1.0,
  42. reduction: str = 'mean',
  43. avg_factor: Optional[Union[int, float]] = None) -> Tensor:
  44. """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
  45. distribution.
  46. Note: The index with a value of 1 in ``gaussian_target`` in the
  47. ``gaussian_focal_loss`` function is a positive sample, but in
  48. ``gaussian_focal_loss_with_pos_inds`` the positive sample is passed
  49. in through the ``pos_inds`` parameter.
  50. Args:
  51. pred (torch.Tensor): The prediction. The shape is (N, num_classes).
  52. gaussian_target (torch.Tensor): The learning target of the prediction
  53. in gaussian distribution. The shape is (N, num_classes).
  54. pos_inds (torch.Tensor): The positive sample index.
  55. The shape is (M, ).
  56. pos_labels (torch.Tensor): The label corresponding to the positive
  57. sample index. The shape is (M, ).
  58. alpha (float, optional): A balanced form for Focal Loss.
  59. Defaults to 2.0.
  60. gamma (float, optional): The gamma for calculating the modulating
  61. factor. Defaults to 4.0.
  62. pos_weight(float): Positive sample loss weight. Defaults to 1.0.
  63. neg_weight(float): Negative sample loss weight. Defaults to 1.0.
  64. reduction (str): Options are "none", "mean" and "sum".
  65. Defaults to 'mean`.
  66. avg_factor (int, float, optional): Average factor that is used to
  67. average the loss. Defaults to None.
  68. """
  69. eps = 1e-12
  70. neg_weights = (1 - gaussian_target).pow(gamma)
  71. pos_pred_pix = pred[pos_inds]
  72. pos_pred = pos_pred_pix.gather(1, pos_labels.unsqueeze(1))
  73. pos_loss = -(pos_pred + eps).log() * (1 - pos_pred).pow(alpha)
  74. pos_loss = weight_reduce_loss(pos_loss, None, reduction, avg_factor)
  75. neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
  76. neg_loss = weight_reduce_loss(neg_loss, None, reduction, avg_factor)
  77. return pos_weight * pos_loss + neg_weight * neg_loss
  78. @MODELS.register_module()
  79. class GaussianFocalLoss(nn.Module):
  80. """GaussianFocalLoss is a variant of focal loss.
  81. More details can be found in the `paper
  82. <https://arxiv.org/abs/1808.01244>`_
  83. Code is modified from `kp_utils.py
  84. <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501
  85. Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
  86. not 0/1 binary target.
  87. Args:
  88. alpha (float): Power of prediction.
  89. gamma (float): Power of target for negative samples.
  90. reduction (str): Options are "none", "mean" and "sum".
  91. loss_weight (float): Loss weight of current loss.
  92. pos_weight(float): Positive sample loss weight. Defaults to 1.0.
  93. neg_weight(float): Negative sample loss weight. Defaults to 1.0.
  94. """
  95. def __init__(self,
  96. alpha: float = 2.0,
  97. gamma: float = 4.0,
  98. reduction: str = 'mean',
  99. loss_weight: float = 1.0,
  100. pos_weight: float = 1.0,
  101. neg_weight: float = 1.0) -> None:
  102. super().__init__()
  103. self.alpha = alpha
  104. self.gamma = gamma
  105. self.reduction = reduction
  106. self.loss_weight = loss_weight
  107. self.pos_weight = pos_weight
  108. self.neg_weight = neg_weight
  109. def forward(self,
  110. pred: Tensor,
  111. target: Tensor,
  112. pos_inds: Optional[Tensor] = None,
  113. pos_labels: Optional[Tensor] = None,
  114. weight: Optional[Tensor] = None,
  115. avg_factor: Optional[Union[int, float]] = None,
  116. reduction_override: Optional[str] = None) -> Tensor:
  117. """Forward function.
  118. If you want to manually determine which positions are
  119. positive samples, you can set the pos_index and pos_label
  120. parameter. Currently, only the CenterNet update version uses
  121. the parameter.
  122. Args:
  123. pred (torch.Tensor): The prediction. The shape is (N, num_classes).
  124. target (torch.Tensor): The learning target of the prediction
  125. in gaussian distribution. The shape is (N, num_classes).
  126. pos_inds (torch.Tensor): The positive sample index.
  127. Defaults to None.
  128. pos_labels (torch.Tensor): The label corresponding to the positive
  129. sample index. Defaults to None.
  130. weight (torch.Tensor, optional): The weight of loss for each
  131. prediction. Defaults to None.
  132. avg_factor (int, float, optional): Average factor that is used to
  133. average the loss. Defaults to None.
  134. reduction_override (str, optional): The reduction method used to
  135. override the original reduction method of the loss.
  136. Defaults to None.
  137. """
  138. assert reduction_override in (None, 'none', 'mean', 'sum')
  139. reduction = (
  140. reduction_override if reduction_override else self.reduction)
  141. if pos_inds is not None:
  142. assert pos_labels is not None
  143. # Only used by centernet update version
  144. loss_reg = self.loss_weight * gaussian_focal_loss_with_pos_inds(
  145. pred,
  146. target,
  147. pos_inds,
  148. pos_labels,
  149. alpha=self.alpha,
  150. gamma=self.gamma,
  151. pos_weight=self.pos_weight,
  152. neg_weight=self.neg_weight,
  153. reduction=reduction,
  154. avg_factor=avg_factor)
  155. else:
  156. loss_reg = self.loss_weight * gaussian_focal_loss(
  157. pred,
  158. target,
  159. weight,
  160. alpha=self.alpha,
  161. gamma=self.gamma,
  162. pos_weight=self.pos_weight,
  163. neg_weight=self.neg_weight,
  164. reduction=reduction,
  165. avg_factor=avg_factor)
  166. return loss_reg