diff --git a/ebd-all-minilm/main.py b/ebd-all-minilm/main.py index 8fb43df..7e68642 100644 --- a/ebd-all-minilm/main.py +++ b/ebd-all-minilm/main.py @@ -1,5 +1,6 @@ import argparse import logging +import os import uvicorn from dotenv import load_dotenv @@ -10,6 +11,15 @@ load_dotenv() +MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'all-MiniLM-L6-v2')}" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--port", help="Port to run model server on", type=int, default=8444) + parser.add_argument("--model-path", help="Path to model", default=MODEL_PATH) + args = parser.parse_args() + MODEL_PATH = args.model_path + logging.basicConfig( format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, @@ -19,7 +29,7 @@ def create_start_app_handler(app: FastAPI): def start_app() -> None: - SentenceTransformerBasedModel.get_model() + SentenceTransformerBasedModel.get_model(MODEL_PATH) return start_app @@ -42,7 +52,4 @@ def get_application() -> FastAPI: if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--port", help="Port to run model server on", type=int, default=8444) - args = parser.parse_args() uvicorn.run("main:app", host="0.0.0.0", port=args.port) diff --git a/ebd-all-minilm/models.py b/ebd-all-minilm/models.py index 7ed2ed9..115ab95 100644 --- a/ebd-all-minilm/models.py +++ b/ebd-all-minilm/models.py @@ -12,10 +12,12 @@ def embeddings(cls, texts): return values.tolist() @classmethod - def get_model(cls): + def get_model(cls, model=None): if cls.model is None: + if model is None or not os.path.exists(model): + model = os.getenv("MODEL_ID", "all-MiniLM-L6-v2") cls.model = SentenceTransformer( - os.getenv("MODEL_ID", "all-MiniLM-L6-v2"), + model, device=os.getenv("DEVICE", "cpu"), ) return cls.model