瀏覽代碼

Support portrait images in Segment-Anything Model

Kentaro Wada 2 年之前
父節點
當前提交
97b7b81c1a
共有 1 個文件被更改，包括 28 次插入和 10 次删除
  1. 28 10
      labelme/ai/models/segment_anything.py

+ 28 - 10
labelme/ai/models/segment_anything.py

@@ -82,15 +82,34 @@ class SegmentAnythingModel:
         return polygon
         return polygon
 
 
 
 
-def compute_image_embedding(image_size, encoder_session, image):
-    assert image.shape[1] > image.shape[0]
-    scale = image_size / image.shape[1]
-    x = imgviz.resize(
+def _compute_scale_to_resize_image(image_size, image):
+    height, width = image.shape[:2]
+    if width > height:
+        scale = image_size / width
+        new_height = int(round(height * scale))
+        new_width = image_size
+    else:
+        scale = image_size / height
+        new_height = image_size
+        new_width = int(round(width * scale))
+    return scale, new_height, new_width
+
+
+def _resize_image(image_size, image):
+    scale, new_height, new_width = _compute_scale_to_resize_image(
+        image_size=image_size, image=image
+    )
+    scaled_image = imgviz.resize(
         image,
         image,
-        height=int(round(image.shape[0] * scale)),
-        width=image_size,
+        height=new_height,
+        width=new_width,
         backend="pillow",
         backend="pillow",
     ).astype(np.float32)
     ).astype(np.float32)
+    return scale, scaled_image
+
+
+def compute_image_embedding(image_size, encoder_session, image):
+    scale, x = _resize_image(image_size, image)
     x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
     x = (x - np.array([123.675, 116.28, 103.53], dtype=np.float32)) / np.array(
         [58.395, 57.12, 57.375], dtype=np.float32
         [58.395, 57.12, 57.375], dtype=np.float32
     )
     )
@@ -129,10 +148,9 @@ def compute_polygon_from_points(
         None, :
         None, :
     ].astype(np.float32)
     ].astype(np.float32)
 
 
-    assert image.shape[1] > image.shape[0]
-    scale = image_size / image.shape[1]
-    new_height = int(round(image.shape[0] * scale))
-    new_width = image_size
+    scale, new_height, new_width = _compute_scale_to_resize_image(
+        image_size=image_size, image=image
+    )
     onnx_coord = (
     onnx_coord = (
         onnx_coord.astype(float)
         onnx_coord.astype(float)
         * (new_width / image.shape[1], new_height / image.shape[0])
         * (new_width / image.shape[1], new_height / image.shape[0])