Эх сурвалжийг харах

Add colormap arg to draw_label

Kentaro Wada 6 жил өмнө
parent
commit
37fe44a980
1 өөрчлөгдсөн 32 нэмэгдсэн , 8 устгасан
  1. 32 8
      labelme/utils/draw.py

+ 32 - 8
labelme/utils/draw.py

@@ -26,15 +26,28 @@ def label_colormap(N=256):
     return cmap
 
 
+def _validate_colormap(colormap, n_labels):
+    if colormap is None:
+        colormap = label_colormap(n_labels)
+    else:
+        assert colormap.shape == (colormap.shape[0], 3), \
+            'colormap must be sequence of RGB values'
+        assert 0 <= colormap.min() and colormap.max() <= 1, \
+            'colormap must ranges 0 to 1'
+    return colormap
+
+
 # similar function as skimage.color.label2rgb
-def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
+def label2rgb(
+    lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
+):
     if n_labels is None:
         n_labels = len(np.unique(lbl))
 
-    cmap = label_colormap(n_labels)
-    cmap = (cmap * 255).astype(np.uint8)
+    colormap = _validate_colormap(colormap, n_labels)
+    colormap = (colormap * 255).astype(np.uint8)
 
-    lbl_viz = cmap[lbl]
+    lbl_viz = colormap[lbl]
     lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled
 
     if img is not None:
@@ -48,8 +61,18 @@ def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
     return lbl_viz
 
 
-def draw_label(label, img=None, label_names=None, colormap=None):
+def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
+    """Draw pixel-wise label with colorization and label names.
+
+    label: ndarray, (H, W)
+        Pixel-wise labels to colorize.
+    img: ndarray, (H, W, 3), optional
+        Image on which the colorized label will be drawn.
+    label_names: iterable
+        List of label names.
+    """
     import matplotlib.pyplot as plt
+
     backend_org = plt.rcParams['backend']
     plt.switch_backend('agg')
 
@@ -62,10 +85,11 @@ def draw_label(label, img=None, label_names=None, colormap=None):
     if label_names is None:
         label_names = [str(l) for l in range(label.max() + 1)]
 
-    if colormap is None:
-        colormap = label_colormap(len(label_names))
+    colormap = _validate_colormap(colormap, len(label_names))
 
-    label_viz = label2rgb(label, img, n_labels=len(label_names))
+    label_viz = label2rgb(
+        label, img, n_labels=len(label_names), colormap=colormap, **kwargs
+    )
     plt.imshow(label_viz)
     plt.axis('off')