diff --git a/presets/rag_service/embedding/huggingface_remote.py b/presets/rag_service/embedding/huggingface_remote.py
index 341a1d03b..f45e08c2c 100644
--- a/presets/rag_service/embedding/huggingface_remote.py
+++ b/presets/rag_service/embedding/huggingface_remote.py
@@ -6,7 +6,7 @@
 
 class RemoteHuggingFaceEmbedding(BaseEmbeddingModel):
     def __init__(self, model_name: str, api_key: str):
-        self.model = HuggingFaceInferenceAPIEmbedding(model_name=model_name, api_key=api_key)
+        self.model = HuggingFaceInferenceAPIEmbedding(model_name=model_name, token=api_key)
 
     def get_text_embedding(self, text: str):
         """Returns the text embedding for a given input string."""
diff --git a/presets/rag_service/main.py b/presets/rag_service/main.py
index 80f6da87f..97fed9151 100644
--- a/presets/rag_service/main.py
+++ b/presets/rag_service/main.py
@@ -1,7 +1,7 @@
 from typing import Dict, List
 
 from crud.operations import RAGOperations
-from embedding import get_embedding_model
+from embedding.huggingface_local import LocalHuggingFaceEmbedding
 from fastapi import FastAPI, HTTPException
 from models import (DocumentResponse, IndexRequest, ListDocumentsResponse,
                     QueryRequest, RefreshRequest, UpdateRequest)
@@ -12,10 +12,10 @@
 app = FastAPI()
 
 # Initialize embedding model
-embed_model = get_embedding_model(EMBEDDING_TYPE, MODEL_ID, ACCESS_SECRET)
+embed_model = LocalHuggingFaceEmbedding(MODEL_ID)
 
 # Initialize vector store
-vector_store = FaissVectorStoreManager(dimension=384, embed_model=embed_model)
+vector_store = FaissVectorStoreManager(embed_model=embed_model)
 
 # Initialize RAG operations
 rag_ops = RAGOperations(vector_store)