12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Union
- import cv2
- import mmcv
- import numpy as np
- from mmcv.transforms import BaseTransform
- from mmcv.transforms.utils import cache_randomness
- from mmdet.registry import TRANSFORMS
- from mmdet.structures.bbox import autocast_box_type
- from .augment_wrappers import _MAX_LEVEL, level_to_mag
@TRANSFORMS.register_module()
class GeomTransform(BaseTransform):
    """Base class for geometric transformations. All geometric transformations
    need to inherit from this base class. ``GeomTransform`` unifies the class
    attributes and class functions of geometric transformations (ShearX,
    ShearY, Rotate, TranslateX, and TranslateY), and records the homography
    matrix.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        prob (float): The probability for performing the geometric
            transformation and should be in range [0, 1]. Defaults to 1.0.
        level (int, optional): The level should be in range [0, _MAX_LEVEL].
            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
            Defaults to None.
        min_mag (float): The minimum magnitude for geometric transformation.
            Defaults to 0.0.
        max_mag (float): The maximum magnitude for geometric transformation.
            Defaults to 1.0.
        reversal_prob (float): The probability that the sampled magnitude is
            negated, i.e. the transformation is applied in the opposite
            direction. Should be in range [0, 1]. Defaults to 0.5.
        img_border_value (int | float | tuple): The filled values for
            image border. If float, the same fill value will be used for
            all the three channels of image. If tuple, it should be 3 elements.
            Defaults to 128.
        mask_border_value (int): The fill value used for masks. Defaults to 0.
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def __init__(self,
                 prob: float = 1.0,
                 level: Optional[int] = None,
                 min_mag: float = 0.0,
                 max_mag: float = 1.0,
                 reversal_prob: float = 0.5,
                 img_border_value: Union[int, float, tuple] = 128,
                 mask_border_value: int = 0,
                 seg_ignore_label: int = 255,
                 interpolation: str = 'bilinear') -> None:
        assert 0 <= prob <= 1.0, f'The probability of the transformation ' \
                                 f'should be in range [0,1], got {prob}.'
        assert level is None or isinstance(level, int), \
            f'The level should be None or type int, got {type(level)}.'
        assert level is None or 0 <= level <= _MAX_LEVEL, \
            f'The level should be in range [0,{_MAX_LEVEL}], got {level}.'
        assert isinstance(min_mag, float), \
            f'min_mag should be type float, got {type(min_mag)}.'
        assert isinstance(max_mag, float), \
            f'max_mag should be type float, got {type(max_mag)}.'
        assert min_mag <= max_mag, \
            f'min_mag should be smaller than or equal to max_mag, ' \
            f'got min_mag={min_mag} and max_mag={max_mag}'
        assert isinstance(reversal_prob, float), \
            f'reversal_prob should be type float, got {type(reversal_prob)}.'
        assert 0 <= reversal_prob <= 1.0, \
            f'The reversal probability of the transformation magnitude ' \
            f'should be in range [0,1], got {reversal_prob}.'
        # Normalize the image fill value to a 3-tuple of floats so that each
        # of the three image channels gets an explicit border value.
        if isinstance(img_border_value, (float, int)):
            img_border_value = tuple([float(img_border_value)] * 3)
        elif isinstance(img_border_value, tuple):
            assert len(img_border_value) == 3, \
                f'img_border_value as tuple must have 3 elements, ' \
                f'got {len(img_border_value)}.'
            img_border_value = tuple([float(val) for val in img_border_value])
        else:
            raise ValueError(
                'img_border_value must be float or tuple with 3 elements.')
        assert np.all([0 <= val <= 255 for val in img_border_value]), 'all ' \
            'elements of img_border_value should be in range [0,255], ' \
            f'got {img_border_value}.'
        self.prob = prob
        self.level = level
        self.min_mag = min_mag
        self.max_mag = max_mag
        self.reversal_prob = reversal_prob
        self.img_border_value = img_border_value
        self.mask_border_value = mask_border_value
        self.seg_ignore_label = seg_ignore_label
        self.interpolation = interpolation

    def _transform_img(self, results: dict, mag: float) -> None:
        """Transform the image. No-op here; subclasses override it."""
        pass

    def _transform_masks(self, results: dict, mag: float) -> None:
        """Transform the masks. No-op here; subclasses override it."""
        pass

    def _transform_seg(self, results: dict, mag: float) -> None:
        """Transform the segmentation map. No-op here; subclasses override
        it."""
        pass

    def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
        """Get the homography matrix for the geometric transformation.

        The base class returns the 3x3 identity; subclasses override it.
        """
        return np.eye(3, dtype=np.float32)

    def _transform_bboxes(self, results: dict, mag: float) -> None:
        """Project the bboxes with ``self.homography_matrix`` and clip them
        to ``results['img_shape']``."""
        results['gt_bboxes'].project_(self.homography_matrix)
        results['gt_bboxes'].clip_(results['img_shape'])

    def _record_homography_matrix(self, results: dict) -> None:
        """Record the homography matrix for the geometric transformation.

        If a matrix is already present, the new one is left-multiplied, so the
        stored matrix applies the earlier transforms first and this one last.
        """
        if results.get('homography_matrix', None) is None:
            results['homography_matrix'] = self.homography_matrix
        else:
            results['homography_matrix'] = self.homography_matrix @ results[
                'homography_matrix']

    @cache_randomness
    def _random_disable(self):
        """Randomly disable the transform.

        Returns True (skip) with probability ``1 - prob``.
        """
        return np.random.rand() > self.prob

    @cache_randomness
    def _get_mag(self):
        """Get the magnitude of the transform.

        The magnitude derived from ``level`` is negated with probability
        ``reversal_prob``.
        """
        mag = level_to_mag(self.level, self.min_mag, self.max_mag)
        # rand() is uniform on [0, 1), so ``rand() < reversal_prob`` holds
        # with probability ``reversal_prob``, matching the documented meaning.
        return -mag if np.random.rand() < self.reversal_prob else mag

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Transform function for images, bounding boxes, masks and semantic
        segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results.
        """
        if self._random_disable():
            return results
        mag = self._get_mag()
        # Stored on the instance so _transform_bboxes and
        # _record_homography_matrix can read it.
        self.homography_matrix = self._get_homography_matrix(results, mag)
        self._record_homography_matrix(results)
        self._transform_img(results, mag)
        if results.get('gt_bboxes', None) is not None:
            self._transform_bboxes(results, mag)
        if results.get('gt_masks', None) is not None:
            self._transform_masks(results, mag)
        if results.get('gt_seg_map', None) is not None:
            self._transform_seg(results, mag)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'level={self.level}, '
        repr_str += f'min_mag={self.min_mag}, '
        repr_str += f'max_mag={self.max_mag}, '
        repr_str += f'reversal_prob={self.reversal_prob}, '
        repr_str += f'img_border_value={self.img_border_value}, '
        repr_str += f'mask_border_value={self.mask_border_value}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
        repr_str += f'interpolation={self.interpolation})'
        return repr_str
@TRANSFORMS.register_module()
class ShearX(GeomTransform):
    """Shear the images, bboxes, masks and segmentation map horizontally.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        prob (float): The probability for performing Shear and should be in
            range [0, 1]. Defaults to 1.0.
        level (int, optional): The level should be in range [0, _MAX_LEVEL].
            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
            Defaults to None.
        min_mag (float): The minimum angle, in degrees, for the horizontal
            shear. Should be in range [0, 90]. Defaults to 0.0.
        max_mag (float): The maximum angle, in degrees, for the horizontal
            shear. Should be in range [0, 90]. Defaults to 30.0.
        reversal_prob (float): The probability that the horizontal shear
            factor is negated. Should be in range [0,1]. Defaults to 0.5.
        img_border_value (int | float | tuple): The filled values for
            image border. If float, the same fill value will be used for
            all the three channels of image. If tuple, it should be 3 elements.
            Defaults to 128.
        mask_border_value (int): The fill value used for masks. Defaults to 0.
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def __init__(self,
                 prob: float = 1.0,
                 level: Optional[int] = None,
                 min_mag: float = 0.0,
                 max_mag: float = 30.0,
                 reversal_prob: float = 0.5,
                 img_border_value: Union[int, float, tuple] = 128,
                 mask_border_value: int = 0,
                 seg_ignore_label: int = 255,
                 interpolation: str = 'bilinear') -> None:
        assert 0. <= min_mag <= 90., \
            f'min_mag angle for ShearX should be ' \
            f'in range [0, 90], got {min_mag}.'
        assert 0. <= max_mag <= 90., \
            f'max_mag angle for ShearX should be ' \
            f'in range [0, 90], got {max_mag}.'
        super().__init__(
            prob=prob,
            level=level,
            min_mag=min_mag,
            max_mag=max_mag,
            reversal_prob=reversal_prob,
            img_border_value=img_border_value,
            mask_border_value=mask_border_value,
            seg_ignore_label=seg_ignore_label,
            interpolation=interpolation)

    @cache_randomness
    def _get_mag(self):
        """Get the shear factor of the transform.

        The angle (degrees) derived from ``level`` is converted to the shear
        factor ``tan(angle)``, which is negated with probability
        ``reversal_prob``.
        """
        mag = level_to_mag(self.level, self.min_mag, self.max_mag)
        mag = np.tan(mag * np.pi / 180)
        # rand() is uniform on [0, 1), so this negation happens with
        # probability ``reversal_prob``, matching the documented meaning.
        return -mag if np.random.rand() < self.reversal_prob else mag

    def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
        """Get the homography matrix for ShearX (x' = x + mag * y)."""
        return np.array([[1, mag, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)

    def _transform_img(self, results: dict, mag: float) -> None:
        """Shear the image horizontally."""
        results['img'] = mmcv.imshear(
            results['img'],
            mag,
            direction='horizontal',
            border_value=self.img_border_value,
            interpolation=self.interpolation)

    def _transform_masks(self, results: dict, mag: float) -> None:
        """Shear the masks horizontally."""
        results['gt_masks'] = results['gt_masks'].shear(
            results['img_shape'],
            mag,
            direction='horizontal',
            border_value=self.mask_border_value,
            interpolation=self.interpolation)

    def _transform_seg(self, results: dict, mag: float) -> None:
        """Shear the segmentation map horizontally.

        Nearest interpolation is used so label ids are never blended.
        """
        results['gt_seg_map'] = mmcv.imshear(
            results['gt_seg_map'],
            mag,
            direction='horizontal',
            border_value=self.seg_ignore_label,
            interpolation='nearest')
@TRANSFORMS.register_module()
class ShearY(GeomTransform):
    """Shear the images, bboxes, masks and segmentation map vertically.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        prob (float): The probability for performing ShearY and should be in
            range [0, 1]. Defaults to 1.0.
        level (int, optional): The level should be in range [0,_MAX_LEVEL].
            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
            Defaults to None.
        min_mag (float): The minimum angle, in degrees, for the vertical
            shear. Should be in range [0, 90]. Defaults to 0.0.
        max_mag (float): The maximum angle, in degrees, for the vertical
            shear. Should be in range [0, 90]. Defaults to 30.0.
        reversal_prob (float): The probability that the vertical shear
            factor is negated. Should be in range [0,1]. Defaults to 0.5.
        img_border_value (int | float | tuple): The filled values for
            image border. If float, the same fill value will be used for
            all the three channels of image. If tuple, it should be 3 elements.
            Defaults to 128.
        mask_border_value (int): The fill value used for masks. Defaults to 0.
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def __init__(self,
                 prob: float = 1.0,
                 level: Optional[int] = None,
                 min_mag: float = 0.0,
                 max_mag: float = 30.,
                 reversal_prob: float = 0.5,
                 img_border_value: Union[int, float, tuple] = 128,
                 mask_border_value: int = 0,
                 seg_ignore_label: int = 255,
                 interpolation: str = 'bilinear') -> None:
        assert 0. <= min_mag <= 90., \
            f'min_mag angle for ShearY should be ' \
            f'in range [0, 90], got {min_mag}.'
        assert 0. <= max_mag <= 90., \
            f'max_mag angle for ShearY should be ' \
            f'in range [0, 90], got {max_mag}.'
        super().__init__(
            prob=prob,
            level=level,
            min_mag=min_mag,
            max_mag=max_mag,
            reversal_prob=reversal_prob,
            img_border_value=img_border_value,
            mask_border_value=mask_border_value,
            seg_ignore_label=seg_ignore_label,
            interpolation=interpolation)

    @cache_randomness
    def _get_mag(self):
        """Get the shear factor of the transform.

        The angle (degrees) derived from ``level`` is converted to the shear
        factor ``tan(angle)``, which is negated with probability
        ``reversal_prob``.
        """
        mag = level_to_mag(self.level, self.min_mag, self.max_mag)
        mag = np.tan(mag * np.pi / 180)
        # rand() is uniform on [0, 1), so this negation happens with
        # probability ``reversal_prob``, matching the documented meaning.
        return -mag if np.random.rand() < self.reversal_prob else mag

    def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
        """Get the homography matrix for ShearY (y' = mag * x + y)."""
        return np.array([[1, 0, 0], [mag, 1, 0], [0, 0, 1]], dtype=np.float32)

    def _transform_img(self, results: dict, mag: float) -> None:
        """Shear the image vertically."""
        results['img'] = mmcv.imshear(
            results['img'],
            mag,
            direction='vertical',
            border_value=self.img_border_value,
            interpolation=self.interpolation)

    def _transform_masks(self, results: dict, mag: float) -> None:
        """Shear the masks vertically."""
        results['gt_masks'] = results['gt_masks'].shear(
            results['img_shape'],
            mag,
            direction='vertical',
            border_value=self.mask_border_value,
            interpolation=self.interpolation)

    def _transform_seg(self, results: dict, mag: float) -> None:
        """Shear the segmentation map vertically.

        Nearest interpolation is used so label ids are never blended.
        """
        results['gt_seg_map'] = mmcv.imshear(
            results['gt_seg_map'],
            mag,
            direction='vertical',
            border_value=self.seg_ignore_label,
            interpolation='nearest')
@TRANSFORMS.register_module()
class Rotate(GeomTransform):
    """Rotate the images, bboxes, masks and segmentation map.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        prob (float): The probability for performing the transformation and
            should be in range 0 to 1. Defaults to 1.0.
        level (int, optional): The level should be in range [0, _MAX_LEVEL].
            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
            Defaults to None.
        min_mag (float): The minimum angle, in degrees, for rotation.
            Should be in range [0, 180]. Defaults to 0.0.
        max_mag (float): The maximum angle, in degrees, for rotation.
            Should be in range [0, 180]. Defaults to 30.0.
        reversal_prob (float): The probability that reverses the rotation
            magnitude. Should be in range [0,1]. Defaults to 0.5.
        img_border_value (int | float | tuple): The filled values for
            image border. If float, the same fill value will be used for
            all the three channels of image. If tuple, it should be 3 elements.
            Defaults to 128.
        mask_border_value (int): The fill value used for masks. Defaults to 0.
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def __init__(self,
                 prob: float = 1.0,
                 level: Optional[int] = None,
                 min_mag: float = 0.0,
                 max_mag: float = 30.0,
                 reversal_prob: float = 0.5,
                 img_border_value: Union[int, float, tuple] = 128,
                 mask_border_value: int = 0,
                 seg_ignore_label: int = 255,
                 interpolation: str = 'bilinear') -> None:
        assert 0. <= min_mag <= 180., \
            f'min_mag for Rotate should be in range [0,180], got {min_mag}.'
        assert 0. <= max_mag <= 180., \
            f'max_mag for Rotate should be in range [0,180], got {max_mag}.'
        super().__init__(
            prob=prob,
            level=level,
            min_mag=min_mag,
            max_mag=max_mag,
            reversal_prob=reversal_prob,
            img_border_value=img_border_value,
            mask_border_value=mask_border_value,
            seg_ignore_label=seg_ignore_label,
            interpolation=interpolation)

    def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
        """Get the 3x3 homography matrix for Rotate about the image center."""
        img_shape = results['img_shape']
        # (x, y) center; img_shape is indexed as (height, width, ...).
        center = ((img_shape[1] - 1) * 0.5, (img_shape[0] - 1) * 0.5)
        # cv2.getRotationMatrix2D treats positive angles as counter-clockwise
        # while mmcv.imrotate treats them as clockwise, hence ``-mag``.
        cv2_rotation_matrix = cv2.getRotationMatrix2D(center, -mag, 1.0)
        # Append the [0, 0, 1] row to lift the 2x3 affine matrix to 3x3.
        return np.concatenate(
            [cv2_rotation_matrix,
             np.array([0, 0, 1]).reshape((1, 3))]).astype(np.float32)

    def _transform_img(self, results: dict, mag: float) -> None:
        """Rotate the image."""
        results['img'] = mmcv.imrotate(
            results['img'],
            mag,
            border_value=self.img_border_value,
            interpolation=self.interpolation)

    def _transform_masks(self, results: dict, mag: float) -> None:
        """Rotate the masks."""
        results['gt_masks'] = results['gt_masks'].rotate(
            results['img_shape'],
            mag,
            border_value=self.mask_border_value,
            interpolation=self.interpolation)

    def _transform_seg(self, results: dict, mag: float) -> None:
        """Rotate the segmentation map.

        Nearest interpolation is used so label ids are never blended.
        """
        results['gt_seg_map'] = mmcv.imrotate(
            results['gt_seg_map'],
            mag,
            border_value=self.seg_ignore_label,
            interpolation='nearest')
@TRANSFORMS.register_module()
class TranslateX(GeomTransform):
    """Translate the images, bboxes, masks and segmentation map horizontally.

    The magnitude is a ratio of the image width. Every step converts it to
    the same integer pixel offset via ``int(width * mag)``.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        prob (float): The probability for performing the transformation and
            should be in range 0 to 1. Defaults to 1.0.
        level (int, optional): The level should be in range [0, _MAX_LEVEL].
            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
            Defaults to None.
        min_mag (float): The minimum offset, as a ratio of the image width,
            for horizontal translation. Defaults to 0.0.
        max_mag (float): The maximum offset, as a ratio of the image width,
            for horizontal translation. Defaults to 0.1.
        reversal_prob (float): The probability that reverses the horizontal
            translation magnitude. Should be in range [0,1]. Defaults to 0.5.
        img_border_value (int | float | tuple): The filled values for
            image border. If float, the same fill value will be used for
            all the three channels of image. If tuple, it should be 3 elements.
            Defaults to 128.
        mask_border_value (int): The fill value used for masks. Defaults to 0.
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def __init__(self,
                 prob: float = 1.0,
                 level: Optional[int] = None,
                 min_mag: float = 0.0,
                 max_mag: float = 0.1,
                 reversal_prob: float = 0.5,
                 img_border_value: Union[int, float, tuple] = 128,
                 mask_border_value: int = 0,
                 seg_ignore_label: int = 255,
                 interpolation: str = 'bilinear') -> None:
        # Both bounds are ratios of the image width.
        for bound_name, bound in (('min_mag', min_mag), ('max_mag', max_mag)):
            assert 0. <= bound <= 1., \
                f'{bound_name} ratio for TranslateX should be ' \
                f'in range [0, 1], got {bound}.'
        super().__init__(
            prob=prob,
            level=level,
            min_mag=min_mag,
            max_mag=max_mag,
            reversal_prob=reversal_prob,
            img_border_value=img_border_value,
            mask_border_value=mask_border_value,
            seg_ignore_label=seg_ignore_label,
            interpolation=interpolation)

    def _pixel_offset(self, results: dict, mag: float) -> int:
        """Convert the width ratio ``mag`` to an integer pixel offset."""
        width = results['img_shape'][1]
        return int(width * mag)

    def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
        """Build the 3x3 matrix that shifts x by the pixel offset."""
        matrix = np.eye(3, dtype=np.float32)
        matrix[0, 2] = self._pixel_offset(results, mag)
        return matrix

    def _transform_img(self, results: dict, mag: float) -> None:
        """Shift the image horizontally by the pixel offset."""
        offset = self._pixel_offset(results, mag)
        results['img'] = mmcv.imtranslate(
            results['img'],
            offset,
            direction='horizontal',
            border_value=self.img_border_value,
            interpolation=self.interpolation)

    def _transform_masks(self, results: dict, mag: float) -> None:
        """Shift the masks horizontally by the pixel offset."""
        offset = self._pixel_offset(results, mag)
        results['gt_masks'] = results['gt_masks'].translate(
            results['img_shape'],
            offset,
            direction='horizontal',
            border_value=self.mask_border_value,
            interpolation=self.interpolation)

    def _transform_seg(self, results: dict, mag: float) -> None:
        """Shift the segmentation map horizontally by the pixel offset.

        Uses nearest interpolation so label ids stay discrete.
        """
        offset = self._pixel_offset(results, mag)
        results['gt_seg_map'] = mmcv.imtranslate(
            results['gt_seg_map'],
            offset,
            direction='horizontal',
            border_value=self.seg_ignore_label,
            interpolation='nearest')
@TRANSFORMS.register_module()
class TranslateY(GeomTransform):
    """Translate the images, bboxes, masks and segmentation map vertically.

    The magnitude is a ratio of the image height. Every step converts it to
    the same integer pixel offset via ``int(height * mag)``.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_masks
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        prob (float): The probability for performing the transformation and
            should be in range 0 to 1. Defaults to 1.0.
        level (int, optional): The level should be in range [0, _MAX_LEVEL].
            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
            Defaults to None.
        min_mag (float): The minimum offset, as a ratio of the image height,
            for vertical translation. Defaults to 0.0.
        max_mag (float): The maximum offset, as a ratio of the image height,
            for vertical translation. Defaults to 0.1.
        reversal_prob (float): The probability that reverses the vertical
            translation magnitude. Should be in range [0,1]. Defaults to 0.5.
        img_border_value (int | float | tuple): The filled values for
            image border. If float, the same fill value will be used for
            all the three channels of image. If tuple, it should be 3 elements.
            Defaults to 128.
        mask_border_value (int): The fill value used for masks. Defaults to 0.
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def __init__(self,
                 prob: float = 1.0,
                 level: Optional[int] = None,
                 min_mag: float = 0.0,
                 max_mag: float = 0.1,
                 reversal_prob: float = 0.5,
                 img_border_value: Union[int, float, tuple] = 128,
                 mask_border_value: int = 0,
                 seg_ignore_label: int = 255,
                 interpolation: str = 'bilinear') -> None:
        # Both bounds are ratios of the image height.
        for bound_name, bound in (('min_mag', min_mag), ('max_mag', max_mag)):
            assert 0. <= bound <= 1., \
                f'{bound_name} ratio for TranslateY should be ' \
                f'in range [0,1], got {bound}.'
        super().__init__(
            prob=prob,
            level=level,
            min_mag=min_mag,
            max_mag=max_mag,
            reversal_prob=reversal_prob,
            img_border_value=img_border_value,
            mask_border_value=mask_border_value,
            seg_ignore_label=seg_ignore_label,
            interpolation=interpolation)

    def _pixel_offset(self, results: dict, mag: float) -> int:
        """Convert the height ratio ``mag`` to an integer pixel offset."""
        height = results['img_shape'][0]
        return int(height * mag)

    def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
        """Build the 3x3 matrix that shifts y by the pixel offset."""
        matrix = np.eye(3, dtype=np.float32)
        matrix[1, 2] = self._pixel_offset(results, mag)
        return matrix

    def _transform_img(self, results: dict, mag: float) -> None:
        """Shift the image vertically by the pixel offset."""
        offset = self._pixel_offset(results, mag)
        results['img'] = mmcv.imtranslate(
            results['img'],
            offset,
            direction='vertical',
            border_value=self.img_border_value,
            interpolation=self.interpolation)

    def _transform_masks(self, results: dict, mag: float) -> None:
        """Shift the masks vertically by the pixel offset."""
        offset = self._pixel_offset(results, mag)
        results['gt_masks'] = results['gt_masks'].translate(
            results['img_shape'],
            offset,
            direction='vertical',
            border_value=self.mask_border_value,
            interpolation=self.interpolation)

    def _transform_seg(self, results: dict, mag: float) -> None:
        """Shift the segmentation map vertically by the pixel offset.

        Uses nearest interpolation so label ids stay discrete.
        """
        offset = self._pixel_offset(results, mag)
        results['gt_seg_map'] = mmcv.imtranslate(
            results['gt_seg_map'],
            offset,
            direction='vertical',
            border_value=self.seg_ignore_label,
            interpolation='nearest')
|