image.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Union
  3. import mmcv
  4. import numpy as np
  5. import torch
  6. from torch import Tensor
  7. def imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict,
  8. new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]:
  9. """Re-normalize the image.
  10. Args:
  11. img (Tensor | ndarray): Input image. If the input is a Tensor, the
  12. shape is (1, C, H, W). If the input is a ndarray, the shape
  13. is (H, W, C).
  14. img_norm_cfg (dict): Original configuration for the normalization.
  15. new_img_norm_cfg (dict): New configuration for the normalization.
  16. Returns:
  17. Tensor | ndarray: Output image with the same type and shape of
  18. the input.
  19. """
  20. if isinstance(img, torch.Tensor):
  21. assert img.ndim == 4 and img.shape[0] == 1
  22. new_img = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
  23. new_img = _imrenormalize(new_img, img_norm_cfg, new_img_norm_cfg)
  24. new_img = new_img.transpose(2, 0, 1)[None]
  25. return torch.from_numpy(new_img).to(img)
  26. else:
  27. return _imrenormalize(img, img_norm_cfg, new_img_norm_cfg)
  28. def _imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict,
  29. new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]:
  30. """Re-normalize the image."""
  31. img_norm_cfg = img_norm_cfg.copy()
  32. new_img_norm_cfg = new_img_norm_cfg.copy()
  33. for k, v in img_norm_cfg.items():
  34. if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray):
  35. img_norm_cfg[k] = np.array(v, dtype=img.dtype)
  36. # reverse cfg
  37. if 'bgr_to_rgb' in img_norm_cfg:
  38. img_norm_cfg['rgb_to_bgr'] = img_norm_cfg['bgr_to_rgb']
  39. img_norm_cfg.pop('bgr_to_rgb')
  40. for k, v in new_img_norm_cfg.items():
  41. if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray):
  42. new_img_norm_cfg[k] = np.array(v, dtype=img.dtype)
  43. img = mmcv.imdenormalize(img, **img_norm_cfg)
  44. img = mmcv.imnormalize(img, **new_img_norm_cfg)
  45. return img