Преглед изворни кода

Merge semantic_segmentation contrib

Kentaro Wada пре 7 година
родитељ
комит
68f3395fa9
2 измењених фајлова са 25 додато и 76 уклоњено
  1. 3 65
      examples/semantic_segmentation/labelme2voc.py
  2. 22 11
      labelme/utils.py

+ 3 - 65
examples/semantic_segmentation/labelme2voc.py

@@ -24,69 +24,6 @@ from labelme.utils import label2rgb
 from labelme.utils import label_colormap
 
 
-# TODO(wkentaro): Move to labelme/utils.py
-# contrib
-# -----------------------------------------------------------------------------
-
-
-def labelme_shapes_to_label(img_shape, shapes, label_name_to_value):
-    lbl = np.zeros(img_shape[:2], dtype=np.int32)
-    for shape in shapes:
-        polygons = shape['points']
-        label_name = shape['label']
-        if label_name in label_name_to_value:
-            label_value = label_name_to_value[label_name]
-        else:
-            label_value = len(label_name_to_value)
-            label_name_to_value[label_name] = label_value
-        mask = labelme.utils.polygons_to_mask(img_shape[:2], polygons)
-        lbl[mask] = label_value
-
-    return lbl
-
-
-def draw_label(label, img, label_names, colormap=None):
-    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
-                        wspace=0, hspace=0)
-    plt.margins(0, 0)
-    plt.gca().xaxis.set_major_locator(plt.NullLocator())
-    plt.gca().yaxis.set_major_locator(plt.NullLocator())
-
-    if colormap is None:
-        colormap = label_colormap(len(label_names))
-
-    label_viz = label2rgb(
-        label, img, n_labels=len(label_names), alpha=.5)
-    plt.imshow(label_viz)
-    plt.axis('off')
-
-    plt_handlers = []
-    plt_titles = []
-    for label_value, label_name in enumerate(label_names):
-        if label_value not in label:
-            continue
-        if label_name.startswith('_'):
-            continue
-        fc = colormap[label_value]
-        p = plt.Rectangle((0, 0), 1, 1, fc=fc)
-        plt_handlers.append(p)
-        plt_titles.append(label_name)
-    plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
-
-    f = io.BytesIO()
-    plt.savefig(f, bbox_inches='tight', pad_inches=0)
-    plt.cla()
-    plt.close()
-
-    out_size = (img.shape[1], img.shape[0])
-    out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
-    out = np.asarray(out)
-    return out
-
-
-# -----------------------------------------------------------------------------
-
-
 def main():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -142,18 +79,19 @@ def main():
             img = skimage.io.imread(img_file)
             skimage.io.imsave(out_img_file, img)
 
-            lbl = labelme_shapes_to_label(
+            lbl = labelme.utils.shapes_to_label(
                 img_shape=img.shape,
                 shapes=data['shapes'],
                 label_name_to_value=class_name_to_id,
             )
+
             lbl_pil = PIL.Image.fromarray(lbl)
             # Only works with uint8 label
             # lbl_pil = PIL.Image.fromarray(lbl, mode='P')
             # lbl_pil.putpalette((colormap * 255).flatten())
             lbl_pil.save(out_lbl_file)
 
-            viz = draw_label(
+            viz = labelme.utils.draw_label(
                 lbl, img, class_names, colormap=colormap)
             skimage.io.imsave(out_viz_file, viz)
 

+ 22 - 11
labelme/utils.py

@@ -95,6 +95,10 @@ def draw_label(label, img, label_names, colormap=None):
     plt_handlers = []
     plt_titles = []
     for label_value, label_name in enumerate(label_names):
+        if label_value not in label:
+            continue
+        if label_name.startswith('_'):
+            continue
         fc = colormap[label_value]
         p = plt.Rectangle((0, 0), 1, 1, fc=fc)
         plt_handlers.append(p)
@@ -114,22 +118,29 @@ def draw_label(label, img, label_names, colormap=None):
     return out
 
 
-def labelme_shapes_to_label(img_shape, shapes):
-    label_name_to_val = {'background': 0}
+def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
     lbl = np.zeros(img_shape[:2], dtype=np.int32)
     for shape in shapes:
         polygons = shape['points']
         label_name = shape['label']
-        if label_name in label_name_to_val:
-            label_value = label_name_to_val[label_name]
-        else:
-            label_value = len(label_name_to_val)
-            label_name_to_val[label_name] = label_value
+        label_value = label_name_to_value[label_name]
         mask = polygons_to_mask(img_shape[:2], polygons)
         lbl[mask] = label_value
+    return lbl
+
+
+def labelme_shapes_to_label(img_shape, shapes):
+    warnings.warn('labelme_shapes_to_label is deprecated, so please use '
+                  'shapes_to_label.')
 
-    lbl_names = [None] * (max(label_name_to_val.values()) + 1)
-    for label_name, label_value in label_name_to_val.items():
-        lbl_names[label_value] = label_name
+    label_name_to_value = {}
+    for shape in shapes:
+        label_name = shape['label']
+        if label_name in label_name_to_value:
+            label_value = label_name_to_value[label_name]
+        else:
+            label_value = len(label_name_to_value)
+            label_name_to_value[label_name] = label_value
 
-    return lbl, lbl_names
+    lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
+    return lbl, label_name_to_value