utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import logging
  2. from contextlib import contextmanager
  3. from functools import wraps
  4. import torch
  5. from mmcv.cnn.bricks.wrappers import obsolete_torch_version
  6. from torch.nn import functional as F
  def _parse_torch_version(version_str, num_parts=2):
      """Parse the leading numeric components of a version string.

      Local version labels (everything after ``'+'``, e.g. ``'+cu118'``) are
      dropped, and each dotted component keeps only its leading run of ASCII
      digits, so ``'1.8rc1'`` parses as ``(1, 8)``. A component with no
      leading digits parses as ``0``. An unparseable string such as
      ``'parrots'`` therefore yields ``(0,)``, which sorts below any real
      release and takes the "old torch" code paths instead of crashing at
      import time.

      Args:
          version_str (str): A version string such as ``torch.__version__``.
          num_parts (int): How many leading dotted components to keep.

      Returns:
          tuple[int, ...]: At most ``num_parts`` integers.
      """
      components = []
      public_part = str(version_str).split('+')[0]
      for piece in public_part.split('.')[:num_parts]:
          digits = ''
          for ch in piece:
              if not '0' <= ch <= '9':
                  break
              digits += ch
          components.append(int(digits) if digits else 0)
      return tuple(components)


  # (major, minor) of the installed PyTorch, e.g. (1, 13) for '1.13.1+cu117'.
  TORCH_VERSION = _parse_torch_version(torch.__version__)
  def is_lower_torch_version(version=(1, 10)):
      """Tell whether the installed PyTorch counts as "old" relative to a threshold.

      The module-level ``TORCH_VERSION`` tuple and ``version`` are handed to
      mmcv's ``obsolete_torch_version``, which makes the actual decision.

      NOTE(review): mmcv's helper appears to compare with ``<=``. If so, a
      PyTorch exactly equal to ``version`` also counts as old. Confirm this
      against the installed mmcv.

      Args:
          version (tuple): Threshold ``(major, minor)``. Defaults to (1, 10).

      Returns:
          bool: The value returned by ``obsolete_torch_version``.
      """
      installed = TORCH_VERSION
      return obsolete_torch_version(installed, version)
  11. @contextmanager
  12. def _ignore_torch_cuda_oom():
  13. """A context which ignores CUDA OOM exception from pytorch."""
  14. try:
  15. yield
  16. except RuntimeError as e:
  17. if 'CUDA out of memory. ' in str(e):
  18. pass
  19. else:
  20. raise
  def retry_if_cuda_oom(func):
      """Make a function retry itself after a PyTorch CUDA OOM error.

      The first retry happens after calling ``torch.cuda.empty_cache()``. If
      that still fails, the last attempt converts the inputs to CPU and calls
      ``func`` once more. In that case ``func`` is expected to dispatch to a
      CPU implementation. The return values may then be CPU tensors, and it
      is the caller's responsibility to move them back to CUDA if needed.

      Args:
          func: A stateless callable that takes tensor-like objects as
              arguments.

      Returns:
          A callable which retries ``func`` if OOM is encountered.

      Examples:
          ::
              output = retry_if_cuda_oom(some_torch_function)(input1, input2)
              # output may be on CPU even if inputs are on GPU

      Note:
          1. When converting inputs to CPU, it only looks at each top-level
             argument and checks whether it has ``.device`` and ``.to``.
             Nested structures of tensors are not supported.
          2. Since the function might be called more than once, it has to be
             stateless.
      """

      def maybe_to_cpu(x):
          # Duck-typed: anything whose ``.device.type`` is 'cuda' and that has
          # a ``.to`` method is copied to CPU. Everything else passes through.
          try:
              like_gpu_tensor = x.device.type == 'cuda' and hasattr(x, 'to')
          except AttributeError:
              like_gpu_tensor = False
          if like_gpu_tensor:
              return x.to(device='cpu')
          return x

      # Readable identifier for the log line. The previous str(func)[0:5]
      # always produced '<func', which identified nothing.
      func_name = (getattr(func, '__qualname__', None)
                   or getattr(func, '__name__', None) or repr(func))

      @wraps(func)
      def wrapped(*args, **kwargs):
          # Attempt 1: run as-is. A CUDA OOM is swallowed and we fall through.
          with _ignore_torch_cuda_oom():
              return func(*args, **kwargs)

          # Attempt 2: release cached, unused CUDA memory and retry.
          torch.cuda.empty_cache()
          with _ignore_torch_cuda_oom():
              return func(*args, **kwargs)

          # Attempt 3: run on CPU. This slows down the code significantly,
          # so log a notice. Errors here are not suppressed.
          logger = logging.getLogger(__name__)
          logger.info('Attempting to copy inputs of %s to CPU due to CUDA OOM',
                      func_name)
          new_args = [maybe_to_cpu(x) for x in args]
          new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
          return func(*new_args, **new_kwargs)

      return wrapped
  def sem_seg_postprocess(result, img_size, output_height, output_width):
      """Resize semantic segmentation logits back to the original resolution.

      Segmentor inputs are usually resized. In some cases they are also
      padded inside the segmentor so that their size is divisible by the
      network's maximum stride. This helper therefore crops ``result`` to the
      ``img_size`` region and then resamples it to the requested output size.

      Args:
          result (Tensor): Semantic segmentation logits of shape (C, H, W),
              where C is the number of classes and H, W are the height and
              width of the prediction.
          img_size (tuple): Image size the segmentor took as input. Index 0
              bounds the height and index 1 bounds the width of the crop.
          output_height (int): Desired output height.
          output_width (int): Desired output width.

      Returns:
          Tensor: Per-pixel soft predictions of shape
          (C, output_height, output_width).
      """
      crop_h, crop_w = img_size[0], img_size[1]
      # Crop away padding, then add a leading batch axis for F.interpolate.
      batched = result[:, :crop_h, :crop_w].expand(1, -1, -1, -1)
      interp_kwargs = {
          'size': (output_height, output_width),
          'mode': 'bicubic',
          'align_corners': False,
      }
      # ``antialias`` is only passed when the installed PyTorch is not
      # considered old by ``is_lower_torch_version``.
      if not is_lower_torch_version():
          interp_kwargs['antialias'] = True
      resized = F.interpolate(batched, **interp_kwargs)
      return resized[0]
  # CLIP-style prompt templates; each contains exactly one '{}' placeholder
  # for the class name. The wording (including "a origami" / "a embroidered")
  # is kept verbatim because changing it changes the text embeddings.
  _PROMPT_TEMPLATES = (
      '{}.',
      'a photo of a {}.',
      'a bad photo of a {}.',
      'a photo of many {}.',
      'a sculpture of a {}.',
      'a photo of the hard to see {}.',
      'a low resolution photo of the {}.',
      'a rendering of a {}.',
      'graffiti of a {}.',
      'a bad photo of the {}.',
      'a cropped photo of the {}.',
      'a tattoo of a {}.',
      'the embroidered {}.',
      'a photo of a hard to see {}.',
      'a bright photo of a {}.',
      'a photo of a clean {}.',
      'a photo of a dirty {}.',
      'a dark photo of the {}.',
      'a drawing of a {}.',
      'a photo of my {}.',
      'the plastic {}.',
      'a photo of the cool {}.',
      'a close-up photo of a {}.',
      'a black and white photo of the {}.',
      'a painting of the {}.',
      'a painting of a {}.',
      'a pixelated photo of the {}.',
      'a sculpture of the {}.',
      'a bright photo of the {}.',
      'a cropped photo of a {}.',
      'a plastic {}.',
      'a photo of the dirty {}.',
      'a jpeg corrupted photo of a {}.',
      'a blurry photo of the {}.',
      'a photo of the {}.',
      'a good photo of the {}.',
      'a rendering of the {}.',
      'a {} in a video game.',
      'a photo of one {}.',
      'a doodle of a {}.',
      'a close-up photo of the {}.',
      'the origami {}.',
      'the {} in a video game.',
      'a sketch of a {}.',
      'a doodle of the {}.',
      'a origami {}.',
      'a low resolution photo of a {}.',
      'the toy {}.',
      'a rendition of the {}.',
      'a photo of the clean {}.',
      'a photo of a large {}.',
      'a rendition of a {}.',
      'a photo of a nice {}.',
      'a photo of a weird {}.',
      'a blurry photo of a {}.',
      'a cartoon {}.',
      'art of a {}.',
      'a sketch of the {}.',
      'a embroidered {}.',
      'a pixelated photo of a {}.',
      'itap of the {}.',
      'a jpeg corrupted photo of the {}.',
      'a good photo of a {}.',
      'a plushie {}.',
      'a photo of the nice {}.',
      'a photo of the small {}.',
      'a photo of the weird {}.',
      'the cartoon {}.',
      'art of the {}.',
      'a drawing of the {}.',
      'a photo of the large {}.',
      'a black and white photo of a {}.',
      'the plushie {}.',
      'a dark photo of a {}.',
      'itap of a {}.',
      'graffiti of the {}.',
      'a toy {}.',
      'itap of my {}.',
      'a photo of a cool {}.',
      'a photo of a small {}.',
      'a tattoo of the {}.',
  )


  def get_prompt_templates():
      """Return the prompt templates used to build class-name text prompts.

      Returns:
          list[str]: A new list on every call, so callers may mutate it
          freely. Each entry has one ``'{}'`` placeholder for a class name.
      """
      return list(_PROMPT_TEMPLATES)