Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
added loading model from path
Browse files Browse the repository at this point in the history
  • Loading branch information
nsosio committed Nov 3, 2023
1 parent 00b28a5 commit 0f310d6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
15 changes: 11 additions & 4 deletions ebd-all-minilm/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
import os

import uvicorn
from dotenv import load_dotenv
Expand All @@ -10,6 +11,15 @@

load_dotenv()

MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'all-MiniLM-L6-v2')}"

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", help="Port to run model server on", type=int, default=8444)
parser.add_argument("--model-path", help="Path to model", default=MODEL_PATH)
args = parser.parse_args()
MODEL_PATH = args.model_path

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
Expand All @@ -19,7 +29,7 @@

def create_start_app_handler(app: FastAPI):
def start_app() -> None:
SentenceTransformerBasedModel.get_model()
SentenceTransformerBasedModel.get_model(MODEL_PATH)

return start_app

Expand All @@ -42,7 +52,4 @@ def get_application() -> FastAPI:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", help="Port to run model server on", type=int, default=8444)
args = parser.parse_args()
uvicorn.run("main:app", host="0.0.0.0", port=args.port)
6 changes: 4 additions & 2 deletions ebd-all-minilm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ def embeddings(cls, texts):
return values.tolist()

@classmethod
def get_model(cls):
def get_model(cls, model=None):
if cls.model is None:
if model is None or not os.path.exists(model):
model = os.getenv("MODEL_ID", "all-MiniLM-L6-v2")
cls.model = SentenceTransformer(
os.getenv("MODEL_ID", "all-MiniLM-L6-v2"),
model,
device=os.getenv("DEVICE", "cpu"),
)
return cls.model

0 comments on commit 0f310d6

Please sign in to comment.