test_similarity.py 251 B

12345678910
import torch

from mmdet.models.task_modules import embed_similarity
  3. def test_embed_similarity():
  4. """Test embed similarity."""
  5. embeds = torch.rand(2, 3)
  6. similarity = embed_similarity(embeds, embeds)
  7. assert similarity.shape == (2, 2)