similarity.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn.functional as F
  4. from torch import Tensor
  5. def embed_similarity(key_embeds: Tensor,
  6. ref_embeds: Tensor,
  7. method: str = 'dot_product',
  8. temperature: int = -1) -> Tensor:
  9. """Calculate feature similarity from embeddings.
  10. Args:
  11. key_embeds (Tensor): Shape (N1, C).
  12. ref_embeds (Tensor): Shape (N2, C).
  13. method (str, optional): Method to calculate the similarity,
  14. options are 'dot_product' and 'cosine'. Defaults to
  15. 'dot_product'.
  16. temperature (int, optional): Softmax temperature. Defaults to -1.
  17. Returns:
  18. Tensor: Similarity matrix of shape (N1, N2).
  19. """
  20. assert method in ['dot_product', 'cosine']
  21. if method == 'cosine':
  22. key_embeds = F.normalize(key_embeds, p=2, dim=1)
  23. ref_embeds = F.normalize(ref_embeds, p=2, dim=1)
  24. similarity = torch.mm(key_embeds, ref_embeds.T)
  25. if temperature > 0:
  26. similarity /= float(temperature)
  27. return similarity