utils.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import base64
  2. import io
  3. import warnings
  4. import numpy as np
  5. import PIL.Image
  6. import PIL.ImageDraw
  def label_colormap(N=256):
      """Build an ``(N, 3)`` RGB colormap indexed by integer label id.

      The color of label ``i`` is derived from the bits of ``i`` alone:
      bits 0, 3, 6, ... go to red, bits 1, 4, 7, ... to green and
      bits 2, 5, 8, ... to blue.  Within each channel the first such bit
      becomes the most significant bit (value 128), the next one 64, and so
      on for 8 bits, so only the lowest 24 bits of an id contribute.

      Args:
          N: number of label ids (rows) to generate.

      Returns:
          ``np.float32`` array of shape ``(N, 3)`` with values in [0, 1].
      """
      label_ids = np.arange(N)
      rgb = np.zeros((N, 3), dtype=np.int64)
      for octet in range(8):
          for channel in range(3):
              # Bit (3 * octet + channel) of every id, placed at channel bit
              # position (7 - octet).
              bit = (label_ids >> (3 * octet + channel)) & 1
              rgb[:, channel] |= bit << (7 - octet)
      return rgb.astype(np.float32) / 255
  def labelcolormap(N=256):
      """Deprecated alias of ``label_colormap``.

      Emits a warning, then returns ``label_colormap(N=N)`` unchanged.
      """
      # stacklevel=2 attributes the warning to the caller's line rather than
      # to this wrapper, so users can find the call site to migrate.
      warnings.warn('labelcolormap is deprecated. Please use label_colormap.',
                    stacklevel=2)
      return label_colormap(N=N)
  # similar function as skimage.color.label2rgb
  def label2rgb(lbl, img=None, n_labels=None, alpha=0.3, thresh_suppress=0):
      """Colorize a label map, optionally blended over a grayscale image.

      Args:
          lbl: integer label array; each value indexes the colormap from
              ``label_colormap``.  Pixels equal to -1 are treated as
              unlabeled and painted black.
          img: optional image array accepted by ``PIL.Image.fromarray``.  It
              is converted to grayscale, replicated to RGB and blended under
              the label colors (presumably the same height/width as ``lbl``;
              the blend broadcasts the two arrays).
          n_labels: number of colormap entries to build.  Defaults to
              ``lbl.max() + 1`` (at least 1) so every non-negative label has
              a color.
          alpha: weight of the label colors in the blend; the image gets
              ``1 - alpha``.
          thresh_suppress: unused; kept for backward compatibility.

      Returns:
          ``np.uint8`` array of shape ``lbl.shape + (3,)``.
      """
      lbl = np.asarray(lbl)
      if n_labels is None:
          # Size the colormap by the largest label id, not by the number of
          # distinct ids: with non-contiguous ids (e.g. {0, 5}) a count-based
          # size is too small and ``cmap[lbl]`` raises IndexError.  Colors do
          # not change, because row i of label_colormap(N) depends only on i.
          # Keep at least one entry so an all -1 (unlabeled) map still indexes.
          n_labels = max(int(lbl.max()) + 1, 1) if lbl.size else 1
      cmap = label_colormap(n_labels)
      cmap = (cmap * 255).astype(np.uint8)
      lbl_viz = cmap[lbl]
      lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled

      if img is not None:
          img_gray = PIL.Image.fromarray(img).convert('LA')
          img_gray = np.asarray(img_gray.convert('RGB'))
          lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
          lbl_viz = lbl_viz.astype(np.uint8)
      return lbl_viz
  def img_b64_to_array(img_b64):
      """Deprecated alias of ``img_b64_to_arr``.

      Emits a warning, then returns ``img_b64_to_arr(img_b64)`` unchanged.
      """
      # The message must name this function correctly (it previously said
      # "img_ba64_to_array"); stacklevel=2 points the warning at the caller.
      warnings.warn('img_b64_to_array is deprecated. '
                    'Please use img_b64_to_arr.', stacklevel=2)
      return img_b64_to_arr(img_b64)
  def img_b64_to_arr(img_b64):
      """Decode base64-encoded image file data into a NumPy array.

      Args:
          img_b64: base64 data of an image file in any format PIL can open.

      Returns:
          The pixel array produced by ``np.array`` on the opened PIL image.
      """
      raw_bytes = base64.b64decode(img_b64)
      with PIL.Image.open(io.BytesIO(raw_bytes)) as image:
          return np.array(image)
  def img_arr_to_b64(img_arr):
      """Encode an image array as base64 PNG data.

      Args:
          img_arr: array accepted by ``PIL.Image.fromarray``.

      Returns:
          ``bytes`` of base64 text with a newline inserted every 76 output
          characters and a trailing newline (``base64.encodebytes`` format).
      """
      img_pil = PIL.Image.fromarray(img_arr)
      f = io.BytesIO()
      img_pil.save(f, format='PNG')
      img_bin = f.getvalue()
      # base64.encodestring was removed in Python 3.9; encodebytes is its
      # drop-in replacement with identical output.  Fall back only on
      # interpreters that predate encodebytes.
      if hasattr(base64, 'encodebytes'):
          img_b64 = base64.encodebytes(img_bin)
      else:
          img_b64 = base64.encodestring(img_bin)
      return img_b64
  def polygons_to_mask(img_shape, polygons):
      """Rasterize a single polygon into a boolean mask.

      Args:
          img_shape: shape of the target image; only ``img_shape[:2]`` is
              used as the mask shape.
          polygons: sequence of vertices, each converted to a tuple and passed
              to ``PIL.ImageDraw.Draw.polygon`` (presumably ``(x, y)`` pairs).

      Returns:
          ``bool`` array of shape ``img_shape[:2]`` that is True on the
          polygon's filled interior and outline.
      """
      blank = np.zeros(img_shape[:2], dtype=np.uint8)
      canvas = PIL.Image.fromarray(blank)
      vertices = [tuple(point) for point in polygons]
      drawer = PIL.ImageDraw.Draw(canvas)
      drawer.polygon(xy=vertices, outline=1, fill=1)
      return np.array(canvas, dtype=bool)
  def draw_label(label, img=None, label_names=None, colormap=None):
      """Render a label map with a legend and return it as an RGB array.

      The overlay from ``label2rgb`` is drawn with matplotlib on the 'agg'
      backend.  A legend entry is added for every label value that occurs in
      ``label`` and whose name does not start with ``'_'``.  The figure is
      then rasterized to PNG and resized back to the overlay's width/height.

      Args:
          label: integer label array.
          img: optional image blended under the labels (passed to
              ``label2rgb``).
          label_names: names indexed by label value; defaults to
              ``['0', ..., str(label.max())]``.
          colormap: per-label RGB colors used for the legend swatches;
              defaults to ``label_colormap(len(label_names))``.

      Returns:
          RGB array of shape ``(H, W, 3)``, where ``H, W`` are the first two
          dimensions of the ``label2rgb`` overlay.
      """
      import matplotlib.pyplot as plt

      backend_org = plt.rcParams['backend']
      plt.switch_backend('agg')
      try:
          plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                              wspace=0, hspace=0)
          plt.margins(0, 0)
          plt.gca().xaxis.set_major_locator(plt.NullLocator())
          plt.gca().yaxis.set_major_locator(plt.NullLocator())

          if label_names is None:
              label_names = [str(value) for value in range(label.max() + 1)]
          if colormap is None:
              colormap = label_colormap(len(label_names))

          # NOTE(review): label2rgb is not given ``colormap``, so a custom
          # colormap only changes the legend swatches, not the overlay colors.
          label_viz = label2rgb(label, img, n_labels=len(label_names))
          plt.imshow(label_viz)
          plt.axis('off')

          plt_handlers = []
          plt_titles = []
          for label_value, label_name in enumerate(label_names):
              if label_value not in label:
                  continue
              if label_name.startswith('_'):
                  continue
              fc = colormap[label_value]
              plt_handlers.append(plt.Rectangle((0, 0), 1, 1, fc=fc))
              plt_titles.append(label_name)
          plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

          f = io.BytesIO()
          plt.savefig(f, bbox_inches='tight', pad_inches=0)
      finally:
          # Release the figure and restore the caller's backend even when
          # rendering fails part-way.
          plt.cla()
          plt.close()
          plt.switch_backend(backend_org)

      out_size = (label_viz.shape[1], label_viz.shape[0])
      out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
      return np.asarray(out)
  def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
      """Rasterize annotated polygons into a class map and optional instance map.

      Args:
          img_shape: target image shape; only ``img_shape[:2]`` is used.
          shapes: iterable of dicts with a ``'points'`` entry (polygon
              vertices) and a ``'label'`` entry (name).
          label_name_to_value: maps class names to the integer written into
              the class map.
          type: ``'class'`` to return only the class map, or ``'instance'`` to
              also return an instance map.  In instance mode the class name is
              the part of the label before the first ``'-'`` (so ``'car-2'``
              is class ``'car'``).  Each distinct full label gets its own
              instance id, numbered from 1 in order of first appearance;
              ``'_background_'`` is id 0.

      Returns:
          ``cls`` (``np.int32``, shape ``img_shape[:2]``) for ``'class'``, or
          ``(cls, ins)`` for ``'instance'``.  Where shapes overlap, later
          shapes overwrite earlier ones.

      Raises:
          ValueError: if ``type`` is neither ``'class'`` nor ``'instance'``.
          KeyError: if a class name is missing from ``label_name_to_value``.
      """
      # An explicit check rather than ``assert``: asserts are stripped under -O.
      if type not in ('class', 'instance'):
          raise ValueError(
              "type must be 'class' or 'instance', got {!r}".format(type))
      cls = np.zeros(img_shape[:2], dtype=np.int32)
      if type == 'instance':
          ins = np.zeros(img_shape[:2], dtype=np.int32)
          # Instance id for each full label seen so far.
          instance_ids = {'_background_': 0}

      for shape in shapes:
          polygons = shape['points']
          label = shape['label']
          if type == 'class':
              cls_name = label
          else:
              cls_name = label.split('-')[0]
              # Look the id up per label so that a label seen before keeps its
              # own id even when other instances were added in between.
              ins_id = instance_ids.setdefault(label, len(instance_ids))
          cls_id = label_name_to_value[cls_name]
          mask = polygons_to_mask(img_shape[:2], polygons)
          cls[mask] = cls_id
          if type == 'instance':
              ins[mask] = ins_id

      if type == 'instance':
          return cls, ins
      return cls
  def labelme_shapes_to_label(img_shape, shapes):
      """Deprecated: build a class map, assigning class values on the fly.

      ``'_background_'`` is value 0.  Each new ``shape['label']`` gets the next
      integer in order of first appearance.

      Args:
          img_shape: target image shape, passed to ``shapes_to_label``.
          shapes: iterable of dicts with ``'label'`` and ``'points'`` entries.

      Returns:
          ``(lbl, label_name_to_value)``: the class map returned by
          ``shapes_to_label`` and the name-to-value mapping that was built.
      """
      # stacklevel=2 attributes the warning to the caller's line.
      warnings.warn('labelme_shapes_to_label is deprecated, so please use '
                    'shapes_to_label.', stacklevel=2)
      label_name_to_value = {'_background_': 0}
      for shape in shapes:
          # len() is evaluated before insertion, so a new name gets the next id.
          label_name_to_value.setdefault(shape['label'],
                                         len(label_name_to_value))
      lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
      return lbl, label_name_to_value