|
@@ -45,32 +45,35 @@ class SegmentAnythingModel:
|
|
|
self._image.tobytes()
|
|
|
)
|
|
|
|
|
|
- self._thread = threading.Thread(target=self._get_image_embedding)
|
|
|
+ self._thread = threading.Thread(
|
|
|
+ target=self._compute_and_cache_image_embedding
|
|
|
+ )
|
|
|
self._thread.start()
|
|
|
|
|
|
- def _get_image_embedding(self):
|
|
|
- if self._image_embedding is None:
|
|
|
+ def _compute_and_cache_image_embedding(self):
|
|
|
+ with self._lock:
|
|
|
logger.debug("Computing image embedding...")
|
|
|
- with self._lock:
|
|
|
- self._image_embedding = _compute_image_embedding(
|
|
|
- image_size=self._image_size,
|
|
|
- encoder_session=self._encoder_session,
|
|
|
- image=self._image,
|
|
|
- )
|
|
|
- if len(self._image_embedding_cache) > 10:
|
|
|
- self._image_embedding_cache.popitem(last=False)
|
|
|
- self._image_embedding_cache[
|
|
|
- self._image.tobytes()
|
|
|
- ] = self._image_embedding
|
|
|
+ self._image_embedding = _compute_image_embedding(
|
|
|
+ image_size=self._image_size,
|
|
|
+ encoder_session=self._encoder_session,
|
|
|
+ image=self._image,
|
|
|
+ )
|
|
|
+ if len(self._image_embedding_cache) > 10:
|
|
|
+ self._image_embedding_cache.popitem(last=False)
|
|
|
+ self._image_embedding_cache[
|
|
|
+ self._image.tobytes()
|
|
|
+ ] = self._image_embedding
|
|
|
logger.debug("Done computing image embedding.")
|
|
|
- return self._image_embedding
|
|
|
+
|
|
|
+ def _get_image_embedding(self):
|
|
|
+ if self._thread is not None:
|
|
|
+ self._thread.join()
|
|
|
+ self._thread = None
|
|
|
+ with self._lock:
|
|
|
+ return self._image_embedding
|
|
|
|
|
|
def predict_polygon_from_points(self, points, point_labels):
|
|
|
- logger.debug("Waiting for image embedding...")
|
|
|
- self._thread.join()
|
|
|
image_embedding = self._get_image_embedding()
|
|
|
- logger.debug("Done waiting for image embedding.")
|
|
|
-
|
|
|
polygon = _compute_polygon_from_points(
|
|
|
image_size=self._image_size,
|
|
|
decoder_session=self._decoder_session,
|