smooth_l1_loss.py (5.2 KB)

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from mmdet.registry import MODELS
  7. from .utils import weighted_loss
  @weighted_loss
  def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
      """Compute the element-wise Smooth L1 loss.

      With ``d = |pred - target|``, each element of the result is
      ``0.5 * d * d / beta`` where ``d < beta`` and ``d - 0.5 * beta``
      elsewhere, i.e. quadratic near zero and linear beyond ``beta``.

      Note: the loss classes call this function with ``weight``,
      ``reduction`` and ``avg_factor`` arguments, so those are presumably
      handled by the ``weighted_loss`` decorator, not by this body.

      Args:
          pred (Tensor): The prediction.
          target (Tensor): The learning target of the prediction. Must match
              ``pred`` in size unless it is empty.
          beta (float, optional): The threshold between the quadratic and
              linear pieces; must be positive. Defaults to 1.0.

      Returns:
          Tensor: The unreduced, element-wise loss. If ``target`` is empty,
          ``pred.sum() * 0`` instead, a scalar that keeps ``pred`` in the
          autograd graph.
      """
      assert beta > 0
      if not target.numel():
          # Nothing to compare against; return a zero-valued scalar that is
          # still derived from pred so gradients can flow through it.
          return pred.sum() * 0
      assert pred.size() == target.size()
      abs_err = (pred - target).abs()
      quadratic = 0.5 * abs_err * abs_err / beta
      linear = abs_err - 0.5 * beta
      return torch.where(abs_err < beta, quadratic, linear)
  @weighted_loss
  def l1_loss(pred: Tensor, target: Tensor) -> Tensor:
      """Compute the element-wise L1 (absolute error) loss.

      Note: the loss classes call this function with ``weight``,
      ``reduction`` and ``avg_factor`` arguments, so those are presumably
      handled by the ``weighted_loss`` decorator, not by this body.

      Args:
          pred (Tensor): The prediction.
          target (Tensor): The learning target of the prediction. Must match
              ``pred`` in size unless it is empty.

      Returns:
          Tensor: ``|pred - target|`` element-wise, unreduced. If ``target``
          is empty, ``pred.sum() * 0`` instead, a scalar that keeps ``pred``
          in the autograd graph.
      """
      if not target.numel():
          # Empty target: zero-valued scalar that still depends on pred.
          return pred.sum() * 0
      assert pred.size() == target.size()
      return (pred - target).abs()
  @MODELS.register_module()
  class SmoothL1Loss(nn.Module):
      """Smooth L1 loss.

      Args:
          beta (float, optional): The threshold in the piecewise function.
              Defaults to 1.0.
          reduction (str, optional): The method to reduce the loss.
              Options are "none", "mean" and "sum". Defaults to "mean".
          loss_weight (float, optional): Scalar multiplier applied to the
              computed loss. Defaults to 1.0.
      """

      def __init__(self,
                   beta: float = 1.0,
                   reduction: str = 'mean',
                   loss_weight: float = 1.0) -> None:
          super().__init__()
          self.beta = beta
          self.reduction = reduction
          self.loss_weight = loss_weight

      def forward(self,
                  pred: Tensor,
                  target: Tensor,
                  weight: Optional[Tensor] = None,
                  avg_factor: Optional[int] = None,
                  reduction_override: Optional[str] = None,
                  **kwargs) -> Tensor:
          """Forward function.

          Args:
              pred (Tensor): The prediction.
              target (Tensor): The learning target of the prediction.
              weight (Tensor, optional): The weight of loss for each
                  prediction. Defaults to None.
              avg_factor (int, optional): Average factor that is used to
                  average the loss. Defaults to None.
              reduction_override (str, optional): The reduction method used
                  to override the original reduction method of the loss.
                  One of None, "none", "mean" or "sum". Defaults to None.
              **kwargs: Extra keyword arguments forwarded to
                  ``smooth_l1_loss``.

          Returns:
              Tensor: Calculated loss, scaled by ``loss_weight``. If
              ``weight`` is given but has no positive entry, the loss
              computation is skipped and ``pred * weight`` is returned
              instead: element-wise when the effective reduction is "none",
              otherwise summed to a scalar. Either way the result stays
              connected to ``pred``'s autograd graph.
          """
          # Validate first so an invalid override is rejected even when the
          # shortcut below is taken.
          assert reduction_override in (None, 'none', 'mean', 'sum')
          reduction = (
              reduction_override if reduction_override else self.reduction)
          if weight is not None and not torch.any(weight > 0):
              # weight has one dimension fewer than pred: insert a size-1 axis
              # at dim 1 so it broadcasts (e.g. (N,) -> (N, 1) vs. (N, C)).
              if pred.dim() == weight.dim() + 1:
                  weight = weight.unsqueeze(1)
              masked = pred * weight
              # With reduction "none" callers expect an element-wise tensor,
              # not a scalar, so only collapse for "mean"/"sum".
              if reduction == 'none':
                  return masked
              return masked.sum()
          loss_bbox = self.loss_weight * smooth_l1_loss(
              pred,
              target,
              weight,
              beta=self.beta,
              reduction=reduction,
              avg_factor=avg_factor,
              **kwargs)
          return loss_bbox
  @MODELS.register_module()
  class L1Loss(nn.Module):
      """L1 loss.

      Args:
          reduction (str, optional): The method to reduce the loss.
              Options are "none", "mean" and "sum". Defaults to "mean".
          loss_weight (float, optional): Scalar multiplier applied to the
              computed loss. Defaults to 1.0.
      """

      def __init__(self,
                   reduction: str = 'mean',
                   loss_weight: float = 1.0) -> None:
          super().__init__()
          self.reduction = reduction
          self.loss_weight = loss_weight

      def forward(self,
                  pred: Tensor,
                  target: Tensor,
                  weight: Optional[Tensor] = None,
                  avg_factor: Optional[int] = None,
                  reduction_override: Optional[str] = None) -> Tensor:
          """Forward function.

          Args:
              pred (Tensor): The prediction.
              target (Tensor): The learning target of the prediction.
              weight (Tensor, optional): The weight of loss for each
                  prediction. Defaults to None.
              avg_factor (int, optional): Average factor that is used to
                  average the loss. Defaults to None.
              reduction_override (str, optional): The reduction method used
                  to override the original reduction method of the loss.
                  One of None, "none", "mean" or "sum". Defaults to None.

          Returns:
              Tensor: Calculated loss, scaled by ``loss_weight``. If
              ``weight`` is given but has no positive entry, the loss
              computation is skipped and ``pred * weight`` is returned
              instead: element-wise when the effective reduction is "none",
              otherwise summed to a scalar. Either way the result stays
              connected to ``pred``'s autograd graph.
          """
          # Validate first so an invalid override is rejected even when the
          # shortcut below is taken.
          assert reduction_override in (None, 'none', 'mean', 'sum')
          reduction = (
              reduction_override if reduction_override else self.reduction)
          if weight is not None and not torch.any(weight > 0):
              # weight has one dimension fewer than pred: insert a size-1 axis
              # at dim 1 so it broadcasts (e.g. (N,) -> (N, 1) vs. (N, C)).
              if pred.dim() == weight.dim() + 1:
                  weight = weight.unsqueeze(1)
              masked = pred * weight
              # With reduction "none" callers expect an element-wise tensor,
              # not a scalar, so only collapse for "mean"/"sum".
              if reduction == 'none':
                  return masked
              return masked.sum()
          loss_bbox = self.loss_weight * l1_loss(
              pred, target, weight, reduction=reduction, avg_factor=avg_factor)
          return loss_bbox