draw.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import io
  2. import os.path as osp
  3. import numpy as np
  4. import PIL.Image
  5. import PIL.ImageDraw
  6. import PIL.ImageFont
  7. def label_colormap(N=256):
  8. def bitget(byteval, idx):
  9. return ((byteval & (1 << idx)) != 0)
  10. cmap = np.zeros((N, 3))
  11. for i in range(0, N):
  12. id = i
  13. r, g, b = 0, 0, 0
  14. for j in range(0, 8):
  15. r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
  16. g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
  17. b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
  18. id = (id >> 3)
  19. cmap[i, 0] = r
  20. cmap[i, 1] = g
  21. cmap[i, 2] = b
  22. cmap = cmap.astype(np.float32) / 255
  23. return cmap
  24. def _validate_colormap(colormap, n_labels):
  25. if colormap is None:
  26. colormap = label_colormap(n_labels)
  27. else:
  28. assert colormap.shape == (colormap.shape[0], 3), \
  29. 'colormap must be sequence of RGB values'
  30. assert 0 <= colormap.min() and colormap.max() <= 1, \
  31. 'colormap must ranges 0 to 1'
  32. return colormap
  33. # similar function as skimage.color.label2rgb
  34. def label2rgb(
  35. lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
  36. ):
  37. if n_labels is None:
  38. n_labels = len(np.unique(lbl))
  39. colormap = _validate_colormap(colormap, n_labels)
  40. colormap = (colormap * 255).astype(np.uint8)
  41. lbl_viz = colormap[lbl]
  42. lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled
  43. if img is not None:
  44. img_gray = PIL.Image.fromarray(img).convert('LA')
  45. img_gray = np.asarray(img_gray.convert('RGB'))
  46. # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  47. # img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
  48. lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
  49. lbl_viz = lbl_viz.astype(np.uint8)
  50. return lbl_viz
  51. def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
  52. """Draw pixel-wise label with colorization and label names.
  53. label: ndarray, (H, W)
  54. Pixel-wise labels to colorize.
  55. img: ndarray, (H, W, 3), optional
  56. Image on which the colorized label will be drawn.
  57. label_names: iterable
  58. List of label names.
  59. """
  60. import matplotlib.pyplot as plt
  61. backend_org = plt.rcParams['backend']
  62. plt.switch_backend('agg')
  63. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  64. wspace=0, hspace=0)
  65. plt.margins(0, 0)
  66. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  67. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  68. if label_names is None:
  69. label_names = [str(l) for l in range(label.max() + 1)]
  70. colormap = _validate_colormap(colormap, len(label_names))
  71. label_viz = label2rgb(
  72. label, img, n_labels=len(label_names), colormap=colormap, **kwargs
  73. )
  74. plt.imshow(label_viz)
  75. plt.axis('off')
  76. plt_handlers = []
  77. plt_titles = []
  78. for label_value, label_name in enumerate(label_names):
  79. if label_value not in label:
  80. continue
  81. fc = colormap[label_value]
  82. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  83. plt_handlers.append(p)
  84. plt_titles.append('{value}: {name}'
  85. .format(value=label_value, name=label_name))
  86. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  87. f = io.BytesIO()
  88. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  89. plt.cla()
  90. plt.close()
  91. plt.switch_backend(backend_org)
  92. out_size = (label_viz.shape[1], label_viz.shape[0])
  93. out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
  94. out = np.asarray(out)
  95. return out
  96. def draw_instances(
  97. image=None,
  98. bboxes=None,
  99. labels=None,
  100. masks=None,
  101. captions=None,
  102. ):
  103. import matplotlib
  104. # TODO(wkentaro)
  105. assert image is not None
  106. assert bboxes is not None
  107. assert labels is not None
  108. assert masks is None
  109. assert captions is not None
  110. viz = PIL.Image.fromarray(image)
  111. draw = PIL.ImageDraw.ImageDraw(viz)
  112. font_path = osp.join(
  113. osp.dirname(matplotlib.__file__),
  114. 'mpl-data/fonts/ttf/DejaVuSans.ttf'
  115. )
  116. font = PIL.ImageFont.truetype(font_path)
  117. colormap = label_colormap(255)
  118. for bbox, label, caption in zip(bboxes, labels, captions):
  119. color = colormap[label]
  120. color = tuple((color * 255).astype(np.uint8).tolist())
  121. xmin, ymin, xmax, ymax = bbox
  122. draw.rectangle((xmin, ymin, xmax, ymax), outline=color)
  123. draw.text((xmin, ymin), caption, font=font)
  124. return np.asarray(viz)