diff --git a/CHANGELOG.md b/CHANGELOG.md index 40010e6..890d8f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ *Andrea Sponziello* ### **Copyrigth**: *Tiledesk SRL* +## [2024-06-15] +### 0.2.1 +- update: langchain v. 0.1.16 +- modified: prompt for q&A + ## [2024-06-08] ### 0.2.0 - refactor: refactor repository in order to manage pod and serverless diff --git a/pyproject.toml b/pyproject.toml index 1ed2e93..1f325ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tilellm" -version = "0.2.0" +version = "0.2.1" description = "tiledesk for RAG" authors = ["Gianluca Lorenzo "] repository = "https://github.com/Tiledesk/tiledesk-llm" @@ -18,14 +18,14 @@ jsonschema= "^4.20.0" redis= "^5.0.0" aioredis= "^2.0.0" #redismutex = "^1.0.0" -langchain = "^0.1.9" +langchain = "^0.1.16" jq = "^1.6.0" openai = "^1.12.0" -langchain_openai = "^0.0.7" +langchain_openai = "0.0.x" pinecone-client = "^3.1.0" python-dotenv = "^1.0.1" -langchain_community = "^0.0.24" -tiktoken = "^0.6.0" +langchain_community = "0.0.x" +tiktoken = "0.6.x" beautifulsoup4 ="^4.12.3" #uvicorn = "^0.28" unstructured= "^0.12.6" diff --git a/tilellm/controller/openai_controller.py b/tilellm/controller/openai_controller.py index bead0d1..38712ec 100644 --- a/tilellm/controller/openai_controller.py +++ b/tilellm/controller/openai_controller.py @@ -1,5 +1,7 @@ +import uuid + import fastapi -from langchain.chains import ConversationalRetrievalChain, LLMChain # Per la conversazione va usata questa classe +from langchain.chains import ConversationalRetrievalChain, LLMChain # Deprecata from langchain_core.prompts import PromptTemplate, SystemMessagePromptTemplate from langchain_openai import ChatOpenAI # from tilellm.store.pinecone_repository import add_pc_item as pinecone_add_item @@ -8,15 +10,24 @@ from langchain_community.callbacks.openai_info import OpenAICallbackHandler from tilellm.models.item_model import RetrievalResult, ChatEntry from tilellm.shared.utility import inject_repo +import tilellm.shared.const as const # from tilellm.store.pinecone_repository_base import PineconeRepositoryBase +from langchain.chains import create_history_aware_retriever +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain.chains import create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory + import logging logger = logging.getLogger(__name__) @inject_repo -async def ask_with_memory(question_answer, repo=None): +async def ask_with_memory1(question_answer, repo=None): try: logger.info(question_answer) @@ -59,6 +70,7 @@ async def ask_with_memory(question_answer, repo=None): # pprint(len(mydocs)) if question_answer.system_context is not None and question_answer.system_context: + print("blocco if") from langchain.chains import LLMChain # prompt_template = "Tell me a {adjective} joke" @@ -78,6 +90,7 @@ async def ask_with_memory(question_answer, repo=None): llm=llm, retriever=retriever, return_source_documents=True, + verbose=True, combine_docs_chain_kwargs={"prompt": sys_prompt} ) # from pprint import pprint @@ -90,15 +103,22 @@ async def ask_with_memory(question_answer, repo=None): ) else: + print("blocco else") + #PromptTemplate.from_template() crc = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, - return_source_documents=True) + return_source_documents=True, + verbose=True) + + # 'Use the following pieces of context to answer the user\'s question. If you don\'t know the answer, just say that you don\'t know, don\'t try to make up an answer.', result = crc.invoke({'question': question_answer.question, 'chat_history': question_answer_list} ) docs = result["source_documents"] + from pprint import pprint + pprint(result) ids = [] sources = [] @@ -153,6 +173,206 @@ async def ask_with_memory(question_answer, repo=None): raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump()) +@inject_repo +async def ask_with_memory(question_answer, repo=None): + try: + logger.info(question_answer) + # question = str + # namespace: str + # gptkey: str + # model: str =Field(default="gpt-3.5-turbo") + # temperature: float = Field(default=0.0) + # top_k: int = Field(default=5) + # max_tokens: int = Field(default=128) + # system_context: Optional[str] + # chat_history_dict : Dict[str, ChatEntry] + + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + + logger.info(question_answer_list) + openai_callback_handler = OpenAICallbackHandler() + + llm = ChatOpenAI(model_name=question_answer.model, + temperature=question_answer.temperature, + openai_api_key=question_answer.gptkey, + max_tokens=question_answer.max_tokens, + callbacks=[openai_callback_handler]) + + emb_dimension = repo.get_embeddings_dimension(question_answer.embedding) + oai_embeddings = OpenAIEmbeddings(api_key=question_answer.gptkey, model=question_answer.embedding) + + vector_store = await repo.create_pc_index(oai_embeddings, emb_dimension) + + retriever = vector_store.as_retriever(search_type='similarity', + search_kwargs={'k': question_answer.top_k, + 'namespace': question_answer.namespace} + ) + + if question_answer.system_context is not None and question_answer.system_context: + + # Contextualize question + contextualize_q_system_prompt = const.contextualize_q_system_prompt + contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + history_aware_retriever = create_history_aware_retriever( + llm, retriever, contextualize_q_prompt + ) + + # Answer question + qa_system_prompt = question_answer.system_context + qa_prompt = ChatPromptTemplate.from_messages( + [ + ("system", qa_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + + question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) + + rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + + store = {} + + def get_session_history(session_id: str) -> BaseChatMessageHistory: + if session_id not in store: + store[session_id] = ChatMessageHistory() + return store[session_id] + + conversational_rag_chain = RunnableWithMessageHistory( + rag_chain, + get_session_history, + input_messages_key="input", + history_messages_key="chat_history", + output_messages_key="answer", + ) + + result = conversational_rag_chain.invoke( + {"input": question_answer.question, 'chat_history': question_answer_list}, + config={"configurable": {"session_id": uuid.uuid4().hex} + }, # constructs a key "abc123" in `store`. + ) + + else: + # Contextualize question + contextualize_q_system_prompt = const.contextualize_q_system_prompt + contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + history_aware_retriever = create_history_aware_retriever( + llm, retriever, contextualize_q_prompt + ) + + # Answer question + qa_system_prompt = const.qa_system_prompt + qa_prompt = ChatPromptTemplate.from_messages( + [ + ("system", qa_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + + question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) + + rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + + store = {} + + def get_session_history(session_id: str) -> BaseChatMessageHistory: + if session_id not in store: + store[session_id] = ChatMessageHistory() + return store[session_id] + + conversational_rag_chain = RunnableWithMessageHistory( + rag_chain, + get_session_history, + input_messages_key="input", + history_messages_key="chat_history", + output_messages_key="answer", + ) + + result = conversational_rag_chain.invoke( + {"input": question_answer.question, 'chat_history': question_answer_list}, + config={"configurable": {"session_id": uuid.uuid4().hex} + }, # constructs a key "abc123" in `store`. + ) + + # print(store) + # print(question_answer_list) + + docs = result["context"] + from pprint import pprint + pprint(docs) + + ids = [] + sources = [] + for doc in docs: + ids.append(doc.metadata['id']) + sources.append(doc.metadata['source']) + + ids = list(set(ids)) + sources = list(set(sources)) + source = " ".join(sources) + metadata_id = ids[0] + + logger.info(result) + print(result['answer']) + result['answer'], success = verify_answer(result['answer']) + + question_answer_list.append((result['input'], result['answer'])) + + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + + + # success = bool(openai_callback_handler.successful_requests) + prompt_token_size = openai_callback_handler.total_tokens + + result_to_return = RetrievalResult( + answer=result['answer'], + namespace=question_answer.namespace, + sources=sources, + ids=ids, + source=source, + id=metadata_id, + prompt_token_size=prompt_token_size, + success=success, + error_message=None, + chat_history_dict=chat_history_dict + ) + + return result_to_return.dict() + except Exception as e: + import traceback + traceback.print_exc() + question_answer_list = [] + if question_answer.chat_history_dict is not None: + for key, entry in question_answer.chat_history_dict.items(): + question_answer_list.append((entry.question, entry.answer)) + chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list] + chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)} + + result_to_return = RetrievalResult( + namespace=question_answer.namespace, + error_message=repr(e), + chat_history_dict=chat_history_dict + ) + raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump()) + @inject_repo async def ask_with_sequence(question_answer, repo=None): try: @@ -407,3 +627,12 @@ def get_idproduct_chain(llm) -> LLMChain: ) return LLMChain(llm=llm, prompt=summary_prompt_template) + + +def verify_answer(s): + if s.endswith(""): + s = s[:-7] # Rimuove dalla fine della stringa + success = False + else: + success = True + return s, success diff --git a/tilellm/shared/const.py b/tilellm/shared/const.py index 62cd2c7..1fa41f7 100644 --- a/tilellm/shared/const.py +++ b/tilellm/shared/const.py @@ -8,6 +8,19 @@ PINECONE_INDEX = None PINECONE_TEXT_KEY = None +contextualize_q_system_prompt = """Given a chat history and the latest user question \ + which might reference context in the chat history, formulate a standalone question \ + which can be understood without the chat history. Do NOT answer the question, \ + just reformulate it if needed and otherwise return it as is.""" + +qa_system_prompt = """You are an helpful assistant for question-answering tasks. \ + Use ONLY the following pieces of retrieved context to answer the question. \ + If you don't know the answer, just say that you don't know. \ + If none of the retrieved context answer the question, add this word to the end \ + + + {context}""" + def populate_constant(): global PINECONE_API_KEY, PINECONE_INDEX, PINECONE_TEXT_KEY