eqlv2_loss.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import logging
  3. from functools import partial
  4. from typing import Optional
  5. import torch
  6. import torch.distributed as dist
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from mmengine.logging import print_log
  10. from torch import Tensor
  11. from mmdet.registry import MODELS
  12. @MODELS.register_module()
  13. class EQLV2Loss(nn.Module):
  14. def __init__(self,
  15. use_sigmoid: bool = True,
  16. reduction: str = 'mean',
  17. class_weight: Optional[Tensor] = None,
  18. loss_weight: float = 1.0,
  19. num_classes: int = 1203,
  20. use_distributed: bool = False,
  21. mu: float = 0.8,
  22. alpha: float = 4.0,
  23. gamma: int = 12,
  24. vis_grad: bool = False,
  25. test_with_obj: bool = True) -> None:
  26. """`Equalization Loss v2 <https://arxiv.org/abs/2012.08548>`_
  27. Args:
  28. use_sigmoid (bool): EQLv2 uses the sigmoid function to transform
  29. the predicted logits to an estimated probability distribution.
  30. reduction (str, optional): The method used to reduce the loss into
  31. a scalar. Defaults to 'mean'.
  32. class_weight (Tensor, optional): The weight of loss for each
  33. prediction. Defaults to None.
  34. loss_weight (float, optional): The weight of the total EQLv2 loss.
  35. Defaults to 1.0.
  36. num_classes (int): 1203 for lvis v1.0, 1230 for lvis v0.5.
  37. use_distributed (bool, float): EQLv2 will calculate the gradients
  38. on all GPUs if there is any. Change to True if you are using
  39. distributed training. Default to False.
  40. mu (float, optional): Defaults to 0.8
  41. alpha (float, optional): A balance factor for the negative part of
  42. EQLV2 Loss. Defaults to 4.0.
  43. gamma (int, optional): The gamma for calculating the modulating
  44. factor. Defaults to 12.
  45. vis_grad (bool, optional): Default to False.
  46. test_with_obj (bool, optional): Default to True.
  47. Returns:
  48. None.
  49. """
  50. super().__init__()
  51. self.use_sigmoid = True
  52. self.reduction = reduction
  53. self.loss_weight = loss_weight
  54. self.class_weight = class_weight
  55. self.num_classes = num_classes
  56. self.group = True
  57. # cfg for eqlv2
  58. self.vis_grad = vis_grad
  59. self.mu = mu
  60. self.alpha = alpha
  61. self.gamma = gamma
  62. self.use_distributed = use_distributed
  63. # initial variables
  64. self.register_buffer('pos_grad', torch.zeros(self.num_classes))
  65. self.register_buffer('neg_grad', torch.zeros(self.num_classes))
  66. # At the beginning of training, we set a high value (eg. 100)
  67. # for the initial gradient ratio so that the weight for pos
  68. # gradients and neg gradients are 1.
  69. self.register_buffer('pos_neg', torch.ones(self.num_classes) * 100)
  70. self.test_with_obj = test_with_obj
  71. def _func(x, gamma, mu):
  72. return 1 / (1 + torch.exp(-gamma * (x - mu)))
  73. self.map_func = partial(_func, gamma=self.gamma, mu=self.mu)
  74. print_log(
  75. f'build EQL v2, gamma: {gamma}, mu: {mu}, alpha: {alpha}',
  76. logger='current',
  77. level=logging.DEBUG)
  78. def forward(self,
  79. cls_score: Tensor,
  80. label: Tensor,
  81. weight: Optional[Tensor] = None,
  82. avg_factor: Optional[int] = None,
  83. reduction_override: Optional[Tensor] = None) -> Tensor:
  84. """`Equalization Loss v2 <https://arxiv.org/abs/2012.08548>`_
  85. Args:
  86. cls_score (Tensor): The prediction with shape (N, C), C is the
  87. number of classes.
  88. label (Tensor): The ground truth label of the predicted target with
  89. shape (N, C), C is the number of classes.
  90. weight (Tensor, optional): The weight of loss for each prediction.
  91. Defaults to None.
  92. avg_factor (int, optional): Average factor that is used to average
  93. the loss. Defaults to None.
  94. reduction_override (str, optional): The reduction method used to
  95. override the original reduction method of the loss.
  96. Options are "none", "mean" and "sum".
  97. Returns:
  98. Tensor: The calculated loss
  99. """
  100. self.n_i, self.n_c = cls_score.size()
  101. self.gt_classes = label
  102. self.pred_class_logits = cls_score
  103. def expand_label(pred, gt_classes):
  104. target = pred.new_zeros(self.n_i, self.n_c)
  105. target[torch.arange(self.n_i), gt_classes] = 1
  106. return target
  107. target = expand_label(cls_score, label)
  108. pos_w, neg_w = self.get_weight(cls_score)
  109. weight = pos_w * target + neg_w * (1 - target)
  110. cls_loss = F.binary_cross_entropy_with_logits(
  111. cls_score, target, reduction='none')
  112. cls_loss = torch.sum(cls_loss * weight) / self.n_i
  113. self.collect_grad(cls_score.detach(), target.detach(), weight.detach())
  114. return self.loss_weight * cls_loss
  115. def get_channel_num(self, num_classes):
  116. num_channel = num_classes + 1
  117. return num_channel
  118. def get_activation(self, pred):
  119. pred = torch.sigmoid(pred)
  120. n_i, n_c = pred.size()
  121. bg_score = pred[:, -1].view(n_i, 1)
  122. if self.test_with_obj:
  123. pred[:, :-1] *= (1 - bg_score)
  124. return pred
  125. def collect_grad(self, pred, target, weight):
  126. prob = torch.sigmoid(pred)
  127. grad = target * (prob - 1) + (1 - target) * prob
  128. grad = torch.abs(grad)
  129. # do not collect grad for objectiveness branch [:-1]
  130. pos_grad = torch.sum(grad * target * weight, dim=0)[:-1]
  131. neg_grad = torch.sum(grad * (1 - target) * weight, dim=0)[:-1]
  132. if self.use_distributed:
  133. dist.all_reduce(pos_grad)
  134. dist.all_reduce(neg_grad)
  135. self.pos_grad += pos_grad
  136. self.neg_grad += neg_grad
  137. self.pos_neg = self.pos_grad / (self.neg_grad + 1e-10)
  138. def get_weight(self, pred):
  139. neg_w = torch.cat([self.map_func(self.pos_neg), pred.new_ones(1)])
  140. pos_w = 1 + self.alpha * (1 - neg_w)
  141. neg_w = neg_w.view(1, -1).expand(self.n_i, self.n_c)
  142. pos_w = pos_w.view(1, -1).expand(self.n_i, self.n_c)
  143. return pos_w, neg_w