12345678910111213141516171819 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmdet.models import TripletLoss
class TestTripletLoss(TestCase):
    """Unit tests for ``TripletLoss``."""

    def test_triplet_loss(self):
        """Check TripletLoss on one feature set under two labelings.

        The four 2-D features form two coincident pairs: samples 0 and 1
        both sit at (1, 1) and samples 2 and 3 both sit at (0, 0). The only
        non-zero pairwise Euclidean distance is therefore sqrt(2).
        """
        margin = 0.3
        feature = torch.Tensor([[1, 1], [1, 1], [0, 0], [0, 0]])
        loss = TripletLoss(margin=margin, loss_weight=1.0)

        # Labels agree with the geometry: same-label samples coincide and
        # different-label samples are sqrt(2) apart (> margin), so the
        # expected loss is zero. Labels are integer identity ids.
        label = torch.tensor([1, 1, 0, 0])
        assert torch.allclose(loss(feature, label), torch.tensor(0.))

        # Labels cut across the geometry: same-label samples are sqrt(2)
        # apart while a different-label sample coincides with each anchor.
        # The expected value is sqrt(2) + margin (previously hard-coded as
        # 1.7142, which only passed allclose's default rtol by ~4e-6).
        # The explicit atol absorbs small numerical offsets, e.g. any epsilon
        # the implementation may use to stabilise zero distances.
        label = torch.tensor([1, 0, 1, 0])
        expected = torch.tensor(2.).sqrt() + margin
        assert torch.allclose(loss(feature, label), expected, atol=1e-4)
|