diff --git a/langchain_folder/agent.py b/langchain_folder/agent.py index ca9d312..92ec51e 100644 --- a/langchain_folder/agent.py +++ b/langchain_folder/agent.py @@ -3,42 +3,77 @@ import tempfile +from langchain_core.tools import Tool, tool +from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper +from langchain.pydantic_v1 import BaseModel, Field +from langchain_community.document_loaders import PyPDFLoader from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.vectorstores.docarray import DocArrayInMemorySearch +from langchain.chains import RetrievalQA +from langchain_openai import ChatOpenAI from langchain_community.llms import ollama from langchain_community.chat_models import ChatOllama from langchain.prompts import ChatPromptTemplate from langchain_core.utils.function_calling import convert_to_openai_function from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.prompts import MessagesPlaceholder -from langchain.schema.runnable.config import RunnableConfig -from langchain.schema.runnable import RunnablePassthrough, Runnable +from langchain.schema.runnable import RunnablePassthrough from langchain.agents import AgentExecutor from langchain.memory import ConversationBufferMemory from langchain.agents.format_scratchpad import format_to_openai_functions -from langchain.pydantic_v1 import BaseModel, Field -from langchain_core.tools import tool -from langchain_community.document_loaders import PyPDFLoader -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.vectorstores.docarray import DocArrayInMemorySearch -from langchain.chains import RetrievalQA +from langchain.agents.format_scratchpad.openai_tools import ( + format_to_openai_tool_messages, +) +from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser from dotenv import load_dotenv load_dotenv() -from langchain_folder.tools.serper_tool import serper_tool +from tools.serper_tool import serper_tool api_key = os.getenv('OPENAI_API_KEY') os.environ["SERPER_API_KEY"] = os.getenv('SERPER_API_KEY') -file_path = [] +file_path = ['/workspaces/Full_Agent_build/langchain_folder/data/cover_docs.pdf', '/workspaces/Full_Agent_build/langchain_folder/data/data_two.pdf'] + + +def main(): + st.title("Multi-Agent Chatbot") + st.write("Ask questions based on the uploaded document and get a response") + + texts = st.text_area("Enter your questions here") + + file_uploaded = st.sidebar.file_uploader( + "Upload two PDF files containing any document", key="file_upload", + accept_multiple_files = True, + type=['pdf'] + ) + + temp_dir = tempfile.mkdtemp() + + if st.sidebar.button("Upload PDF File"): + if file_uploaded: + for file in file_uploaded: + file_dest = os.path.join(temp_dir, file.name) + bytes_data = file.read() + with open(file_dest, "wb") as f: + f.write(bytes_data) + file_path.append(file_dest) + + if st.button("Ask Question"): + with st.spinner(text="Generating"): + response = generate(texts) + st.markdown(response) + -class RagToolOne(BaseModel): +class RagTool(BaseModel): query: str = Field(description = "This should be a search query") -@tool("rag-tool-one", args_schema=RagToolOne, return_direct=False) +@tool("rag-tool-one", args_schema=RagTool, return_direct=False) def rag_one(query: str) -> str: """A tool that retrieves contents that are semantically relevant to the input query from the provided document. @@ -49,8 +84,8 @@ def rag_one(query: str) -> str: str: top k amount of retrieved content from the uploaded document. content that are semantically similar to the input query. """ try: - pdf_file_path = file_path[0] - + pdf_file_path = file_path[-2] + loader = PyPDFLoader(pdf_file_path) pages = loader.load_and_split() @@ -63,7 +98,7 @@ def rag_one(query: str) -> str: results = vectordb.similarity_search(query, k = 4 ) result_string = "\n\n".join(str(result) for result in results) - + except FileNotFoundError: result_string = "" @@ -73,11 +108,8 @@ def rag_one(query: str) -> str: """The second RAG tool""" -class RagToolTwo(BaseModel): - query: str = Field(description = "This should be a search query") - -@tool("rag-tool-two", args_schema=RagToolTwo, return_direct=False) +@tool("rag-tool-two", args_schema=RagTool, return_direct=False) def rag_two(query: str) -> str: """A tool that retrieves contents that are semantically relevant to the input query from the provided document. @@ -88,7 +120,7 @@ def rag_two(query: str) -> str: str: top k amount of retrieved content from the uploaded document. content that are semantically similar to the input query. """ try: - pdf_file_path = file_path[1] + pdf_file_path = file_path[-1] loader = PyPDFLoader(pdf_file_path) pages = loader.load_and_split() @@ -110,25 +142,40 @@ def rag_two(query: str) -> str: tools = [rag_one, rag_two, serper_tool] -functions = [convert_to_openai_function(f) for f in tools] -model = ChatOpenAI(temperature= 0, api_key = api_key).bind_functions(functions=functions) -#model = ChatOllama(model="mistral").bind(functions = functions) + +# functions = [convert_to_openai_function(f) for f in tools] +# model = ChatOpenAI(temperature= 0, api_key = api_key).bind_functions(functions=functions) +model = ChatOpenAI(temperature= 0, api_key = api_key).bind_tools(tools=tools) + def generate(query: str): """Function for interracting with the AI Agent""" prompt = ChatPromptTemplate.from_messages([ - ("system", "You are an expert agent with the ability to decide if a function is needed and route queries to the right function"), + ("system", "You are an expert agent with the ability to decide if a function is needed and route queries to the right function. Don't ask, just route to the function"), MessagesPlaceholder(variable_name="chat_history"), ("user", "{input}"), MessagesPlaceholder(variable_name="agent_scratchpad") ]) - - agent_chain = RunnablePassthrough.assign( - agent_scratchpad = lambda x: format_to_openai_functions(x["intermediate_steps"]) - ) | prompt | model | OpenAIFunctionsAgentOutputParser() - - memory = ConversationBufferMemory(return_message = True, memory_key = "chat_history") + + agent_chain = ( + { + "input": lambda x: x["input"], + "agent_scratchpad": lambda x: format_to_openai_tool_messages( + x["intermediate_steps"] + ), + "chat_history": lambda x: x["chat_history"], + } + | prompt + | model + | OpenAIToolsAgentOutputParser() +) + + # agent_chain = RunnablePassthrough.assign( + # agent_scratchpad = lambda x: format_to_openai_functions(x["intermediate_steps"]) + # ) | prompt | model | OpenAIFunctionsAgentOutputParser() + + memory = ConversationBufferMemory(return_messages = True, memory_key = "chat_history") agent_executor = AgentExecutor(agent = agent_chain, tools = tools, verbose= False, memory = memory) @@ -136,28 +183,7 @@ def generate(query: str): return result['output'] -def main(): - st.title("Multi-Agent Chatbot") - st.write("Ask questions based on the uploaded document and get a response") - texts = st.text_area("Enter your questions here") - - file_uploaded = st.sidebar.file_uploader( - "Upload two PDF files containing any document", key="file_upload", - accept_multiple_files = True - ) - - if st.sidebar.button("Upload PDF File"): - if file_uploaded: - for file in file_uploaded: - while len(file_path) > 2: - file_path.pop(0) - file_path.append(file.name) - - if st.button("Ask Question"): - with st.spinner(text="Generating"): - response = generate(texts) - st.markdown(response) if __name__ =="__main__": diff --git a/langchain_folder/tools/serper_tool.py b/langchain_folder/tools/serper_tool.py index 29aa6ea..2498bbc 100644 --- a/langchain_folder/tools/serper_tool.py +++ b/langchain_folder/tools/serper_tool.py @@ -15,7 +15,7 @@ class SerperTool(BaseModel): @tool("serper_tool_main", args_schema=SerperTool, return_direct=False) def serper_tool(query:str) -> str: - """A useful for when you need to ask with search. Very useful when recent or specific information is needed from the web + """A useful for when you need to ask with search. Very useful when recent or specific information is needed from the web """ search = GoogleSerperAPIWrapper(k=4, type="search") initial_result = search.results(query)