draw.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import io
  2. import numpy as np
  3. import PIL.Image
  4. import PIL.ImageDraw
  def label_colormap(N=256):
      """Return an ``(N, 3)`` float32 array of RGB colors in ``[0, 1]``.

      Row ``i`` is the color assigned to label ``i``. The bits of the label
      are consumed three at a time (bit 0 -> red, bit 1 -> green,
      bit 2 -> blue) and written into each 8-bit channel from its most
      significant bit downwards, so small consecutive labels receive
      well-separated colors. This is the same scheme as the PASCAL VOC
      colormap. The color of a given label does not depend on ``N``.

      Args:
          N: Number of colors to generate (labels ``0 .. N - 1``).

      Returns:
          ``numpy.ndarray`` of shape ``(N, 3)`` and dtype ``float32``.
      """
      # Work on all labels at once; ``remaining`` holds the not-yet-consumed
      # bits of every label.
      remaining = np.arange(N, dtype=np.int64)
      channels = np.zeros((N, 3), dtype=np.int64)
      for shift in range(7, -1, -1):  # fill channel bits 7 (MSB) down to 0
          for channel in range(3):  # 0 = red, 1 = green, 2 = blue
              channels[:, channel] |= ((remaining >> channel) & 1) << shift
          remaining >>= 3
      return channels.astype(np.float32) / 255
  # Similar in purpose to skimage.color.label2rgb.
  def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
      """Render an integer label map as an RGB image, optionally over ``img``.

      Args:
          lbl: Integer label array. Pixels equal to ``-1`` are treated as
              unlabeled and drawn black.
          img: Optional image array accepted by ``PIL.Image.fromarray``. When
              given, it is converted to grayscale and blended underneath the
              label colors.
          n_labels: Number of colormap entries to generate. When ``None`` it
              is chosen large enough to cover every label value in ``lbl``.
          alpha: Weight of the label colors in the blend; the grayscale
              image gets ``1 - alpha``. Only used when ``img`` is given.
          thresh_suppress: Unused; kept so existing callers keep working.

      Returns:
          ``uint8`` array of shape ``lbl.shape + (3,)``.
      """
      lbl = np.asarray(lbl)
      if n_labels is None:
          # The colormap is indexed by label *value*, so it must reach the
          # largest label, not merely the number of distinct labels;
          # otherwise sparse ids such as {0, 5} raise IndexError.
          largest = int(lbl.max()) if lbl.size else -1
          n_labels = max(len(np.unique(lbl)), largest + 1)

      # Round before the truncating uint8 cast to guard against float
      # round-off in the [0, 1] -> [0, 255] conversion.
      cmap = np.rint(label_colormap(n_labels) * 255).astype(np.uint8)

      lbl_viz = cmap[lbl]
      lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled

      if img is not None:
          # Gray out the image so the label colors stay readable on top.
          img_gray = PIL.Image.fromarray(img).convert('LA')
          img_gray = np.asarray(img_gray.convert('RGB'))
          lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
          lbl_viz = lbl_viz.astype(np.uint8)

      return lbl_viz
  def draw_label(label, img=None, label_names=None, colormap=None):
      """Render a label map with a legend and return it as an RGB array.

      The figure is drawn with matplotlib's ``agg`` backend into an
      in-memory buffer, then resized to the label map's width and height.
      The backend that was active on entry is restored before returning,
      including when rendering raises.

      Args:
          label: Integer label array (must support ``.max()``).
          img: Optional image forwarded to ``label2rgb`` as the background.
          label_names: Sequence of names indexed by label value. Defaults to
              ``'0' .. str(label.max())``. Names starting with ``'_'`` are
              left out of the legend.
          colormap: Colors indexed by label value, used for the legend
              patches. Defaults to ``label_colormap(len(label_names))``.
              NOTE: this only affects the legend; the image itself is
              colored by ``label2rgb``, which is not passed ``colormap``.

      Returns:
          ``numpy.ndarray`` holding the rendered RGB image, with the same
          height and width as ``label``.
      """
      import matplotlib.pyplot as plt

      backend_org = plt.rcParams['backend']
      plt.switch_backend('agg')
      try:
          # Strip every margin/tick so the saved figure is just the image.
          plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                              wspace=0, hspace=0)
          plt.margins(0, 0)
          plt.gca().xaxis.set_major_locator(plt.NullLocator())
          plt.gca().yaxis.set_major_locator(plt.NullLocator())

          if label_names is None:
              label_names = [str(value) for value in range(label.max() + 1)]

          if colormap is None:
              colormap = label_colormap(len(label_names))

          label_viz = label2rgb(label, img, n_labels=len(label_names))
          plt.imshow(label_viz)
          plt.axis('off')

          # Collect the labels that occur once, instead of scanning the
          # whole array for every entry of label_names.
          present = set(np.unique(label).tolist())

          plt_handlers = []
          plt_titles = []
          for label_value, label_name in enumerate(label_names):
              if label_value not in present:
                  continue
              if label_name.startswith('_'):
                  continue
              fc = colormap[label_value]
              plt_handlers.append(plt.Rectangle((0, 0), 1, 1, fc=fc))
              plt_titles.append('{value}: {name}'
                                .format(value=label_value, name=label_name))
          plt.legend(plt_handlers, plt_titles, loc='lower right',
                     framealpha=.5)

          f = io.BytesIO()
          plt.savefig(f, bbox_inches='tight', pad_inches=0)
      finally:
          # Always release the figure and hand the caller's backend back.
          plt.close()
          plt.switch_backend(backend_org)

      f.seek(0)  # rewind so the saved figure is decoded from its start
      out_size = (label_viz.shape[1], label_viz.shape[0])
      out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
      return np.asarray(out)