12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Union
- import mmcv
- import numpy as np
- import torch
- from torch import Tensor
def imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict,
                  new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]:
    """Re-normalize the image.

    The image is de-normalized with ``img_norm_cfg`` and then normalized
    again with ``new_img_norm_cfg``.

    Args:
        img (Tensor | ndarray): Input image. If the input is a Tensor, the
            shape is (1, C, H, W). If the input is a ndarray, the shape
            is (H, W, C).
        img_norm_cfg (dict): Original configuration for the normalization.
        new_img_norm_cfg (dict): New configuration for the normalization.

    Returns:
        Tensor | ndarray: Output image with the same type and shape of
        the input. A Tensor result is moved to the input's device and cast
        to the input's dtype.

    Raises:
        AssertionError: If a Tensor input is not of shape (1, C, H, W)
            (only when assertions are enabled).
    """
    if not isinstance(img, torch.Tensor):
        return _imrenormalize(img, img_norm_cfg, new_img_norm_cfg)

    assert img.ndim == 4 and img.shape[0] == 1, (
        f'Expected a Tensor of shape (1, C, H, W), got {tuple(img.shape)}')
    # detach() lets tensors that require grad be converted to numpy; the
    # numpy round trip below is not differentiable anyway.
    hwc = img.detach().squeeze(0).cpu().numpy().transpose(1, 2, 0)
    hwc = _imrenormalize(hwc, img_norm_cfg, new_img_norm_cfg)
    # Back to (1, C, H, W); .to(img) restores the input's dtype and device.
    chw = hwc.transpose(2, 0, 1)[None]
    return torch.from_numpy(chw).to(img)
def _imrenormalize(img: np.ndarray, img_norm_cfg: dict,
                   new_img_norm_cfg: dict) -> np.ndarray:
    """Re-normalize an ndarray image with mmcv.

    ``img`` is first de-normalized with ``mmcv.imdenormalize`` using
    ``img_norm_cfg``. It is then normalized again with ``mmcv.imnormalize``
    using ``new_img_norm_cfg``. Both config dicts are copied and never
    modified in place.

    ``mean``/``std`` entries that are not already ndarrays are converted to
    arrays with ``img``'s dtype. The colour-order flag is translated to the
    keyword each mmcv function accepts:

    * ``img_norm_cfg``: ``bgr_to_rgb`` or ``to_rgb`` describe the original
      forward normalization. They become ``to_bgr`` for
      ``mmcv.imdenormalize``, which swaps the channels back.
    * ``new_img_norm_cfg``: ``bgr_to_rgb`` becomes ``to_rgb`` for
      ``mmcv.imnormalize``.

    An explicit ``to_bgr`` / ``to_rgb`` key takes precedence over an alias.

    Args:
        img (ndarray): Image to re-normalize, (H, W, C) as passed by
            ``imrenormalize``.
        img_norm_cfg (dict): Configuration of the normalization currently
            applied to ``img``.
        new_img_norm_cfg (dict): Configuration of the desired
            normalization.

    Returns:
        ndarray: The re-normalized image.
    """

    def _to_mmcv_kwargs(cfg: dict, dtype, flag: str,
                        aliases: tuple) -> dict:
        """Copy ``cfg``, array-ify mean/std and rename colour-flag aliases."""
        kwargs = cfg.copy()  # never mutate the caller's dict
        for key in ('mean', 'std'):
            if key in kwargs and not isinstance(kwargs[key], np.ndarray):
                kwargs[key] = np.array(kwargs[key], dtype=dtype)
        for alias in aliases:
            if alias in kwargs:
                # mmcv.imdenormalize only accepts `to_bgr` and
                # mmcv.imnormalize only accepts `to_rgb`; forwarding the
                # alias name unchanged would raise TypeError.
                kwargs.setdefault(flag, kwargs.pop(alias))
        return kwargs

    # Both configs are derived from the *input* image dtype, before img is
    # reassigned below.
    denorm_kwargs = _to_mmcv_kwargs(img_norm_cfg, img.dtype, 'to_bgr',
                                    ('bgr_to_rgb', 'to_rgb'))
    norm_kwargs = _to_mmcv_kwargs(new_img_norm_cfg, img.dtype, 'to_rgb',
                                  ('bgr_to_rgb', ))
    img = mmcv.imdenormalize(img, **denorm_kwargs)
    return mmcv.imnormalize(img, **norm_kwargs)
|