frame_sampling.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import random
  3. from collections import defaultdict
  4. from typing import Dict, List, Optional, Union
  5. from mmcv.transforms import BaseTransform
  6. from mmdet.registry import TRANSFORMS
  7. @TRANSFORMS.register_module()
  8. class BaseFrameSample(BaseTransform):
  9. """Directly get the key frame, no reference frames.
  10. Args:
  11. collect_video_keys (list[str]): The keys of video info to be
  12. collected.
  13. """
  14. def __init__(self,
  15. collect_video_keys: List[str] = ['video_id', 'video_length']):
  16. self.collect_video_keys = collect_video_keys
  17. def prepare_data(self, video_infos: dict,
  18. sampled_inds: List[int]) -> Dict[str, List]:
  19. """Prepare data for the subsequent pipeline.
  20. Args:
  21. video_infos (dict): The whole video information.
  22. sampled_inds (list[int]): The sampled frame indices.
  23. Returns:
  24. dict: The processed data information.
  25. """
  26. frames_anns = video_infos['images']
  27. final_data_info = defaultdict(list)
  28. # for data in frames_anns:
  29. for index in sampled_inds:
  30. data = frames_anns[index]
  31. # copy the info in video-level into img-level
  32. for key in self.collect_video_keys:
  33. if key == 'video_length':
  34. data['ori_video_length'] = video_infos[key]
  35. data['video_length'] = len(sampled_inds)
  36. else:
  37. data[key] = video_infos[key]
  38. # Collate data_list (list of dict to dict of list)
  39. for key, value in data.items():
  40. final_data_info[key].append(value)
  41. return final_data_info
  42. def transform(self, video_infos: dict) -> Optional[Dict[str, List]]:
  43. """Transform the video information.
  44. Args:
  45. video_infos (dict): The whole video information.
  46. Returns:
  47. dict: The data information of the key frames.
  48. """
  49. if 'key_frame_id' in video_infos:
  50. key_frame_id = video_infos['key_frame_id']
  51. assert isinstance(video_infos['key_frame_id'], int)
  52. else:
  53. key_frame_id = random.sample(
  54. list(range(video_infos['video_length'])), 1)[0]
  55. results = self.prepare_data(video_infos, [key_frame_id])
  56. return results
  57. def __repr__(self) -> str:
  58. repr_str = self.__class__.__name__
  59. repr_str += f'(collect_video_keys={self.collect_video_keys})'
  60. return repr_str
  61. @TRANSFORMS.register_module()
  62. class UniformRefFrameSample(BaseFrameSample):
  63. """Uniformly sample reference frames.
  64. Args:
  65. num_ref_imgs (int): Number of reference frames to be sampled.
  66. frame_range (int | list[int]): Range of frames to be sampled around
  67. key frame. If int, the range is [-frame_range, frame_range].
  68. Defaults to 10.
  69. filter_key_img (bool): Whether to filter the key frame when
  70. sampling reference frames. Defaults to True.
  71. collect_video_keys (list[str]): The keys of video info to be
  72. collected.
  73. """
  74. def __init__(self,
  75. num_ref_imgs: int = 1,
  76. frame_range: Union[int, List[int]] = 10,
  77. filter_key_img: bool = True,
  78. collect_video_keys: List[str] = ['video_id', 'video_length']):
  79. self.num_ref_imgs = num_ref_imgs
  80. self.filter_key_img = filter_key_img
  81. if isinstance(frame_range, int):
  82. assert frame_range >= 0, 'frame_range can not be a negative value.'
  83. frame_range = [-frame_range, frame_range]
  84. elif isinstance(frame_range, list):
  85. assert len(frame_range) == 2, 'The length must be 2.'
  86. assert frame_range[0] <= 0 and frame_range[1] >= 0
  87. for i in frame_range:
  88. assert isinstance(i, int), 'Each element must be int.'
  89. else:
  90. raise TypeError('The type of frame_range must be int or list.')
  91. self.frame_range = frame_range
  92. super().__init__(collect_video_keys=collect_video_keys)
  93. def sampling_frames(self, video_length: int, key_frame_id: int):
  94. """Sampling frames.
  95. Args:
  96. video_length (int): The length of the video.
  97. key_frame_id (int): The key frame id.
  98. Returns:
  99. list[int]: The sampled frame indices.
  100. """
  101. if video_length > 1:
  102. left = max(0, key_frame_id + self.frame_range[0])
  103. right = min(key_frame_id + self.frame_range[1], video_length - 1)
  104. frame_ids = list(range(0, video_length))
  105. valid_ids = frame_ids[left:right + 1]
  106. if self.filter_key_img and key_frame_id in valid_ids:
  107. valid_ids.remove(key_frame_id)
  108. assert len(
  109. valid_ids
  110. ) > 0, 'After filtering key frame, there are no valid frames'
  111. if len(valid_ids) < self.num_ref_imgs:
  112. valid_ids = valid_ids * self.num_ref_imgs
  113. ref_frame_ids = random.sample(valid_ids, self.num_ref_imgs)
  114. else:
  115. ref_frame_ids = [key_frame_id] * self.num_ref_imgs
  116. sampled_frames_ids = [key_frame_id] + ref_frame_ids
  117. sampled_frames_ids = sorted(sampled_frames_ids)
  118. key_frames_ind = sampled_frames_ids.index(key_frame_id)
  119. key_frame_flags = [False] * len(sampled_frames_ids)
  120. key_frame_flags[key_frames_ind] = True
  121. return sampled_frames_ids, key_frame_flags
  122. def transform(self, video_infos: dict) -> Optional[Dict[str, List]]:
  123. """Transform the video information.
  124. Args:
  125. video_infos (dict): The whole video information.
  126. Returns:
  127. dict: The data information of the sampled frames.
  128. """
  129. if 'key_frame_id' in video_infos:
  130. key_frame_id = video_infos['key_frame_id']
  131. assert isinstance(video_infos['key_frame_id'], int)
  132. else:
  133. key_frame_id = random.sample(
  134. list(range(video_infos['video_length'])), 1)[0]
  135. (sampled_frames_ids, key_frame_flags) = self.sampling_frames(
  136. video_infos['video_length'], key_frame_id=key_frame_id)
  137. results = self.prepare_data(video_infos, sampled_frames_ids)
  138. results['key_frame_flags'] = key_frame_flags
  139. return results
  140. def __repr__(self) -> str:
  141. repr_str = self.__class__.__name__
  142. repr_str += f'(num_ref_imgs={self.num_ref_imgs}, '
  143. repr_str += f'frame_range={self.frame_range}, '
  144. repr_str += f'filter_key_img={self.filter_key_img}, '
  145. repr_str += f'collect_video_keys={self.collect_video_keys})'
  146. return repr_str