|
@@ -46,23 +46,25 @@ def polygons_to_mask(img_shape, polygons):
|
|
|
return mask
|
|
|
|
|
|
|
|
|
-def draw_label(label, img, label_names):
|
|
|
+def draw_label(label, img, label_names, colormap=None):
|
|
|
plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
|
|
|
wspace=0, hspace=0)
|
|
|
plt.margins(0, 0)
|
|
|
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
|
|
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
|
|
|
|
|
- cmap = labelcolormap(len(label_names))
|
|
|
+ if colormap is None:
|
|
|
+ colormap = labelcolormap(len(label_names))
|
|
|
+
|
|
|
label_viz = skimage.color.label2rgb(
|
|
|
- label, img, colors=cmap[1:], bg_label=0)
|
|
|
+ label, img, colors=colormap[1:], bg_label=0, bg_color=colormap[0])
|
|
|
plt.imshow(label_viz)
|
|
|
plt.axis('off')
|
|
|
|
|
|
plt_handlers = []
|
|
|
plt_titles = []
|
|
|
for label_value, label_name in enumerate(label_names):
|
|
|
- fc = cmap[label_value]
|
|
|
+ fc = colormap[label_value]
|
|
|
p = plt.Rectangle((0, 0), 1, 1, fc=fc)
|
|
|
plt_handlers.append(p)
|
|
|
plt_titles.append(label_name)
|