Skip to content

Commit

Permalink
langchain agent modification
Browse files Browse the repository at this point in the history
  • Loading branch information
Emerald33 committed Mar 18, 2024
1 parent e1a58c5 commit 24c0729
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
29 changes: 15 additions & 14 deletions langchain_local/tools/rag_tool.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
import os
import openai

from langchain.pydantic_v1 import BaseModel, Field
from langchain_local.tools import BaseTool, StructuredTool, tool
from langchain_core.tools import tool
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.docarray import DocArrayInMemorySearch
from langchain.chains import RetrievalQA

from dotenv import load_dotenv, find_dotenv
from dotenv import load_dotenv
load_dotenv()

_ = load_dotenv(find_dotenv()) # reading local files

openai_api_key = os.getenv('OPENAI_API_KEY')
openai.api_key = os.environ.get('OPENAI_API_KEY')
api_key = os.getenv('OPENAI_API_KEY')


class RagTool(BaseModel):
    # Argument schema shared by the RAG tools; the agent fills in `query`.
    query: str = Field(description="This should be a search query")


@tool("rag-tool-one", args_schema=RagTool, return_direct=True)
@tool("rag-tool-one", args_schema=RagTool, return_direct=False)
def rag_one(query: str) -> str:
"""A tool that retrieves contents that are semantically relevant to the input query from the provided document.
Expand All @@ -29,15 +30,15 @@ def rag_one(query: str) -> str:
Returns:
str: top k amount of retrieved content from the uploaded document. content that are semantically similar to the input query.
"""
pdf_file_path = _ #placehloder code from html, the path to pdf file
pdf_file_path = #'/content/data/cover_docs.pdf'

loader = PyPDFLoader(pdf_file_path)
pages = loader.load_and_split()

text_splitter = RecursiveCharacterTextSplitter()
splits = text_splitter.split_documents(pages)

embedding = OpenAIEmbeddings(api_key = openai_api_key)
embedding = OpenAIEmbeddings(api_key = api_key)

vectordb = DocArrayInMemorySearch.from_documents(splits, embedding)

Expand All @@ -51,7 +52,7 @@ def rag_one(query: str) -> str:

"""The second RAG tool"""

@tool("rag-tool-two", args_schema=RagTool, return_direct=True)
@tool("rag-tool-two", args_schema=RagTool, return_direct=False)
def rag_two(query: str) -> str:
"""A tool that retrieves contents that are semantically relevant to the input query from the provided document.
Expand All @@ -61,19 +62,19 @@ def rag_two(query: str) -> str:
Returns:
str: top k amount of retrieved content from the uploaded document. content that are semantically similar to the input query.
"""
pdf_file_path = _ #placehloder code from html, the path to pdf file
pdf_file_path = #'/content/data/ptdf_sop.pdf'

loader = PyPDFLoader(pdf_file_path)
pages = loader.load_and_split()

text_splitter = RecursiveCharacterTextSplitter()
splits = text_splitter.split_documents(pages)

embedding = OpenAIEmbeddings(api_key = openai_api_key)
embedding = OpenAIEmbeddings(api_key = api_key)

vectordb = DocArrayInMemorySearch.from_documents(splits, embedding)

results = vectordb.similarity_search(query, k = 4 )
result_string = "\n\n".join(str(result) for result in results)

return result_string
return result_string
31 changes: 20 additions & 11 deletions langchain_local/tools/serper_tool.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import os
import pprint

from langchain_core.tools import Tool
from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())
from langchain_core.tools import tool
from langchain.pydantic_v1 import BaseModel, Field
from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper
from dotenv import load_dotenv

serper_api_key = os.getenv('SERPER_API_KEY')
load_dotenv()
os.environ["SERPER_API_KEY"] = os.getenv('SERPER_API_KEY')

search = GoogleSerperAPIWrapper(serper_api_key = serper_api_key)
class SerperTool(BaseModel):
    # Argument schema for the Serper web-search tool; the agent fills in `query`.
    query: str = Field(description="This should be a search query")

serper_tool = Tool(
name = "Intermediate Answer",
func = search.run,
description="useful for when you need to ask with search"
)
@tool("serper_tool_main", args_schema=SerperTool, return_direct=False)
def serper_tool(query:str) -> str:
"""A useful for when you need to ask with search. Very useful when recent or specific information is needed from the web
"""
search = GoogleSerperAPIWrapper(k=4, type="search")
initial_result = search.results(query)
result = initial_result['organic']
results = ""
for r in result:
data = f"'Title':{r['title']}\n 'content':{r['snippet']}"
results += f"{data}\n\n"
return results

0 comments on commit 24c0729

Please sign in to comment.