triplet_loss.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. from mmengine.model import BaseModule
  5. from mmdet.registry import MODELS
  6. @MODELS.register_module()
  7. class TripletLoss(BaseModule):
  8. """Triplet loss with hard positive/negative mining.
  9. Reference:
  10. Hermans et al. In Defense of the Triplet Loss for
  11. Person Re-Identification. arXiv:1703.07737.
  12. Imported from `<https://github.com/KaiyangZhou/deep-person-reid/blob/
  13. master/torchreid/losses/hard_mine_triplet_loss.py>`_.
  14. Args:
  15. margin (float, optional): Margin for triplet loss. Defaults to 0.3.
  16. loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
  17. hard_mining (bool, optional): Whether to perform hard mining.
  18. Defaults to True.
  19. """
  20. def __init__(self,
  21. margin: float = 0.3,
  22. loss_weight: float = 1.0,
  23. hard_mining=True):
  24. super(TripletLoss, self).__init__()
  25. self.margin = margin
  26. self.ranking_loss = nn.MarginRankingLoss(margin=margin)
  27. self.loss_weight = loss_weight
  28. self.hard_mining = hard_mining
  29. def hard_mining_triplet_loss_forward(
  30. self, inputs: torch.Tensor,
  31. targets: torch.LongTensor) -> torch.Tensor:
  32. """
  33. Args:
  34. inputs (torch.Tensor): feature matrix with shape
  35. (batch_size, feat_dim).
  36. targets (torch.LongTensor): ground truth labels with shape
  37. (num_classes).
  38. Returns:
  39. torch.Tensor: triplet loss with hard mining.
  40. """
  41. batch_size = inputs.size(0)
  42. # Compute Euclidean distance
  43. dist = torch.pow(inputs, 2).sum(
  44. dim=1, keepdim=True).expand(batch_size, batch_size)
  45. dist = dist + dist.t()
  46. dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
  47. dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
  48. # For each anchor, find the furthest positive sample
  49. # and nearest negative sample in the embedding space
  50. mask = targets.expand(batch_size, batch_size).eq(
  51. targets.expand(batch_size, batch_size).t())
  52. dist_ap, dist_an = [], []
  53. for i in range(batch_size):
  54. dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
  55. dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
  56. dist_ap = torch.cat(dist_ap)
  57. dist_an = torch.cat(dist_an)
  58. # Compute ranking hinge loss
  59. y = torch.ones_like(dist_an)
  60. return self.loss_weight * self.ranking_loss(dist_an, dist_ap, y)
  61. def forward(self, inputs: torch.Tensor,
  62. targets: torch.LongTensor) -> torch.Tensor:
  63. """
  64. Args:
  65. inputs (torch.Tensor): feature matrix with shape
  66. (batch_size, feat_dim).
  67. targets (torch.LongTensor): ground truth labels with shape
  68. (num_classes).
  69. Returns:
  70. torch.Tensor: triplet loss.
  71. """
  72. if self.hard_mining:
  73. return self.hard_mining_triplet_loss_forward(inputs, targets)
  74. else:
  75. raise NotImplementedError()