# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_overlaps

from .utils import weighted_loss


@weighted_loss
def iou_loss(pred: Tensor,
             target: Tensor,
             linear: bool = False,
             mode: str = 'log',
             eps: float = 1e-6) -> Tensor:
    """IoU loss.

    Computing the IoU loss between a set of predicted bboxes and target
    bboxes. By default the loss is calculated as the negative log of IoU.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        linear (bool, optional): If True, use linear scale of loss instead of
            log scale. Default: False.
        mode (str): Loss scaling mode, including "linear", "square", and
            "log". Default: 'log'.
        eps (float): Epsilon to avoid log(0).

    Returns:
        Tensor: Loss tensor.
    """
    assert mode in ['linear', 'square', 'log']
    if linear:
        mode = 'linear'
        warnings.warn('DeprecationWarning: Setting "linear=True" in '
                      'iou_loss is deprecated, please use "mode=`linear`" '
                      'instead.')
    ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
    if mode == 'linear':
        loss = 1 - ious
    elif mode == 'square':
        loss = 1 - ious**2
    elif mode == 'log':
        loss = -ious.log()
    else:
        raise NotImplementedError
    return loss
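
# Usage sketch (illustrative, not in the upstream file): for a prediction
# covering exactly half of its target, IoU = 0.5, so mode='linear' gives 0.5,
# mode='square' gives 0.75, and the default mode='log' gives
# -log(0.5) ≈ 0.693. The `reduction` keyword is forwarded by the
# @weighted_loss wrapper.
# >>> pred = torch.tensor([[0., 0., 4., 4.]])
# >>> target = torch.tensor([[0., 0., 4., 2.]])
# >>> iou_loss(pred, target, reduction='none')  # tensor([0.6931])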


@weighted_loss
def bounded_iou_loss(pred: Tensor,
                     target: Tensor,
                     beta: float = 0.2,
                     eps: float = 1e-3) -> Tensor:
    """BIoULoss.

    This is an implementation of paper
    `Improving Object Localization with Fitness NMS and Bounded IoU Loss.
    <https://arxiv.org/abs/1711.00164>`_.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        beta (float, optional): Beta parameter in smoothl1.
        eps (float, optional): Epsilon to avoid NaN values.

    Returns:
        Tensor: Loss tensor.
    """
    pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
    pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
    pred_w = pred[:, 2] - pred[:, 0]
    pred_h = pred[:, 3] - pred[:, 1]
    with torch.no_grad():
        target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
        target_ctry = (target[:, 1] + target[:, 3]) * 0.5
        target_w = target[:, 2] - target[:, 0]
        target_h = target[:, 3] - target[:, 1]

    dx = target_ctrx - pred_ctrx
    dy = target_ctry - pred_ctry

    loss_dx = 1 - torch.max(
        (target_w - 2 * dx.abs()) /
        (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
    loss_dy = 1 - torch.max(
        (target_h - 2 * dy.abs()) /
        (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
    loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
                            (target_w + eps))
    loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
                            (target_h + eps))

    # view(..., -1) does not work for empty tensor
    loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
                            dim=-1).flatten(1)

    loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
                       loss_comb - 0.5 * beta)
    return loss
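
# Sketch (illustrative, not in the upstream file): bounded_iou_loss applies a
# smooth-L1 with threshold `beta` to four bounded sub-losses (dx, dy, dw, dh),
# so it returns shape (n, 4) rather than (n,).
# >>> pred = torch.tensor([[0., 0., 4., 4.]])
# >>> target = torch.tensor([[1., 0., 4., 4.]])
# >>> bounded_iou_loss(pred, target, beta=0.2, reduction='none').shape
# torch.Size([1, 4])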


@weighted_loss
def giou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor:
    r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
    Box Regression <https://arxiv.org/abs/1902.09630>`_.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps)
    loss = 1 - gious
    return loss
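
# Worked example (illustrative, not in the upstream file): for disjoint boxes
# GIoU goes negative, so the loss exceeds 1 and still provides a gradient
# where plain IoU would be flat. With pred (0, 0, 2, 2) and target
# (4, 0, 6, 2): IoU = 0, enclosing-box area = 12, union = 8,
# GIoU = 0 - (12 - 8) / 12 = -1/3, loss = 1 - (-1/3) ≈ 1.333.
# >>> giou_loss(torch.tensor([[0., 0., 2., 2.]]),
# ...           torch.tensor([[4., 0., 6., 2.]]), reduction='none')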


@weighted_loss
def diou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Implementation of `Distance-IoU Loss: Faster and Better
    Learning for Bounding Box Regression
    <https://arxiv.org/abs/1911.08287>`_.

    Code is modified from https://github.com/Zzh-tju/DIoU.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    # overlap
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]

    # union
    ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    union = ap + ag - overlap + eps

    # IoU
    ious = overlap / union

    # enclose area
    enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
    enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)

    cw = enclose_wh[:, 0]
    ch = enclose_wh[:, 1]

    # squared diagonal of the smallest enclosing box
    c2 = cw**2 + ch**2 + eps

    b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
    b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
    b2_x1, b2_y1 = target[:, 0], target[:, 1]
    b2_x2, b2_y2 = target[:, 2], target[:, 3]

    # squared distance between box centers
    left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
    right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
    rho2 = left + right

    # DIoU
    dious = ious - rho2 / c2
    loss = 1 - dious
    return loss
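
# Worked example (illustrative, not in the upstream file): DIoU subtracts the
# squared center distance over the squared enclosing-box diagonal from IoU.
# With pred (0, 0, 2, 2) and target (1, 1, 3, 3): rho2 = 2, c2 = 18,
# IoU = 1/7, so dious = 1/7 - 2/18 ≈ 0.032 and loss ≈ 0.968. For concentric
# boxes the penalty vanishes and the loss reduces to 1 - IoU.
# >>> diou_loss(torch.tensor([[0., 0., 2., 2.]]),
# ...           torch.tensor([[1., 1., 3., 3.]]), reduction='none')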


@weighted_loss
def ciou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Implementation of paper `Enhancing Geometric Factors into
    Model Learning and Inference for Object Detection and Instance
    Segmentation <https://arxiv.org/abs/2005.03572>`_.

    Code is modified from https://github.com/Zzh-tju/CIoU.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    # overlap
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]

    # union
    ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    union = ap + ag - overlap + eps

    # IoU
    ious = overlap / union

    # enclose area
    enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
    enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)

    cw = enclose_wh[:, 0]
    ch = enclose_wh[:, 1]

    c2 = cw**2 + ch**2 + eps

    b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
    b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
    b2_x1, b2_y1 = target[:, 0], target[:, 1]
    b2_x2, b2_y2 = target[:, 2], target[:, 3]

    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
    right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
    rho2 = left + right

    # aspect-ratio consistency term
    factor = 4 / math.pi**2
    v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)

    with torch.no_grad():
        alpha = (ious > 0.5).float() * v / (1 - ious + v)

    # CIoU
    cious = ious - (rho2 / c2 + alpha * v)
    loss = 1 - cious.clamp(min=-1.0, max=1.0)
    return loss
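
# Note (illustrative, not in the upstream file): in this implementation the
# aspect-ratio term is gated by alpha, which is zero whenever IoU <= 0.5, so
# the shape penalty only activates once boxes already overlap substantially;
# until then ciou_loss behaves like diou_loss. E.g. for these boxes IoU = 1/3,
# so alpha = 0 and only the center-distance penalty applies:
# >>> pred = torch.tensor([[0., 0., 4., 2.]])   # wide box
# >>> target = torch.tensor([[0., 0., 2., 4.]])  # tall box, same corner
# >>> ciou_loss(pred, target, reduction='none')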


@weighted_loss
def eiou_loss(pred: Tensor,
              target: Tensor,
              smooth_point: float = 0.1,
              eps: float = 1e-7) -> Tensor:
    r"""Implementation of paper `Extended-IoU Loss: A Systematic
    IoU-Related Method: Beyond Simplified Regression for Better
    Localization <https://ieeexplore.ieee.org/abstract/document/9429909>`_.

    Code is modified from https://github.com/ShiqiYu/libfacedetection.train.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        smooth_point (float): hyperparameter, default is 0.1.
        eps (float): Epsilon to avoid division by zero.

    Returns:
        Tensor: Loss tensor.
    """
    px1, py1, px2, py2 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
    tx1, ty1, tx2, ty2 = target[:, 0], target[:, 1], target[:, 2], target[:, 3]

    # extent top left
    ex1 = torch.min(px1, tx1)
    ey1 = torch.min(py1, ty1)

    # intersection coordinates
    ix1 = torch.max(px1, tx1)
    iy1 = torch.max(py1, ty1)
    ix2 = torch.min(px2, tx2)
    iy2 = torch.min(py2, ty2)

    # extra
    xmin = torch.min(ix1, ix2)
    ymin = torch.min(iy1, iy2)
    xmax = torch.max(ix1, ix2)
    ymax = torch.max(iy1, iy2)

    # Intersection
    intersection = (ix2 - ex1) * (iy2 - ey1) + (xmin - ex1) * (ymin - ey1) - (
        ix1 - ex1) * (ymax - ey1) - (xmax - ex1) * (
            iy1 - ey1)
    # Union
    union = (px2 - px1) * (py2 - py1) + (tx2 - tx1) * (
        ty2 - ty1) - intersection + eps
    # IoU
    ious = 1 - (intersection / union)

    # Smooth-EIoU
    smooth_sign = (ious < smooth_point).detach().float()
    loss = 0.5 * smooth_sign * (ious**2) / smooth_point + (1 - smooth_sign) * (
        ious - 0.5 * smooth_point)
    return loss
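
# Note (illustrative, not in the upstream file): despite its name, `ious`
# above holds 1 - IoU, so `smooth_sign` selects the quadratic branch for
# well-overlapping pairs (1 - IoU < smooth_point) and the linear branch
# otherwise, mirroring a smooth-L1 (Huber) loss applied to the IoU error with
# beta = smooth_point.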


@weighted_loss
def siou_loss(pred: Tensor,
              target: Tensor,
              eps: float = 1e-7,
              neg_gamma: bool = False) -> Tensor:
    r"""Implementation of paper `SIoU Loss: More Powerful Learning
    for Bounding Box Regression <https://arxiv.org/abs/2205.12740>`_.

    Code is modified from https://github.com/meituan/YOLOv6.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Epsilon to avoid division by zero.
        neg_gamma (bool): `True` follows original implementation in paper.

    Returns:
        Tensor: Loss tensor.
    """
    # overlap
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]

    # union
    ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    union = ap + ag - overlap + eps

    # IoU
    ious = overlap / union

    # enclose area
    enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
    enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
    # modified clamp threshold from zero to eps to avoid NaN
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=eps)

    cw = enclose_wh[:, 0]
    ch = enclose_wh[:, 1]

    b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
    b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
    b2_x1, b2_y1 = target[:, 0], target[:, 1]
    b2_x2, b2_y2 = target[:, 2], target[:, 3]

    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # angle cost
    s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
    s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps

    sigma = torch.pow(s_cw**2 + s_ch**2, 0.5)

    sin_alpha_1 = torch.abs(s_cw) / sigma
    sin_alpha_2 = torch.abs(s_ch) / sigma
    threshold = pow(2, 0.5) / 2
    sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
    angle_cost = torch.cos(torch.asin(sin_alpha) * 2 - math.pi / 2)

    # distance cost
    rho_x = (s_cw / cw)**2
    rho_y = (s_ch / ch)**2

    # `neg_gamma=True` follows the original implementation in the paper,
    # but setting `neg_gamma=False` makes training more stable.
    gamma = angle_cost - 2 if neg_gamma else 2 - angle_cost
    distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)

    # shape cost
    omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
    omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
    shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(
        1 - torch.exp(-1 * omiga_h), 4)

    # SIoU
    sious = ious - 0.5 * (distance_cost + shape_cost)
    loss = 1 - sious.clamp(min=-1.0, max=1.0)
    return loss
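
# Usage sketch (illustrative, not in the upstream file): siou_loss combines
# IoU with angle, distance, and shape costs; `neg_gamma=True` reproduces the
# paper's sign for gamma, while the default False is noted above as more
# stable in training.
# >>> pred = torch.tensor([[0., 0., 4., 4.]])
# >>> target = torch.tensor([[1., 1., 5., 5.]])
# >>> loss = siou_loss(pred, target, reduction='none')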


@MODELS.register_module()
class IoULoss(nn.Module):
    """IoULoss.

    Computing the IoU loss between a set of predicted bboxes and target
    bboxes.

    Args:
        linear (bool): If True, use linear scale of loss else determined
            by mode. Default: False.
        eps (float): Epsilon to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        mode (str): Loss scaling mode, including "linear", "square", and
            "log". Default: 'log'.
    """

    def __init__(self,
                 linear: bool = False,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 mode: str = 'log') -> None:
        super().__init__()
        assert mode in ['linear', 'square', 'log']
        if linear:
            mode = 'linear'
            warnings.warn('DeprecationWarning: Setting "linear=True" in '
                          'IoULoss is deprecated, please use "mode=`linear`" '
                          'instead.')
        self.mode = mode
        self.linear = linear
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # return a differentiable zero so the graph stays connected
            return (pred * weight).sum()  # 0
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # iou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * iou_loss(
            pred,
            target,
            weight,
            mode=self.mode,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss
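
# Config sketch (illustrative, not in the upstream file): like other mmdet
# losses, these modules are normally instantiated from a config dict through
# the MODELS registry, e.g. as the bbox regression loss of a detection head.
# The keys and values below are an example, not a recommendation:
# >>> from mmdet.registry import MODELS
# >>> loss_bbox = MODELS.build(
# ...     dict(type='IoULoss', mode='linear', loss_weight=10.0))
# >>> # loss = loss_bbox(pred, target, weight, avg_factor=num_pos)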


@MODELS.register_module()
class BoundedIoULoss(nn.Module):
    """BIoULoss.

    This is an implementation of paper
    `Improving Object Localization with Fitness NMS and Bounded IoU Loss.
    <https://arxiv.org/abs/1711.00164>`_.

    Args:
        beta (float, optional): Beta parameter in smoothl1.
        eps (float, optional): Epsilon to avoid NaN values.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 beta: float = 0.2,
                 eps: float = 1e-3,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.beta = beta
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # return a differentiable zero when all weights are zero
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * bounded_iou_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class GIoULoss(nn.Module):
    r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
    Box Regression <https://arxiv.org/abs/1902.09630>`_.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # return a differentiable zero when all weights are zero
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # giou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * giou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class DIoULoss(nn.Module):
    r"""Implementation of `Distance-IoU Loss: Faster and Better
    Learning for Bounding Box Regression
    <https://arxiv.org/abs/1911.08287>`_.

    Code is modified from https://github.com/Zzh-tju/DIoU.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # return a differentiable zero when all weights are zero
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # diou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * diou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class CIoULoss(nn.Module):
    r"""Implementation of paper `Enhancing Geometric Factors into
    Model Learning and Inference for Object Detection and Instance
    Segmentation <https://arxiv.org/abs/2005.03572>`_.

    Code is modified from https://github.com/Zzh-tju/CIoU.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # return a differentiable zero when all weights are zero
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # ciou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * ciou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class EIoULoss(nn.Module):
    r"""Implementation of paper `Extended-IoU Loss: A Systematic
    IoU-Related Method: Beyond Simplified Regression for Better
    Localization <https://ieeexplore.ieee.org/abstract/document/9429909>`_.

    Code is modified from https://github.com/ShiqiYu/libfacedetection.train.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        smooth_point (float): hyperparameter, default is 0.1.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 smooth_point: float = 0.1) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.smooth_point = smooth_point

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # return a differentiable zero when all weights are zero
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # reduce the weight of shape (n, 4) to (n,) to match the
            # eiou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * eiou_loss(
            pred,
            target,
            weight,
            smooth_point=self.smooth_point,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@MODELS.register_module()
class SIoULoss(nn.Module):
    r"""Implementation of paper `SIoU Loss: More Powerful Learning
    for Bounding Box Regression <https://arxiv.org/abs/2205.12740>`_.

    Code is modified from https://github.com/meituan/YOLOv6.

    Args:
        eps (float): Epsilon to avoid division by zero.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        neg_gamma (bool): `True` follows original implementation in paper.
    """

    def __init__(self,
                 eps: float = 1e-6,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 neg_gamma: bool = False) -> None:
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.neg_gamma = neg_gamma

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
                shape (n, 4).
            target (Tensor): The learning target of the prediction,
                shape (n, 4).
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (Optional[str], optional): The reduction method
                used to override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # return a differentiable zero when all weights are zero
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # siou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * siou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            neg_gamma=self.neg_gamma,
            **kwargs)
        return loss
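

if __name__ == '__main__':
    # Minimal smoke test (an illustrative addition, not part of the upstream
    # file): evaluate each functional loss on two overlapping box pairs.
    # bounded_iou_loss returns shape (n, 4); the others return shape (n,).
    pred = torch.tensor([[0., 0., 4., 4.], [1., 1., 3., 3.]])
    target = torch.tensor([[0., 0., 4., 2.], [1., 1., 4., 4.]])
    for name, fn in [('iou', iou_loss), ('bounded_iou', bounded_iou_loss),
                     ('giou', giou_loss), ('diou', diou_loss),
                     ('ciou', ciou_loss), ('eiou', eiou_loss),
                     ('siou', siou_loss)]:
        # reduction='none' is forwarded by the @weighted_loss wrapper
        print(name, fn(pred, target, reduction='none'))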