From 9c505d50392bc07f80be5a64e01c7271a8c8be9c Mon Sep 17 00:00:00 2001
From: glorenzo972
Date: Mon, 1 Jul 2024 13:39:08 +0200
Subject: [PATCH] add: /api/ask

---
 CHANGELOG.md                                  |   6 +
 README.md                                     |  23 +
 pyproject.toml                                |  33 +-
 tilellm/__main__.py                           |  94 ++-
 tilellm/controller/controller.py              | 688 ++++++++++++++++++
 tilellm/controller/openai_controller.py       |  54 +-
 tilellm/models/item_model.py                  |  44 +-
 tilellm/shared/const.py                       |   4 +-
 tilellm/shared/utility.py                     | 107 ++-
 tilellm/store/pinecone/__init__.py            |   0
 .../pinecone_repository_base.py               |  20 +-
 .../{ => pinecone}/pinecone_repository_pod.py |  14 +-
 .../pinecone_repository_serverless.py         |  22 +-
 tilellm/tools/document_tool_simple.py         |  17 +
 14 files changed, 1027 insertions(+), 99 deletions(-)
 create mode 100644 tilellm/controller/controller.py
 create mode 100644 tilellm/store/pinecone/__init__.py
 rename tilellm/store/{ => pinecone}/pinecone_repository_base.py (97%)
 rename tilellm/store/{ => pinecone}/pinecone_repository_pod.py (92%)
 rename tilellm/store/{ => pinecone}/pinecone_repository_serverless.py (90%)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 34b1c65..d2e7d2f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@
 *Andrea Sponziello*
 ### **Copyrigth**: *Tiledesk SRL*
+## [2024-07-01]
+### 0.2.4
+- fix: scrape_type=0
+- added: /api/ask endpoint to query the LLM directly
+
+
 ## [2024-06-21]
 ### 0.2.3
 - fix: delete chunks from namespace by metadata id
diff --git a/README.md b/README.md
index 026a33a..dd3a96c 100644
--- a/README.md
+++ b/README.md
@@ -85,3 +85,26 @@ pc.create_index(const.PINECONE_INDEX,
     )
 )
 ```
+
+## Models
+### OpenAI - engine: openai
+- gpt-3.5-turbo
+- gpt-4
+- gpt-4-turbo
+- gpt-4o
+
+### Cohere - engine: cohere
+- command-r
+- command-r-plus
+
+### Google - engine: google
+- gemini-pro
+
+### Anthropic - engine: anthropic
+- claude-3-5-sonnet-20240620
+
+### Groq - engine: groq
+- Llama3-70b-8192
+- Llama3-8b-8192
+- Mixtral-8x7b-32768
+- Gemma-7b-It
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index fef283f..895c1c0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tilellm"
-version = "0.2.3"
+version = "0.2.4"
 description = "tiledesk for RAG"
 authors = ["Gianluca Lorenzo "]
 repository = "https://github.com/Tiledesk/tiledesk-llm"
@@ -13,26 +13,33 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry.dependencies]
 python = "^3.10"
-fastapi = "^0.110"
-jsonschema= "^4.20.0"
-redis= "^5.0.0"
-aioredis= "^2.0.0"
+fastapi = "^0.110.0"
+jsonschema= "^4.22.0"
+redis= "^5.0.7"
+aioredis= "^2.0.1"
 #redismutex = "^1.0.0"
-langchain = "^0.1.16"
-jq = "^1.6.0"
-openai = "^1.12.0"
-langchain_openai = "0.0.x"
-pinecone-client = "^3.1.0"
+langchain = "^0.2.6"
+jq = "^1.7.0"
+openai = "^1.35.7"
+langchain_openai = "0.1.x"
+langchain-voyageai = "0.1.1"
+langchain-anthropic = "0.1.16"
+langchain-cohere="0.1.8"
+langchain-google-genai= "1.0.7"
+langchain-groq ="0.1.6"
+langchain-aws="0.1.8"
+pinecone-client = "^4.1.1"
 python-dotenv = "^1.0.1"
-langchain_community = "0.0.x"
-tiktoken = "0.6.x"
+langchain_community = "0.2.x"
+tiktoken = "0.7.x"
 beautifulsoup4 ="^4.12.3"
 #uvicorn = "^0.28"
-unstructured= "^0.12.6"
+unstructured= "0.14.x"
 #playwright = "^1.43.0"
 pypdf="^4.2.0"
 docx2txt="^0.8"
 wikipedia="^1.4.0"
+psutil="^6.0.0"

 [tool.poetry.dependencies.uvicorn]
 version = "^0.28"
diff --git a/tilellm/__main__.py b/tilellm/__main__.py
index 7a19120..b21cf16 100644
--- a/tilellm/__main__.py
+++ b/tilellm/__main__.py
@@ -21,19 +21,20 @@
                                         PineconeNamespaceToDelete, ScrapeStatusReq,
                                         ScrapeStatusResponse,
-                                        PineconeIndexingResult)
+                                        PineconeIndexingResult, RetrievalResult, PineconeNamespaceResult,
+                                        PineconeDescNamespaceResult, PineconeItems, QuestionToLLM, SimpleAnswer)

 from tilellm.store.redis_repository import redis_xgroup_create
-from tilellm.controller.openai_controller import (ask_with_memory,
-                                                  ask_with_sequence,
-                                                  add_pc_item,
-                                                  delete_namespace,
-                                                  delete_id_from_namespace,
-                                                  get_ids_namespace,
-                                                  get_listitems_namespace,
-                                                  get_desc_namespace,
-                                                  get_list_namespace,
-                                                  get_sources_namespace)
+from tilellm.controller.controller import (ask_with_memory,
+                                           ask_with_sequence,
+                                           add_pc_item,
+                                           delete_namespace,
+                                           delete_id_from_namespace,
+                                           get_ids_namespace,
+                                           get_listitems_namespace,
+                                           get_desc_namespace,
+                                           get_list_namespace,
+                                           get_sources_namespace, ask_to_llm)

 import logging
@@ -215,7 +216,7 @@ async def enqueue_scrape_item_main(item: ItemSingle, redis_client: aioredis.clie
     enqueue item to redis. Consumer read message and add it to namespace
     :param item:
     :param redis_client:
-    :return:
+    :return: PineconeIndexingResult
     """
     from tilellm.shared import const
     logger.debug(item)
@@ -231,13 +232,14 @@ async def enqueue_scrape_item_main(item: ItemSingle, redis_client: aioredis.clie
     return {"message": f"Item {item.id} created successfully, more {res}"}


-@app.post("/api/scrape/single")
+
+@app.post("/api/scrape/single", response_model=PineconeIndexingResult)
 async def create_scrape_item_single(item: ItemSingle, redis_client: aioredis.client.Redis = Depends(get_redis_client)):
     """
     Add item to namespace
     :param item:
     :param redis_client:
-    :return:
+    :return: PineconeIndexingResult
     """
     webhook = ""
     token = ""
         scrape_status_response = ScrapeStatusResponse(status_message="Crawling is started",
                                                       status_code=2
                                                       )
         add_to_queue = await redis_client.set(f"{item.namespace}/{item.id}",
-                                              scrape_status_response.model_dump_json(),
-                                              ex=expiration_in_seconds)
+                                               scrape_status_response.model_dump_json(),
+                                               ex=expiration_in_seconds)
         logger.debug(f"Start {add_to_queue}")
@@ -290,8 +292,8 @@ async def create_scrape_item_single(item: ItemSingle, redis_client: aioredis.cli
                                                       status_code=3
                                                       )
         add_to_queue = await redis_client.set(f"{item.namespace}/{item.id}",
-                                              scrape_status_response.model_dump_json(),
-                                              ex=expiration_in_seconds)
+                                               scrape_status_response.model_dump_json(),
+                                               ex=expiration_in_seconds)
         # logger.debug(f"End {add_to_queue}")

         # if webhook:
@@ -313,8 +315,8 @@ async def create_scrape_item_single(item: ItemSingle, redis_client: aioredis.cli
                                                       status_code=4
                                                       )
         add_to_queue = await redis_client.set(f"{item.namespace}/{item.id}",
-                                              scrape_status_response.model_dump_json(),
-                                              ex=expiration_in_seconds)
+                                               scrape_status_response.model_dump_json(),
+                                               ex=expiration_in_seconds)
         logger.error(f"Error {add_to_queue}")

         import traceback
@@ -330,27 +332,48 @@ async def create_scrape_item_single(item: ItemSingle, redis_client: aioredis.cli
         logger.error(e)
         raise HTTPException(status_code=400, detail=repr(e))

-@app.post("/api/qa")
+
+@app.post("/api/qa", response_model=RetrievalResult)
 async def post_ask_with_memory_main(question_answer: QuestionAnswer):
+    """
+    Query and answer with chat history
+    :param question_answer:
+    :return: RetrievalResult
+    """
     logger.debug(question_answer)
     result = await ask_with_memory(question_answer)
     logger.debug(result)
-    return JSONResponse(content=result)
+    return JSONResponse(content=result.model_dump())


-@app.post("/api/qachain")
+@app.post("/api/ask", response_model=SimpleAnswer)
+async def post_ask_to_llm_main(question: QuestionToLLM):
+    """
+    Query and answer with an LLM
+    :param question:
+    :return: SimpleAnswer
+    """
+    logger.info(question)
+
+    result = await ask_to_llm(question=question)
+
+    logger.debug(result)
+    return JSONResponse(content=result.model_dump())
+
+
+@app.post("/api/qachain", response_model=RetrievalResult)
 async def post_ask_with_memory_chain_main(question_answer: QuestionAnswer):
     print(question_answer)
     logger.debug(question_answer)
-    result = ask_with_sequence(question_answer)
+    result = await ask_with_sequence(question_answer)
     logger.debug(result)
-    return JSONResponse(content=result)
+    return JSONResponse(content=result.model_dump())
     # return result


-@app.get("/api/list/namespace")
+@app.get("/api/list/namespace", response_model=PineconeNamespaceResult)
 async def list_namespace_main():
     """
     Get all namespaces with id and vector count
@@ -364,7 +387,8 @@ async def list_namespace_main():
         logger.error(ex)
         raise HTTPException(status_code=400, detail=repr(ex))

-@app.get("/api/desc/namespace/{namespace}")
+
+@app.get("/api/desc/namespace/{namespace}", response_model=PineconeDescNamespaceResult)
 async def list_namespace_items_main(namespace: str):
     """
     Get description for given namespace
@@ -380,7 +404,8 @@ async def list_namespace_items_main(namespace: str):
         logger.error(ex)
         raise HTTPException(status_code=400, detail=repr(ex))

-@app.get("/api/listitems/namespace/{namespace}")
+
+@app.get("/api/listitems/namespace/{namespace}", response_model=PineconeItems)
 async def list_namespace_items_main(namespace: str):
     """
     Get all item with given namespace
@@ -397,7 +422,8 @@ async def list_namespace_items_main(namespace: str):
         raise HTTPException(status_code=400, detail=repr(ex))


-@app.post("/api/scrape/status")
+@app.post("/api/scrape/status", response_model=ScrapeStatusResponse)
 async def scrape_status_main(scrape_status_req: ScrapeStatusReq,
                              redis_client: aioredis.client.Redis = Depends(get_redis_client)):
     """
@@ -502,7 +528,7 @@ async def delete_item_id_namespace_post(item_to_delete: PineconeItemToDelete):
 #        raise HTTPException(status_code=400, detail=repr(ex))


-@app.get("/api/id/{metadata_id}/namespace/{namespace}")
+@app.get("/api/id/{metadata_id}/namespace/{namespace}", response_model=PineconeItems)
 async def get_items_id_namespace_main(metadata_id: str, namespace: str):
     """
     Get all items from namespace given id of document
@@ -512,7 +538,7 @@ async def get_items_id_namespace_main(metadata_id: str, namespace: str):
     """
     try:
         logger.info(f"retrieve id {metadata_id} dal namespace {namespace}")
-        result = await get_ids_namespace(metadata_id,namespace)
+        result = await get_ids_namespace(metadata_id, namespace)

         return JSONResponse(content=result.model_dump())
     except Exception as ex:
@@ -520,8 +546,8 @@ async def get_items_id_namespace_main(metadata_id: str, namespace: str):
         raise HTTPException(status_code=400, detail=repr(ex))


-@app.get("/api/items")#?source={source}&namespace={namespace}
-async def get_items_source_namespace_main(source: str, namespace: str ):
+@app.get("/api/items", response_model=PineconeItems)#?source={source}&namespace={namespace}
+async def get_items_source_namespace_main(source: str, namespace: str):
     """
     Get all item given the source and namespace
     :param source: source of document
@@ -533,7 +559,7 @@ async def get_items_source_namespace_main(source: str, namespace: str ):
         from urllib.parse import unquote
         source = unquote(source)

-        result = await get_sources_namespace(source,namespace)
+        result = await get_sources_namespace(source, namespace)

         return JSONResponse(content=result.model_dump())
     except
Exception as ex: diff --git a/tilellm/controller/controller.py b/tilellm/controller/controller.py new file mode 100644 index 0000000..d51637f --- /dev/null +++ b/tilellm/controller/controller.py @@ -0,0 +1,688 @@ +import uuid + +import fastapi +from langchain.chains import ConversationalRetrievalChain, LLMChain # Deprecata +from langchain_core.prompts import PromptTemplate, SystemMessagePromptTemplate +from langchain_openai import ChatOpenAI +# from tilellm.store.pinecone_repository import add_pc_item as pinecone_add_item +# from tilellm.store.pinecone_repository import create_pc_index, get_embeddings_dimension +from langchain_openai import OpenAIEmbeddings +from langchain_community.callbacks.openai_info import OpenAICallbackHandler +from tilellm.models.item_model import RetrievalResult, ChatEntry, PineconeIndexingResult, PineconeNamespaceResult, \ + PineconeDescNamespaceResult, PineconeItems, SimpleAnswer +from tilellm.shared.utility import inject_repo, inject_llm +import tilellm.shared.const as const +# from tilellm.store.pinecone_repository_base import PineconeRepositoryBase + +from langchain.chains import create_history_aware_retriever +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain.chains import create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory + +from langchain.schema import( + AIMessage, + HumanMessage, + SystemMessage + +) + +import logging + +logger = logging.getLogger(__name__) + + +@inject_repo +async def ask_with_memory1(question_answer, repo=None): + try: + logger.info(question_answer) + # question = str + # namespace: str + # gptkey: str + # model: str =Field(default="gpt-3.5-turbo") + # temperature: float = Field(default=0.0) + # top_k: int = Field(default=5) + # max_tokens: int = Field(default=128) + # system_context: Optional[str] + # chat_history_dict : Dict[str, ChatEntry] + + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + + logger.info(question_answer_list) + openai_callback_handler = OpenAICallbackHandler() + + llm = ChatOpenAI(model_name=question_answer.model, + temperature=question_answer.temperature, + openai_api_key=question_answer.gptkey, + max_tokens=question_answer.max_tokens, + callbacks=[openai_callback_handler]) + + emb_dimension = repo.get_embeddings_dimension(question_answer.embedding) + oai_embeddings = OpenAIEmbeddings(api_key=question_answer.gptkey, model=question_answer.embedding) + + vector_store = await repo.create_pc_index(oai_embeddings, emb_dimension) + + retriever = vector_store.as_retriever(search_type='similarity', + search_kwargs={'k': question_answer.top_k, + 'namespace': question_answer.namespace} + ) + # Query on store for relevant document, returned by callback + # mydocs = retriever.get_relevant_documents( question_answer.question) + # from pprint import pprint + # pprint(len(mydocs)) + + if question_answer.system_context is not None and question_answer.system_context: + print("blocco if") + from langchain.chains import LLMChain + + # prompt_template = "Tell me a {adjective} joke" + # prompt = PromptTemplate( + # input_variables=["adjective"], template=prompt_template + # ) + # llm = 
LLMChain(llm=OpenAI(), prompt=prompt) + sys_template = """{system_context}. + + {context} + """ + + sys_prompt = PromptTemplate.from_template(sys_template) + + # llm_chain = LLMChain(llm=llm, prompt=prompt) + crc = ConversationalRetrievalChain.from_llm( + llm=llm, + retriever=retriever, + return_source_documents=True, + verbose=True, + combine_docs_chain_kwargs={"prompt": sys_prompt} + ) + # from pprint import pprint + # pprint(crc.combine_docs_chain.llm_chain.prompt.messages) + # crc.combine_docs_chain.llm_chain.prompt.messages[0]=SystemMessagePromptTemplate.from_template(sys_prompt) + + result = crc.invoke({'question': question_answer.question, + 'system_context': question_answer.system_context, + 'chat_history': question_answer_list} + ) + + else: + print("blocco else") + # PromptTemplate.from_template() + crc = ConversationalRetrievalChain.from_llm(llm=llm, + retriever=retriever, + return_source_documents=True, + verbose=True) + + # 'Use the following pieces of context to answer the user\'s question. If you don\'t know the answer, just say that you don\'t know, don\'t try to make up an answer.', + result = crc.invoke({'question': question_answer.question, + 'chat_history': question_answer_list} + ) + + docs = result["source_documents"] + from pprint import pprint + pprint(result) + + ids = [] + sources = [] + for doc in docs: + ids.append(doc.metadata['id']) + sources.append(doc.metadata['source']) + + ids = list(set(ids)) + sources = list(set(sources)) + source = " ".join(sources) + metadata_id = ids[0] + + logger.info(result) + + question_answer_list.append((result['question'], result['answer'])) + + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + success = bool(openai_callback_handler.successful_requests) + prompt_token_size = openai_callback_handler.total_tokens + + result_to_return = RetrievalResult( + answer=result['answer'], + namespace=question_answer.namespace, + sources=sources, + ids=ids, + source=source, + id=metadata_id, + prompt_token_size=prompt_token_size, + success=success, + error_message=None, + chat_history_dict=chat_history_dict + ) + + return result_to_return.dict() + except Exception as e: + import traceback + traceback.print_exc() + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + result_to_return = RetrievalResult( + namespace=question_answer.namespace, + error_message=repr(e), + chat_history_dict=chat_history_dict + ) + raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump()) + + +@inject_llm +async def ask_to_llm(question, chat_model=None): + try: + logger.info(question) + if question.llm == "cohere": + quest = question.system_context+" "+question.question + messages = [ + HumanMessage(content=quest) + ] + else: + messages = [ + SystemMessage(content=question.system_context), + HumanMessage(content=question.question) + ] + + a = chat_model.invoke(messages) + return SimpleAnswer(content=a.content) + + except Exception as e: + import traceback + traceback.print_exc() + question_answer_list = [] + + result_to_return = SimpleAnswer( + content=repr(e) + ) + raise 
fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump()) + +@inject_repo +async def ask_with_memory(question_answer, repo=None) -> RetrievalResult: + try: + logger.info(question_answer) + # question = str + # namespace: str + # gptkey: str + # model: str =Field(default="gpt-3.5-turbo") + # temperature: float = Field(default=0.0) + # top_k: int = Field(default=5) + # max_tokens: int = Field(default=128) + # system_context: Optional[str] + # chat_history_dict : Dict[str, ChatEntry] + + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + + logger.info(question_answer_list) + openai_callback_handler = OpenAICallbackHandler() + + llm = ChatOpenAI(model_name=question_answer.model, + temperature=question_answer.temperature, + openai_api_key=question_answer.gptkey, + max_tokens=question_answer.max_tokens, + callbacks=[openai_callback_handler]) + + emb_dimension = repo.get_embeddings_dimension(question_answer.embedding) + oai_embeddings = OpenAIEmbeddings(api_key=question_answer.gptkey, model=question_answer.embedding) + + vector_store = await repo.create_pc_index(oai_embeddings, emb_dimension) + + retriever = vector_store.as_retriever(search_type='similarity', + search_kwargs={'k': question_answer.top_k, + 'namespace': question_answer.namespace} + ) + + if question_answer.system_context is not None and question_answer.system_context: + + # Contextualize question + contextualize_q_system_prompt = const.contextualize_q_system_prompt + contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + history_aware_retriever = create_history_aware_retriever( + llm, retriever, contextualize_q_prompt + ) + + # Answer question + qa_system_prompt = question_answer.system_context + qa_prompt = ChatPromptTemplate.from_messages( + [ + ("system", qa_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + + question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) + + rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + + store = {} + + def get_session_history(session_id: str) -> BaseChatMessageHistory: + if session_id not in store: + store[session_id] = ChatMessageHistory() + return store[session_id] + + conversational_rag_chain = RunnableWithMessageHistory( + rag_chain, + get_session_history, + input_messages_key="input", + history_messages_key="chat_history", + output_messages_key="answer", + ) + + result = conversational_rag_chain.invoke( + {"input": question_answer.question, 'chat_history': question_answer_list}, + config={"configurable": {"session_id": uuid.uuid4().hex} + }, # constructs a key "abc123" in `store`. 
+ ) + + else: + # Contextualize question + contextualize_q_system_prompt = const.contextualize_q_system_prompt + contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + history_aware_retriever = create_history_aware_retriever( + llm, retriever, contextualize_q_prompt + ) + + # Answer question + qa_system_prompt = const.qa_system_prompt + qa_prompt = ChatPromptTemplate.from_messages( + [ + ("system", qa_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + + question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) + + rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + + store = {} + + def get_session_history(session_id: str) -> BaseChatMessageHistory: + if session_id not in store: + store[session_id] = ChatMessageHistory() + return store[session_id] + + conversational_rag_chain = RunnableWithMessageHistory( + rag_chain, + get_session_history, + input_messages_key="input", + history_messages_key="chat_history", + output_messages_key="answer", + ) + + result = conversational_rag_chain.invoke( + {"input": question_answer.question, 'chat_history': question_answer_list}, + config={"configurable": {"session_id": uuid.uuid4().hex} + }, # constructs a key "abc123" in `store`. + ) + + # print(store) + # print(question_answer_list) + + docs = result["context"] + # from pprint import pprint + # pprint(docs) + + ids = [] + sources = [] + content_chunks = None + if question_answer.debug: + content_chunks = [] + for doc in docs: + ids.append(doc.metadata['id']) + sources.append(doc.metadata['source']) + content_chunks.append(doc.page_content) + else: + for doc in docs: + ids.append(doc.metadata['id']) + sources.append(doc.metadata['source']) + + ids = list(set(ids)) + sources = list(set(sources)) + source = " ".join(sources) + metadata_id = ids[0] + + logger.info(result) + + result['answer'], success = verify_answer(result['answer']) + + question_answer_list.append((result['input'], result['answer'])) + + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + # success = bool(openai_callback_handler.successful_requests) + prompt_token_size = openai_callback_handler.total_tokens + + result_to_return = RetrievalResult( + answer=result['answer'], + namespace=question_answer.namespace, + sources=sources, + ids=ids, + source=source, + id=metadata_id, + prompt_token_size=prompt_token_size, + content_chunks=content_chunks, + success=success, + error_message=None, + chat_history_dict=chat_history_dict + ) + + return result_to_return + except Exception as e: + import traceback + traceback.print_exc() + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + result_to_return = RetrievalResult( + namespace=question_answer.namespace, + error_message=repr(e), + chat_history_dict=chat_history_dict + ) + raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump()) + + +@inject_repo +async def ask_with_sequence(question_answer, repo=None) -> RetrievalResult: + try: + 
logger.info(question_answer) + # question = str + # namespace: str + # gptkey: str + # model: str =Field(default="gpt-3.5-turbo") + # temperature: float = Field(default=0.0) + # top_k: int = Field(default=5) + # max_tokens: int = Field(default=128) + # system_context: Optional[str] + # chat_history_dict : Dict[str, ChatEntry] + + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + + logger.info(question_answer_list) + openai_callback_handler = OpenAICallbackHandler() + + llm = ChatOpenAI(model_name=question_answer.model, + temperature=question_answer.temperature, + openai_api_key=question_answer.gptkey, + max_tokens=question_answer.max_tokens, + + callbacks=[openai_callback_handler]) + + emb_dimension = repo.get_embeddings_dimension(question_answer.embedding) + oai_embeddings = OpenAIEmbeddings(api_key=question_answer.gptkey, model=question_answer.embedding) + + vector_store = await repo.create_pc_index(oai_embeddings, emb_dimension) + idllmchain = get_idproduct_chain(llm) + res = idllmchain.invoke(question_answer.question) + + retriever = vector_store.as_retriever(search_type='similarity', search_kwargs={'k': question_answer.top_k, + 'namespace': question_answer.namespace}) + + # mydocs = retriever.get_relevant_documents( question_answer.question) + # from pprint import pprint + # pprint(len(mydocs)) + + if question_answer.system_context is not None and question_answer.system_context: + from langchain.chains import LLMChain + + # prompt_template = "Tell me a {adjective} joke" + # prompt = PromptTemplate( + # input_variables=["adjective"], template=prompt_template + # ) + # llm = LLMChain(llm=OpenAI(), prompt=prompt) + sys_template = """{system_context}. 
+ + {context} + """ + + sys_prompt = PromptTemplate.from_template(sys_template) + + # llm_chain = LLMChain(llm=llm, prompt=prompt) + crc = ConversationalRetrievalChain.from_llm( + llm=llm, + retriever=retriever, + return_source_documents=True, + combine_docs_chain_kwargs={"prompt": sys_prompt} + ) + # from pprint import pprint + # pprint(crc.combine_docs_chain.llm_chain.prompt.messages) + # crc.combine_docs_chain.llm_chain.prompt.messages[0] = SystemMessagePromptTemplate.from_template(sys_prompt) + + result = crc.invoke({'question': question_answer.question, 'system_context': question_answer.system_context, + 'chat_history': question_answer_list}) + + else: + crc = ConversationalRetrievalChain.from_llm(llm=llm, + retriever=retriever, + return_source_documents=True) + + result = crc.invoke({'question': res.get('text'), 'chat_history': question_answer_list}) + + docs = result["source_documents"] + + ids = [] + sources = [] + for doc in docs: + ids.append(doc.metadata['id']) + sources.append(doc.metadata['source']) + print(doc) + + ids = list(set(ids)) + sources = list(set(sources)) + source = " ".join(sources) + id = ids[0] + + logger.info(result) + + question_answer_list.append((result['question'], result['answer'])) + + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + success = bool(openai_callback_handler.successful_requests) + prompt_token_size = openai_callback_handler.total_tokens + + result_to_return = RetrievalResult( + answer=result['answer'], + namespace=question_answer.namespace, + sources=sources, + ids=ids, + source=source, + id=id, + prompt_token_size=prompt_token_size, + success=success, + error_message=None, + chat_history_dict=chat_history_dict + + ) + + return result_to_return + except Exception as e: + import traceback + traceback.print_exc() + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + result_to_return = RetrievalResult( + namespace=question_answer.namespace, + error_message=repr(e), + chat_history_dict=chat_history_dict + + ) + raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump()) + + +@inject_repo +async def add_pc_item(item, repo=None) -> PineconeIndexingResult: + """ + Add items to namespace + :type repo: PineconeRepositoryBase + :param item: + :param repo: + :return: PineconeIndexingResult + """ + + return await repo.add_pc_item(item) + + +@inject_repo +async def delete_namespace(namespace: str, repo=None): + """ + Delete Namespace from index + :param namespace: + :param repo: + :return: + """ + # from tilellm.store.pinecone_repository import delete_pc_namespace + try: + return await repo.delete_pc_namespace(namespace) + except Exception as ex: + raise ex + + +@inject_repo +async def delete_id_from_namespace(metadata_id: str, namespace: str, repo=None): + """ + Delete items from namespace + :param metadata_id: + :param namespace: + :param repo: + :return: + """ + # from tilellm.store.pinecone_repository import delete_pc_ids_namespace # , delete_pc_ids_namespace1 + try: + return await repo.delete_pc_ids_namespace(metadata_id=metadata_id, namespace=namespace) + except Exception as ex: + 
logger.error(ex) + raise ex + + +@inject_repo +async def get_list_namespace(repo=None) -> PineconeNamespaceResult: + """ + Get list namespaces with namespace id and vector count + :param repo: + :return: list of all namespaces in index + """ + # from tilellm.store.pinecone_repository import pinecone_list_namespaces + try: + return await repo.pinecone_list_namespaces() + except Exception as ex: + raise ex + + +@inject_repo +async def get_ids_namespace(metadata_id: str, namespace: str, repo=None) -> PineconeItems: + """ + Get all items from namespace given id + :param metadata_id: + :param namespace: + :param repo: + :return: + """ + # from tilellm.store.pinecone_repository import get_pc_ids_namespace + try: + return await repo.get_pc_ids_namespace(metadata_id=metadata_id, namespace=namespace) + except Exception as ex: + raise ex + + +@inject_repo +async def get_listitems_namespace(namespace: str, repo=None) -> PineconeItems: + """ + Get all items from given namespace + :param namespace: namespace_id + :param repo: + :return: list of al items PineconeItems + """ + # from tilellm.store.pinecone_repository import get_pc_all_obj_namespace + try: + return await repo.get_pc_all_obj_namespace(namespace=namespace) + except Exception as ex: + raise ex + + +@inject_repo +async def get_desc_namespace(namespace: str, repo=None) -> PineconeDescNamespaceResult: + try: + return await repo.get_pc_desc_namespace(namespace=namespace) + except Exception as ex: + raise ex + + +@inject_repo +async def get_sources_namespace(source: str, namespace: str, repo=None) -> PineconeItems: + """ + Get all item from namespace given source + :param source: + :param namespace: + :param repo: + :return: + """ + # from tilellm.store.pinecone_repository import get_pc_sources_namespace + try: + return await repo.get_pc_sources_namespace(source=source, namespace=namespace) + except Exception as ex: + raise ex + + +def get_idproduct_chain(llm) -> LLMChain: + summary_template = """ + I want the product Identifier from this phrase (remember, it's composed by 5 digit like 36400. Ignore the other informations). Give me only the number. {question}. 
+ """ + + summary_prompt_template = PromptTemplate( + input_variables=["question"], + template=summary_template, + ) + + return LLMChain(llm=llm, prompt=summary_prompt_template) + + +def verify_answer(s): + if s.endswith(""): + s = s[:-7] # Rimuove dalla fine della stringa + success = False + else: + success = True + return s, success diff --git a/tilellm/controller/openai_controller.py b/tilellm/controller/openai_controller.py index f81124c..693fdaa 100644 --- a/tilellm/controller/openai_controller.py +++ b/tilellm/controller/openai_controller.py @@ -8,7 +8,8 @@ # from tilellm.store.pinecone_repository import create_pc_index, get_embeddings_dimension from langchain_openai import OpenAIEmbeddings from langchain_community.callbacks.openai_info import OpenAICallbackHandler -from tilellm.models.item_model import RetrievalResult, ChatEntry +from tilellm.models.item_model import RetrievalResult, ChatEntry, PineconeIndexingResult, PineconeNamespaceResult, \ + PineconeDescNamespaceResult, PineconeItems from tilellm.shared.utility import inject_repo import tilellm.shared.const as const # from tilellm.store.pinecone_repository_base import PineconeRepositoryBase @@ -21,8 +22,10 @@ from langchain_community.chat_message_histories import ChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory + import logging + logger = logging.getLogger(__name__) @@ -174,7 +177,7 @@ async def ask_with_memory1(question_answer, repo=None): @inject_repo -async def ask_with_memory(question_answer, repo=None): +async def ask_with_memory(question_answer, repo=None) -> RetrievalResult: try: logger.info(question_answer) # question = str @@ -314,14 +317,22 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: # print(question_answer_list) docs = result["context"] - from pprint import pprint - pprint(docs) + # from pprint import pprint + # pprint(docs) ids = [] sources = [] - for doc in docs: - ids.append(doc.metadata['id']) - sources.append(doc.metadata['source']) + content_chunks = None + if question_answer.debug: + content_chunks = [] + for doc in docs: + ids.append(doc.metadata['id']) + sources.append(doc.metadata['source']) + content_chunks.append(doc.page_content) + else: + for doc in docs: + ids.append(doc.metadata['id']) + sources.append(doc.metadata['source']) ids = list(set(ids)) sources = list(set(sources)) @@ -329,7 +340,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: metadata_id = ids[0] logger.info(result) - print(result['answer']) + result['answer'], success = verify_answer(result['answer']) question_answer_list.append((result['input'], result['answer'])) @@ -337,8 +348,6 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} - - # success = bool(openai_callback_handler.successful_requests) prompt_token_size = openai_callback_handler.total_tokens @@ -350,12 +359,13 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: source=source, id=metadata_id, prompt_token_size=prompt_token_size, + content_chunks=content_chunks, success=success, error_message=None, chat_history_dict=chat_history_dict ) - return result_to_return.dict() + return result_to_return except Exception as e: import traceback traceback.print_exc() @@ -373,8 +383,9 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: ) raise fastapi.exceptions.HTTPException(status_code=400, 
detail=result_to_return.model_dump()) + @inject_repo -async def ask_with_sequence(question_answer, repo=None): +async def ask_with_sequence(question_answer, repo=None) -> RetrievalResult: try: logger.info(question_answer) # question = str @@ -491,7 +502,7 @@ async def ask_with_sequence(question_answer, repo=None): ) - return result_to_return.dict() + return result_to_return except Exception as e: import traceback traceback.print_exc() @@ -512,13 +523,13 @@ async def ask_with_sequence(question_answer, repo=None): @inject_repo -async def add_pc_item(item, repo=None): +async def add_pc_item(item, repo=None) -> PineconeIndexingResult: """ Add items to namespace :type repo: PineconeRepositoryBase :param item: :param repo: - :return: + :return: PineconeIndexingResult """ return await repo.add_pc_item(item) @@ -556,7 +567,7 @@ async def delete_id_from_namespace(metadata_id: str, namespace: str, repo=None): @inject_repo -async def get_list_namespace(repo=None): +async def get_list_namespace(repo=None) -> PineconeNamespaceResult: """ Get list namespaces with namespace id and vector count :param repo: @@ -570,7 +581,7 @@ async def get_list_namespace(repo=None): @inject_repo -async def get_ids_namespace(metadata_id: str, namespace: str, repo=None): +async def get_ids_namespace(metadata_id: str, namespace: str, repo=None) -> PineconeItems: """ Get all items from namespace given id :param metadata_id: @@ -586,12 +597,12 @@ async def get_ids_namespace(metadata_id: str, namespace: str, repo=None): @inject_repo -async def get_listitems_namespace(namespace: str, repo=None): +async def get_listitems_namespace(namespace: str, repo=None) -> PineconeItems: """ Get all items from given namespace :param namespace: namespace_id :param repo: - :return: list of al items + :return: list of al items PineconeItems """ # from tilellm.store.pinecone_repository import get_pc_all_obj_namespace try: @@ -599,8 +610,9 @@ async def get_listitems_namespace(namespace: str, repo=None): except Exception as ex: raise ex + @inject_repo -async def get_desc_namespace(namespace: str, repo=None): +async def get_desc_namespace(namespace: str, repo=None) -> PineconeDescNamespaceResult: try: return await repo.get_pc_desc_namespace(namespace=namespace) except Exception as ex: @@ -609,7 +621,7 @@ async def get_desc_namespace(namespace: str, repo=None): @inject_repo -async def get_sources_namespace(source: str, namespace: str, repo=None): +async def get_sources_namespace(source: str, namespace: str, repo=None) -> PineconeItems: """ Get all item from namespace given source :param source: diff --git a/tilellm/models/item_model.py b/tilellm/models/item_model.py index f7bf973..9c1d1d8 100644 --- a/tilellm/models/item_model.py +++ b/tilellm/models/item_model.py @@ -13,9 +13,8 @@ class ItemSingle(BaseModel): embedding: str = Field(default_factory=lambda: "text-embedding-ada-002") namespace: str | None = None webhook: str = Field(default_factory=lambda: "") - chunk_size: int = Field(default_factory=lambda: 256) - chunk_overlap: int = Field(default_factory=lambda: 10) - + chunk_size: int = Field(default_factory=lambda: 1000) + chunk_overlap: int = Field(default_factory=lambda: 400) class MetadataItem(BaseModel): @@ -58,6 +57,7 @@ class QuestionAnswer(BaseModel): top_k: int = Field(default=5) max_tokens: int = Field(default=128) embedding: str = Field(default_factory=lambda: "text-embedding-ada-002") + debug: bool = Field(default_factory=lambda: False) system_context: Optional[str] = None chat_history_dict: Optional[Dict[str, ChatEntry]] = None @@ 
-76,15 +76,45 @@ def top_k_range(cls, v):
         return v


+class QuestionToLLM(BaseModel):
+    question: str
+    llm_key: str
+    llm: str
+    model: str = Field(default="gpt-3.5-turbo")
+    temperature: float = Field(default=0.0)
+    max_tokens: int = Field(default=128)
+    debug: bool = Field(default_factory=lambda: False)
+    system_context: str = Field(default="You are a helpful AI bot. Always reply in the same language as the question.")
+
+    @field_validator("temperature")
+    def temperature_range(cls, v):
+        """Ensures temperature is within valid range (usually 0.0 to 1.0)."""
+        if not 0.0 <= v <= 1.0:
+            raise ValueError("Temperature must be between 0.0 and 1.0.")
+        return v
+
+    @field_validator("max_tokens")
+    def max_tokens_range(cls, v):
+        """Ensures max_tokens is within the allowed range (50 to 2000)."""
+        if not 50 <= v <= 2000:
+            raise ValueError("max_tokens must be between 50 and 2000.")
+        return v
+
+
+class SimpleAnswer(BaseModel):
+    content: str
+
+
 class RetrievalResult(BaseModel):
     answer: str = Field(default="No answer")
-    sources: Optional[List[str]] | None = None
-    source: str | None = None
-    id: str | None = None
+    success: bool = Field(default=False)
     namespace: str
+    id: str | None = None
     ids: Optional[List[str]] | None = None
+    source: str | None = None
+    sources: Optional[List[str]] | None = None
+    content_chunks: Optional[List[str]] | None = None
     prompt_token_size: int = Field(default=0)
-    success: bool = Field(default=False)
     error_message: Optional[str] | None = None
     chat_history_dict: Optional[Dict[str, ChatEntry]]

diff --git a/tilellm/shared/const.py b/tilellm/shared/const.py
index 1fa41f7..250d10d 100644
--- a/tilellm/shared/const.py
+++ b/tilellm/shared/const.py
@@ -7,6 +7,7 @@
 PINECONE_API_KEY = None
 PINECONE_INDEX = None
 PINECONE_TEXT_KEY = None
+VOYAGEAI_API_KEY = None

 contextualize_q_system_prompt = """Given a chat history and the latest user question \
 which might reference context in the chat history, formulate a standalone question \
@@ -23,10 +24,11 @@ def populate_constant():
-    global PINECONE_API_KEY, PINECONE_INDEX, PINECONE_TEXT_KEY
+    global PINECONE_API_KEY, PINECONE_INDEX, PINECONE_TEXT_KEY, VOYAGEAI_API_KEY
     PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
     PINECONE_INDEX = os.environ.get("PINECONE_INDEX")
     PINECONE_TEXT_KEY = os.environ.get("PINECONE_TEXT_KEY")
+    VOYAGEAI_API_KEY = os.environ.get("VOYAGEAI_API_KEY")

diff --git a/tilellm/shared/utility.py b/tilellm/shared/utility.py
index c116174..e750809 100644
--- a/tilellm/shared/utility.py
+++ b/tilellm/shared/utility.py
@@ -3,6 +3,18 @@
 import logging

+from langchain_voyageai import VoyageAIEmbeddings
+from langchain_openai import OpenAIEmbeddings
+from tilellm.shared import const
+
+from langchain_openai.chat_models import ChatOpenAI
+from langchain_anthropic import ChatAnthropic
+from langchain_cohere import ChatCohere
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_aws.chat_models import ChatBedrockConverse, ChatBedrock
+
+
 logger = logging.getLogger(__name__)

@@ -14,16 +26,17 @@ def inject_repo(func):
     :param func:
     :return:
     """
+
     @wraps(func)
     def wrapper(*args, **kwargs):
         repo_type = os.environ.get("PINECONE_TYPE")
         logger.info(f"pinecone type {repo_type}")
         if repo_type == 'pod':
-            from tilellm.store.pinecone_repository_pod import PineconeRepositoryPod
+            from tilellm.store.pinecone.pinecone_repository_pod import PineconeRepositoryPod
             repo = PineconeRepositoryPod()
         elif repo_type == 'serverless':
-            from tilellm.store.pinecone_repository_serverless import PineconeRepositoryServerless
+            from tilellm.store.pinecone.pinecone_repository_serverless import PineconeRepositoryServerless
             repo = PineconeRepositoryServerless()
         else:
             raise ValueError("Unknown repository type")
@@ -32,3 +45,93 @@ def wrapper(*args, **kwargs):
         return func(*args, **kwargs)

     return wrapper
+
+
+def inject_embedding():
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(self, item, *args, **kwargs):
+
+            # Choose the embedding object and its dimension based on 'item'
+            if item.embedding == "text-embedding-ada-002":
+                embedding_obj = OpenAIEmbeddings(api_key=item.gptkey, model=item.embedding)
+                dimension = 1536
+            elif item.embedding == "text-embedding-3-large":
+                embedding_obj = OpenAIEmbeddings(api_key=item.gptkey, model=item.embedding)
+                dimension = 3072
+            elif item.embedding == "text-embedding-3-small":
+                embedding_obj = OpenAIEmbeddings(api_key=item.gptkey, model=item.embedding)
+                dimension = 1536
+            elif item.embedding == "claude-3":
+                embedding_obj = VoyageAIEmbeddings(voyage_api_key=const.VOYAGEAI_API_KEY, model="voyage-multilingual-2")
+                # query_result = voyage.embed_query(text)
+                dimension = 1024
+            else:
+                embedding_obj = OpenAIEmbeddings(api_key=item.gptkey, model=item.embedding)
+                dimension = 1536
+
+            # Add the embedding object and its dimension to kwargs
+            kwargs['embedding_obj'] = embedding_obj
+            kwargs['embedding_dimension'] = dimension
+
+            # Call the original function with the new kwargs
+            return await func(self, item, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def inject_llm(func):
+    @wraps(func)
+    async def wrapper(question, *args, **kwargs):
+        print(question)
+        if question.llm == "openai":
+            chat_model = ChatOpenAI(api_key=question.llm_key,
+                                    model=question.model,
+                                    temperature=question.temperature,
+                                    max_tokens=question.max_tokens)
+
+        elif question.llm == "anthropic":
+            chat_model = ChatAnthropic(anthropic_api_key=question.llm_key,
+                                       model=question.model,
+                                       temperature=question.temperature,
+                                       max_tokens=question.max_tokens)
+
+        elif question.llm == "cohere":
+            chat_model = ChatCohere(cohere_api_key=question.llm_key,
+                                    model=question.model,
+                                    temperature=question.temperature,
+                                    max_tokens=question.max_tokens)
+
+        elif question.llm == "google":
+            chat_model = ChatGoogleGenerativeAI(google_api_key=question.llm_key,
+                                                model=question.model,
+                                                temperature=question.temperature,
+                                                max_tokens=question.max_tokens,
+                                                convert_system_message_to_human=True)
+
+        elif question.llm == "groq":
+            chat_model = ChatGroq(api_key=question.llm_key,
+                                  model=question.model,
+                                  temperature=question.temperature,
+                                  max_tokens=question.max_tokens
+                                  )
+
+        else:
+            chat_model = ChatOpenAI(api_key=question.llm_key,
+                                    model=question.model,
+                                    temperature=question.temperature,
+                                    max_tokens=question.max_tokens)
+
+        # Add chat_model to kwargs
+        kwargs['chat_model'] = chat_model
+
+        # Call the original function with the new kwargs
+        return await func(question, *args, **kwargs)
+
+    return wrapper
+
+
+
diff --git a/tilellm/store/pinecone/__init__.py b/tilellm/store/pinecone/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tilellm/store/pinecone_repository_base.py b/tilellm/store/pinecone/pinecone_repository_base.py
similarity index 97%
rename from tilellm/store/pinecone_repository_base.py
rename to tilellm/store/pinecone/pinecone_repository_base.py
index 1c31a02..44c6fc7 100644
--- a/tilellm/store/pinecone_repository_base.py
+++ b/tilellm/store/pinecone/pinecone_repository_base.py
@@ -3,6 +3,7 @@
                                        PineconeQueryResult,
                                        PineconeItems,
                                        PineconeIndexingResult,
+
PineconeNamespaceResult, PineconeItemNamespaceResult, PineconeIdSummaryResult, PineconeDescNamespaceResult @@ -50,7 +51,7 @@ async def delete_pc_ids_namespace(self, metadata_id: str, namespace: str): pass @staticmethod - async def get_pc_ids_namespace( metadata_id: str, namespace: str): + async def get_pc_ids_namespace( metadata_id: str, namespace: str) -> PineconeItems: """ Get from Pinecone all items from namespace given document id :param metadata_id: @@ -116,9 +117,8 @@ async def get_pc_ids_namespace( metadata_id: str, namespace: str): raise ex @staticmethod - async def pinecone_list_namespaces(): + async def pinecone_list_namespaces() -> PineconeNamespaceResult: import pinecone - from tilellm.models.item_model import PineconeNamespaceResult, PineconeItemNamespaceResult try: pc = pinecone.Pinecone( @@ -152,7 +152,7 @@ async def pinecone_list_namespaces(): raise ex @staticmethod - async def get_pc_all_obj_namespace(namespace: str): + async def get_pc_all_obj_namespace(namespace: str) -> PineconeItems: """ Query Pinecone to get all object :param namespace: @@ -217,11 +217,11 @@ async def get_pc_all_obj_namespace(namespace: str): raise ex @staticmethod - async def get_pc_desc_namespace(namespace: str): + async def get_pc_desc_namespace(namespace: str) -> PineconeDescNamespaceResult: """ Query Pinecone to get all object :param namespace: - :return: + :return: PineconeDescNamespaceResult """ import pinecone @@ -235,7 +235,7 @@ async def get_pc_desc_namespace(namespace: str): # vector_store = Pinecone.from_existing_index(const.PINECONE_INDEX, ) describe = index.describe_index_stats() - print(describe) + logger.debug(describe) namespaces = describe.get("namespaces", {}) total_vectors = 1 @@ -246,7 +246,7 @@ async def get_pc_desc_namespace(namespace: str): description = PineconeItemNamespaceResult(namespace=namespace, vector_count=total_vectors) logger.debug(f"pinecone total vector in {namespace}: {total_vectors}") - print(description) + batch_size = min([total_vectors, 10000]) pc_res = index.query( @@ -275,7 +275,7 @@ async def get_pc_desc_namespace(namespace: str): chunks_count=1) res = PineconeDescNamespaceResult(namespace_desc=description, ids=list(ids_count.values())) - print(res) + logger.debug(res) return res @@ -286,7 +286,7 @@ async def get_pc_desc_namespace(namespace: str): raise ex @staticmethod - async def get_pc_sources_namespace(source: str, namespace: str): + async def get_pc_sources_namespace(source: str, namespace: str) -> PineconeItems: """ Get from Pinecone all items from namespace given source :param source: diff --git a/tilellm/store/pinecone_repository_pod.py b/tilellm/store/pinecone/pinecone_repository_pod.py similarity index 92% rename from tilellm/store/pinecone_repository_pod.py rename to tilellm/store/pinecone/pinecone_repository_pod.py index 419354a..baa7cee 100644 --- a/tilellm/store/pinecone_repository_pod.py +++ b/tilellm/store/pinecone/pinecone_repository_pod.py @@ -1,13 +1,14 @@ from tilellm.models.item_model import (MetadataItem, PineconeIndexingResult ) +from tilellm.shared.utility import inject_embedding from tilellm.tools.document_tool_simple import (get_content_by_url, get_content_by_url_with_bs, load_document, load_from_wikipedia ) -from tilellm.store.pinecone_repository_base import PineconeRepositoryBase +from tilellm.store.pinecone.pinecone_repository_base import PineconeRepositoryBase from tilellm.shared import const from langchain_core.documents import Document @@ -21,12 +22,15 @@ class PineconeRepositoryPod(PineconeRepositoryBase): - async def 
add_pc_item(self, item): + @inject_embedding() + async def add_pc_item(self, item, embedding_obj=None, embedding_dimension=None) -> PineconeIndexingResult: """ Add items to name space into Pinecone index :param item: - :return: + :param embedding_obj: + :param embedding_dimension: + :return: PineconeIndexingResult """ logger.info(item) metadata_id = item.id @@ -45,10 +49,10 @@ async def add_pc_item(self, item): logger.warning(ex) pass - emb_dimension = self.get_embeddings_dimension(embedding) + emb_dimension = embedding_dimension # self.get_embeddings_dimension(embedding) # default text-embedding-ada-002 1536, text-embedding-3-large 3072, text-embedding-3-small 1536 - oai_embeddings = OpenAIEmbeddings(api_key=gpt_key, model=embedding) + oai_embeddings = embedding_obj # OpenAIEmbeddings(api_key=gpt_key, model=embedding) vector_store = await self.create_pc_index(embeddings=oai_embeddings, emb_dimension=emb_dimension) chunks = [] diff --git a/tilellm/store/pinecone_repository_serverless.py b/tilellm/store/pinecone/pinecone_repository_serverless.py similarity index 90% rename from tilellm/store/pinecone_repository_serverless.py rename to tilellm/store/pinecone/pinecone_repository_serverless.py index f9553ab..7cabc3b 100644 --- a/tilellm/store/pinecone_repository_serverless.py +++ b/tilellm/store/pinecone/pinecone_repository_serverless.py @@ -7,7 +7,8 @@ load_from_wikipedia ) -from tilellm.store.pinecone_repository_base import PineconeRepositoryBase +from tilellm.store.pinecone.pinecone_repository_base import PineconeRepositoryBase +from tilellm.shared.utility import inject_embedding from tilellm.shared import const from langchain_core.documents import Document @@ -22,11 +23,14 @@ class PineconeRepositoryServerless(PineconeRepositoryBase): - async def add_pc_item(self, item): + @inject_embedding() + async def add_pc_item(self, item, embedding_obj=None, embedding_dimension=None): """ Add items to name space into Pinecone index :param item: + :param embedding_obj: + :param embedding_dimension: :return: """ logger.info(item) @@ -35,7 +39,7 @@ async def add_pc_item(self, item): source = item.source type_source = item.type content = item.content - gpt_key = item.gptkey + #gpt_key = item.gptkey embedding = item.embedding namespace = item.namespace scrape_type = item.scrape_type @@ -47,12 +51,16 @@ async def add_pc_item(self, item): logger.warning(ex) pass - emb_dimension = self.get_embeddings_dimension(embedding) + emb_dimension = embedding_dimension # self.get_embeddings_dimension(embedding) # default text-embedding-ada-002 1536, text-embedding-3-large 3072, text-embedding-3-small 1536 - oai_embeddings = OpenAIEmbeddings(api_key=gpt_key, model=embedding) + oai_embeddings = embedding_obj # OpenAIEmbeddings(api_key=gpt_key, model=embedding) vector_store = await self.create_pc_index(embeddings=oai_embeddings, emb_dimension=emb_dimension) - + # textprova ="test degli embeddings di voyage" + # query_result = oai_embeddings.embed_query(textprova) + # print(f"len: {len(query_result)}") + # print(query_result) + # raise Exception chunks = [] total_tokens = 0 cost = 0 @@ -139,6 +147,8 @@ async def add_pc_item(self, item): pinecone_result = PineconeIndexingResult(id=metadata_id, chunks=len(chunks), total_tokens=total_tokens, cost=f"{cost:.6f}") except Exception as ex: + import traceback + traceback.print_exc() logger.error(repr(ex)) pinecone_result = PineconeIndexingResult(id=metadata_id, chunks=len(chunks), total_tokens=total_tokens, status=400, diff --git a/tilellm/tools/document_tool_simple.py 
b/tilellm/tools/document_tool_simple.py index 0f8c009..bc9ba5e 100644 --- a/tilellm/tools/document_tool_simple.py +++ b/tilellm/tools/document_tool_simple.py @@ -23,6 +23,9 @@ def get_content_by_url(url: str, scrape_type: int): ) docs = loader.load() + for doc in docs: + doc.metadata = clean_metadata(doc.metadata) + # from pprint import pprint # pprint(docs) @@ -108,4 +111,18 @@ def get_content_by_url_with_bs(url:str): return testi +def is_valid_value(value): + if isinstance(value, (str, int, float, bool)): + return True + elif isinstance(value, list) and all(isinstance(item, str) for item in value): + return True + return False + + +def clean_metadata(dictionary): + return {k: v for k, v in dictionary.items() if is_valid_value(v)} + + + +
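
For reviewers who want to try the new endpoint by hand, a minimal request sketch follows. The field names mirror the QuestionToLLM model added in tilellm/models/item_model.py and the response shape mirrors SimpleAnswer; the host/port and the API key value are placeholders, not part of this patch.

```python
import requests

# Hypothetical local instance started with uvicorn; adjust host/port to your deployment.
TILELLM_URL = "http://localhost:8000/api/ask"

payload = {
    "question": "What is Tiledesk?",
    "llm": "openai",            # one of: openai, anthropic, cohere, google, groq
    "llm_key": "sk-...",        # provider API key (placeholder)
    "model": "gpt-3.5-turbo",
    "temperature": 0.0,         # validated range: 0.0 to 1.0
    "max_tokens": 128,          # validated range: 50 to 2000
    "system_context": "You are a helpful AI bot. Always reply in the same language as the question."
}

response = requests.post(TILELLM_URL, json=payload)
# The endpoint returns a SimpleAnswer, e.g. {"content": "..."}
print(response.json()["content"])
```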