Skip to content

Commit

Permalink
update langchain
Browse files Browse the repository at this point in the history
  • Loading branch information
qingzhong1 committed Jan 23, 2024
1 parent e7bfe4f commit b6efb46
Show file tree
Hide file tree
Showing 11 changed files with 439 additions and 150 deletions.
41 changes: 33 additions & 8 deletions erniebot-agent/applications/erniebot_researcher/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,35 @@ pip install -r requirements.txt
```
wget https://paddlenlp.bj.bcebos.com/pipelines/fonts/SimSun.ttf
```
> 第四步:创建索引
首先需要在[AI Studio星河社区](https://aistudio.baidu.com/index)注册并登录账号,然后在AI Studio的[访问令牌页面](https://aistudio.baidu.com/index/accessToken)获取`Access Token`,最后设置环境变量:
```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
export AISTUDIO_ACCESS_TOKEN=<aistudio-access-token>
export AZURE_OPENAI_API_KEY=<openai-api-token>
export OPENAI_API_KEY=<openai-api-token>
export AZURE_OPENAI_ENDPOINT=<openai-endpoint>
export OPENAI_API_VERSION=<openai-api-version>
```

> 第四步:运行
如果用户有url链接,可以传入存储url链接的txt或者json文件:
在txt文件中,每一行存储文件的路径和对应的url链接,例如'https://zhuanlan.zhihu.com/p/659457816 data/Ai_Agent的起源.md'。
在json文件中,每一个键是文件的路径,对应的值是该文件的url链接。
如果用户不传入url文件,则默认以文件的路径作为其url链接。

用户也可以自己传入文件摘要的存储路径。摘要需要用json文件存储:json文件内存储多个字典,每个字典有3组键值对,其中"page_content"存储文件的摘要,"url"是文件的url链接,"name"是文章的名字。

```
python ./tools/preprocessing.py \
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text> \
--path_full_text <the folder path of your full text> \
--url_path <the path of your url text> \
--path_abstract <the json path of your abstract text>
```

> 第五步:运行
首先需要在[AI Studio星河社区](https://aistudio.baidu.com/index)注册并登录账号,然后在AI Studio的[访问令牌页面](https://aistudio.baidu.com/index/accessToken)获取`Access Token`,最后设置环境变量:

```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
Expand All @@ -78,23 +103,23 @@ Base版本示例运行:

```
python sample_report_example.py --num_research_agent 2 \
--index_name_full_text <your full text> \
--index_name_abstract <your abstract text>
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text>
```

Base版本WebUI运行:

```
python ui.py --num_research_agent 2 \
--index_name_full_text <your full text> \
--index_name_abstract <your abstract text>
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text>
```

高阶版本多智能体自动调度示例脚本运行:

```
python sample_group_agent.py --index_name_full_text <your full text> \
--index_name_abstract <your abstract text>
python sample_group_agent.py --index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text>
```

## Reference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ async def verifications(self, facts_problems: List[dict]):
for item in facts_problems:
question = item["question"]
claim = item["fact"]
context = self.retriever_db.search(question)
context = [i["content"] for i in context]
context = await self.retriever_db(question)
context = [i["content"] for i in context["documents"]]
item["evidence"] = context
anwser = await self.generate_anwser(question, context)
item["anwser"] = anwser
Expand Down
12 changes: 11 additions & 1 deletion erniebot-agent/applications/erniebot_researcher/polish_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(
citation_index_name: str,
dir_path: str,
report_type: str,
build_index_function: Any,
search_tool: Any,
system_message: Optional[SystemMessage] = None,
callbacks=None,
):
Expand All @@ -58,6 +60,8 @@ def __init__(
self.prompt_template_polish = PromptTemplate(
template=self.template_polish, input_variables=["content"]
)
self.build_index_function = build_index_function
self.search_tool = search_tool
if callbacks is None:
self._callback_manager = ReportCallbackHandler()
else:
Expand Down Expand Up @@ -143,7 +147,13 @@ async def _run(self, report, summarize=None):
final_report = await self.polish_paragraph(report, abstract, key)
await self._callback_manager.on_tool_start(self, tool=self.citation_tool, input_args=final_report)
if summarize is not None:
citation_search = add_citation(summarize, self.citation_index_name, self.embeddings)
citation_search = add_citation(
summarize,
self.citation_index_name,
self.embeddings,
self.build_index_function,
self.search_tool,
)
final_report, path = await self.citation_tool(
report=final_report,
agent_name=self.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ langchain
scikit-learn
markdown
WeasyPrint==52.5
openai
langchain_openai
18 changes: 7 additions & 11 deletions erniebot-agent/applications/erniebot_researcher/research_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,21 @@ def __init__(

async def run_search_summary(self, query: str):
    """Retrieve documents relevant to *query* and summarize each of them.

    Searches the full-text retriever for the top 3 matching documents and
    summarizes each one, keeping the source url/name next to every summary
    so later report stages can cite their evidence.

    Args:
        query: The (sub-)query to retrieve and summarize documents for.

    Returns:
        A list of dicts, each with keys "summary", "url" and "name".
    """
    responses = []
    # The retriever is an awaitable tool returning {"documents": [...]}
    # (each document carries "content" plus a "meta" dict with url/name).
    results = await self.retriever_fulltext_db(query, top_k=3)
    length_limit = 0
    await self._callback_manager.on_tool_start(agent=self, tool=self.summarize_tool, input_args=query)
    for doc in results["documents"]:
        res = await self.summarize_tool(doc["content"], query)
        # Add reference to avoid hallucination
        data = {"summary": res, "url": doc["meta"]["url"], "name": doc["meta"]["name"]}
        length_limit += len(res)
        if length_limit < SUMMARIZE_MAX_LENGTH:
            responses.append(data)
        else:
            # Stop collecting once the accumulated summaries exceed the budget.
            logger.warning(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
            break
    await self._callback_manager.on_tool_end(self, tool=self.summarize_tool, response=responses)
    return responses

async def run(self, query: str):
"""
Expand All @@ -117,8 +113,8 @@ async def run(self, query: str):

if self.use_context_planning:
sub_queries = []
res = self.retriever_abstract_db.search(query, top_k=3)
context = [item["content"] for item in res]
res = await self.retriever_abstract_db(query, top_k=3)
context = [item["content"] for item in res["documents"]]
context_content = ""
await self._callback_manager.on_tool_start(
agent=self, tool=self.task_planning_tool, input_args=query
Expand Down Expand Up @@ -157,7 +153,7 @@ async def run(self, query: str):
# Run Sub-Queries
paragraphs_item = []
for sub_query in sub_queries:
research_result, url_dict = await self.run_search_summary(sub_query)
research_result = await self.run_search_summary(sub_query)
paragraphs_item.extend(research_result)

paragraphs = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@

from editor_actor_agent import EditorActorAgent
from group_agent import GroupChat, GroupChatManager
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings
from polish_agent import PolishAgent
from ranking_agent import RankingAgent
from research_agent import ResearchAgent
from reviser_actor_agent import ReviserActorAgent
from tools.intent_detection_tool import IntentDetectionTool
from tools.outline_generation_tool import OutlineGenerationTool
from tools.preprocessing import get_retriver_by_type
from tools.ranking_tool import TextRankingTool
from tools.report_writing_tool import ReportWritingTool
from tools.semantic_citation_tool import SemanticCitationTool
from tools.summarization_tool import TextSummarizationTool
from tools.task_planning_tool import TaskPlanningTool
from tools.utils import FaissSearch, build_index

from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
Expand Down Expand Up @@ -71,22 +71,29 @@
default="openai_embedding",
help="['openai_embedding','baizhong','ernie_embedding']",
)
parser.add_argument(
"--use_frame",
type=str,
default="langchain",
choices=["langchain", "llama_index"],
help="['langchain','llama_index']",
)
args = parser.parse_args()


def get_retrievers():
def get_retrievers(build_index_function, retrieval_tool):
if args.embedding_type == "openai_embedding":
embeddings = OpenAIEmbeddings(deployment="text-embedding-ada")
paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
retriever_search = FaissSearch(paper_db, embeddings=embeddings)
embeddings = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
elif args.embedding_type == "ernie_embedding":
embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
retriever_search = FaissSearch(paper_db, embeddings=embeddings)
paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
elif args.embedding_type == "baizhong":
embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
retriever_search = BaizhongSearch(
Expand All @@ -102,7 +109,9 @@ def get_retrievers():
return {"full_text": retriever_search, "abstract": abstract_search, "embeddings": embeddings}


def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
def get_agents(
retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
):
research_actor = ResearchAgent(
name="generate_report",
system_message=SystemMessage("你是一个报告生成助手。你可以根据用户的指定内容生成一份报告手稿"),
Expand Down Expand Up @@ -134,6 +143,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
dir_path=target_path,
report_type=args.report_type,
citation_tool=tool_sets["semantic_citation"],
build_index_function=build_index_function,
search_tool=retrieval_tool,
)
return {
"research_agents": research_actor,
Expand Down Expand Up @@ -171,10 +182,12 @@ def main(query):
os.makedirs(target_path, exist_ok=True)
llm_long = ERNIEBot(model="ernie-longtext")
llm = ERNIEBot(model="ernie-4.0")

retriever_sets = get_retrievers()
build_index_function, retrieval_tool = get_retriver_by_type(args.use_frame)
retriever_sets = get_retrievers(build_index_function, retrieval_tool)
tool_sets = get_tools(llm, llm_long)
agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path)
agent_sets = get_agents(
retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
)
research_actor = agent_sets["research_agents"]
report = asyncio.run(research_actor.run(query))
report = {"report": report[0], "paragraphs": report[1]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@

from editor_actor_agent import EditorActorAgent
from fact_check_agent import FactCheckerAgent
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings
from polish_agent import PolishAgent
from ranking_agent import RankingAgent
from research_agent import ResearchAgent
from research_team import ResearchTeam
from reviser_actor_agent import ReviserActorAgent
from tools.intent_detection_tool import IntentDetectionTool
from tools.outline_generation_tool import OutlineGenerationTool
from tools.preprocessing import get_retriver_by_type
from tools.ranking_tool import TextRankingTool
from tools.report_writing_tool import ReportWritingTool
from tools.semantic_citation_tool import SemanticCitationTool
from tools.summarization_tool import TextSummarizationTool
from tools.task_planning_tool import TaskPlanningTool
from tools.utils import FaissSearch, build_index

from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
Expand Down Expand Up @@ -71,25 +71,31 @@
default="openai_embedding",
help="['openai_embedding','baizhong','ernie_embedding']",
)

parser.add_argument(
"--use_frame",
type=str,
default="langchain",
choices=["langchain", "llama_index"],
help="['langchain','llama_index']",
)
args = parser.parse_args()
os.environ["api_type"] = args.api_type
access_token = os.environ.get("EB_AGENT_ACCESS_TOKEN", None)


def get_retrievers():
def get_retrievers(build_index_function, retrieval_tool):
if args.embedding_type == "openai_embedding":
embeddings = OpenAIEmbeddings(deployment="text-embedding-ada")
paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
retriever_search = FaissSearch(paper_db, embeddings=embeddings)
embeddings = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
elif args.embedding_type == "ernie_embedding":
embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
retriever_search = FaissSearch(paper_db, embeddings=embeddings)
paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
elif args.embedding_type == "baizhong":
embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
retriever_search = BaizhongSearch(
Expand Down Expand Up @@ -125,7 +131,7 @@ def get_tools(llm, llm_long):
}


def get_agents(retriever_sets, tool_sets, llm, llm_long):
def get_agents(retriever_sets, tool_sets, llm, llm_long, build_index_function, retrieval_tool):
dir_path = f"{args.save_path}/{hashlib.sha1(query.encode()).hexdigest()}"
os.makedirs(dir_path, exist_ok=True)

Expand Down Expand Up @@ -167,6 +173,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long):
llm_long=llm_long,
name="ranker",
ranking_tool=tool_sets["ranking"],
build_index_function=build_index_function,
search_tool=retrieval_tool,
)
return {
"research_actor": research_actor,
Expand All @@ -181,10 +189,10 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long):
def main(query):
llm_long = ERNIEBot(model="ernie-longtext")
llm = ERNIEBot(model="ernie-4.0")

retriever_sets = get_retrievers()
build_index_function, retrieval_tool = get_retriver_by_type(args.use_frame)
retriever_sets = get_retrievers(build_index_function, retrieval_tool)
tool_sets = get_tools(llm, llm_long)
agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long)
agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, build_index_function, retrieval_tool)
research_team = ResearchTeam(**agent_sets)

report, file_path = asyncio.run(research_team.run(query))
Expand Down
Loading

0 comments on commit b6efb46

Please sign in to comment.