track_data_preprocessor.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model.utils import stack_batch

from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample
from mmdet.structures.mask import BitmapMasks
from .data_preprocessor import DetDataPreprocessor


@MODELS.register_module()
class TrackDataPreprocessor(DetDataPreprocessor):
    """Image pre-processor for tracking tasks.

    Accepts the data sampled by the dataloader, and preprocesses it into the
    format of the model input. ``TrackDataPreprocessor`` provides the
    tracking data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of the current batch with the defined
      ``pad_value``. The padding size can be divisible by a defined
      ``pad_size_divisor``.
    - Stack inputs into a batch.
    - Convert inputs from bgr to rgb if the shape of input is (1, 3, H, W).
    - Normalize images with the defined std and mean.
    - Do batch augmentations during training.
    - Record the information of ``batch_input_shape`` and ``pad_shape``.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B
            channels. Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        bgr_to_rgb (bool): Whether to convert images from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert images from RGB to BGR.
            Defaults to False.
        use_det_processor (bool): Whether to use ``DetDataPreprocessor``
            in the training phase. This is mainly for tracking models that
            are fed a single image rather than a group of images during
            training. Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to ``Tensor`` type. Defaults to True.
        batch_augments (list[dict], optional): Batch-level augmentations.
            Defaults to None.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 use_det_processor: bool = False,
                 **kwargs):
        super().__init__(mean=mean, std=std, **kwargs)
        self.use_det_processor = use_det_processor
        if mean is not None and not self.use_det_processor:
            # overwrite the ``register_buffer`` in ``ImgDataPreprocessor``
            # since the shape of ``mean`` and ``std`` in tracking tasks must
            # be (T, C, H, W), where T is the temporal length of the video.
            self.register_buffer('mean',
                                 torch.tensor(mean).view(1, -1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(1, -1, 1, 1), False)

    def forward(self, data: dict, training: bool = False) -> Dict:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``TrackDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        if self.use_det_processor and training:
            batch_pad_shape = self._get_pad_shape(data)
        else:
            batch_pad_shape = self._get_track_pad_shape(data)

        data = self.cast_data(data)
        imgs, data_samples = data['inputs'], data['data_samples']

        if self.use_det_processor and training:
            assert imgs[0].dim() == 3, \
                'Only support 3-dim inputs when using the det preprocessor ' \
                'in training'
            if self._channel_conversion:
                imgs = [_img[[2, 1, 0], ...] for _img in imgs]
            # Convert to `float`
            imgs = [_img.float() for _img in imgs]
            if self._enable_normalize:
                imgs = [(_img - self.mean) / self.std for _img in imgs]
            inputs = stack_batch(imgs, self.pad_size_divisor, self.pad_value)
        else:
            assert imgs[0].dim() == 4, \
                'Only support 4-dim inputs when using the track processor ' \
                'in training'
            # The shape of imgs[0] is (T, C, H, W).
            channel = imgs[0].size(1)
            if self._channel_conversion and channel == 3:
                imgs = [_img[:, [2, 1, 0], ...] for _img in imgs]
            # Convert to `float`
            imgs = [_img.float() for _img in imgs]
            if self._enable_normalize:
                imgs = [(_img - self.mean) / self.std for _img in imgs]
            inputs = stack_track_batch(imgs, self.pad_size_divisor,
                                       self.pad_value)

        if data_samples is not None:
            # NOTE the batched image size information may be useful, e.g.
            # in DETR, this is needed for the construction of masks, which is
            # then used for the transformer_head.
            batch_input_shape = tuple(inputs.size()[-2:])
            if self.use_det_processor and training:
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })
                if self.boxtype2tensor:
                    samplelist_boxtype2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)
            else:
                for track_data_sample, pad_shapes in zip(
                        data_samples, batch_pad_shape):
                    for i in range(len(track_data_sample)):
                        det_data_sample = track_data_sample[i]
                        det_data_sample.set_metainfo({
                            'batch_input_shape': batch_input_shape,
                            'pad_shape': pad_shapes[i]
                        })
                if self.pad_mask and training:
                    self.pad_track_gt_masks(data_samples)

        if training and self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                if self.use_det_processor and training:
                    inputs, data_samples = batch_aug(inputs, data_samples)
                else:
                    # We only support T==1 when using batch augments.
                    # Only YOLOX needs batch_aug, and YOLOX can only process
                    # inputs of shape (N, C, H, W).
                    # The shape of `inputs` is (N, T, C, H, W), hence we use
                    # inputs[:, 0] to change the shape to (N, C, H, W).
                    assert inputs.size(1) == 1 and len(
                        data_samples[0]
                    ) == 1, 'Only support the number of sequence images equals to 1 when using batch augment.'  # noqa: E501
                    det_data_samples = [
                        track_data_sample[0]
                        for track_data_sample in data_samples
                    ]
                    aug_inputs, aug_det_samples = batch_aug(
                        inputs[:, 0], det_data_samples)
                    inputs = aug_inputs.unsqueeze(1)
                    for track_data_sample, det_sample in zip(
                            data_samples, aug_det_samples):
                        track_data_sample.video_data_samples = [det_sample]

        # Note: inputs may contain a large number of frames, so we must make
        # sure that the memory is contiguous for a stable forward pass.
        inputs = inputs.contiguous()

        return dict(inputs=inputs, data_samples=data_samples)

    def _get_track_pad_shape(self, data: dict) -> List[List[tuple]]:
        """Get the pad_shape of each image based on data and pad_size_divisor.

        Args:
            data (dict): Data sampled from the dataloader.

        Returns:
            List[List[tuple]]: The padded shapes of the sequence images in
            each sample.
        """
        batch_pad_shape = []
        for imgs in data['inputs']:
            # The sequence images in one sample among a batch have the same
            # original shape.
            pad_h = int(
                np.ceil(imgs.shape[-2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(imgs.shape[-1] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_shapes = [(pad_h, pad_w)] * imgs.size(0)
            batch_pad_shape.append(pad_shapes)
        return batch_pad_shape

    def pad_track_gt_masks(self,
                           data_samples: Sequence[TrackDataSample]) -> None:
        """Pad gt_masks to the shape of batch_input_shape."""
        if 'masks' in data_samples[0][0].get('gt_instances', None):
            for track_data_sample in data_samples:
                for i in range(len(track_data_sample)):
                    det_data_sample = track_data_sample[i]
                    masks = det_data_sample.gt_instances.masks
                    # TODO: whether to use BitmapMasks
                    assert isinstance(masks, BitmapMasks)
                    batch_input_shape = det_data_sample.batch_input_shape
                    det_data_sample.gt_instances.masks = masks.pad(
                        batch_input_shape, pad_val=self.mask_pad_value)


def stack_track_batch(tensors: List[torch.Tensor],
                      pad_size_divisor: int = 0,
                      pad_value: Union[int, float] = 0) -> torch.Tensor:
    """Stack multiple tensors to form a batch and pad the images to the max
    shape, using right-bottom padding. If ``pad_size_divisor > 0``, add
    padding to ensure the common height and width are divisible by
    ``pad_size_divisor``. The difference between this function and
    ``stack_batch`` in MMEngine is that this function can process a batch of
    sequence images with shape (N, T, C, H, W).

    Args:
        tensors (List[Tensor]): The input multiple tensors. Each is a
            TCHW 4D-tensor. T denotes the number of key/reference frames.
        pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
            to ensure the common height and width are divisible by
            ``pad_size_divisor``. This depends on the model, and many
            models need a divisibility of 32. Defaults to 0.
        pad_value (int, float): The padding value. Defaults to 0.

    Returns:
        Tensor: The NTCHW 5D-tensor. N denotes the batch size.
    """
    assert isinstance(tensors, list), \
        f'Expected input type to be list, but got {type(tensors)}'
    assert len(set([tensor.ndim for tensor in tensors])) == 1, \
        f'Expected all tensors to have the same number of dimensions, ' \
        f'but got {[tensor.ndim for tensor in tensors]}'
    assert tensors[0].ndim == 4, f'Expected tensor dimension to be 4, ' \
        f'but got {tensors[0].ndim}'
    assert len(set([tensor.shape[0] for tensor in tensors])) == 1, \
        f'Expected all tensors to have the same temporal length (dim 0), ' \
        f'but got {[tensor.shape[0] for tensor in tensors]}'

    tensor_sizes = [(tensor.shape[-2], tensor.shape[-1]) for tensor in tensors]
    max_size = np.stack(tensor_sizes).max(0)

    if pad_size_divisor > 1:
        # The last two dims are H and W, both subject to the divisibility
        # requirement.
        max_size = (
            max_size +
            (pad_size_divisor - 1)) // pad_size_divisor * pad_size_divisor

    padded_samples = []
    for tensor in tensors:
        padding_size = [
            0, max_size[-1] - tensor.shape[-1], 0,
            max_size[-2] - tensor.shape[-2]
        ]
        if sum(padding_size) == 0:
            padded_samples.append(tensor)
        else:
            padded_samples.append(F.pad(tensor, padding_size, value=pad_value))

    return torch.stack(padded_samples, dim=0)
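

# A minimal, self-contained sanity check of ``stack_track_batch``. The shapes
# are illustrative and this block is only a sketch; it does not run on import.
if __name__ == '__main__':
    frames_a = torch.zeros(2, 3, 224, 300)  # two frames, smaller spatial size
    frames_b = torch.zeros(2, 3, 256, 320)  # two frames, larger spatial size
    batch = stack_track_batch([frames_a, frames_b], pad_size_divisor=32)
    # Both samples are padded to the common (256, 320) size and stacked into
    # an NTCHW tensor.
    assert batch.shape == (2, 2, 3, 256, 320)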