- # Copyright (c) OpenMMLab. All rights reserved.
- import random
- from numbers import Number
- from typing import List, Optional, Sequence, Tuple, Union
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmengine.dist import barrier, broadcast, get_dist_info
- from mmengine.logging import MessageHub
- from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
- from mmengine.structures import PixelData
- from mmengine.utils import is_seq_of
- from torch import Tensor
- from mmdet.models.utils import unfold_wo_center
- from mmdet.models.utils.misc import samplelist_boxtype2tensor
- from mmdet.registry import MODELS
- from mmdet.structures import DetDataSample
- from mmdet.structures.mask import BitmapMasks
- from mmdet.utils import ConfigType
- try:
- import skimage
- except ImportError:
- skimage = None
- @MODELS.register_module()
- class DetDataPreprocessor(ImgDataPreprocessor):
- """Image pre-processor for detection tasks.
- Compared with :class:`mmengine.ImgDataPreprocessor`,
- 1. It supports batch augmentations.
- 2. It additionally appends ``batch_input_shape`` and ``pad_shape``
- to data_samples, which are needed for the object detection task.
- It provides the data pre-processing as follows:
- - Collate and move data to the target device.
- - Pad inputs to the maximum size of the current batch with the defined
- ``pad_value``, so that the padded size is divisible by the defined
- ``pad_size_divisor``.
- - Stack inputs to batch_inputs.
- - Convert inputs from BGR to RGB if the shape of input is (3, H, W).
- - Normalize images with the defined std and mean.
- - Do batch augmentations during training.
- Args:
- mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
- Defaults to None.
- std (Sequence[Number], optional): The pixel standard deviation of
- R, G, B channels. Defaults to None.
- pad_size_divisor (int): The size of padded image should be
- divisible by ``pad_size_divisor``. Defaults to 1.
- pad_value (Number): The padded pixel value. Defaults to 0.
- pad_mask (bool): Whether to pad instance masks. Defaults to False.
- mask_pad_value (int): The padded pixel value for instance masks.
- Defaults to 0.
- pad_seg (bool): Whether to pad semantic segmentation maps.
- Defaults to False.
- seg_pad_value (int): The padded pixel value for semantic
- segmentation maps. Defaults to 255.
- bgr_to_rgb (bool): Whether to convert the image from BGR to RGB.
- Defaults to False.
- rgb_to_bgr (bool): Whether to convert the image from RGB to BGR.
- Defaults to False.
- boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
- bboxes data to ``Tensor`` type. Defaults to True.
- non_blocking (bool): Whether to block the current process when
- transferring data to the device. Defaults to False.
- batch_augments (list[dict], optional): Batch-level augmentations.
- Defaults to None.
- """
- def __init__(self,
- mean: Sequence[Number] = None,
- std: Sequence[Number] = None,
- pad_size_divisor: int = 1,
- pad_value: Union[float, int] = 0,
- pad_mask: bool = False,
- mask_pad_value: int = 0,
- pad_seg: bool = False,
- seg_pad_value: int = 255,
- bgr_to_rgb: bool = False,
- rgb_to_bgr: bool = False,
- boxtype2tensor: bool = True,
- non_blocking: Optional[bool] = False,
- batch_augments: Optional[List[dict]] = None):
- super().__init__(
- mean=mean,
- std=std,
- pad_size_divisor=pad_size_divisor,
- pad_value=pad_value,
- bgr_to_rgb=bgr_to_rgb,
- rgb_to_bgr=rgb_to_bgr,
- non_blocking=non_blocking)
- if batch_augments is not None:
- self.batch_augments = nn.ModuleList(
- [MODELS.build(aug) for aug in batch_augments])
- else:
- self.batch_augments = None
- self.pad_mask = pad_mask
- self.mask_pad_value = mask_pad_value
- self.pad_seg = pad_seg
- self.seg_pad_value = seg_pad_value
- self.boxtype2tensor = boxtype2tensor
- def forward(self, data: dict, training: bool = False) -> dict:
- """Perform normalization、padding and bgr2rgb conversion based on
- ``BaseDataPreprocessor``.
- Args:
- data (dict): Data sampled from dataloader.
- training (bool): Whether to enable training time augmentation.
- Returns:
- dict: Data in the same format as the model input.
- """
- batch_pad_shape = self._get_pad_shape(data)
- data = super().forward(data=data, training=training)
- inputs, data_samples = data['inputs'], data['data_samples']
- if data_samples is not None:
- # NOTE the batched image size information may be useful, e.g.
- # in DETR, this is needed for the construction of masks, which is
- # then used for the transformer_head.
- batch_input_shape = tuple(inputs[0].size()[-2:])
- for data_sample, pad_shape in zip(data_samples, batch_pad_shape):
- data_sample.set_metainfo({
- 'batch_input_shape': batch_input_shape,
- 'pad_shape': pad_shape
- })
- if self.boxtype2tensor:
- samplelist_boxtype2tensor(data_samples)
- if self.pad_mask and training:
- self.pad_gt_masks(data_samples)
- if self.pad_seg and training:
- self.pad_gt_sem_seg(data_samples)
- if training and self.batch_augments is not None:
- for batch_aug in self.batch_augments:
- inputs, data_samples = batch_aug(inputs, data_samples)
- return {'inputs': inputs, 'data_samples': data_samples}
- def _get_pad_shape(self, data: dict) -> List[tuple]:
- """Get the pad_shape of each image based on data and
- pad_size_divisor."""
- _batch_inputs = data['inputs']
- # Process data with `pseudo_collate`.
- if is_seq_of(_batch_inputs, torch.Tensor):
- batch_pad_shape = []
- for ori_input in _batch_inputs:
- pad_h = int(
- np.ceil(ori_input.shape[1] /
- self.pad_size_divisor)) * self.pad_size_divisor
- pad_w = int(
- np.ceil(ori_input.shape[2] /
- self.pad_size_divisor)) * self.pad_size_divisor
- batch_pad_shape.append((pad_h, pad_w))
- # Process data with `default_collate`.
- elif isinstance(_batch_inputs, torch.Tensor):
- assert _batch_inputs.dim() == 4, (
- 'The input of `ImgDataPreprocessor` should be a NCHW tensor '
- 'or a list of tensor, but got a tensor with shape: '
- f'{_batch_inputs.shape}')
- # ``_batch_inputs`` is an NCHW tensor, so height and width are dims 2
- # and 3 rather than 1 and 2.
- pad_h = int(
- np.ceil(_batch_inputs.shape[2] /
- self.pad_size_divisor)) * self.pad_size_divisor
- pad_w = int(
- np.ceil(_batch_inputs.shape[3] /
- self.pad_size_divisor)) * self.pad_size_divisor
- batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
- else:
- raise TypeError('Output of `cast_data` should be a dict '
- 'or a tuple with inputs and data_samples, but got'
- f'{type(data)}: {data}')
- return batch_pad_shape
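- # Worked example of the arithmetic above (an assumption for illustration):
- # with ``pad_size_divisor=32`` an image of shape (3, 300, 400) yields a
- # pad_shape of (320, 416), since ceil(300 / 32) * 32 = 320 and
- # ceil(400 / 32) * 32 = 416.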
- def pad_gt_masks(self,
- batch_data_samples: Sequence[DetDataSample]) -> None:
- """Pad gt_masks to shape of batch_input_shape."""
- if 'masks' in batch_data_samples[0].gt_instances:
- for data_samples in batch_data_samples:
- masks = data_samples.gt_instances.masks
- data_samples.gt_instances.masks = masks.pad(
- data_samples.batch_input_shape,
- pad_val=self.mask_pad_value)
- def pad_gt_sem_seg(self,
- batch_data_samples: Sequence[DetDataSample]) -> None:
- """Pad gt_sem_seg to shape of batch_input_shape."""
- if 'gt_sem_seg' in batch_data_samples[0]:
- for data_samples in batch_data_samples:
- gt_sem_seg = data_samples.gt_sem_seg.sem_seg
- h, w = gt_sem_seg.shape[-2:]
- pad_h, pad_w = data_samples.batch_input_shape
- gt_sem_seg = F.pad(
- gt_sem_seg,
- pad=(0, max(pad_w - w, 0), 0, max(pad_h - h, 0)),
- mode='constant',
- value=self.seg_pad_value)
- data_samples.gt_sem_seg = PixelData(sem_seg=gt_sem_seg)
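- # Illustrative usage sketch (not part of the library API): how the padding,
- # stacking and normalization described in the class docstring behave on a
- # small CPU batch. The mean/std values below are the common ImageNet
- # statistics and are only an assumption for this example.
- def _demo_det_data_preprocessor() -> dict:
-     preprocessor = DetDataPreprocessor(
-         mean=[123.675, 116.28, 103.53],
-         std=[58.395, 57.12, 57.375],
-         bgr_to_rgb=True,
-         pad_size_divisor=32)
-     # Two images of different sizes, as produced by ``pseudo_collate``.
-     data = {
-         'inputs': [
-             torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8),
-             torch.randint(0, 256, (3, 320, 352), dtype=torch.uint8)
-         ],
-         'data_samples': None
-     }
-     out = preprocessor(data, training=False)
-     # Both images are padded to the per-batch maximum rounded up to a
-     # multiple of 32, i.e. (320, 416), then stacked into one NCHW tensor.
-     assert tuple(out['inputs'].shape) == (2, 3, 320, 416)
-     return out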
- @MODELS.register_module()
- class BatchSyncRandomResize(nn.Module):
- """Batch random resize which synchronizes the random size across ranks.
- Args:
- random_size_range (tuple): The multi-scale random range during
- multi-scale training.
- interval (int): The iteration interval at which the image size is
- changed. Defaults to 10.
- size_divisor (int): Image size divisible factor.
- Defaults to 32.
- """
- def __init__(self,
- random_size_range: Tuple[int, int],
- interval: int = 10,
- size_divisor: int = 32) -> None:
- super().__init__()
- self.rank, self.world_size = get_dist_info()
- self._input_size = None
- self._random_size_range = (round(random_size_range[0] / size_divisor),
- round(random_size_range[1] / size_divisor))
- self._interval = interval
- self._size_divisor = size_divisor
- def forward(
- self, inputs: Tensor, data_samples: List[DetDataSample]
- ) -> Tuple[Tensor, List[DetDataSample]]:
- """resize a batch of images and bboxes to shape ``self._input_size``"""
- h, w = inputs.shape[-2:]
- if self._input_size is None:
- self._input_size = (h, w)
- scale_y = self._input_size[0] / h
- scale_x = self._input_size[1] / w
- if scale_x != 1 or scale_y != 1:
- inputs = F.interpolate(
- inputs,
- size=self._input_size,
- mode='bilinear',
- align_corners=False)
- for data_sample in data_samples:
- img_shape = (int(data_sample.img_shape[0] * scale_y),
- int(data_sample.img_shape[1] * scale_x))
- pad_shape = (int(data_sample.pad_shape[0] * scale_y),
- int(data_sample.pad_shape[1] * scale_x))
- data_sample.set_metainfo({
- 'img_shape': img_shape,
- 'pad_shape': pad_shape,
- 'batch_input_shape': self._input_size
- })
- data_sample.gt_instances.bboxes[
- ...,
- 0::2] = data_sample.gt_instances.bboxes[...,
- 0::2] * scale_x
- data_sample.gt_instances.bboxes[
- ...,
- 1::2] = data_sample.gt_instances.bboxes[...,
- 1::2] * scale_y
- if 'ignored_instances' in data_sample:
- data_sample.ignored_instances.bboxes[
- ..., 0::2] = data_sample.ignored_instances.bboxes[
- ..., 0::2] * scale_x
- data_sample.ignored_instances.bboxes[
- ..., 1::2] = data_sample.ignored_instances.bboxes[
- ..., 1::2] * scale_y
- message_hub = MessageHub.get_current_instance()
- if (message_hub.get_info('iter') + 1) % self._interval == 0:
- self._input_size = self._get_random_size(
- aspect_ratio=float(w / h), device=inputs.device)
- return inputs, data_samples
- def _get_random_size(self, aspect_ratio: float,
- device: torch.device) -> Tuple[int, int]:
- """Randomly generate a shape in ``_random_size_range`` and broadcast to
- all ranks."""
- tensor = torch.LongTensor(2).to(device)
- if self.rank == 0:
- size = random.randint(*self._random_size_range)
- size = (self._size_divisor * size,
- self._size_divisor * int(aspect_ratio * size))
- tensor[0] = size[0]
- tensor[1] = size[1]
- barrier()
- broadcast(tensor, 0)
- input_size = (tensor[0].item(), tensor[1].item())
- return input_size
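- # Illustrative config sketch (an assumption, not copied from a shipped
- # config): ``BatchSyncRandomResize`` is normally attached to
- # ``DetDataPreprocessor`` through ``batch_augments``. Every ``interval``
- # iterations, rank 0 draws a new target height (a multiple of
- # ``size_divisor`` within ``random_size_range``), derives the width from the
- # batch aspect ratio and broadcasts the pair so all ranks train at the same
- # resolution.
- def _demo_batch_sync_random_resize_cfg() -> dict:
-     return dict(
-         type='DetDataPreprocessor',
-         pad_size_divisor=32,
-         batch_augments=[
-             dict(
-                 type='BatchSyncRandomResize',
-                 random_size_range=(480, 800),
-                 interval=10,
-                 size_divisor=32)
-         ])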
- @MODELS.register_module()
- class BatchFixedSizePad(nn.Module):
- """Fixed size padding for batch images.
- Args:
- size (Tuple[int, int]): Fixed padding size. Expected padding
- shape (h, w). Defaults to None.
- img_pad_value (int): The padded pixel value for images.
- Defaults to 0.
- pad_mask (bool): Whether to pad instance masks. Defaults to False.
- mask_pad_value (int): The padded pixel value for instance masks.
- Defaults to 0.
- pad_seg (bool): Whether to pad semantic segmentation maps.
- Defaults to False.
- seg_pad_value (int): The padded pixel value for semantic
- segmentation maps. Defaults to 255.
- """
- def __init__(self,
- size: Tuple[int, int],
- img_pad_value: int = 0,
- pad_mask: bool = False,
- mask_pad_value: int = 0,
- pad_seg: bool = False,
- seg_pad_value: int = 255) -> None:
- super().__init__()
- self.size = size
- self.pad_mask = pad_mask
- self.pad_seg = pad_seg
- self.img_pad_value = img_pad_value
- self.mask_pad_value = mask_pad_value
- self.seg_pad_value = seg_pad_value
- def forward(
- self,
- inputs: Tensor,
- data_samples: Optional[List[dict]] = None
- ) -> Tuple[Tensor, Optional[List[dict]]]:
- """Pad image, instance masks, segmantic segmentation maps."""
- src_h, src_w = inputs.shape[-2:]
- dst_h, dst_w = self.size
- if src_h >= dst_h and src_w >= dst_w:
- return inputs, data_samples
- inputs = F.pad(
- inputs,
- pad=(0, max(0, dst_w - src_w), 0, max(0, dst_h - src_h)),
- mode='constant',
- value=self.img_pad_value)
- if data_samples is not None:
- # update batch_input_shape
- for data_sample in data_samples:
- data_sample.set_metainfo({
- 'batch_input_shape': (dst_h, dst_w),
- 'pad_shape': (dst_h, dst_w)
- })
- if self.pad_mask:
- for data_sample in data_samples:
- masks = data_sample.gt_instances.masks
- data_sample.gt_instances.masks = masks.pad(
- (dst_h, dst_w), pad_val=self.mask_pad_value)
- if self.pad_seg:
- for data_sample in data_samples:
- gt_sem_seg = data_sample.gt_sem_seg.sem_seg
- h, w = gt_sem_seg.shape[-2:]
- gt_sem_seg = F.pad(
- gt_sem_seg,
- pad=(0, max(0, dst_w - w), 0, max(0, dst_h - h)),
- mode='constant',
- value=self.seg_pad_value)
- data_sample.gt_sem_seg = PixelData(sem_seg=gt_sem_seg)
- return inputs, data_samples
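- # Illustrative sketch (not part of the library API): padding a batch of
- # (3, 320, 400) images onto a fixed (640, 640) canvas. The module is a no-op
- # when the batch already reaches ``size`` in both dimensions.
- def _demo_batch_fixed_size_pad() -> Tensor:
-     pad = BatchFixedSizePad(size=(640, 640), img_pad_value=0)
-     images = torch.rand(2, 3, 320, 400)
-     padded, _ = pad(images, data_samples=None)
-     assert tuple(padded.shape) == (2, 3, 640, 640)
-     return padded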
- @MODELS.register_module()
- class MultiBranchDataPreprocessor(BaseDataPreprocessor):
- """DataPreprocessor wrapper for multi-branch data.
- Taking semi-supervised object detection as an example, assume that the
- ratio of labeled data to unlabeled data in a batch is 1:2. `sup`
- indicates the branch where the labeled data is augmented, while
- `unsup_teacher` and `unsup_student` indicate the branches where the
- unlabeled data is augmented by different pipelines.
- The input format of multi-branch data is shown below:
- .. code-block:: none
- {
- 'inputs':
- {
- 'sup': [Tensor, None, None],
- 'unsup_teacher': [None, Tensor, Tensor],
- 'unsup_student': [None, Tensor, Tensor],
- },
- 'data_sample':
- {
- 'sup': [DetDataSample, None, None],
- 'unsup_teacher': [None, DetDataSample, DetDataSample],
- 'unsup_student': [None, DetDataSample, DetDataSample],
- }
- }
- The format of multi-branch data after filtering out None entries is
- shown below:
- .. code-block:: none
- {
- 'inputs':
- {
- 'sup': [Tensor],
- 'unsup_teacher': [Tensor, Tensor],
- 'unsup_student': [Tensor, Tensor],
- },
- 'data_sample':
- {
- 'sup': [DetDataSample],
- 'unsup_teacher': [DetDataSample, DetDataSample],
- 'unsup_student': [DetDataSample, DetDataSample],
- }
- }
- In order to reuse `DetDataPreprocessor` for the data
- from different branches, the format of multi-branch data
- grouped by branch is shown below:
- .. code-block:: none
- {
- 'sup':
- {
- 'inputs': [Tensor]
- 'data_sample': [DetDataSample, DetDataSample]
- },
- 'unsup_teacher':
- {
- 'inputs': [Tensor, Tensor]
- 'data_sample': [DetDataSample, DetDataSample]
- },
- 'unsup_student':
- {
- 'inputs': [Tensor, Tensor]
- 'data_sample': [DetDataSample, DetDataSample]
- },
- }
- After preprocessing data from different branches,
- the multi-branch data needs to be reformatted as:
- .. code-block:: none
- {
- 'inputs':
- {
- 'sup': [Tensor],
- 'unsup_teacher': [Tensor, Tensor],
- 'unsup_student': [Tensor, Tensor],
- },
- 'data_sample':
- {
- 'sup': [DetDataSample],
- 'unsup_teacher': [DetDataSample, DetDataSample],
- 'unsup_student': [DetDataSample, DetDataSample],
- }
- }
- Args:
- data_preprocessor (:obj:`ConfigDict` or dict): Config of
- :class:`DetDataPreprocessor` to process the input data.
- """
- def __init__(self, data_preprocessor: ConfigType) -> None:
- super().__init__()
- self.data_preprocessor = MODELS.build(data_preprocessor)
- def forward(self, data: dict, training: bool = False) -> dict:
- """Perform normalization、padding and bgr2rgb conversion based on
- ``BaseDataPreprocessor`` for multi-branch data.
- Args:
- data (dict): Data sampled from dataloader.
- training (bool): Whether to enable training time augmentation.
- Returns:
- dict:
- - 'inputs' (Dict[str, :obj:`torch.Tensor`]): The forward data of
- models from different branches.
- - 'data_sample' (Dict[str, :obj:`DetDataSample`]): The annotation
- info of the samples from different branches.
- """
- if training is False:
- return self.data_preprocessor(data, training)
- # Filter out branches with a value of None
- for key in data.keys():
- for branch in data[key].keys():
- data[key][branch] = list(
- filter(lambda x: x is not None, data[key][branch]))
- # Group data by branch
- multi_branch_data = {}
- for key in data.keys():
- for branch in data[key].keys():
- if multi_branch_data.get(branch, None) is None:
- multi_branch_data[branch] = {key: data[key][branch]}
- elif multi_branch_data[branch].get(key, None) is None:
- multi_branch_data[branch][key] = data[key][branch]
- else:
- multi_branch_data[branch][key].append(data[key][branch])
- # Preprocess data from different branches
- for branch, _data in multi_branch_data.items():
- multi_branch_data[branch] = self.data_preprocessor(_data, training)
- # Format data by inputs and data_samples
- format_data = {}
- for branch in multi_branch_data.keys():
- for key in multi_branch_data[branch].keys():
- if format_data.get(key, None) is None:
- format_data[key] = {branch: multi_branch_data[branch][key]}
- elif format_data[key].get(branch, None) is None:
- format_data[key][branch] = multi_branch_data[branch][key]
- else:
- format_data[key][branch].append(
- multi_branch_data[branch][key])
- return format_data
- @property
- def device(self):
- return self.data_preprocessor.device
- def to(self, device: Optional[Union[int, torch.device]], *args,
- **kwargs) -> nn.Module:
- """Overrides this method to set the :attr:`device`
- Args:
- device (int or torch.device, optional): The desired device of the
- parameters and buffers in this module.
- Returns:
- nn.Module: The model itself.
- """
- return self.data_preprocessor.to(device, *args, **kwargs)
- def cuda(self, *args, **kwargs) -> nn.Module:
- """Overrides this method to set the :attr:`device`
- Returns:
- nn.Module: The model itself.
- """
- return self.data_preprocessor.cuda(*args, **kwargs)
- def cpu(self, *args, **kwargs) -> nn.Module:
- """Overrides this method to set the :attr:`device`
- Returns:
- nn.Module: The model itself.
- """
- return self.data_preprocessor.cpu(*args, **kwargs)
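- # Illustrative sketch (toy placeholder values instead of real tensors) of the
- # regrouping performed in ``MultiBranchDataPreprocessor.forward``: ``None``
- # placeholders are filtered out, the data is regrouped per branch so that the
- # wrapped ``DetDataPreprocessor`` can be reused, and the outputs are then
- # regrouped back by 'inputs'/'data_samples'.
- def _demo_multi_branch_grouping() -> dict:
-     data = {
-         'inputs': {
-             'sup': ['img_0', None],
-             'unsup_teacher': [None, 'img_1'],
-         },
-         'data_samples': {
-             'sup': ['ann_0', None],
-             'unsup_teacher': [None, 'ann_1'],
-         },
-     }
-     # Filter out branches with a value of None.
-     for key in data:
-         for branch in data[key]:
-             data[key][branch] = [
-                 x for x in data[key][branch] if x is not None
-             ]
-     # Group data by branch.
-     grouped = {
-         branch: {key: data[key][branch]
-                  for key in data}
-         for branch in data['inputs']
-     }
-     # grouped == {
-     #     'sup': {'inputs': ['img_0'], 'data_samples': ['ann_0']},
-     #     'unsup_teacher': {'inputs': ['img_1'], 'data_samples': ['ann_1']}}
-     return grouped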
- @MODELS.register_module()
- class BatchResize(nn.Module):
- """Batch resize during training. This implementation is modified from
- https://github.com/Purkialo/CrowdDet/blob/master/lib/data/CrowdHuman.py.
- It provides the data pre-processing as follows:
- - A batch of images is first padded to a uniform size and stacked into a
- torch.Tensor by `DetDataPreprocessor`.
- - `BatchResize` then resizes all images to the target size.
- - The resized images are padded again so that their size is divisible by
- ``pad_size_divisor``.
- Args:
- scale (tuple): Images scales for resizing.
- pad_size_divisor (int): Image size divisible factor.
- Defaults to 1.
- pad_value (Number): The padded pixel value. Defaults to 0.
- """
- def __init__(
- self,
- scale: tuple,
- pad_size_divisor: int = 1,
- pad_value: Union[float, int] = 0,
- ) -> None:
- super().__init__()
- self.min_size = min(scale)
- self.max_size = max(scale)
- self.pad_size_divisor = pad_size_divisor
- self.pad_value = pad_value
- def forward(
- self, inputs: Tensor, data_samples: List[DetDataSample]
- ) -> Tuple[Tensor, List[DetDataSample]]:
- """resize a batch of images and bboxes."""
- batch_height, batch_width = inputs.shape[-2:]
- target_height, target_width, scale = self.get_target_size(
- batch_height, batch_width)
- inputs = F.interpolate(
- inputs,
- size=(target_height, target_width),
- mode='bilinear',
- align_corners=False)
- inputs = self.get_padded_tensor(inputs, self.pad_value)
- if data_samples is not None:
- batch_input_shape = tuple(inputs.size()[-2:])
- for data_sample in data_samples:
- img_shape = [
- int(scale * _) for _ in list(data_sample.img_shape)
- ]
- data_sample.set_metainfo({
- 'img_shape': tuple(img_shape),
- 'batch_input_shape': batch_input_shape,
- 'pad_shape': batch_input_shape,
- 'scale_factor': (scale, scale)
- })
- data_sample.gt_instances.bboxes *= scale
- data_sample.ignored_instances.bboxes *= scale
- return inputs, data_samples
- def get_target_size(self, height: int,
- width: int) -> Tuple[int, int, float]:
- """Get the target size of a batch of images based on data and scale."""
- im_size_min = np.min([height, width])
- im_size_max = np.max([height, width])
- scale = self.min_size / im_size_min
- if scale * im_size_max > self.max_size:
- scale = self.max_size / im_size_max
- target_height, target_width = int(round(height * scale)), int(
- round(width * scale))
- return target_height, target_width, scale
- def get_padded_tensor(self, tensor: Tensor, pad_value: int) -> Tensor:
- """Pad images according to pad_size_divisor."""
- assert tensor.ndim == 4
- target_height, target_width = tensor.shape[-2], tensor.shape[-1]
- divisor = self.pad_size_divisor
- padded_height = (target_height + divisor - 1) // divisor * divisor
- padded_width = (target_width + divisor - 1) // divisor * divisor
- padded_tensor = torch.ones([
- tensor.shape[0], tensor.shape[1], padded_height, padded_width
- ]) * pad_value
- padded_tensor = padded_tensor.type_as(tensor)
- padded_tensor[:, :, :target_height, :target_width] = tensor
- return padded_tensor
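- # Worked example (the numbers are assumptions for illustration) of the resize
- # arithmetic in ``BatchResize.get_target_size``: the short side is scaled
- # towards min(scale) unless that would push the long side past max(scale), in
- # which case the long side is clamped instead.
- def _demo_batch_resize_target_size() -> Tuple[int, int, float]:
-     resize = BatchResize(scale=(800, 1400), pad_size_divisor=32)
-     # For a 600 x 1200 batch: 800 / 600 = 1.33, but 1200 * 1.33 = 1600 > 1400,
-     # so the scale falls back to 1400 / 1200 = 1.1667.
-     target_h, target_w, scale = resize.get_target_size(600, 1200)
-     assert (target_h, target_w) == (700, 1400)
-     return target_h, target_w, scale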
- @MODELS.register_module()
- class BoxInstDataPreprocessor(DetDataPreprocessor):
- """Pseudo mask pre-processor for BoxInst.
- Compared with :class:`mmdet.DetDataPreprocessor`,
- 1. It generates masks using box annotations.
- 2. It computes the image color similarity in LAB color space.
- Args:
- mask_stride (int): The mask output stride in boxinst. Defaults to 4.
- pairwise_size (int): The size of neighborhood for each pixel.
- Defaults to 3.
- pairwise_dilation (int): The dilation of neighborhood for each pixel.
- Defaults to 2.
- pairwise_color_thresh (float): The threshold of image color similarity.
- Defaults to 0.3.
- bottom_pixels_removed (int): The number of pixels removed from the
- bottom of each image, which compensates for annotation errors in the
- COCO dataset. Defaults to 10.
- """
- def __init__(self,
- *arg,
- mask_stride: int = 4,
- pairwise_size: int = 3,
- pairwise_dilation: int = 2,
- pairwise_color_thresh: float = 0.3,
- bottom_pixels_removed: int = 10,
- **kwargs) -> None:
- super().__init__(*arg, **kwargs)
- self.mask_stride = mask_stride
- self.pairwise_size = pairwise_size
- self.pairwise_dilation = pairwise_dilation
- self.pairwise_color_thresh = pairwise_color_thresh
- self.bottom_pixels_removed = bottom_pixels_removed
- if skimage is None:
- raise RuntimeError('skimage is not installed, please install it '
- 'by: pip install scikit-image')
- def get_images_color_similarity(self, inputs: Tensor,
- image_masks: Tensor) -> Tensor:
- """Compute the image color similarity in LAB color space."""
- assert inputs.dim() == 4
- assert inputs.size(0) == 1
- unfolded_images = unfold_wo_center(
- inputs,
- kernel_size=self.pairwise_size,
- dilation=self.pairwise_dilation)
- diff = inputs[:, :, None] - unfolded_images
- similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5)
- unfolded_weights = unfold_wo_center(
- image_masks[None, None],
- kernel_size=self.pairwise_size,
- dilation=self.pairwise_dilation)
- unfolded_weights = torch.max(unfolded_weights, dim=1)[0]
- return similarity * unfolded_weights
- def forward(self, data: dict, training: bool = False) -> dict:
- """Get pseudo mask labels using color similarity."""
- det_data = super().forward(data, training)
- inputs, data_samples = det_data['inputs'], det_data['data_samples']
- if training:
- # get image masks and remove bottom pixels
- b_img_h, b_img_w = data_samples[0].batch_input_shape
- img_masks = []
- for i in range(inputs.shape[0]):
- img_h, img_w = data_samples[i].img_shape
- img_mask = inputs.new_ones((img_h, img_w))
- pixels_removed = int(self.bottom_pixels_removed *
- float(img_h) / float(b_img_h))
- if pixels_removed > 0:
- img_mask[-pixels_removed:, :] = 0
- pad_w = b_img_w - img_w
- pad_h = b_img_h - img_h
- img_mask = F.pad(img_mask, (0, pad_w, 0, pad_h), 'constant',
- 0.)
- img_masks.append(img_mask)
- img_masks = torch.stack(img_masks, dim=0)
- start = int(self.mask_stride // 2)
- img_masks = img_masks[:, start::self.mask_stride,
- start::self.mask_stride]
- # Get origin rgb image for color similarity
- ori_imgs = inputs * self.std + self.mean
- downsampled_imgs = F.avg_pool2d(
- ori_imgs.float(),
- kernel_size=self.mask_stride,
- stride=self.mask_stride,
- padding=0)
- # Compute color similarity for pseudo mask generation
- for im_i, data_sample in enumerate(data_samples):
- # TODO: Support rgb2lab in mmengine?
- images_lab = skimage.color.rgb2lab(
- downsampled_imgs[im_i].byte().permute(1, 2,
- 0).cpu().numpy())
- images_lab = torch.as_tensor(
- images_lab, device=ori_imgs.device, dtype=torch.float32)
- images_lab = images_lab.permute(2, 0, 1)[None]
- images_color_similarity = self.get_images_color_similarity(
- images_lab, img_masks[im_i])
- pairwise_mask = (images_color_similarity >=
- self.pairwise_color_thresh).float()
- per_im_bboxes = data_sample.gt_instances.bboxes
- if per_im_bboxes.shape[0] > 0:
- per_im_masks = []
- for per_box in per_im_bboxes:
- mask_full = torch.zeros((b_img_h, b_img_w),
- device=self.device).float()
- mask_full[int(per_box[1]):int(per_box[3] + 1),
- int(per_box[0]):int(per_box[2] + 1)] = 1.0
- per_im_masks.append(mask_full)
- per_im_masks = torch.stack(per_im_masks, dim=0)
- pairwise_masks = torch.cat(
- [pairwise_mask for _ in range(per_im_bboxes.shape[0])],
- dim=0)
- else:
- per_im_masks = torch.zeros((0, b_img_h, b_img_w))
- pairwise_masks = torch.zeros(
- (0, self.pairwise_size**2 - 1, b_img_h, b_img_w))
- # TODO: Support BitmapMasks with tensor?
- data_sample.gt_instances.masks = BitmapMasks(
- per_im_masks.cpu().numpy(), b_img_h, b_img_w)
- data_sample.gt_instances.pairwise_masks = pairwise_masks
- return {'inputs': inputs, 'data_samples': data_samples}
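- # Illustrative sketch (not part of the library API) of the pairwise
- # color-similarity term used above: each pixel of a single LAB image is
- # compared with its ``pairwise_size**2 - 1`` dilated neighbours, and
- # neighbours outside the valid image region are masked out. The mean/std
- # values are assumptions for the example; constructing the preprocessor
- # requires scikit-image to be installed.
- def _demo_boxinst_color_similarity() -> Tensor:
-     preprocessor = BoxInstDataPreprocessor(
-         mean=[103.53, 116.28, 123.675],
-         std=[57.375, 57.12, 58.395],
-         pad_size_divisor=32)
-     images_lab = torch.rand(1, 3, 96, 160)  # stand-in for a LAB image
-     image_masks = torch.ones(96, 160)  # fully valid image region
-     similarity = preprocessor.get_images_color_similarity(
-         images_lab, image_masks)
-     # With the default pairwise_size=3, every pixel has 8 neighbours.
-     assert tuple(similarity.shape) == (1, 8, 96, 160)
-     return similarity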
|