utils.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.logging import print_log
  6. from .text_encoder import CLIPTextEncoder
  7. # download from
  8. # https://github.com/facebookresearch/Detic/tree/main/datasets/metadata
  9. DATASET_EMBEDDINGS = {
  10. 'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
  11. 'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
  12. 'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
  13. 'coco': 'datasets/metadata/coco_clip_a+cname.npy',
  14. }
  15. def get_text_embeddings(dataset=None,
  16. custom_vocabulary=None,
  17. prompt_prefix='a '):
  18. assert (dataset is None) ^ (custom_vocabulary is None), \
  19. 'Either `dataset` or `custom_vocabulary` should be specified.'
  20. if dataset:
  21. if dataset in DATASET_EMBEDDINGS:
  22. return DATASET_EMBEDDINGS[dataset]
  23. else:
  24. custom_vocabulary = get_class_names(dataset)
  25. text_encoder = CLIPTextEncoder()
  26. text_encoder.eval()
  27. texts = [prompt_prefix + x for x in custom_vocabulary]
  28. print_log(
  29. f'Computing text embeddings for {len(custom_vocabulary)} classes.')
  30. embeddings = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
  31. return embeddings
  32. def get_class_names(dataset):
  33. if dataset == 'coco':
  34. from mmdet.datasets import CocoDataset
  35. class_names = CocoDataset.METAINFO['classes']
  36. elif dataset == 'cityscapes':
  37. from mmdet.datasets import CityscapesDataset
  38. class_names = CityscapesDataset.METAINFO['classes']
  39. elif dataset == 'voc':
  40. from mmdet.datasets import VOCDataset
  41. class_names = VOCDataset.METAINFO['classes']
  42. elif dataset == 'openimages':
  43. from mmdet.datasets import OpenImagesDataset
  44. class_names = OpenImagesDataset.METAINFO['classes']
  45. elif dataset == 'lvis':
  46. from mmdet.datasets import LVISV1Dataset
  47. class_names = LVISV1Dataset.METAINFO['classes']
  48. else:
  49. raise TypeError(f'Invalid type for dataset name: {type(dataset)}')
  50. return class_names
  51. def reset_cls_layer_weight(model, weight):
  52. if type(weight) == str:
  53. print_log(f'Resetting cls_layer_weight from file: {weight}')
  54. zs_weight = torch.tensor(
  55. np.load(weight),
  56. dtype=torch.float32).permute(1, 0).contiguous() # D x C
  57. else:
  58. zs_weight = weight
  59. zs_weight = torch.cat(
  60. [zs_weight, zs_weight.new_zeros(
  61. (zs_weight.shape[0], 1))], dim=1) # D x (C + 1)
  62. zs_weight = F.normalize(zs_weight, p=2, dim=0)
  63. zs_weight = zs_weight.to('cuda')
  64. num_classes = zs_weight.shape[-1]
  65. for bbox_head in model.roi_head.bbox_head:
  66. bbox_head.num_classes = num_classes
  67. del bbox_head.fc_cls.zs_weight
  68. bbox_head.fc_cls.zs_weight = zs_weight