# Copyright (c) OpenMMLab. All rights reserved.
import math
import random
from typing import Iterator, Optional, Sized

import numpy as np
from mmengine.dataset import ClassBalancedDataset, ConcatDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler

from mmdet.registry import DATA_SAMPLERS

from ..base_video_dataset import BaseVideoDataset
  11. @DATA_SAMPLERS.register_module()
  12. class TrackImgSampler(Sampler):
  13. """Sampler that providing image-level sampling outputs for video datasets
  14. in tracking tasks. It could be both used in both distributed and
  15. non-distributed environment.
  16. If using the default sampler in pytorch, the subsequent data receiver will
  17. get one video, which is not desired in some cases:
  18. (Take a non-distributed environment as an example)
  19. 1. In test mode, we want only one image is fed into the data pipeline. This
  20. is in consideration of memory usage since feeding the whole video commonly
  21. requires a large amount of memory (>=20G on MOTChallenge17 dataset), which
  22. is not available in some machines.
  23. 2. In training mode, we may want to make sure all the images in one video
  24. are randomly sampled once in one epoch and this can not be guaranteed in
  25. the default sampler in pytorch.
  26. Args:
  27. dataset (Sized): Dataset used for sampling.
  28. seed (int, optional): random seed used to shuffle the sampler. This
  29. number should be identical across all processes in the distributed
  30. group. Defaults to None.
  31. """
  32. def __init__(
  33. self,
  34. dataset: Sized,
  35. seed: Optional[int] = None,
  36. ) -> None:
  37. rank, world_size = get_dist_info()
  38. self.rank = rank
  39. self.world_size = world_size
  40. self.epoch = 0
  41. if seed is None:
  42. self.seed = sync_random_seed()
  43. else:
  44. self.seed = seed
  45. self.dataset = dataset
  46. self.indices = []
  47. # Hard code here to handle different dataset wrapper
  48. if isinstance(self.dataset, ConcatDataset):
  49. cat_datasets = self.dataset.datasets
  50. assert isinstance(
  51. cat_datasets[0], BaseVideoDataset
  52. ), f'expected BaseVideoDataset, but got {type(cat_datasets[0])}'
  53. self.test_mode = cat_datasets[0].test_mode
  54. assert not self.test_mode, "'ConcatDataset' should not exist in "
  55. 'test mode'
  56. for dataset in cat_datasets:
  57. num_videos = len(dataset)
  58. for video_ind in range(num_videos):
  59. self.indices.extend([
  60. (video_ind, frame_ind) for frame_ind in range(
  61. dataset.get_len_per_video(video_ind))
  62. ])
  63. elif isinstance(self.dataset, ClassBalancedDataset):
  64. ori_dataset = self.dataset.dataset
  65. assert isinstance(
  66. ori_dataset, BaseVideoDataset
  67. ), f'expected BaseVideoDataset, but got {type(ori_dataset)}'
  68. self.test_mode = ori_dataset.test_mode
  69. assert not self.test_mode, "'ClassBalancedDataset' should not "
  70. 'exist in test mode'
  71. video_indices = self.dataset.repeat_indices
  72. for index in video_indices:
  73. self.indices.extend([(index, frame_ind) for frame_ind in range(
  74. ori_dataset.get_len_per_video(index))])
  75. else:
  76. assert isinstance(
  77. self.dataset, BaseVideoDataset
  78. ), 'TrackImgSampler is only supported in BaseVideoDataset or '
  79. 'dataset wrapper: ClassBalancedDataset and ConcatDataset, but '
  80. f'got {type(self.dataset)} '
  81. self.test_mode = self.dataset.test_mode
  82. num_videos = len(self.dataset)
  83. if self.test_mode:
  84. # in test mode, the images belong to the same video must be put
  85. # on the same device.
  86. if num_videos < self.world_size:
  87. raise ValueError(f'only {num_videos} videos loaded,'
  88. f'but {self.world_size} gpus were given.')
  89. chunks = np.array_split(
  90. list(range(num_videos)), self.world_size)
  91. for videos_inds in chunks:
  92. indices_chunk = []
  93. for video_ind in videos_inds:
  94. indices_chunk.extend([
  95. (video_ind, frame_ind) for frame_ind in range(
  96. self.dataset.get_len_per_video(video_ind))
  97. ])
  98. self.indices.append(indices_chunk)
  99. else:
  100. for video_ind in range(num_videos):
  101. self.indices.extend([
  102. (video_ind, frame_ind) for frame_ind in range(
  103. self.dataset.get_len_per_video(video_ind))
  104. ])
  105. if self.test_mode:
  106. self.num_samples = len(self.indices[self.rank])
  107. self.total_size = sum(
  108. [len(index_list) for index_list in self.indices])
  109. else:
  110. self.num_samples = int(
  111. math.ceil(len(self.indices) * 1.0 / self.world_size))
  112. self.total_size = self.num_samples * self.world_size
  113. def __iter__(self) -> Iterator:
  114. if self.test_mode:
  115. # in test mode, the order of frames can not be shuffled.
  116. indices = self.indices[self.rank]
  117. else:
  118. # deterministically shuffle based on epoch
  119. rng = random.Random(self.epoch + self.seed)
  120. indices = rng.sample(self.indices, len(self.indices))
  121. # add extra samples to make it evenly divisible
  122. indices += indices[:(self.total_size - len(indices))]
  123. assert len(indices) == self.total_size
  124. # subsample
  125. indices = indices[self.rank:self.total_size:self.world_size]
  126. assert len(indices) == self.num_samples
  127. return iter(indices)
  128. def __len__(self):
  129. return self.num_samples
  130. def set_epoch(self, epoch):
  131. self.epoch = epoch