# reid_dataset.py (4.9 KB)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.utils import check_file_exist

from mmdet.registry import DATASETS
  10. @DATASETS.register_module()
  11. class ReIDDataset(BaseDataset):
  12. """Dataset for ReID.
  13. Args:
  14. triplet_sampler (dict, optional): The sampler for hard mining
  15. triplet loss. Defaults to None.
  16. keys: num_ids (int): The number of person ids.
  17. ins_per_id (int): The number of image for each person.
  18. """
  19. def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
  20. self.triplet_sampler = triplet_sampler
  21. super().__init__(*args, **kwargs)
  22. def load_data_list(self) -> List[dict]:
  23. """Load annotations from an annotation file named as ''self.ann_file''.
  24. Returns:
  25. list[dict]: A list of annotation.
  26. """
  27. assert isinstance(self.ann_file, str)
  28. check_file_exist(self.ann_file)
  29. data_list = []
  30. with open(self.ann_file) as f:
  31. samples = [x.strip().split(' ') for x in f.readlines()]
  32. for filename, gt_label in samples:
  33. info = dict(img_prefix=self.data_prefix)
  34. if self.data_prefix['img_path'] is not None:
  35. info['img_path'] = osp.join(self.data_prefix['img_path'],
  36. filename)
  37. else:
  38. info['img_path'] = filename
  39. info['gt_label'] = np.array(gt_label, dtype=np.int64)
  40. data_list.append(info)
  41. self._parse_ann_info(data_list)
  42. return data_list
  43. def _parse_ann_info(self, data_list: List[dict]):
  44. """Parse person id annotations."""
  45. index_tmp_dic = defaultdict(list) # pid->[idx1,...,idxN]
  46. self.index_dic = dict() # pid->array([idx1,...,idxN])
  47. for idx, info in enumerate(data_list):
  48. pid = info['gt_label']
  49. index_tmp_dic[int(pid)].append(idx)
  50. for pid, idxs in index_tmp_dic.items():
  51. self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
  52. self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)
  53. def prepare_data(self, idx: int) -> Any:
  54. """Get data processed by ''self.pipeline''.
  55. Args:
  56. idx (int): The index of ''data_info''
  57. Returns:
  58. Any: Depends on ''self.pipeline''
  59. """
  60. data_info = self.get_data_info(idx)
  61. if self.triplet_sampler is not None:
  62. img_info = self.triplet_sampling(data_info['gt_label'],
  63. **self.triplet_sampler)
  64. data_info = copy.deepcopy(img_info) # triplet -> list
  65. else:
  66. data_info = copy.deepcopy(data_info) # no triplet -> dict
  67. return self.pipeline(data_info)
  68. def triplet_sampling(self,
  69. pos_pid,
  70. num_ids: int = 8,
  71. ins_per_id: int = 4) -> Dict:
  72. """Triplet sampler for hard mining triplet loss. First, for one
  73. pos_pid, random sample ins_per_id images with same person id.
  74. Then, random sample num_ids - 1 images for each negative id.
  75. Finally, random sample ins_per_id images for each negative id.
  76. Args:
  77. pos_pid (ndarray): The person id of the anchor.
  78. num_ids (int): The number of person ids.
  79. ins_per_id (int): The number of images for each person.
  80. Returns:
  81. Dict: Annotation information of num_ids X ins_per_id images.
  82. """
  83. assert len(self.pids) >= num_ids, \
  84. 'The number of person ids in the training set must ' \
  85. 'be greater than the number of person ids in the sample.'
  86. pos_idxs = self.index_dic[int(
  87. pos_pid)] # all positive idxs for pos_pid
  88. idxs_list = []
  89. # select positive samplers
  90. idxs_list.extend(pos_idxs[np.random.choice(
  91. pos_idxs.shape[0], ins_per_id, replace=True)])
  92. # select negative ids
  93. neg_pids = np.random.choice(
  94. [i for i, _ in enumerate(self.pids) if i != pos_pid],
  95. num_ids - 1,
  96. replace=False)
  97. # select negative samplers for each negative id
  98. for neg_pid in neg_pids:
  99. neg_idxs = self.index_dic[neg_pid]
  100. idxs_list.extend(neg_idxs[np.random.choice(
  101. neg_idxs.shape[0], ins_per_id, replace=True)])
  102. # return the final triplet batch
  103. triplet_img_infos = []
  104. for idx in idxs_list:
  105. triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
  106. # Collect data_list scatters (list of dict -> dict of list)
  107. out = dict()
  108. for key in triplet_img_infos[0].keys():
  109. out[key] = [_info[key] for _info in triplet_img_infos]
  110. return out