123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- import logging
- from contextlib import contextmanager
- from functools import wraps
- import torch
- from mmcv.cnn.bricks.wrappers import obsolete_torch_version
- from torch.nn import functional as F
# Leading two dot-separated components of ``torch.__version__`` as a tuple
# of ints, i.e. ``(major, minor)``.
TORCH_VERSION = tuple(map(int, torch.__version__.split('.')[:2]))
def is_lower_torch_version(version=(1, 10)):
    """Check the installed PyTorch version against a threshold.

    Args:
        version (tuple[int]): ``(major, minor)`` threshold that
            ``TORCH_VERSION`` is compared with. Defaults to ``(1, 10)``.

    Returns:
        The result of mmcv's ``obsolete_torch_version(TORCH_VERSION,
        version)``. Presumably truthy when the installed PyTorch is not
        newer than ``version`` -- the exact comparison lives in mmcv,
        confirm there.
    """
    is_obsolete = obsolete_torch_version(TORCH_VERSION, version)
    return is_obsolete
- @contextmanager
- def _ignore_torch_cuda_oom():
- """A context which ignores CUDA OOM exception from pytorch."""
- try:
- yield
- except RuntimeError as e:
- if 'CUDA out of memory. ' in str(e):
- pass
- else:
- raise
def retry_if_cuda_oom(func):
    """Make a function retry itself after hitting PyTorch's CUDA OOM error.

    The wrapped callable is attempted up to three times:

    1. A plain call with the original arguments.
    2. After ``torch.cuda.empty_cache()``, a second call with the original
       arguments.
    3. A final call with every CUDA-tensor-like argument copied to CPU. In
       this case ``func`` is expected to dispatch to a CPU implementation.
       The return values may become CPU tensors as well, and it is the
       caller's responsibility to move them back to CUDA if needed.

    Only CUDA OOM errors trigger a retry; any other exception propagates
    immediately. An OOM raised by the final (CPU) attempt also propagates.

    Args:
        func: a stateless callable that takes tensor-like objects as
            arguments.

    Returns:
        A callable which retries ``func`` if OOM is encountered.

    Examples:
        ::
            output = retry_if_cuda_oom(some_torch_function)(input1, input2)
            # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, only each top-level argument is
           inspected: it is converted if it has a ``.device`` of type
           ``'cuda'`` and a ``.to`` method. Nested structures of tensors
           are not supported.
        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def maybe_to_cpu(x):
        """Return ``x`` copied to CPU if it looks like a CUDA tensor."""
        try:
            like_gpu_tensor = x.device.type == 'cuda' and hasattr(x, 'to')
        except AttributeError:
            # No ``.device`` (or no ``.device.type``): not tensor-like.
            like_gpu_tensor = False
        return x.to(device='cpu') if like_gpu_tensor else x

    @wraps(func)
    def wrapped(*args, **kwargs):
        # A successful call returns from inside the ``with``; a CUDA OOM is
        # swallowed by the context manager and control falls through.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Clear cache and retry.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Try on CPU. This slows down the code significantly, therefore
        # print a notice. The full repr of ``func`` is logged so the notice
        # identifies which callable fell back (a 5-character prefix of the
        # repr, e.g. ``<func``, carries no information).
        logger = logging.getLogger(__name__)
        logger.info(
            'Attempting to copy inputs of %s to CPU due to CUDA OOM',
            str(func))
        new_args = tuple(maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped
def sem_seg_postprocess(result, img_size, output_height, output_width):
    """Resize semantic segmentation logits to the requested resolution.

    Input images are often resized before entering the semantic segmentor
    and, in some cases, additionally padded inside the segmentor so that
    their size is divisible by the maximum network stride. The predictions
    therefore usually need to be brought to a resolution different from the
    network input. This function crops the prediction to ``img_size`` and
    then resamples it with bicubic interpolation.

    Args:
        result (Tensor): semantic segmentation prediction logits.
            A tensor of shape (C, H, W), where C is the number of classes,
            and H, W are the height and width of the prediction.
        img_size (tuple): image size that the segmentor takes as input;
            ``img_size[0]`` and ``img_size[1]`` bound the crop along H
            and W respectively.
        output_height, output_width: the desired output resolution.

    Returns:
        Tensor: A tensor of shape (C, output_height, output_width) that
        contains per-pixel soft predictions.
    """
    # Drop the padded region, then add a leading batch axis because
    # ``F.interpolate`` works on batched input.
    cropped = result[:, :img_size[0], :img_size[1]]
    batched = cropped.expand(1, -1, -1, -1)

    interp_kwargs = dict(
        size=(output_height, output_width),
        mode='bicubic',
        align_corners=False)
    # ``antialias`` is only passed when the version check does not flag the
    # installed PyTorch as a lower version.
    if not is_lower_torch_version():
        interp_kwargs['antialias'] = True

    return F.interpolate(batched, **interp_kwargs)[0]
def get_prompt_templates():
    """Return the list of text prompt templates.

    Every template contains a single ``{}`` placeholder that callers can
    fill via ``str.format``. A fresh list is built on each call, so the
    caller may mutate the result freely.

    Returns:
        list[str]: the prompt templates, in a fixed order.
    """
    return [
        '{}.',
        'a photo of a {}.',
        'a bad photo of a {}.',
        'a photo of many {}.',
        'a sculpture of a {}.',
        'a photo of the hard to see {}.',
        'a low resolution photo of the {}.',
        'a rendering of a {}.',
        'graffiti of a {}.',
        'a bad photo of the {}.',
        'a cropped photo of the {}.',
        'a tattoo of a {}.',
        'the embroidered {}.',
        'a photo of a hard to see {}.',
        'a bright photo of a {}.',
        'a photo of a clean {}.',
        'a photo of a dirty {}.',
        'a dark photo of the {}.',
        'a drawing of a {}.',
        'a photo of my {}.',
        'the plastic {}.',
        'a photo of the cool {}.',
        'a close-up photo of a {}.',
        'a black and white photo of the {}.',
        'a painting of the {}.',
        'a painting of a {}.',
        'a pixelated photo of the {}.',
        'a sculpture of the {}.',
        'a bright photo of the {}.',
        'a cropped photo of a {}.',
        'a plastic {}.',
        'a photo of the dirty {}.',
        'a jpeg corrupted photo of a {}.',
        'a blurry photo of the {}.',
        'a photo of the {}.',
        'a good photo of the {}.',
        'a rendering of the {}.',
        'a {} in a video game.',
        'a photo of one {}.',
        'a doodle of a {}.',
        'a close-up photo of the {}.',
        'the origami {}.',
        'the {} in a video game.',
        'a sketch of a {}.',
        'a doodle of the {}.',
        'a origami {}.',
        'a low resolution photo of a {}.',
        'the toy {}.',
        'a rendition of the {}.',
        'a photo of the clean {}.',
        'a photo of a large {}.',
        'a rendition of a {}.',
        'a photo of a nice {}.',
        'a photo of a weird {}.',
        'a blurry photo of a {}.',
        'a cartoon {}.',
        'art of a {}.',
        'a sketch of the {}.',
        'a embroidered {}.',
        'a pixelated photo of a {}.',
        'itap of the {}.',
        'a jpeg corrupted photo of the {}.',
        'a good photo of a {}.',
        'a plushie {}.',
        'a photo of the nice {}.',
        'a photo of the small {}.',
        'a photo of the weird {}.',
        'the cartoon {}.',
        'art of the {}.',
        'a drawing of the {}.',
        'a photo of the large {}.',
        'a black and white photo of a {}.',
        'the plushie {}.',
        'a dark photo of a {}.',
        'itap of a {}.',
        'graffiti of the {}.',
        'a toy {}.',
        'itap of my {}.',
        'a photo of a cool {}.',
        'a photo of a small {}.',
        'a tattoo of the {}.',
    ]
|