test_l2_loss.py 527 B

123456789101112131415161718192021
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmdet.models import L2Loss
  class TestL2Loss(TestCase):
      """Unit tests for ``L2Loss`` as imported from ``mmdet.models``."""

      def test_l2_loss(self):
          """Check the loss value for one fixed prediction/target pair.

          ``pred`` and ``target`` are single-row tensors that differ only in
          their last element. The expected value ``0.1350`` is a recorded
          reference result for this configuration; it is not derived here.
          """
          pred = torch.Tensor([[1, 1, 0, 0, 0, 0, 1]])
          target = torch.Tensor([[1, 1, 0, 0, 0, 0, 0]])

          loss = L2Loss(
              neg_pos_ub=2,
              pos_margin=0,
              neg_margin=0.1,
              hard_mining=True,
              loss_weight=1.0)

          actual = loss(pred, target)
          expected = torch.tensor(0.1350)
          # Use the TestCase assertion instead of a bare ``assert`` so the
          # check is not stripped when Python runs with ``-O`` and so a
          # failure reports the offending values.
          self.assertTrue(
              torch.allclose(actual, expected),
              msg=f'L2Loss returned {actual}, expected {expected}')