Forráskód Böngészése

Add image_embedding_cache

Kentaro Wada 2 éve
szülő
commit
0bacc48517
1 módosított fájl, 20 hozzáadás és 7 törlés
  1. 20 7
      labelme/ai/models/segment_anything.py

+ 20 - 7
labelme/ai/models/segment_anything.py

@@ -1,3 +1,4 @@
+import collections
 import threading
 
 import gdown
@@ -32,20 +33,32 @@ class SegmentAnythingModel:
         self._encoder_session = onnxruntime.InferenceSession(encoder_path)
         self._decoder_session = onnxruntime.InferenceSession(decoder_path)
 
+        self._lock = threading.Lock()
+        self._image_embedding_cache = collections.OrderedDict()
+
     def set_image(self, image: np.ndarray):
-        self._image = image
-        self._image_embedding = None
+        with self._lock:
+            self._image = image
+            self._image_embedding = self._image_embedding_cache.get(
+                self._image.tobytes()
+            )
 
         self._thread = threading.Thread(target=self.get_image_embedding)
         self._thread.start()
 
     def get_image_embedding(self):
         if self._image_embedding is None:
-            self._image_embedding = compute_image_embedding(
-                image_size=self._image_size,
-                encoder_session=self._encoder_session,
-                image=self._image,
-            )
+            with self._lock:
+                self._image_embedding = compute_image_embedding(
+                    image_size=self._image_size,
+                    encoder_session=self._encoder_session,
+                    image=self._image,
+                )
+                if len(self._image_embedding_cache) > 10:
+                    self._image_embedding_cache.popitem(last=False)
+                self._image_embedding_cache[
+                    self._image.tobytes()
+                ] = self._image_embedding
         return self._image_embedding
 
     def points_to_polygon_callback(self, points, point_labels):