test_triplet_loss.py 552 B

12345678910111213141516171819
  # Copyright (c) OpenMMLab. All rights reserved.
import math
from unittest import TestCase

import torch

from mmdet.models import TripletLoss


class TestTripletLoss(TestCase):
    """Unit tests for ``mmdet.models.TripletLoss``."""

    def test_triplet_loss(self):
        """Check the loss for a consistent and an inconsistent labelling.

        The four 2-D features form two coincident pairs: two copies of
        ``(1, 1)`` and two copies of ``(0, 0)``. Points in different pairs
        are ``sqrt(2)`` apart; points in the same pair are at distance 0.
        """
        margin = 0.3
        feature = torch.Tensor([[1, 1], [1, 1], [0, 0], [0, 0]])
        loss = TripletLoss(margin=margin, loss_weight=1.0)

        # Labels agree with the geometry: every same-label point coincides
        # with its anchor and every other-label point is sqrt(2) > margin
        # away, so the loss is expected to be exactly zero.
        label = torch.Tensor([1, 1, 0, 0])
        output = loss(feature, label)
        self.assertTrue(
            torch.allclose(output, torch.tensor(0.)),
            msg=f'expected zero loss, got {output}')

        # Labels contradict the geometry: each anchor's same-label point is
        # sqrt(2) away while an other-label point coincides with it.
        # NOTE(review): this assumes TripletLoss hinges the farthest positive
        # against the nearest negative Euclidean distance, giving
        # sqrt(2) - 0 + margin; the previously hard-coded 1.7142 is exactly
        # that value truncated. Deriving it (instead of the truncated
        # literal) keeps the check well inside the tolerance rather than
        # relying on allclose's default rtol barely absorbing the ~1.3e-5
        # truncation error.
        label = torch.Tensor([1, 0, 1, 0])
        output = loss(feature, label)
        expected = torch.tensor(math.sqrt(2) + margin)
        self.assertTrue(
            torch.allclose(output, expected, atol=1e-4),
            msg=f'expected {expected}, got {output}')