# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model.utils import stack_batch

from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample
from mmdet.structures.mask import BitmapMasks
from .data_preprocessor import DetDataPreprocessor


@MODELS.register_module()
class TrackDataPreprocessor(DetDataPreprocessor):
- """Image pre-processor for tracking tasks.
- Accepts the data sampled by the dataloader, and preprocesses
- it into the format of the model input. ``TrackDataPreprocessor``
- provides the tracking data pre-processing as follows:
- - Collate and move data to the target device.
- - Pad inputs to the maximum size of current batch with defined
- ``pad_value``. The padding size can be divisible by a defined
- ``pad_size_divisor``
- - Stack inputs to inputs.
- - Convert inputs from bgr to rgb if the shape of input is (1, 3, H, W).
- - Normalize image with defined std and mean.
- - Do batch augmentations during training.
- - Record the information of ``batch_input_shape`` and ``pad_shape``.
- Args:
- mean (Sequence[Number], optional): The pixel mean of R, G, B
- channels. Defaults to None.
- std (Sequence[Number], optional): The pixel standard deviation of
- R, G, B channels. Defaults to None.
- pad_size_divisor (int): The size of padded image should be
- divisible by ``pad_size_divisor``. Defaults to 1.
- pad_value (Number): The padded pixel value. Defaults to 0.
- pad_mask (bool): Whether to pad instance masks. Defaults to False.
- mask_pad_value (int): The padded pixel value for instance masks.
- Defaults to 0.
- bgr_to_rgb (bool): whether to convert image from BGR to RGB.
- Defaults to False.
- rgb_to_bgr (bool): whether to convert image from RGB to RGB.
- Defaults to False.
- use_det_processor: (bool): whether to use DetDataPreprocessor
- in training phrase. This is mainly for some tracking models
- fed into one image rather than a group of image in training.
- Defaults to False.
- . boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
- bboxes data to ``Tensor`` type. Defaults to True.
- batch_augments (list[dict], optional): Batch-level augmentations
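
    A minimal usage sketch (the shapes, mean/std values, and dummy data
    below are illustrative assumptions, not requirements):

    Examples:
        >>> import torch
        >>> preprocessor = TrackDataPreprocessor(
        ...     mean=[123.675, 116.28, 103.53],
        ...     std=[58.395, 57.12, 57.375],
        ...     pad_size_divisor=32)
        >>> # One sample: a 2-frame clip of 120x160 uint8 images.
        >>> data = dict(
        ...     inputs=[torch.randint(0, 256, (2, 3, 120, 160),
        ...                           dtype=torch.uint8)],
        ...     data_samples=None)
        >>> out = preprocessor(data, training=False)
        >>> # Clips are stacked to (N, T, C, H, W), padded to multiples of 32.
        >>> out['inputs'].shape
        torch.Size([1, 2, 3, 128, 160])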
- """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 use_det_processor: bool = False,
                 **kwargs):
        super().__init__(mean=mean, std=std, **kwargs)
        self.use_det_processor = use_det_processor
        if mean is not None and not self.use_det_processor:
            # Overwrite the ``register_buffer`` in ``ImgDataPreprocessor``
            # since the shape of ``mean`` and ``std`` in tracking tasks must
            # be (T, C, H, W), where T is the temporal length of the video.
            self.register_buffer('mean',
                                 torch.tensor(mean).view(1, -1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(1, -1, 1, 1), False)

    def forward(self, data: dict, training: bool = False) -> Dict:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``TrackDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training-time augmentation.

        Returns:
            Dict: Data in the same format as the model input.
        """
        if self.use_det_processor and training:
            batch_pad_shape = self._get_pad_shape(data)
        else:
            batch_pad_shape = self._get_track_pad_shape(data)

        data = self.cast_data(data)
        imgs, data_samples = data['inputs'], data['data_samples']

        if self.use_det_processor and training:
            assert imgs[0].dim() == 3, \
                'Only 3-dim (C, H, W) inputs are supported when using the ' \
                'det preprocessor in training'
            if self._channel_conversion:
                imgs = [_img[[2, 1, 0], ...] for _img in imgs]
            # Convert to `float`.
            imgs = [_img.float() for _img in imgs]
            if self._enable_normalize:
                imgs = [(_img - self.mean) / self.std for _img in imgs]
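            # ``stack_batch`` from MMEngine pads each (C, H, W) image to the
            # max shape of the batch and stacks them into (N, C, H, W).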
            inputs = stack_batch(imgs, self.pad_size_divisor, self.pad_value)
        else:
            assert imgs[0].dim() == 4, \
                'Only 4-dim (T, C, H, W) inputs are supported when using ' \
                'the track preprocessor in training'
            # The shape of imgs[0] is (T, C, H, W).
            channel = imgs[0].size(1)
            if self._channel_conversion and channel == 3:
                imgs = [_img[:, [2, 1, 0], ...] for _img in imgs]
            # Convert to `float`.
            imgs = [_img.float() for _img in imgs]
            if self._enable_normalize:
                imgs = [(_img - self.mean) / self.std for _img in imgs]
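            # ``stack_track_batch`` (defined below) pads each (T, C, H, W)
            # clip to the max shape and stacks them into (N, T, C, H, W).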
            inputs = stack_track_batch(imgs, self.pad_size_divisor,
                                       self.pad_value)

        if data_samples is not None:
            # NOTE the batched image size information may be useful, e.g.
            # in DETR, this is needed for the construction of masks, which
            # is then used for the transformer_head.
            batch_input_shape = tuple(inputs.size()[-2:])
            if self.use_det_processor and training:
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })
                if self.boxtype2tensor:
                    samplelist_boxtype2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)
            else:
                for track_data_sample, pad_shapes in zip(
                        data_samples, batch_pad_shape):
                    for i in range(len(track_data_sample)):
                        det_data_sample = track_data_sample[i]
                        det_data_sample.set_metainfo({
                            'batch_input_shape': batch_input_shape,
                            'pad_shape': pad_shapes[i]
                        })
                if self.pad_mask and training:
                    self.pad_track_gt_masks(data_samples)

        if training and self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                if self.use_det_processor and training:
                    inputs, data_samples = batch_aug(inputs, data_samples)
                else:
                    # Only T == 1 is supported when using batch augments:
                    # only YOLOX needs batch_aug, and YOLOX can only process
                    # (N, C, H, W)-shaped inputs. Since the shape of `inputs`
                    # is (N, T, C, H, W), we use inputs[:, 0] to change the
                    # shape to (N, C, H, W).
                    assert inputs.size(1) == 1 and len(data_samples[0]) == 1, \
                        'Only a single frame (T == 1) is supported when ' \
                        'using batch augments.'
                    det_data_samples = [
                        track_data_sample[0]
                        for track_data_sample in data_samples
                    ]
                    aug_inputs, aug_det_samples = batch_aug(
                        inputs[:, 0], det_data_samples)
                    inputs = aug_inputs.unsqueeze(1)
                    for track_data_sample, det_sample in zip(
                            data_samples, aug_det_samples):
                        track_data_sample.video_data_samples = [det_sample]

        # Note: inputs may contain a large number of frames, so we must make
        # sure that the memory is contiguous for a stable forward pass.
        inputs = inputs.contiguous()

        return dict(inputs=inputs, data_samples=data_samples)

    def _get_track_pad_shape(self, data: dict) -> List[List[tuple]]:
        """Get the pad_shape of each image based on data and
        ``pad_size_divisor``.

        Args:
            data (dict): Data sampled from the dataloader.

        Returns:
            List[List[tuple]]: The padded shapes ``(pad_h, pad_w)`` of each
            frame of each sample in the batch.
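
        A doctest-style sketch (the input shape is an illustrative
        assumption):

        Example:
            >>> preprocessor = TrackDataPreprocessor(pad_size_divisor=32)
            >>> data = dict(inputs=[torch.zeros(2, 3, 37, 58)])
            >>> preprocessor._get_track_pad_shape(data)
            [[(64, 64), (64, 64)]]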
- """
- batch_pad_shape = dict()
- batch_pad_shape = []
        for imgs in data['inputs']:
            # The sequence of images in one sample among a batch all have
            # the same original shape.
            pad_h = int(
                np.ceil(imgs.shape[-2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(imgs.shape[-1] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_shapes = [(pad_h, pad_w)] * imgs.size(0)
            batch_pad_shape.append(pad_shapes)
        return batch_pad_shape

    def pad_track_gt_masks(self,
                           data_samples: Sequence[TrackDataSample]) -> None:
        """Pad gt_masks to the shape of ``batch_input_shape``."""
        if 'masks' in data_samples[0][0].get('gt_instances', None):
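            # A ``TrackDataSample`` holds per-frame ``DetDataSample``s, hence
            # the double indexing ``data_samples[0][0]`` above.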
            for track_data_sample in data_samples:
                for i in range(len(track_data_sample)):
                    det_data_sample = track_data_sample[i]
                    masks = det_data_sample.gt_instances.masks
                    # TODO: whether to use BitmapMasks
                    assert isinstance(masks, BitmapMasks)
                    batch_input_shape = det_data_sample.batch_input_shape
                    det_data_sample.gt_instances.masks = masks.pad(
                        batch_input_shape, pad_val=self.mask_pad_value)


def stack_track_batch(tensors: List[torch.Tensor],
                      pad_size_divisor: int = 0,
                      pad_value: Union[int, float] = 0) -> torch.Tensor:
    """Stack multiple tensors to form a batch and pad the images to the max
    shape, using the right-bottom padding mode. If ``pad_size_divisor > 0``,
    add padding to ensure the common height and width are divisible by
    ``pad_size_divisor``. The difference between this function and
    ``stack_batch`` in MMEngine is that this function can process a batch of
    sequence images with shape (N, T, C, H, W).

    Args:
        tensors (List[Tensor]): The input multiple tensors. Each is a
            TCHW 4D-tensor, where T denotes the number of key/reference
            frames.
        pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
            to ensure the common height and width are divisible by
            ``pad_size_divisor``. This depends on the model, and many
            models need a divisibility of 32. Defaults to 0.
        pad_value (int, float): The padding value. Defaults to 0.

    Returns:
        Tensor: The NTCHW 5D-tensor, where N denotes the batch size.
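
    A doctest-style sketch (the clip sizes are illustrative assumptions):

    Example:
        >>> # Two 2-frame clips of different spatial sizes.
        >>> t1 = torch.zeros(2, 3, 100, 120)
        >>> t2 = torch.zeros(2, 3, 90, 125)
        >>> stack_track_batch([t1, t2], pad_size_divisor=32).shape
        torch.Size([2, 2, 3, 128, 128])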
- """
    assert isinstance(tensors, list), \
        f'Expected input type to be list, but got {type(tensors)}'
    assert len(set([tensor.ndim for tensor in tensors])) == 1, \
        f'Expected the dimensions of all tensors to be the same, ' \
        f'but got {[tensor.ndim for tensor in tensors]}'
    assert tensors[0].ndim == 4, f'Expected tensor dimension to be 4, ' \
        f'but got {tensors[0].ndim}'
    assert len(set([tensor.shape[0] for tensor in tensors])) == 1, \
        f'Expected the number of frames of all tensors to be the same, ' \
        f'but got {[tensor.shape[0] for tensor in tensors]}'

    tensor_sizes = [(tensor.shape[-2], tensor.shape[-1]) for tensor in tensors]
    max_size = np.stack(tensor_sizes).max(0)

    if pad_size_divisor > 1:
        # The last two dims are H and W, both subject to the divisibility
        # requirement.
        max_size = (
            max_size +
            (pad_size_divisor - 1)) // pad_size_divisor * pad_size_divisor

    padded_samples = []
    for tensor in tensors:
        # ``F.pad`` takes pads as (left, right, top, bottom) for the last
        # two dims, so only the right and bottom sides are padded.
        padding_size = [
            0, max_size[-1] - tensor.shape[-1], 0,
            max_size[-2] - tensor.shape[-2]
        ]
        if sum(padding_size) == 0:
            padded_samples.append(tensor)
        else:
            padded_samples.append(F.pad(tensor, padding_size, value=pad_value))

    return torch.stack(padded_samples, dim=0)
|