|
@@ -1,3 +1,4 @@
|
|
|
|
+import collections
|
|
import threading
|
|
import threading
|
|
|
|
|
|
import gdown
|
|
import gdown
|
|
@@ -32,20 +33,32 @@ class SegmentAnythingModel:
|
|
self._encoder_session = onnxruntime.InferenceSession(encoder_path)
|
|
self._encoder_session = onnxruntime.InferenceSession(encoder_path)
|
|
self._decoder_session = onnxruntime.InferenceSession(decoder_path)
|
|
self._decoder_session = onnxruntime.InferenceSession(decoder_path)
|
|
|
|
|
|
|
|
+ self._lock = threading.Lock()
|
|
|
|
+ self._image_embedding_cache = collections.OrderedDict()
|
|
|
|
+
|
|
def set_image(self, image: np.ndarray):
|
|
def set_image(self, image: np.ndarray):
|
|
- self._image = image
|
|
|
|
- self._image_embedding = None
|
|
|
|
|
|
+ with self._lock:
|
|
|
|
+ self._image = image
|
|
|
|
+ self._image_embedding = self._image_embedding_cache.get(
|
|
|
|
+ self._image.tobytes()
|
|
|
|
+ )
|
|
|
|
|
|
self._thread = threading.Thread(target=self.get_image_embedding)
|
|
self._thread = threading.Thread(target=self.get_image_embedding)
|
|
self._thread.start()
|
|
self._thread.start()
|
|
|
|
|
|
def get_image_embedding(self):
|
|
def get_image_embedding(self):
|
|
if self._image_embedding is None:
|
|
if self._image_embedding is None:
|
|
- self._image_embedding = compute_image_embedding(
|
|
|
|
- image_size=self._image_size,
|
|
|
|
- encoder_session=self._encoder_session,
|
|
|
|
- image=self._image,
|
|
|
|
- )
|
|
|
|
|
|
+ with self._lock:
|
|
|
|
+ self._image_embedding = compute_image_embedding(
|
|
|
|
+ image_size=self._image_size,
|
|
|
|
+ encoder_session=self._encoder_session,
|
|
|
|
+ image=self._image,
|
|
|
|
+ )
|
|
|
|
+ if len(self._image_embedding_cache) > 10:
|
|
|
|
+ self._image_embedding_cache.popitem(last=False)
|
|
|
|
+ self._image_embedding_cache[
|
|
|
|
+ self._image.tobytes()
|
|
|
|
+ ] = self._image_embedding
|
|
return self._image_embedding
|
|
return self._image_embedding
|
|
|
|
|
|
def points_to_polygon_callback(self, points, point_labels):
|
|
def points_to_polygon_callback(self, points, point_labels):
|