12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- from mmengine.model import BaseModule
- from mmdet.registry import MODELS
@MODELS.register_module()
class TripletLoss(BaseModule):
    """Triplet loss with hard positive/negative mining.

    For every anchor in the batch the furthest sample sharing its label
    (hardest positive) and the nearest sample with a different label
    (hardest negative) are selected, and a margin ranking loss is applied
    to the two distances.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for
        Person Re-Identification. arXiv:1703.07737.

    Imported from `<https://github.com/KaiyangZhou/deep-person-reid/blob/
    master/torchreid/losses/hard_mine_triplet_loss.py>`_.

    Args:
        margin (float, optional): Margin for triplet loss. Defaults to 0.3.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        hard_mining (bool, optional): Whether to perform hard mining.
            Only ``True`` is implemented; with ``False`` the forward pass
            raises ``NotImplementedError``. Defaults to True.
    """

    def __init__(self,
                 margin: float = 0.3,
                 loss_weight: float = 1.0,
                 hard_mining: bool = True) -> None:
        super().__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.loss_weight = loss_weight
        self.hard_mining = hard_mining

    def hard_mining_triplet_loss_forward(
            self, inputs: torch.Tensor,
            targets: torch.LongTensor) -> torch.Tensor:
        """Compute the batch-hard triplet loss.

        Args:
            inputs (torch.Tensor): feature matrix with shape
                (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels, one per
                sample, with shape (batch_size, ).

        Returns:
            torch.Tensor: triplet loss with hard mining, already scaled
            by ``loss_weight``.

        Raises:
            RuntimeError: If the batch is empty, or if some sample has no
                negative (i.e. every sample in the batch shares its
                label), since no triplet can be formed for it.
        """
        batch_size = inputs.size(0)
        if batch_size == 0:
            raise RuntimeError(
                'TripletLoss cannot be computed on an empty batch.')

        # Pairwise Euclidean distance via
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>
        sq_norm = torch.pow(inputs, 2).sum(
            dim=1, keepdim=True).expand(batch_size, batch_size)
        dist = sq_norm + sq_norm.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        # Clamp before sqrt: rounding can make the squared distance
        # slightly negative, and sqrt has an infinite gradient at 0.
        dist = dist.clamp(min=1e-12).sqrt()

        # is_pos[i, j] is True when samples i and j share a label. The
        # diagonal is always True, so every anchor has a positive.
        labels = targets.expand(batch_size, batch_size)
        is_pos = labels.eq(labels.t())
        is_neg = ~is_pos
        if not bool(is_neg.any(dim=1).all()):
            raise RuntimeError(
                'TripletLoss with hard mining requires every sample in the '
                'batch to have at least one negative (a sample with a '
                'different label).')

        # For each anchor, find the furthest positive sample and the
        # nearest negative sample in the embedding space. Entries that
        # must not take part in a reduction are pushed to -inf / +inf so
        # they can never be selected; amax/amin spread the gradient
        # evenly over ties, like a full max()/min() reduction does.
        dist_ap = dist.masked_fill(is_neg, float('-inf')).amax(dim=1)
        dist_an = dist.masked_fill(is_pos, float('inf')).amin(dim=1)

        # Ranking hinge loss: with y = 1 the loss is
        # max(0, dist_ap - dist_an + margin).
        y = torch.ones_like(dist_an)
        return self.loss_weight * self.ranking_loss(dist_an, dist_ap, y)

    def forward(self, inputs: torch.Tensor,
                targets: torch.LongTensor) -> torch.Tensor:
        """Compute the triplet loss.

        Args:
            inputs (torch.Tensor): feature matrix with shape
                (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels, one per
                sample, with shape (batch_size, ).

        Returns:
            torch.Tensor: triplet loss.

        Raises:
            NotImplementedError: If ``hard_mining`` is False.
        """
        if not self.hard_mining:
            raise NotImplementedError()
        return self.hard_mining_triplet_loss_forward(inputs, targets)
|