reid_data_preprocessor.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Optional, Sequence

import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch

from mmdet.registry import MODELS

try:
    import mmpretrain
    from mmpretrain.models.utils.batch_augments import RandomBatchAugment
    from mmpretrain.structures import (batch_label_to_onehot,
                                       cat_batch_labels, tensor_split)
except ImportError:
    mmpretrain = None


def stack_batch_scores(elements, device=None):
    """Stack the ``score`` of a batch of :obj:`LabelData` to a tensor.

    Args:
        elements (List[LabelData]): A batch of :obj:`LabelData`.
        device (torch.device, optional): The output device of the batch
            label. Defaults to None.

    Returns:
        torch.Tensor: The stacked score tensor.
    """
    item = elements[0]
    if 'score' not in item._data_fields:
        return None

    batch_score = torch.stack([element.score for element in elements])
    if device is not None:
        batch_score = batch_score.to(device)
    return batch_score
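

# A minimal usage sketch of ``stack_batch_scores`` (illustrative only, not
# part of the original module). It assumes ``LabelData`` elements that carry
# a ``score`` field, as produced by mmpretrain-style data samples:
#
#   >>> from mmengine.structures import LabelData
#   >>> elements = [LabelData(score=torch.tensor([0.1, 0.9])),
#   ...             LabelData(score=torch.tensor([0.7, 0.3]))]
#   >>> stack_batch_scores(elements).shape
#   torch.Size([2, 2])

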
@MODELS.register_module()
class ReIDDataPreprocessor(BaseDataPreprocessor):
    """Image pre-processor for ReID tasks.

    Compared with :class:`mmengine.model.ImgDataPreprocessor`,

    1. It won't do normalization if ``mean`` is not specified.
    2. It does normalization and color space conversion after stacking batch.
    3. It supports batch augmentations like mixup and cutmix.

    It provides the data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of the current batch with the defined
      ``pad_value``, so that the padded size is divisible by the defined
      ``pad_size_divisor``.
    - Stack inputs to batch_inputs.
    - Convert inputs from BGR to RGB if the shape of input is (3, H, W).
    - Normalize image with the defined std and mean.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B
            channels. Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        to_onehot (bool): Whether to generate one-hot format gt-labels and
            set them to data samples. Defaults to False.
        num_classes (int, optional): The number of classes. Defaults to None.
        batch_augments (dict, optional): The batch augmentations settings,
            including "augments" and "probs". For more details, see
            :class:`mmpretrain.models.RandomBatchAugment`.
    """
    def __init__(self,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Number = 0,
                 to_rgb: bool = False,
                 to_onehot: bool = False,
                 num_classes: Optional[int] = None,
                 batch_augments: Optional[dict] = None):
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb
        self.to_onehot = to_onehot
        self.num_classes = num_classes

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                'preprocessing, please specify both `mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False

        if batch_augments is not None:
            self.batch_augments = RandomBatchAugment(**batch_augments)
            if not self.to_onehot:
                from mmengine.logging import MMLogger
                MMLogger.get_current_instance().info(
                    'Because batch augmentations are enabled, the data '
                    'preprocessor automatically enables the `to_onehot` '
                    'option to generate one-hot format labels.')
                self.to_onehot = True
        else:
            self.batch_augments = None
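
    # A hedged example of a ``batch_augments`` config this class accepts
    # (keys follow :class:`mmpretrain.models.RandomBatchAugment`; the alpha
    # and probability values below are illustrative assumptions, not
    # defaults of this module):
    #
    #   batch_augments=dict(
    #       augments=[
    #           dict(type='Mixup', alpha=0.8),
    #           dict(type='CutMix', alpha=1.0),
    #       ],
    #       probs=[0.5, 0.5])
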
    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding, bgr2rgb conversion and batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training-time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        inputs = self.cast_data(data['inputs'])

        if isinstance(inputs, torch.Tensor):
            # This branch is taken when `default_collate` is used as the
            # collate_fn of the dataloader.

            # ------ To RGB ------
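            # ``flip`` along the channel dimension reverses the channel
            # order, which converts BGR to RGB (and vice versa).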
            if self.to_rgb and inputs.size(1) == 3:
                inputs = inputs.flip(1)

            # -- Normalization ---
            inputs = inputs.float()
            if self._enable_normalize:
                inputs = (inputs - self.mean) / self.std

            # ------ Padding -----
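            # Round H and W up to the nearest multiple of
            # ``pad_size_divisor`` and pad the bottom/right borders with
            # ``pad_value``.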
            if self.pad_size_divisor > 1:
                h, w = inputs.shape[-2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant',
                               self.pad_value)
        else:
            # This branch is taken when `pseudo_collate` is used as the
            # collate_fn of the dataloader.
            processed_inputs = []
            for input_ in inputs:
                # ------ To RGB ------
                if self.to_rgb and input_.size(0) == 3:
                    input_ = input_.flip(0)

                # -- Normalization ---
                input_ = input_.float()
                if self._enable_normalize:
                    input_ = (input_ - self.mean) / self.std

                processed_inputs.append(input_)
            # Combine padding and stacking in one step.
            inputs = stack_batch(processed_inputs, self.pad_size_divisor,
                                 self.pad_value)

        data_samples = data.get('data_samples', None)
        sample_item = data_samples[0] if data_samples is not None else None

        # Guard against ``data_samples`` being None, in which case the
        # membership test below would raise a TypeError.
        if sample_item is not None and 'gt_label' in sample_item:
            gt_labels = [sample.gt_label for sample in data_samples]
            gt_labels_tensor = [gt_label.label for gt_label in gt_labels]
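            # ``cat_batch_labels`` concatenates the variable-length label
            # tensors into one flat tensor and records the split indices so
            # each sample's labels can be recovered later via
            # ``tensor_split``.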
            batch_label, label_indices = cat_batch_labels(gt_labels_tensor)
            batch_label = batch_label.to(self.device)

            batch_score = stack_batch_scores(gt_labels, device=self.device)
            if batch_score is None and self.to_onehot:
                assert batch_label is not None, \
                    'Cannot generate one-hot labels because no labels ' \
                    'were found.'
                num_classes = self.num_classes or data_samples[0].get(
                    'num_classes')
                assert num_classes is not None, \
                    'Cannot generate one-hot labels because `num_classes` ' \
                    'is not set in `data_preprocessor`.'
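                # Expand the flat integer labels into (N, num_classes)
                # one-hot scores so they can be mixed by batch augments.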
                batch_score = batch_label_to_onehot(batch_label,
                                                    label_indices,
                                                    num_classes)

            # ----- Batch Augmentations ----
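            # Mixup/CutMix mix the images and their one-hot scores jointly,
            # which is why ``to_onehot`` is forced on when they are enabled.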
            if training and self.batch_augments is not None:
                inputs, batch_score = self.batch_augments(inputs, batch_score)

            # ----- scatter labels and scores to data samples ---
            if batch_label is not None:
                for sample, label in zip(
                        data_samples, tensor_split(batch_label,
                                                   label_indices)):
                    sample.set_gt_label(label)

            if batch_score is not None:
                for sample, score in zip(data_samples, batch_score):
                    sample.set_gt_score(score)

        return {'inputs': inputs, 'data_samples': data_samples}
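

# A minimal usage sketch (an assumption-laden illustration, not part of the
# module). It builds the preprocessor directly, whereas in a real config it
# would appear as ``data_preprocessor=dict(type='ReIDDataPreprocessor', ...)``
# inside the model config; the mean/std values are the common ImageNet ones.
#
#   preprocessor = ReIDDataPreprocessor(
#       mean=[123.675, 116.28, 103.53],
#       std=[58.395, 57.12, 57.375],
#       to_rgb=True)
#   data = {'inputs': torch.randint(0, 256, (2, 3, 256, 128),
#                                   dtype=torch.uint8)}
#   out = preprocessor(data)
#   # out['inputs'] is a normalized float tensor on the target device.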