huber_loss.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from mmdet.models.losses.utils import weighted_loss
  7. from mmdet.registry import MODELS
  8. @weighted_loss
  9. def huber_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
  10. """Huber loss.
  11. Args:
  12. pred (Tensor): The prediction.
  13. target (Tensor): The learning target of the prediction.
  14. beta (float, optional): The threshold in the piecewise function.
  15. Defaults to 1.0.
  16. Returns:
  17. Tensor: Calculated loss
  18. """
  19. assert beta > 0
  20. if target.numel() == 0:
  21. return pred.sum() * 0
  22. assert pred.size() == target.size()
  23. diff = torch.abs(pred - target)
  24. loss = torch.where(diff < beta, 0.5 * diff * diff,
  25. beta * diff - 0.5 * beta * beta)
  26. return loss
  27. @MODELS.register_module()
  28. class HuberLoss(nn.Module):
  29. """Huber loss.
  30. Args:
  31. beta (float, optional): The threshold in the piecewise function.
  32. Defaults to 1.0.
  33. reduction (str, optional): The method to reduce the loss.
  34. Options are "none", "mean" and "sum". Defaults to "mean".
  35. loss_weight (float, optional): The weight of loss.
  36. """
  37. def __init__(self,
  38. beta: float = 1.0,
  39. reduction: str = 'mean',
  40. loss_weight: float = 1.0) -> None:
  41. super().__init__()
  42. self.beta = beta
  43. self.reduction = reduction
  44. self.loss_weight = loss_weight
  45. def forward(self,
  46. pred: Tensor,
  47. target: Tensor,
  48. weight: Optional[Tensor] = None,
  49. avg_factor: Optional[int] = None,
  50. reduction_override: Optional[str] = None,
  51. **kwargs) -> Tensor:
  52. """Forward function.
  53. Args:
  54. pred (Tensor): The prediction.
  55. target (Tensor): The learning target of the prediction.
  56. weight (Tensor, optional): The weight of loss for each
  57. prediction. Defaults to None.
  58. avg_factor (int, optional): Average factor that is used to average
  59. the loss. Defaults to None.
  60. reduction_override (str, optional): The reduction method used to
  61. override the original reduction method of the loss.
  62. Defaults to None.
  63. Returns:
  64. Tensor: Calculated loss
  65. """
  66. assert reduction_override in (None, 'none', 'mean', 'sum')
  67. reduction = (
  68. reduction_override if reduction_override else self.reduction)
  69. loss_bbox = self.loss_weight * huber_loss(
  70. pred,
  71. target,
  72. weight,
  73. beta=self.beta,
  74. reduction=reduction,
  75. avg_factor=avg_factor,
  76. **kwargs)
  77. return loss_bbox