diff --git a/semantic_search/main.py b/semantic_search/main.py index 09cb1f6..c03915b 100644 --- a/semantic_search/main.py +++ b/semantic_search/main.py @@ -30,6 +30,7 @@ class Settings(BaseSettings): """ pretrained_model_name_or_path: str = "johngiorgi/declutr-sci-base" + serialized_index_path: Optional[str] = None batch_size: int = 64 max_length: Optional[int] = None mean_pool: bool = True @@ -80,7 +81,10 @@ def app_startup(): settings.pretrained_model_name_or_path, cuda_device=settings.cuda_device ) embedding_dim = model.model.config.hidden_size - model.index = setup_faiss_index(embedding_dim) + if settings.serialized_index_path is not None: + model.index = faiss.swigfaiss.read_index(settings.file_path) + else: + model.index = setup_faiss_index(embedding_dim) @app.post("/")