varifocal_loss.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS

from .utils import weight_reduce_loss


def varifocal_loss(pred: Tensor,
                   target: Tensor,
                   weight: Optional[Tensor] = None,
                   alpha: float = 0.75,
                   gamma: float = 2.0,
                   iou_weighted: bool = True,
                   with_logits: bool = False,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
  17. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  18. Args:
  19. pred (Tensor): The prediction with shape (N, C), C is the
  20. number of classes.
  21. target (Tensor): The learning target of the iou-aware
  22. classification score with shape (N, C), C is the number of classes.
  23. weight (Tensor, optional): The weight of loss for each
  24. prediction. Defaults to None.
  25. alpha (float, optional): A balance factor for the negative part of
  26. Varifocal Loss, which is different from the alpha of Focal Loss.
  27. Defaults to 0.75.
  28. gamma (float, optional): The gamma for calculating the modulating
  29. factor. Defaults to 2.0.
  30. iou_weighted (bool, optional): Whether to weight the loss of the
  31. positive example with the iou target. Defaults to True.
  32. reduction (str, optional): The method used to reduce the loss into
  33. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  34. "sum".
  35. avg_factor (int, optional): Average factor that is used to average
  36. the loss. Defaults to None.
  37. Returns:
  38. Tensor: Loss tensor.
  39. """
    # pred and target should be of the same size
    assert pred.size() == target.size()
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    if iou_weighted and with_logits:
        # Positives are weighted by the IoU-aware target; negatives are
        # down-weighted by sigmoid(pred) ** gamma.
        focal_weight = target * (target > 0.0).float() + \
            alpha * pred_sigmoid.abs().pow(gamma) * \
            (target <= 0.0).float()
    elif iou_weighted:
        # Positives are weighted by the IoU-aware target; negatives are
        # down-weighted by |sigmoid(pred) - target| ** gamma.
        focal_weight = target * (target > 0.0).float() + \
            alpha * (pred_sigmoid - target).abs().pow(gamma) * \
            (target <= 0.0).float()
    else:
        # Positives get unit weight instead of the IoU-aware target.
        focal_weight = (target > 0.0).float() + \
            alpha * (pred_sigmoid - target).abs().pow(gamma) * \
            (target <= 0.0).float()
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@MODELS.register_module()
class VarifocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid: bool = True,
                 alpha: float = 0.75,
                 gamma: float = 2.0,
                 iou_weighted: bool = True,
                 with_logits: bool = False,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
  70. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  71. Args:
  72. use_sigmoid (bool, optional): Whether the prediction is
  73. used for sigmoid or softmax. Defaults to True.
  74. alpha (float, optional): A balance factor for the negative part of
  75. Varifocal Loss, which is different from the alpha of Focal
  76. Loss. Defaults to 0.75.
  77. gamma (float, optional): The gamma for calculating the modulating
  78. factor. Defaults to 2.0.
  79. iou_weighted (bool, optional): Whether to weight the loss of the
  80. positive examples with the iou target. Defaults to True.
  81. reduction (str, optional): The method used to reduce the loss into
  82. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  83. "sum".
  84. loss_weight (float, optional): Weight of loss. Defaults to 1.0.
  85. """
        super().__init__()
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0
        self.use_sigmoid = use_sigmoid
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.with_logits = with_logits

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
  103. """Forward function.
  104. Args:
  105. pred (Tensor): The prediction with shape (N, C), C is the
  106. number of classes.
  107. target (Tensor): The learning target of the iou-aware
  108. classification score with shape (N, C), C is
  109. the number of classes.
  110. weight (Tensor, optional): The weight of loss for each
  111. prediction. Defaults to None.
  112. avg_factor (int, optional): Average factor that is used to average
  113. the loss. Defaults to None.
  114. reduction_override (str, optional): The reduction method used to
  115. override the original reduction method of the loss.
  116. Options are "none", "mean" and "sum".
  117. Returns:
  118. Tensor: The calculated loss
  119. """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            loss_cls = self.loss_weight * varifocal_loss(
                pred,
                target,
                weight,
                alpha=self.alpha,
                gamma=self.gamma,
                iou_weighted=self.iou_weighted,
                reduction=reduction,
                with_logits=self.with_logits,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls
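

# Minimal usage sketch (not part of the original module; the shapes and the
# import path are assumptions, e.g. that this file is exported through
# mmdet.models.losses as in upstream mmdetection):
#
#   import torch
#   from mmdet.models.losses import VarifocalLoss, varifocal_loss
#
#   pred = torch.randn(8, 20)            # raw classification logits, (N, C)
#   target = torch.zeros(8, 20)          # IoU-aware targets in [0, 1]
#   target[torch.arange(8), torch.randint(0, 20, (8,))] = 0.7
#
#   loss = varifocal_loss(pred, target)  # functional form, 'mean' reduction
#
#   criterion = VarifocalLoss(alpha=0.75, gamma=2.0, with_logits=False)
#   loss = criterion(pred, target)
#
# In an mmdetection config the registered class is typically referenced as
# loss_cls=dict(type='VarifocalLoss', use_sigmoid=True, ...); the exact head
# wiring depends on the detector.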