Quellcode durchsuchen

Replace labelme.utils.draw with imgviz

Kentaro Wada vor 5 Jahren
Ursprung
Commit
8edd218200

+ 9 - 6
examples/bbox_detection/labelme2voc.py

@@ -9,6 +9,7 @@ import os
 import os.path as osp
 import sys
 
+import imgviz
 try:
     import lxml.builder
     import lxml.etree
@@ -18,8 +19,6 @@ except ImportError:
 import numpy as np
 import PIL.Image
 
-import labelme
-
 
 def main():
     parser = argparse.ArgumentParser(
@@ -110,7 +109,7 @@ def main():
             xmin, xmax = sorted([xmin, xmax])
             ymin, ymax = sorted([ymin, ymax])
 
-            bboxes.append([xmin, ymin, xmax, ymax])
+            bboxes.append([ymin, xmin, ymax, xmax])
             labels.append(class_id)
 
             xml.append(
@@ -130,10 +129,14 @@ def main():
 
         if not args.noviz:
             captions = [class_names[l] for l in labels]
-            viz = labelme.utils.draw_instances(
-                img, bboxes, labels, captions=captions
+            viz = imgviz.instances2rgb(
+                image=img,
+                labels=labels,
+                bboxes=bboxes,
+                captions=captions,
+                font_size=15,
             )
-            PIL.Image.fromarray(viz).save(out_viz_file)
+            imgviz.io.imsave(out_viz_file, viz)
 
         with open(out_xml_file, 'wb') as f:
             f.write(lxml.etree.tostring(xml, pretty_print=True))

+ 16 - 7
examples/instance_segmentation/labelme2voc.py

@@ -9,6 +9,7 @@ import os
 import os.path as osp
 import sys
 
+import imgviz
 import numpy as np
 import PIL.Image
 
@@ -65,8 +66,6 @@ def main():
         f.writelines('\n'.join(class_names))
     print('Saved class_names:', out_class_names_file)
 
-    colormap = labelme.utils.label_colormap(255)
-
     for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
         print('Generating dataset from:', label_file)
         with open(label_file) as f:
@@ -112,10 +111,14 @@ def main():
             labelme.utils.lblsave(out_clsp_file, cls)
             np.save(out_cls_file, cls)
             if not args.noviz:
-                clsv = labelme.utils.draw_label(
-                    cls, img, class_names, colormap=colormap
+                clsv = imgviz.label2rgb(
+                    label=cls,
+                    img=imgviz.rgb2gray(img),
+                    label_names=class_names,
+                    font_size=15,
+                    loc='rb',
                 )
-                PIL.Image.fromarray(clsv).save(out_clsv_file)
+                imgviz.io.imsave(out_clsv_file, clsv)
 
             # instance label
             labelme.utils.lblsave(out_insp_file, ins)
@@ -123,8 +126,14 @@ def main():
             if not args.noviz:
                 instance_ids = np.unique(ins)
                 instance_names = [str(i) for i in range(max(instance_ids) + 1)]
-                insv = labelme.utils.draw_label(ins, img, instance_names)
-                PIL.Image.fromarray(insv).save(out_insv_file)
+                insv = imgviz.label2rgb(
+                    label=ins,
+                    img=imgviz.rgb2gray(img),
+                    label_names=instance_names,
+                    font_size=15,
+                    loc='rb',
+                )
+                imgviz.io.imsave(out_insv_file, insv)
 
 
 if __name__ == '__main__':

+ 9 - 5
examples/semantic_segmentation/labelme2voc.py

@@ -9,6 +9,7 @@ import os
 import os.path as osp
 import sys
 
+import imgviz
 import numpy as np
 import PIL.Image
 
@@ -59,8 +60,6 @@ def main():
         f.writelines('\n'.join(class_names))
     print('Saved class_names:', out_class_names_file)
 
-    colormap = labelme.utils.label_colormap(255)
-
     for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
         print('Generating dataset from:', label_file)
         with open(label_file) as f:
@@ -94,9 +93,14 @@ def main():
             np.save(out_lbl_file, lbl)
 
             if not args.noviz:
-                viz = labelme.utils.draw_label(
-                    lbl, img, class_names, colormap=colormap)
-                PIL.Image.fromarray(viz).save(out_viz_file)
+                viz = imgviz.label2rgb(
+                    label=lbl,
+                    img=imgviz.rgb2gray(img),
+                    font_size=15,
+                    label_names=class_names,
+                    loc='rb',
+                )
+                imgviz.io.imsave(out_viz_file, viz)
 
 
 if __name__ == '__main__':

+ 8 - 1
labelme/cli/draw_json.py

@@ -6,6 +6,7 @@ import json
 import os
 import sys
 
+import imgviz
 import matplotlib.pyplot as plt
 
 from labelme import utils
@@ -45,7 +46,13 @@ def main():
     label_names = [None] * (max(label_name_to_value.values()) + 1)
     for name, value in label_name_to_value.items():
         label_names[value] = name
-    lbl_viz = utils.draw_label(lbl, img, label_names)
+    lbl_viz = imgviz.label2rgb(
+        label=lbl,
+        img=imgviz.rgb2gray(img),
+        label_names=label_names,
+        font_size=30,
+        loc='rb',
+    )
 
     plt.subplot(121)
     plt.imshow(img)

+ 4 - 1
labelme/cli/json_to_dataset.py

@@ -4,6 +4,7 @@ import json
 import os
 import os.path as osp
 
+import imgviz
 import PIL.Image
 import yaml
 
@@ -54,7 +55,9 @@ def main():
     label_names = [None] * (max(label_name_to_value.values()) + 1)
     for name, value in label_name_to_value.items():
         label_names[value] = name
-    lbl_viz = utils.draw_label(lbl, img, label_names)
+    lbl_viz = imgviz.label2rgb(
+        label=lbl, img=img, label_names=label_names, loc='rb'
+    )
 
     PIL.Image.fromarray(img).save(osp.join(out_dir, 'img.png'))
     utils.lblsave(osp.join(out_dir, 'label.png'), lbl)

+ 0 - 5
labelme/utils/__init__.py

@@ -13,11 +13,6 @@ from .shape import polygons_to_mask
 from .shape import shape_to_mask
 from .shape import shapes_to_label
 
-from .draw import draw_instances
-from .draw import draw_label
-from .draw import label_colormap
-from .draw import label2rgb
-
 from .qt import newIcon
 from .qt import newButton
 from .qt import newAction

+ 3 - 3
labelme/utils/_io.py

@@ -3,17 +3,17 @@ import os.path as osp
 import numpy as np
 import PIL.Image
 
-from labelme.utils.draw import label_colormap
-
 
 def lblsave(filename, lbl):
+    import imgviz
+
     if osp.splitext(filename)[1] != '.png':
         filename += '.png'
     # Assume label ranses [-1, 254] for int32,
     # and [0, 255] for uint8 as VOC.
     if lbl.min() >= -1 and lbl.max() < 255:
         lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
-        colormap = label_colormap(255)
+        colormap = imgviz.label_colormap()
         lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten())
         lbl_pil.save(filename)
     else:

+ 0 - 157
labelme/utils/draw.py

@@ -1,157 +0,0 @@
-import io
-import os.path as osp
-
-import numpy as np
-import PIL.Image
-import PIL.ImageDraw
-import PIL.ImageFont
-
-
-def label_colormap(N=256):
-
-    def bitget(byteval, idx):
-        return ((byteval & (1 << idx)) != 0)
-
-    cmap = np.zeros((N, 3))
-    for i in range(0, N):
-        id = i
-        r, g, b = 0, 0, 0
-        for j in range(0, 8):
-            r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
-            g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
-            b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
-            id = (id >> 3)
-        cmap[i, 0] = r
-        cmap[i, 1] = g
-        cmap[i, 2] = b
-    cmap = cmap.astype(np.float32) / 255
-    return cmap
-
-
-def _validate_colormap(colormap, n_labels):
-    if colormap is None:
-        colormap = label_colormap(n_labels)
-    else:
-        assert colormap.shape == (colormap.shape[0], 3), \
-            'colormap must be sequence of RGB values'
-        assert 0 <= colormap.min() and colormap.max() <= 1, \
-            'colormap must ranges 0 to 1'
-    return colormap
-
-
-# similar function as skimage.color.label2rgb
-def label2rgb(
-    lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
-):
-    if n_labels is None:
-        n_labels = len(np.unique(lbl))
-
-    colormap = _validate_colormap(colormap, n_labels)
-    colormap = (colormap * 255).astype(np.uint8)
-
-    lbl_viz = colormap[lbl]
-    lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled
-
-    if img is not None:
-        img_gray = PIL.Image.fromarray(img).convert('LA')
-        img_gray = np.asarray(img_gray.convert('RGB'))
-        # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
-        # img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
-        lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
-        lbl_viz = lbl_viz.astype(np.uint8)
-
-    return lbl_viz
-
-
-def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
-    """Draw pixel-wise label with colorization and label names.
-
-    label: ndarray, (H, W)
-        Pixel-wise labels to colorize.
-    img: ndarray, (H, W, 3), optional
-        Image on which the colorized label will be drawn.
-    label_names: iterable
-        List of label names.
-    """
-    import matplotlib.pyplot as plt
-
-    backend_org = plt.rcParams['backend']
-    plt.switch_backend('agg')
-
-    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
-                        wspace=0, hspace=0)
-    plt.margins(0, 0)
-    plt.gca().xaxis.set_major_locator(plt.NullLocator())
-    plt.gca().yaxis.set_major_locator(plt.NullLocator())
-
-    if label_names is None:
-        label_names = [str(l) for l in range(label.max() + 1)]
-
-    colormap = _validate_colormap(colormap, len(label_names))
-
-    label_viz = label2rgb(
-        label, img, n_labels=len(label_names), colormap=colormap, **kwargs
-    )
-    plt.imshow(label_viz)
-    plt.axis('off')
-
-    plt_handlers = []
-    plt_titles = []
-    for label_value, label_name in enumerate(label_names):
-        if label_value not in label:
-            continue
-        fc = colormap[label_value]
-        p = plt.Rectangle((0, 0), 1, 1, fc=fc)
-        plt_handlers.append(p)
-        plt_titles.append('{value}: {name}'
-                          .format(value=label_value, name=label_name))
-    plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
-
-    f = io.BytesIO()
-    plt.savefig(f, bbox_inches='tight', pad_inches=0)
-    plt.cla()
-    plt.close()
-
-    plt.switch_backend(backend_org)
-
-    out_size = (label_viz.shape[1], label_viz.shape[0])
-    out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
-    out = np.asarray(out)
-    return out
-
-
-def draw_instances(
-    image=None,
-    bboxes=None,
-    labels=None,
-    masks=None,
-    captions=None,
-):
-    import matplotlib
-
-    # TODO(wkentaro)
-    assert image is not None
-    assert bboxes is not None
-    assert labels is not None
-    assert masks is None
-    assert captions is not None
-
-    viz = PIL.Image.fromarray(image)
-    draw = PIL.ImageDraw.ImageDraw(viz)
-
-    font_path = osp.join(
-        osp.dirname(matplotlib.__file__),
-        'mpl-data/fonts/ttf/DejaVuSans.ttf'
-    )
-    font = PIL.ImageFont.truetype(font_path)
-
-    colormap = label_colormap(255)
-    for bbox, label, caption in zip(bboxes, labels, captions):
-        color = colormap[label]
-        color = tuple((color * 255).astype(np.uint8).tolist())
-
-        xmin, ymin, xmax, ymax = bbox
-        draw.rectangle((xmin, ymin, xmax, ymax), outline=color)
-        draw.text((xmin, ymin), caption, font=font)
-
-    return np.asarray(viz)

+ 1 - 0
setup.py

@@ -29,6 +29,7 @@ del here
 
 
 install_requires = [
+    'imgviz',
     'matplotlib',
     'numpy',
     'Pillow>=2.8.0',

+ 0 - 48
tests/labelme_tests/utils_tests/test_draw.py

@@ -1,48 +0,0 @@
-import numpy as np
-
-from labelme.utils import draw as draw_module
-from labelme.utils import shape as shape_module
-
-from .util import get_img_and_lbl
-
-
-# -----------------------------------------------------------------------------
-
-
-def test_label_colormap():
-    N = 255
-    colormap = draw_module.label_colormap(N=N)
-    assert colormap.shape == (N, 3)
-
-
-def test_label2rgb():
-    img, lbl, label_names = get_img_and_lbl()
-    n_labels = len(label_names)
-
-    viz = draw_module.label2rgb(lbl=lbl, n_labels=n_labels)
-    assert lbl.shape == viz.shape[:2]
-    assert viz.dtype == np.uint8
-
-    viz = draw_module.label2rgb(lbl=lbl, img=img, n_labels=n_labels)
-    assert img.shape[:2] == lbl.shape == viz.shape[:2]
-    assert viz.dtype == np.uint8
-
-
-def test_draw_label():
-    img, lbl, label_names = get_img_and_lbl()
-
-    viz = draw_module.draw_label(lbl, img, label_names=label_names)
-    assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
-    assert viz.dtype == np.uint8
-
-
-def test_draw_instances():
-    img, lbl, label_names = get_img_and_lbl()
-    labels_and_masks = {l: lbl == l for l in np.unique(lbl) if l != 0}
-    labels, masks = zip(*labels_and_masks.items())
-    masks = np.asarray(masks)
-    bboxes = shape_module.masks_to_bboxes(masks)
-    captions = [label_names[l] for l in labels]
-    viz = draw_module.draw_instances(img, bboxes, labels, captions=captions)
-    assert viz.shape[:2] == img.shape[:2]
-    assert viz.dtype == np.uint8