Quellcode durchsuchen

Refactor the threading for image_embedding computation

Kentaro Wada vor 2 Jahren
Ursprung
Commit
1f96108115
1 geänderte Dateien mit 22 neuen und 19 gelöschten Zeilen
  1. 22 19
      labelme/ai/models/segment_anything.py

+ 22 - 19
labelme/ai/models/segment_anything.py

@@ -45,32 +45,35 @@ class SegmentAnythingModel:
                 self._image.tobytes()
             )
 
-        self._thread = threading.Thread(target=self._get_image_embedding)
+        self._thread = threading.Thread(
+            target=self._compute_and_cache_image_embedding
+        )
         self._thread.start()
 
-    def _get_image_embedding(self):
-        if self._image_embedding is None:
+    def _compute_and_cache_image_embedding(self):
+        with self._lock:
             logger.debug("Computing image embedding...")
-            with self._lock:
-                self._image_embedding = _compute_image_embedding(
-                    image_size=self._image_size,
-                    encoder_session=self._encoder_session,
-                    image=self._image,
-                )
-                if len(self._image_embedding_cache) > 10:
-                    self._image_embedding_cache.popitem(last=False)
-                self._image_embedding_cache[
-                    self._image.tobytes()
-                ] = self._image_embedding
+            self._image_embedding = _compute_image_embedding(
+                image_size=self._image_size,
+                encoder_session=self._encoder_session,
+                image=self._image,
+            )
+            if len(self._image_embedding_cache) > 10:
+                self._image_embedding_cache.popitem(last=False)
+            self._image_embedding_cache[
+                self._image.tobytes()
+            ] = self._image_embedding
             logger.debug("Done computing image embedding.")
-        return self._image_embedding
+
+    def _get_image_embedding(self):
+        if self._thread is not None:
+            self._thread.join()
+            self._thread = None
+        with self._lock:
+            return self._image_embedding
 
     def predict_polygon_from_points(self, points, point_labels):
-        logger.debug("Waiting for image embedding...")
-        self._thread.join()
         image_embedding = self._get_image_embedding()
-        logger.debug("Done waiting for image embedding.")
-
         polygon = _compute_polygon_from_points(
             image_size=self._image_size,
             decoder_session=self._decoder_session,