|
@@ -5,7 +5,7 @@ import imgviz
|
|
|
import numpy as np
|
|
|
import onnxruntime
|
|
|
import PIL.Image
|
|
|
-import skimage.measure
|
|
|
+import skimage
|
|
|
|
|
|
from ...logger import logger
|
|
|
|
|
@@ -175,6 +175,12 @@ def _compute_mask_from_points(
|
|
|
masks, _, _ = decoder_session.run(None, decoder_inputs)
|
|
|
mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
|
|
|
mask = mask > 0.0
|
|
|
+
|
|
|
+ MIN_SIZE_RATIO = 0.05
|
|
|
+ skimage.morphology.remove_small_objects(
|
|
|
+ mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
|
|
|
+ )
|
|
|
+
|
|
|
if 0:
|
|
|
imgviz.io.imsave(
|
|
|
"mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
|