# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmengine.utils import digit_version

from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, DiceLoss,
                                 DistributionFocalLoss, EQLV2Loss, FocalLoss,
                                 GaussianFocalLoss,
                                 KnowledgeDistillationKLDivLoss, L1Loss,
                                 MarginL2Loss, MSELoss, QualityFocalLoss,
                                 SeesawLoss, SmoothL1Loss, VarifocalLoss)
from mmdet.models.losses.ghm_loss import GHMC, GHMR
from mmdet.models.losses.iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss,
                                          EIoULoss, GIoULoss, IoULoss,
                                          SIoULoss)


@pytest.mark.parametrize('loss_class', [
    IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, EIoULoss, SIoULoss
])
def test_iou_type_loss_zeros_weight(loss_class):
    pred = torch.rand((10, 4))
    target = torch.rand((10, 4))
    weight = torch.zeros(10)

    loss = loss_class()(pred, target, weight)
    assert loss == 0.
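

# Note (added): the IoU-style losses multiply each elementwise term by
# `weight` before reduction, so an all-zero weight is expected to yield an
# exactly-zero loss regardless of the variant under test.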


@pytest.mark.parametrize('loss_class', [
    BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss,
    EIoULoss, SIoULoss, FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss,
    GaussianFocalLoss, GIoULoss, QualityFocalLoss, IoULoss, L1Loss,
    VarifocalLoss, GHMR, GHMC, SmoothL1Loss, KnowledgeDistillationKLDivLoss,
    DiceLoss
])
def test_loss_with_reduction_override(loss_class):
    pred = torch.rand((10, 4))
    target = torch.rand((10, 4))
    weight = None

    with pytest.raises(AssertionError):
        # only reduction_override from [None, 'none', 'mean', 'sum']
        # is allowed
        reduction_override = True
        loss_class()(
            pred, target, weight, reduction_override=reduction_override)


@pytest.mark.parametrize('loss_class', [QualityFocalLoss])
@pytest.mark.parametrize('activated', [False, True])
def test_QualityFocalLoss_Loss(loss_class, activated):
    input_shape = (4, 5)
    pred = torch.rand(input_shape)
    label = torch.Tensor([0, 1, 2, 0]).long()
    quality_label = torch.rand(input_shape[0])

    original_loss = loss_class(activated=activated)(pred,
                                                    (label, quality_label))
    assert isinstance(original_loss, torch.Tensor)

    target = torch.nn.functional.one_hot(label, 5)
    target = target * quality_label.reshape(input_shape[0], 1)
    new_loss = loss_class(activated=activated)(pred, target)
    assert isinstance(new_loss, torch.Tensor)
    assert new_loss == original_loss
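

# Note (added): QualityFocalLoss accepts two equivalent target encodings, a
# `(class label, IoU quality score)` tuple or the dense soft-label form
# `one_hot(label) * score`, which is why the two losses above must match.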


@pytest.mark.parametrize('loss_class', [
    IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, EIoULoss, SIoULoss,
    MSELoss, L1Loss, SmoothL1Loss, BalancedL1Loss, MarginL2Loss
])
@pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
def test_regression_losses(loss_class, input_shape):
    pred = torch.rand(input_shape)
    target = torch.rand(input_shape)
    weight = torch.rand(input_shape)

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class()(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)
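

# A minimal sketch added for illustration (assumption: under the default
# 'mean' reduction, `avg_factor` replaces the element count in the
# denominator, i.e. the loss becomes loss.sum() / avg_factor, which is also
# why combining it with reduction_override='sum' raises ValueError above).
def test_l1_loss_avg_factor_sketch():
    pred = torch.rand((10, 4))
    target = torch.rand((10, 4))
    avg_factor = 10
    # Reference value: summed absolute error divided by avg_factor.
    expected = (pred - target).abs().sum() / avg_factor
    loss = L1Loss()(pred, target, avg_factor=avg_factor)
    assert torch.allclose(loss, expected)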


@pytest.mark.parametrize('loss_class', [CrossEntropyLoss])
@pytest.mark.parametrize('input_shape', [(10, 5), (0, 5)])
def test_classification_losses(loss_class, input_shape):
    if input_shape[0] == 0 and digit_version(
            torch.__version__) < digit_version('1.5.0'):
        pytest.skip(
            f'CELoss in PyTorch {torch.__version__} does not support empty '
            'tensor.')
    pred = torch.rand(input_shape)
    target = torch.randint(0, 5, (input_shape[0], ))

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)


@pytest.mark.parametrize('loss_class', [FocalLoss])
@pytest.mark.parametrize('input_shape', [(10, 5), (3, 5, 40, 40)])
def test_FocalLoss_loss(loss_class, input_shape):
    pred = torch.rand(input_shape)
    target = torch.randint(0, 5, (input_shape[0], ))
    if len(input_shape) == 4:
        B, N, W, H = input_shape
        target = F.one_hot(torch.randint(0, 5, (B * W * H, )),
                           5).reshape(B, W, H, N).permute(0, 3, 1, 2)

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)
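

# An equivalence sketch added for illustration (assumption: for 2-D inputs,
# sigmoid focal loss treats a hard label and its one-hot encoding identically,
# mirroring the 4-D one-hot target construction in the test above).
def test_focal_loss_one_hot_equivalence_sketch():
    pred = torch.rand((4, 5))
    label = torch.randint(0, 5, (4, ))
    loss_fn = FocalLoss(use_sigmoid=True)
    loss_from_label = loss_fn(pred, label)
    loss_from_one_hot = loss_fn(pred, F.one_hot(label, 5).float())
    assert torch.allclose(loss_from_label, loss_from_one_hot)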


@pytest.mark.parametrize('loss_class', [GHMR])
@pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
def test_GHMR_loss(loss_class, input_shape):
    pred = torch.rand(input_shape)
    target = torch.rand(input_shape)
    weight = torch.rand(input_shape)

    # Test loss forward
    loss = loss_class()(pred, target, weight)
    assert isinstance(loss, torch.Tensor)


@pytest.mark.parametrize('use_sigmoid', [True, False])
@pytest.mark.parametrize('reduction', ['sum', 'mean', None])
@pytest.mark.parametrize('avg_non_ignore', [True, False])
def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
    # Test cross_entropy loss
    loss_class = CrossEntropyLoss(
        use_sigmoid=use_sigmoid,
        use_mask=False,
        ignore_index=255,
        avg_non_ignore=avg_non_ignore)

    pred = torch.rand((10, 5))
    target = torch.randint(0, 5, (10, ))
    ignored_indices = torch.randint(0, 10, (2, ), dtype=torch.long)
    target[ignored_indices] = 255

    # Test loss forward with default ignore
    loss_with_ignore = loss_class(pred, target, reduction_override=reduction)
    assert isinstance(loss_with_ignore, torch.Tensor)

    # Test loss forward with forward ignore
    target[ignored_indices] = 255
    loss_with_forward_ignore = loss_class(
        pred, target, ignore_index=255, reduction_override=reduction)
    assert isinstance(loss_with_forward_ignore, torch.Tensor)

    # Verify correctness
    if avg_non_ignore:
        # manually remove the ignored elements
        not_ignored_indices = (target != 255)
        pred = pred[not_ignored_indices]
        target = target[not_ignored_indices]
        loss = loss_class(pred, target, reduction_override=reduction)
        assert torch.allclose(loss, loss_with_ignore)
        assert torch.allclose(loss, loss_with_forward_ignore)

    # test ignore all target
    pred = torch.rand((10, 5))
    target = torch.ones((10, ), dtype=torch.long) * 255
    loss = loss_class(pred, target, reduction_override=reduction)
    assert loss == 0
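

# A cross-check sketch added for illustration (assumption: with
# use_sigmoid=False and avg_non_ignore=True, the 'mean' reduction should agree
# with PyTorch's own ignore_index handling, which likewise averages only over
# the non-ignored elements).
def test_ce_ignore_index_matches_pytorch_sketch():
    pred = torch.rand((10, 5))
    target = torch.randint(0, 5, (10, ))
    target[:2] = 255
    loss_cls = CrossEntropyLoss(
        use_sigmoid=False, ignore_index=255, avg_non_ignore=True)
    expected = F.cross_entropy(pred, target, ignore_index=255)
    assert torch.allclose(loss_cls(pred, target), expected)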


@pytest.mark.parametrize('naive_dice', [True, False])
def test_dice_loss(naive_dice):
    loss_class = DiceLoss
    pred = torch.rand((10, 4, 4))
    target = torch.rand((10, 4, 4))
    weight = torch.rand((10, ))

    # Test loss forward
    loss = loss_class(naive_dice=naive_dice)(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class(naive_dice=naive_dice)(
        pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class(naive_dice=naive_dice)(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class(naive_dice=naive_dice)(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        loss_class(naive_dice=naive_dice)(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)

    # Test loss forward with activate=True and use_sigmoid=False
    with pytest.raises(NotImplementedError):
        loss_class(
            use_sigmoid=False, activate=True, naive_dice=naive_dice)(pred,
                                                                     target)

    # Test loss forward with weight.ndim != loss.ndim
    with pytest.raises(AssertionError):
        weight = torch.rand((2, 8))
        loss_class(naive_dice=naive_dice)(pred, target, weight)

    # Test loss forward with len(weight) != len(pred)
    with pytest.raises(AssertionError):
        weight = torch.rand((8, ))
        loss_class(naive_dice=naive_dice)(pred, target, weight)
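

# A reference-formula sketch added for illustration (assumption: with
# naive_dice=True and the default sigmoid activation, the per-sample loss is
# 1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps) on flattened inputs,
# followed by a mean reduction).
def test_dice_loss_naive_formula_sketch():
    eps = 1e-3
    pred = torch.rand((10, 4, 4))
    target = torch.rand((10, 4, 4))
    # Reference value computed directly from the naive dice formula.
    p = pred.sigmoid().flatten(1)
    t = target.flatten(1)
    expected = (1 - (2 * (p * t).sum(1) + eps) /
                (p.sum(1) + t.sum(1) + eps)).mean()
    loss = DiceLoss(naive_dice=True, eps=eps)(pred, target)
    assert torch.allclose(loss, expected)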


@pytest.mark.parametrize('loss_class', [EQLV2Loss])
@pytest.mark.parametrize('reduction', ['mean'])
def test_eqlv2_loss(loss_class, reduction):
    # 1204 channels match EQLV2Loss's default setting for LVIS v1
    # (1203 classes plus one background channel).
    cls_score = torch.randn((1204, 1204))
    label = torch.randint(0, 2, (1204, ))
    weight = None

    loss = loss_class()(cls_score, label, weight)
    assert isinstance(loss, torch.Tensor)