# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from mmdet.registry import MODELS
from .utils import weight_reduce_loss


# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # Actually, pt here denotes (1 - pt) in the Focal Loss paper
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    # Thus it's pt.pow(gamma) rather than (1 - pt).pow(gamma)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

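
# Illustrative sanity check (a sketch, not part of the upstream module): with
# gamma=0 and alpha=0.5, the modulating and balancing factors collapse to a
# constant 0.5, so the focal loss should equal half of the plain BCE loss.
# The tensor names and shapes below are assumptions made for this example.
#
#   pred = torch.randn(8, 4)                      # logits, shape (N, C)
#   target = torch.randint(0, 2, (8, 4)).float()  # one-hot / multi-label
#   fl = py_sigmoid_focal_loss(pred, target, gamma=0.0, alpha=0.5,
#                              reduction='mean')
#   bce = F.binary_cross_entropy_with_logits(pred, target, reduction='mean')
#   assert torch.allclose(fl, 0.5 * bce, atol=1e-6)
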
def py_focal_loss_with_prob(pred,
                            target,
                            weight=None,
                            gamma=2.0,
                            alpha=0.25,
                            reduction='mean',
                            avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Different from `py_sigmoid_focal_loss`, this function accepts probability
    as input.

    Args:
        pred (torch.Tensor): The prediction probability with shape (N, C),
            C is the number of classes.
        target (torch.Tensor): The learning label of the prediction.
            The target shape can be (N, C) or (N, ), where (N, C) means
            one-hot form.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    if pred.dim() != target.dim():
        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes + 1)
        target = target[:, :num_classes]

    target = target.type_as(pred)
    pt = (1 - pred) * target + pred * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

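
# Illustrative equivalence check (a sketch, not part of the upstream module):
# feeding sigmoid probabilities into `py_focal_loss_with_prob` should match
# `py_sigmoid_focal_loss` on the raw logits, up to floating-point error.
# The tensor names and shapes below are assumptions made for this example.
#
#   pred = torch.randn(8, 4)                      # logits, shape (N, C)
#   target = torch.randint(0, 2, (8, 4)).float()  # one-hot form
#   loss_prob = py_focal_loss_with_prob(pred.sigmoid(), target)
#   loss_logit = py_sigmoid_focal_loss(pred, target)
#   assert torch.allclose(loss_prob, loss_logit, atol=1e-5)
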
def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
                               alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

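
# Illustrative usage (a sketch, not part of the upstream module): unlike the
# Python versions above, the mmcv CUDA op consumes integer class indices
# rather than one-hot targets. The tensor names, shapes and the use of index
# C as "background" below are assumptions made for this example.
#
#   pred = torch.randn(8, 4, device='cuda')             # logits, shape (N, C)
#   target = torch.randint(0, 5, (8, ), device='cuda')  # indices; 4 = background
#   loss = sigmoid_focal_loss(pred, target, reduction='mean')
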
@MODELS.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is used for
                sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as logits.
                Defaults to False.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
                The target shape can be (N, C) or (N, ), where (N, C) means
                one-hot form.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if pred.dim() == target.dim():
                    # this means that target is already in One-Hot form.
                    calculate_loss_func = py_sigmoid_focal_loss
                elif torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)

        else:
            raise NotImplementedError
        return loss_cls
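

# A minimal usage sketch (illustrative only, not part of the upstream module).
# In MMDetection configs this loss is usually built through the MODELS
# registry, e.g. dict(type='FocalLoss', use_sigmoid=True, gamma=2.0,
# alpha=0.25, loss_weight=1.0); the direct call below assumes 4 foreground
# classes with index 4 acting as background.
#
#   loss_cls = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25,
#                        loss_weight=1.0)
#   cls_score = torch.randn(8, 4)        # logits, shape (N, C)
#   labels = torch.randint(0, 5, (8, ))  # class indices; 4 = background
#   loss = loss_cls(cls_score, labels, avg_factor=8.0)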