|
@@ -26,15 +26,28 @@ def label_colormap(N=256):
|
|
|
return cmap
|
|
|
|
|
|
|
|
|
+def _validate_colormap(colormap, n_labels):
|
|
|
+ if colormap is None:
|
|
|
+ colormap = label_colormap(n_labels)
|
|
|
+ else:
|
|
|
+ assert colormap.shape == (colormap.shape[0], 3), \
|
|
|
+ 'colormap must be sequence of RGB values'
|
|
|
+ assert 0 <= colormap.min() and colormap.max() <= 1, \
|
|
|
+ 'colormap must ranges 0 to 1'
|
|
|
+ return colormap
|
|
|
+
|
|
|
+
|
|
|
# similar function as skimage.color.label2rgb
|
|
|
-def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
|
|
|
+def label2rgb(
|
|
|
+ lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
|
|
|
+):
|
|
|
if n_labels is None:
|
|
|
n_labels = len(np.unique(lbl))
|
|
|
|
|
|
- cmap = label_colormap(n_labels)
|
|
|
- cmap = (cmap * 255).astype(np.uint8)
|
|
|
+ colormap = _validate_colormap(colormap, n_labels)
|
|
|
+ colormap = (colormap * 255).astype(np.uint8)
|
|
|
|
|
|
- lbl_viz = cmap[lbl]
|
|
|
+ lbl_viz = colormap[lbl]
|
|
|
lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled
|
|
|
|
|
|
if img is not None:
|
|
@@ -48,8 +61,18 @@ def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
|
|
|
return lbl_viz
|
|
|
|
|
|
|
|
|
-def draw_label(label, img=None, label_names=None, colormap=None):
|
|
|
+def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
|
|
|
+ """Draw pixel-wise label with colorization and label names.
|
|
|
+
|
|
|
+ label: ndarray, (H, W)
|
|
|
+ Pixel-wise labels to colorize.
|
|
|
+ img: ndarray, (H, W, 3), optional
|
|
|
+ Image on which the colorized label will be drawn.
|
|
|
+ label_names: iterable
|
|
|
+ List of label names.
|
|
|
+ """
|
|
|
import matplotlib.pyplot as plt
|
|
|
+
|
|
|
backend_org = plt.rcParams['backend']
|
|
|
plt.switch_backend('agg')
|
|
|
|
|
@@ -62,10 +85,11 @@ def draw_label(label, img=None, label_names=None, colormap=None):
|
|
|
if label_names is None:
|
|
|
label_names = [str(l) for l in range(label.max() + 1)]
|
|
|
|
|
|
- if colormap is None:
|
|
|
- colormap = label_colormap(len(label_names))
|
|
|
+ colormap = _validate_colormap(colormap, len(label_names))
|
|
|
|
|
|
- label_viz = label2rgb(label, img, n_labels=len(label_names))
|
|
|
+ label_viz = label2rgb(
|
|
|
+ label, img, n_labels=len(label_names), colormap=colormap, **kwargs
|
|
|
+ )
|
|
|
plt.imshow(label_viz)
|
|
|
plt.axis('off')
|
|
|
|