# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet.models import TripletLoss


class TestTripletLoss(TestCase):

    def test_triplet_loss(self):
        feature = torch.Tensor([[1, 1], [1, 1], [0, 0], [0, 0]])
        label = torch.Tensor([1, 1, 0, 0])

        loss = TripletLoss(margin=0.3, loss_weight=1.0)
        assert torch.allclose(loss(feature, label), torch.tensor(0.))

        label = torch.Tensor([1, 0, 1, 0])
        assert torch.allclose(loss(feature, label), torch.tensor(1.7142))