batch_sampler.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Sequence
  3. from torch.utils.data import BatchSampler, Sampler
  4. from mmdet.datasets.samplers.track_img_sampler import TrackImgSampler
  5. from mmdet.registry import DATA_SAMPLERS
  6. # TODO: maybe replace with a data_loader wrapper
  7. @DATA_SAMPLERS.register_module()
  8. class AspectRatioBatchSampler(BatchSampler):
  9. """A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
  10. >= 1) into a same batch.
  11. Args:
  12. sampler (Sampler): Base sampler.
  13. batch_size (int): Size of mini-batch.
  14. drop_last (bool): If ``True``, the sampler will drop the last batch if
  15. its size would be less than ``batch_size``.
  16. """
  17. def __init__(self,
  18. sampler: Sampler,
  19. batch_size: int,
  20. drop_last: bool = False) -> None:
  21. if not isinstance(sampler, Sampler):
  22. raise TypeError('sampler should be an instance of ``Sampler``, '
  23. f'but got {sampler}')
  24. if not isinstance(batch_size, int) or batch_size <= 0:
  25. raise ValueError('batch_size should be a positive integer value, '
  26. f'but got batch_size={batch_size}')
  27. self.sampler = sampler
  28. self.batch_size = batch_size
  29. self.drop_last = drop_last
  30. # two groups for w < h and w >= h
  31. self._aspect_ratio_buckets = [[] for _ in range(2)]
  32. def __iter__(self) -> Sequence[int]:
  33. for idx in self.sampler:
  34. data_info = self.sampler.dataset.get_data_info(idx)
  35. width, height = data_info['width'], data_info['height']
  36. bucket_id = 0 if width < height else 1
  37. bucket = self._aspect_ratio_buckets[bucket_id]
  38. bucket.append(idx)
  39. # yield a batch of indices in the same aspect ratio group
  40. if len(bucket) == self.batch_size:
  41. yield bucket[:]
  42. del bucket[:]
  43. # yield the rest data and reset the bucket
  44. left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
  45. 1]
  46. self._aspect_ratio_buckets = [[] for _ in range(2)]
  47. while len(left_data) > 0:
  48. if len(left_data) <= self.batch_size:
  49. if not self.drop_last:
  50. yield left_data[:]
  51. left_data = []
  52. else:
  53. yield left_data[:self.batch_size]
  54. left_data = left_data[self.batch_size:]
  55. def __len__(self) -> int:
  56. if self.drop_last:
  57. return len(self.sampler) // self.batch_size
  58. else:
  59. return (len(self.sampler) + self.batch_size - 1) // self.batch_size
  60. @DATA_SAMPLERS.register_module()
  61. class TrackAspectRatioBatchSampler(AspectRatioBatchSampler):
  62. """A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
  63. >= 1) into a same batch.
  64. Args:
  65. sampler (Sampler): Base sampler.
  66. batch_size (int): Size of mini-batch.
  67. drop_last (bool): If ``True``, the sampler will drop the last batch if
  68. its size would be less than ``batch_size``.
  69. """
  70. def __iter__(self) -> Sequence[int]:
  71. for idx in self.sampler:
  72. # hard code to solve TrackImgSampler
  73. if isinstance(self.sampler, TrackImgSampler):
  74. video_idx, _ = idx
  75. else:
  76. video_idx = idx
  77. # video_idx
  78. data_info = self.sampler.dataset.get_data_info(video_idx)
  79. # data_info {video_id, images, video_length}
  80. img_data_info = data_info['images'][0]
  81. width, height = img_data_info['width'], img_data_info['height']
  82. bucket_id = 0 if width < height else 1
  83. bucket = self._aspect_ratio_buckets[bucket_id]
  84. bucket.append(idx)
  85. # yield a batch of indices in the same aspect ratio group
  86. if len(bucket) == self.batch_size:
  87. yield bucket[:]
  88. del bucket[:]
  89. # yield the rest data and reset the bucket
  90. left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
  91. 1]
  92. self._aspect_ratio_buckets = [[] for _ in range(2)]
  93. while len(left_data) > 0:
  94. if len(left_data) <= self.batch_size:
  95. if not self.drop_last:
  96. yield left_data[:]
  97. left_data = []
  98. else:
  99. yield left_data[:self.batch_size]
  100. left_data = left_data[self.batch_size:]