diff --git a/erniebot-agent/applications/erniebot_researcher/README.md b/erniebot-agent/applications/erniebot_researcher/README.md
index 020d8341..69bd7d51 100644
--- a/erniebot-agent/applications/erniebot_researcher/README.md
+++ b/erniebot-agent/applications/erniebot_researcher/README.md
@@ -63,10 +63,35 @@ pip install -r requirements.txt
 ```
 wget https://paddlenlp.bj.bcebos.com/pipelines/fonts/SimSun.ttf
 ```
+> 第四步:创建索引
+首先需要在[AI Studio星河社区](https://aistudio.baidu.com/index)注册并登录账号,然后在AI Studio的[访问令牌页面](https://aistudio.baidu.com/index/accessToken)获取`Access Token`,最后设置环境变量:
+```
+export EB_AGENT_ACCESS_TOKEN=
+export AISTUDIO_ACCESS_TOKEN=
+export AZURE_OPENAI_API_KEY=
+export OPENAI_API_KEY=
+export AZURE_OPENAI_ENDPOINT=
+export OPENAI_API_VERSION=
+```
-> 第四步:运行
+如果文档有对应的url链接,用户可以传入存储url链接的txt或json文件。
+在txt文件中,每一行依次存储url链接和对应的文件路径,以空格分隔,例如`https://zhuanlan.zhihu.com/p/659457816 data/Ai_Agent的起源.md`。
+在json文件中,字典的每一个键是文件的路径,对应的值是该文件的url链接,例如`{"data/Ai_Agent的起源.md": "https://zhuanlan.zhihu.com/p/659457816"}`。
+如果用户不传入url文件,则默认以文件的路径作为其url链接。
+
+用户也可以自己传入文件摘要的存储路径。摘要需要用json文件存储:json文件内存储多个字典,每个字典有3组键值对,其中"page_content"存储文件的摘要,"url"是文件的url链接,"name"是文章的名字。
+
+```
+python ./tools/preprocessing.py \
+--index_name_full_text \
+--index_name_abstract \
+--path_full_text \
+--url_path \
+--path_abstract
+```
+
+> 第五步:运行
-首先需要在[AI Studio星河社区](https://aistudio.baidu.com/index)注册并登录账号,然后在AI Studio的[访问令牌页面](https://aistudio.baidu.com/index/accessToken)获取`Access Token`,最后设置环境变量:
 
 ```
 export EB_AGENT_ACCESS_TOKEN=
@@ -78,23 +103,23 @@ Base版本示例运行:
 ```
 python sample_report_example.py --num_research_agent 2 \
-    --index_name_full_text \
-    --index_name_abstract
+                                --index_name_full_text \
+                                --index_name_abstract
 ```
 
 Base版本WebUI运行:
 ```
 python ui.py --num_research_agent 2 \
-    --index_name_full_text \
-    --index_name_abstract
+             --index_name_full_text \
+             --index_name_abstract
 ```
 
 高阶版本多智能体自动调度示例脚本运行:
 ```
-python sample_group_agent.py --index_name_full_text \
-    --index_name_abstract
+python sample_group_agent.py --index_name_full_text \
+                             --index_name_abstract
 ```
 
 ## Reference
diff --git a/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py b/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py
index 5b2047ab..15fa6fd6 100644
--- a/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py
+++ b/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py
@@ -154,8 +154,8 @@ async def verifications(self, facts_problems: List[dict]):
         for item in facts_problems:
             question = item["question"]
             claim = item["fact"]
-            context = self.retriever_db.search(question)
-            context = [i["content"] for i in context]
+            context = await self.retriever_db(question)
+            context = [i["content"] for i in context["documents"]]
             item["evidence"] = context
             anwser = await self.generate_anwser(question, context)
             item["anwser"] = anwser
diff --git a/erniebot-agent/applications/erniebot_researcher/polish_agent.py b/erniebot-agent/applications/erniebot_researcher/polish_agent.py
index 2653f7a7..71b0e5c5 100644
--- a/erniebot-agent/applications/erniebot_researcher/polish_agent.py
+++ b/erniebot-agent/applications/erniebot_researcher/polish_agent.py
@@ -38,6 +38,8 @@ def __init__(
         citation_index_name: str,
         dir_path: str,
         report_type: str,
+        build_index_function: Any,
+        search_tool: Any,
         system_message: Optional[SystemMessage] = None,
         callbacks=None,
     ):
@@ -58,6 +60,8 @@ def __init__(
         self.prompt_template_polish = PromptTemplate(
             template=self.template_polish, input_variables=["content"]
         )
+        self.build_index_function = build_index_function
+        self.search_tool = search_tool
         if callbacks is None:
             self._callback_manager = ReportCallbackHandler()
         else:
@@ -143,7 +147,13 @@ async def _run(self, report, summarize=None):
         final_report = await self.polish_paragraph(report, abstract, key)
         await self._callback_manager.on_tool_start(self, tool=self.citation_tool, input_args=final_report)
         if summarize is not None:
-            citation_search = add_citation(summarize, self.citation_index_name, self.embeddings)
+            citation_search = add_citation(
+                summarize,
+                self.citation_index_name,
+                self.embeddings,
+                self.build_index_function,
+                self.search_tool,
+            )
             final_report, path = await self.citation_tool(
                 report=final_report,
                 agent_name=self.name,
diff --git a/erniebot-agent/applications/erniebot_researcher/requirements.txt b/erniebot-agent/applications/erniebot_researcher/requirements.txt
index a0d2c82c..854c9dba 100644
--- a/erniebot-agent/applications/erniebot_researcher/requirements.txt
+++ b/erniebot-agent/applications/erniebot_researcher/requirements.txt
@@ -4,3 +4,5 @@ langchain
 scikit-learn
 markdown
 WeasyPrint==52.5
+openai
+langchain_openai
diff --git a/erniebot-agent/applications/erniebot_researcher/research_agent.py b/erniebot-agent/applications/erniebot_researcher/research_agent.py
index 6ec76b39..778668c1 100644
--- a/erniebot-agent/applications/erniebot_researcher/research_agent.py
+++ b/erniebot-agent/applications/erniebot_researcher/research_agent.py
@@ -75,25 +75,21 @@ def __init__(
     async def run_search_summary(self, query: str):
         responses = []
-        url_dict = {}
-        results = self.retriever_fulltext_db.search(query, top_k=3)
+        results = await self.retriever_fulltext_db(query, top_k=3)
         length_limit = 0
         await self._callback_manager.on_tool_start(agent=self, tool=self.summarize_tool, input_args=query)
-        for doc in results:
+        for doc in results["documents"]:
             res = await self.summarize_tool(doc["content"], query)
             # Add reference to avoid hallucination
-            data = {"summary": res, "url": doc["url"], "name": doc["title"]}
+            data = {"summary": res, "url": doc["meta"]["url"], "name": doc["meta"]["name"]}
             length_limit += len(res)
             if length_limit < SUMMARIZE_MAX_LENGTH:
                 responses.append(data)
-                key = doc["title"]
-                value = doc["url"]
-                url_dict[key] = value
             else:
                 logger.warning(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
                 break
         await self._callback_manager.on_tool_end(self, tool=self.summarize_tool, response=responses)
-        return responses, url_dict
+        return responses
 
     async def run(self, query: str):
         """
@@ -117,8 +113,8 @@ async def run(self, query: str):
         if self.use_context_planning:
             sub_queries = []
-            res = self.retriever_abstract_db.search(query, top_k=3)
-            context = [item["content"] for item in res]
+            res = await self.retriever_abstract_db(query, top_k=3)
+            context = [item["content"] for item in res["documents"]]
             context_content = ""
             await self._callback_manager.on_tool_start(
                 agent=self, tool=self.task_planning_tool, input_args=query
             )
@@ -157,7 +153,7 @@ async def run(self, query: str):
         # Run Sub-Queries
         paragraphs_item = []
         for sub_query in sub_queries:
-            research_result, url_dict = await self.run_search_summary(sub_query)
+            research_result = await self.run_search_summary(sub_query)
             paragraphs_item.extend(research_result)
 
         paragraphs = []
diff --git a/erniebot-agent/applications/erniebot_researcher/sample_group_agent.py b/erniebot-agent/applications/erniebot_researcher/sample_group_agent.py
index 1fec42ac..48d17398 100644
--- a/erniebot-agent/applications/erniebot_researcher/sample_group_agent.py
+++ b/erniebot-agent/applications/erniebot_researcher/sample_group_agent.py
@@ -6,19 +6,19 @@
 from editor_actor_agent import EditorActorAgent
 from group_agent import GroupChat, GroupChatManager
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings
 from polish_agent import PolishAgent
 from ranking_agent import RankingAgent
 from research_agent import ResearchAgent
 from reviser_actor_agent import ReviserActorAgent
 from tools.intent_detection_tool import IntentDetectionTool
 from tools.outline_generation_tool import OutlineGenerationTool
+from tools.preprocessing import get_retriver_by_type
 from tools.ranking_tool import TextRankingTool
 from tools.report_writing_tool import ReportWritingTool
 from tools.semantic_citation_tool import SemanticCitationTool
 from tools.summarization_tool import TextSummarizationTool
 from tools.task_planning_tool import TaskPlanningTool
-from tools.utils import FaissSearch, build_index
 
 from erniebot_agent.chat_models import ERNIEBot
 from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
@@ -71,22 +71,29 @@
     default="openai_embedding",
     help="['openai_embedding','baizhong','ernie_embedding']",
 )
+parser.add_argument(
+    "--use_frame",
+    type=str,
+    default="langchain",
+    choices=["langchain", "llama_index"],
+    help="['langchain','llama_index']",
+)
 args = parser.parse_args()
 
 
-def get_retrievers():
+def get_retrievers(build_index_function, retrieval_tool):
     if args.embedding_type == "openai_embedding":
-        embeddings = OpenAIEmbeddings(deployment="text-embedding-ada")
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        embeddings = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
+        retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
     elif args.embedding_type == "ernie_embedding":
         embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
+        retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
     elif args.embedding_type == "baizhong":
         embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
         retriever_search = BaizhongSearch(
@@ -102,7 +109,9 @@ def get_retrievers():
     return {"full_text": retriever_search, "abstract": abstract_search, "embeddings": embeddings}
 
 
-def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
+def get_agents(
+    retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
+):
     research_actor = ResearchAgent(
         name="generate_report",
         system_message=SystemMessage("你是一个报告生成助手。你可以根据用户的指定内容生成一份报告手稿"),
@@ -134,6 +143,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
         dir_path=target_path,
         report_type=args.report_type,
         citation_tool=tool_sets["semantic_citation"],
+        build_index_function=build_index_function,
+        search_tool=retrieval_tool,
     )
     return {
         "research_agents": research_actor,
@@ -171,10 +182,12 @@ def main(query):
     os.makedirs(target_path, exist_ok=True)
     llm_long = ERNIEBot(model="ernie-longtext")
     llm = ERNIEBot(model="ernie-4.0")
-
-    retriever_sets = get_retrievers()
+    build_index_function, retrieval_tool = get_retriver_by_type(args.use_frame)
+    retriever_sets = get_retrievers(build_index_function, retrieval_tool)
     tool_sets = get_tools(llm, llm_long)
-    agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path)
+    agent_sets = get_agents(
+        retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
+    )
     research_actor = agent_sets["research_agents"]
     report = asyncio.run(research_actor.run(query))
     report = {"report": report[0], "paragraphs": report[1]}
diff --git a/erniebot-agent/applications/erniebot_researcher/sample_report_example.py b/erniebot-agent/applications/erniebot_researcher/sample_report_example.py
index bf9d1bf3..a1b45d9f 100644
--- a/erniebot-agent/applications/erniebot_researcher/sample_report_example.py
+++ b/erniebot-agent/applications/erniebot_researcher/sample_report_example.py
@@ -6,7 +6,7 @@
 from editor_actor_agent import EditorActorAgent
 from fact_check_agent import FactCheckerAgent
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings
 from polish_agent import PolishAgent
 from ranking_agent import RankingAgent
 from research_agent import ResearchAgent
@@ -14,12 +14,12 @@
 from reviser_actor_agent import ReviserActorAgent
 from tools.intent_detection_tool import IntentDetectionTool
 from tools.outline_generation_tool import OutlineGenerationTool
+from tools.preprocessing import get_retriver_by_type
 from tools.ranking_tool import TextRankingTool
 from tools.report_writing_tool import ReportWritingTool
 from tools.semantic_citation_tool import SemanticCitationTool
 from tools.summarization_tool import TextSummarizationTool
 from tools.task_planning_tool import TaskPlanningTool
-from tools.utils import FaissSearch, build_index
 
 from erniebot_agent.chat_models import ERNIEBot
 from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
@@ -71,25 +71,31 @@
     default="openai_embedding",
     help="['openai_embedding','baizhong','ernie_embedding']",
 )
-
+parser.add_argument(
+    "--use_frame",
+    type=str,
+    default="langchain",
+    choices=["langchain", "llama_index"],
+    help="['langchain','llama_index']",
+)
 args = parser.parse_args()
 os.environ["api_type"] = args.api_type
 access_token = os.environ.get("EB_AGENT_ACCESS_TOKEN", None)
 
 
-def get_retrievers():
+def get_retrievers(build_index_function, retrieval_tool):
     if args.embedding_type == "openai_embedding":
-        embeddings = OpenAIEmbeddings(deployment="text-embedding-ada")
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        embeddings = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
+        retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
     elif args.embedding_type == "ernie_embedding":
         embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db, embeddings=embeddings)
+        retriever_search = retrieval_tool(paper_db, embeddings=embeddings)
     elif args.embedding_type == "baizhong":
         embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
         retriever_search = BaizhongSearch(
@@ -125,7 +131,7 @@ def get_tools(llm, llm_long):
     }
 
 
-def get_agents(retriever_sets, tool_sets, llm, llm_long):
+def get_agents(retriever_sets, tool_sets, llm, llm_long, build_index_function, retrieval_tool):
     dir_path = f"{args.save_path}/{hashlib.sha1(query.encode()).hexdigest()}"
     os.makedirs(dir_path, exist_ok=True)
@@ -167,6 +173,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long):
         llm_long=llm_long,
         name="ranker",
         ranking_tool=tool_sets["ranking"],
+        build_index_function=build_index_function,
+        search_tool=retrieval_tool,
     )
     return {
         "research_actor": research_actor,
@@ -181,10 +189,10 @@ def main(query):
     llm_long = ERNIEBot(model="ernie-longtext")
     llm = ERNIEBot(model="ernie-4.0")
-
-    retriever_sets = get_retrievers()
+    build_index_function, retrieval_tool = get_retriver_by_type(args.use_frame)
+    retriever_sets = get_retrievers(build_index_function, retrieval_tool)
     tool_sets = get_tools(llm, llm_long)
-    agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long)
+    agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, build_index_function, retrieval_tool)
     research_team = ResearchTeam(**agent_sets)
 
     report, file_path = asyncio.run(research_team.run(query))
diff --git a/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py b/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
new file mode 100644
index 00000000..4f92091e
--- /dev/null
+++ b/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
@@ -0,0 +1,286 @@
+import argparse
+import json
+import os
+from typing import Any, Dict, List, Optional
+
+import faiss
+import jsonlines
+from langchain.docstore.document import Document
+from langchain.text_splitter import SpacyTextSplitter
+from langchain.vectorstores import FAISS
+from langchain_community.document_loaders import DirectoryLoader
+from llama_index import (
+    ServiceContext,
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+)
+from llama_index.vector_stores.faiss import FaissVectorStore
+
+from erniebot_agent.memory import HumanMessage, Message
+from erniebot_agent.prompt import PromptTemplate
+from erniebot_agent.tools.langchain_retrieval_tool import LangChainRetrievalTool
+from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool
+
+ABSTRACTPROMPT = """
+{{content}} ,请用中文对上述文章进行总结。
+总结需要有概括性,不允许输出与文章内容无关的信息,字数控制在500字以内
+总结为:
+"""
+
+
+class GenerateAbstract:
+    def __init__(self, llm, chunk_size: int = 1500, chunk_overlap=0, path="./abstract.json"):
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.llm = llm
+        self.text_splitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm", chunk_size=chunk_size, chunk_overlap=chunk_overlap
+        )
+        self.prompt = PromptTemplate(ABSTRACTPROMPT, input_variables=["content"])
+        self.path = path
+
+    def data_load(self, dir_path):
+        loader = DirectoryLoader(path=dir_path)
+        docs = loader.load()
+        return docs
+
+    def split_documents(self, docs: List[Document]):
+        return self.text_splitter.split_documents(docs)
+
+    async def generate_abstract(self, content: str):
+        content = self.prompt.format(content=content)
+        messages: List[Message] = [HumanMessage(content)]
+        response = await self.llm.chat(messages)
+        return response.content
+
+    async def tackle_file(self, docs: List[Document]):
+        docs = self.split_documents(docs)
+        summaries = []
+        for doc in docs:
+            import time
+
+            time1 = time.time()
+            summary = await self.generate_abstract(doc.page_content)
+            time2 = time.time()
+            print(time2 - time1)
+            summaries.append(summary)
+        summary = "\n".join(summaries)
+        if len(summaries) > 1:
+            summary = await self.generate_abstract(summary)
+        return summary
+
+    def write_json(self, data: List[dict]):
+        json_str = json.dumps(data, ensure_ascii=False)
+        with open(self.path, "w") as json_file:
+            json_file.write(json_str)
+
+    async def run(self, data_dir, url_path=None):
+        if url_path:
+            _, suffix = os.path.splitext(url_path)
+            assert suffix == ".json"
+            with open(url_path) as f:
+                url_dict = json.load(f)
+        else:
+            url_dict = None
+        docs = self.data_load(data_dir)
+        summary_list = []
+        for item in docs:
+            summary = await self.tackle_file([item])
+            url = item.metadata["source"]
+            if url_dict and item.metadata["source"] in url_dict:
+                url = url_dict[item.metadata["source"]]
+                item.metadata["source"] = url
+
+            summary_list.append({"page_content": summary, "url": url, "name": item.metadata["source"]})
+        self.write_json(summary_list)
+        return self.path
+
+
+def add_url(url_dict: Dict, path: Optional[str] = None):
+    if not path:
+        path = "./url.json"
+    json_str = json.dumps(url_dict, ensure_ascii=False)
+    with open(path, "w") as json_file:
+        json_file.write(json_str)
+
+
+def preprocess(data_dir, url_path=None):
+    loader = DirectoryLoader(path=data_dir)
+    docs = loader.load()
+    url_dict = {}
+    if url_path:
+        _, suffix = os.path.splitext(url_path)
+        assert suffix == ".json"
+        with open(url_path) as f:
+            url_dict = json.load(f)
+    for item in docs:
+        if "source" not in item.metadata:
+            item.metadata["source"] = ""
+        if item.metadata["source"] in url_dict:
+            item.metadata["url"] = url_dict[item.metadata["source"]]
+        else:
+            item.metadata["url"] = item.metadata["source"]
+    return docs
+
+
+def build_index_langchain(
+    faiss_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None, use_data=False
+):
+    if os.path.exists(faiss_name):
+        db = FAISS.load_local(faiss_name, embeddings)
+    elif abstract and not use_data:
+        all_docs = []
+        with jsonlines.open(path) as reader:
+            for obj in reader:
+                if type(obj) is list:
+                    for item in obj:
+                        if "url" in item:
+                            metadata = {"url": item["url"], "name": item["name"]}
+                        else:
+                            metadata = {"name": item["name"]}
+                        doc = Document(page_content=item["page_content"], metadata=metadata)
+                        all_docs.append(doc)
+                elif type(obj) is dict:
+                    if "url" in obj:
+                        metadata = {"url": obj["url"], "name": obj["name"]}
+                    else:
+                        metadata = {"name": obj["name"]}
+                    doc = Document(page_content=obj["page_content"], metadata=metadata)
+                    all_docs.append(doc)
+        db = FAISS.from_documents(all_docs, embeddings)
+        db.save_local(faiss_name)
+    elif not abstract and not use_data:
+        documents = preprocess(path, url_path)
+        text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0)
+        docs = text_splitter.split_documents(documents)
+        docs_tackle = []
+        for item in docs:
+            item.metadata["name"] = item.metadata["source"].split("/")[-1].split(".")[0]
+            docs_tackle.append(item)
+        db = FAISS.from_documents(docs_tackle, embeddings)
+        db.save_local(faiss_name)
+    elif use_data:
+        db = FAISS.from_documents(origin_data, embeddings)
+        db.save_local(faiss_name)
+    return db
+
+
+def build_index_llama(
+    faiss_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None, use_data=False
+):
+    if embeddings.model == "text-embedding-ada-002":
+        d = 1536
+    elif embeddings.model == "ernie-text-embedding":
+        d = 384
+    else:
+        raise ValueError(f"model {embeddings.model} not support")
+
+    faiss_index = faiss.IndexFlatIP(d)
+    vector_store = FaissVectorStore(faiss_index=faiss_index)
+    if os.path.exists(faiss_name):
+        vector_store = FaissVectorStore.from_persist_dir(persist_dir=faiss_name)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=faiss_name)
+        return storage_context
+    from llama_index.node_parser import SentenceSplitter
+
+    documents = SimpleDirectoryReader(path).load_data()
+    text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
+    storage_context = StorageContext.from_defaults(vector_store=vector_store)
+    service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
+    index = VectorStoreIndex.from_documents(
+        documents,
+        storage_context=storage_context,
+        show_progress=True,
+        service_context=service_context,
+    )
+    index.storage_context.persist(persist_dir=faiss_name)
+    return storage_context
+
+
+def get_retriver_by_type(frame_type):
+    retriver_function = {
+        "langchain": [build_index_langchain, LangChainRetrievalTool],
+        "llama_index": [build_index_llama, LlamaIndexRetrievalTool],
+    }
+    return retriver_function[frame_type]
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    from langchain_openai import AzureOpenAIEmbeddings
+
+    from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--index_name_full_text",
+        type=str,
+        default="",
+        help="The name of the full-text knowledge base(faiss)",
+    )
+    parser.add_argument(
+        "--index_name_abstract", type=str, default="", help="The name of the abstract base(faiss)"
+    )
+    parser.add_argument("--path_full_text", type=str, default="", help="Full-text data storage folder path")
+    parser.add_argument("--path_abstract", type=str, default="", help="json file path to store summary")
+    parser.add_argument(
+        "--embedding_type",
+        type=str,
+        default="openai_embedding",
+        choices=["openai_embedding", "baizhong", "ernie_embedding"],
+        help="['openai_embedding','baizhong','ernie_embedding']",
+    )
+    parser.add_argument("--url_path", type=str, default="", help="json file path to store url link")
+    parser.add_argument(
+        "--use_frame",
+        type=str,
+        default="langchain",
+        choices=["langchain", "llama_index"],
+        help="['langchain','llama_index']",
+    )
+    args = parser.parse_args()
+    access_token = os.environ["AISTUDIO_ACCESS_TOKEN"]
+    if args.embedding_type == "openai_embedding":
+        embeddings: Any = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
+    else:
+        embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
+    _, suffix = os.path.splitext(args.url_path)
+    if suffix == ".txt":
+        url_path = args.url_path.replace(".txt", ".json")
+        with open(args.url_path, "r") as f:
+            data = f.read()
+        data_list = data.split("\n")
+        url_dict = {}
+        for item in data_list:
+            url, path = item.split(" ")
+            url_dict[path] = url
+        add_url(url_dict, path=url_path)
+    else:
+        url_path = args.url_path
+    if not args.path_abstract:
+        from erniebot_agent.chat_models import ERNIEBot
+
+        llm = ERNIEBot(model="ernie-4.0")
+        generate_abstract = GenerateAbstract(llm=llm)
+        abstract_path = asyncio.run(generate_abstract.run(args.path_full_text, url_path))
+    else:
+        abstract_path = args.path_abstract
+    build_index_function, retrieval_tool = get_retriver_by_type(args.use_frame)
+    full_text_db = build_index_function(
+        faiss_name=args.index_name_full_text,
+        embeddings=embeddings,
+        path=args.path_full_text,
+        url_path=url_path,
+    )
+    abstract_db = build_index_function(
+        faiss_name=args.index_name_abstract,
+        embeddings=embeddings,
+        path=abstract_path,
+        abstract=True,
+        url_path=url_path,
+    )
+    retrieval_full = retrieval_tool(full_text_db)
+    retrieval_abstract = retrieval_tool(abstract_db)
+    print(asyncio.run(retrieval_full("agent的发展")))
+    print(asyncio.run(retrieval_abstract("agent的发展", top_k=2)))
diff --git a/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py b/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py
index 25e47e75..7db2ef10 100644
--- a/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py
+++ b/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py
@@ -41,26 +41,26 @@ def __init__(self, theta_min=0.4, theta_max=0.95, citation_num=5) -> None:
         self.recoder_cite_list: List = []
         self.recoder_cite_title: List = []
 
-    def add_url_sentences(self, sententces: str, citation_faiss_research):
+    async def add_url_sentences(self, sententces: str, citation_faiss_research):
         sentence_splits = sententces.split("。")
         output_sent = []
         for sentence in sentence_splits:
             if not sentence:
                 continue
             try:
-                query_result = citation_faiss_research.search(query=sentence, top_k=3, filters=None)
+                query_result = await citation_faiss_research(query=sentence, top_k=3, filters=None)
             except Exception as e:
                 output_sent.append(sentence)
                 logger.error(f"Faiss search error: {e}")
                 continue
             if len(sentence.strip()) > 0:
-                if not self.is_punctuation(sentence[-1]):
+                if not self.is_punctuation(sentence[-1]) or sentence[-1] == "%":
                     sentence += "。"
-            for item in query_result:
-                source = item["url"]
+            for item in query_result["documents"]:
+                source = item["meta"]["url"]
                 if item["score"] >= self.theta_min and item["score"] <= self.theta_max:
                     if source not in self.recoder_cite_list:
-                        self.recoder_cite_title.append(item["title"])
+                        self.recoder_cite_title.append(item["meta"]["name"])
                         self.recoder_cite_list.append(source)
                         self.recoder_cite_dict[source] = 1
                         index = len(self.recoder_cite_list)
@@ -87,7 +87,7 @@ def add_url_sentences(self, sententces: str, citation_faiss_research):
             output_sent.append(sentence)
         return output_sent
 
-    def add_url_report(self, report: str, citation_faiss_research):
+    async def add_url_report(self, report: str, citation_faiss_research):
         list_data = report.split("\n\n")
         output_text = []
         for chunk_text in list_data:
@@ -98,7 +98,7 @@ async def add_url_report(self, report: str, citation_faiss_research):
                 output_text.append(chunk_text)
                 continue
             else:
-                output_sent = self.add_url_sentences(chunk_text, citation_faiss_research)
+                output_sent = await self.add_url_sentences(chunk_text, citation_faiss_research)
                 chunk_text = "".join(output_sent)
                 output_text.append(chunk_text)
         report = "\n\n".join(output_text)
@@ -142,7 +142,7 @@ async def __call__(
         self.theta_max = theta_max
         if citation_num:
             self.citation_num = citation_num
-        report = self.add_url_report(report, citation_faiss_research)
+        report = await self.add_url_report(report, citation_faiss_research)
         report = self.add_reference_report(report)
         path = write_md_to_pdf(agent_name + "__" + report_type, dir_path, report)
         return report, path
diff --git a/erniebot-agent/applications/erniebot_researcher/tools/utils.py b/erniebot-agent/applications/erniebot_researcher/tools/utils.py
index 5c7bf166..7391b530 100644
--- a/erniebot-agent/applications/erniebot_researcher/tools/utils.py
+++ b/erniebot-agent/applications/erniebot_researcher/tools/utils.py
@@ -8,11 +8,7 @@
 import jsonlines
 import markdown  # type: ignore
 from langchain.docstore.document import Document
-from langchain.document_loaders import PyPDFDirectoryLoader
 from langchain.output_parsers.json import parse_json_markdown
-from langchain.text_splitter import SpacyTextSplitter
-from langchain.vectorstores import FAISS
-from sklearn.metrics.pairwise import cosine_similarity
 from weasyprint import CSS, HTML
 from weasyprint.fonts import FontConfiguration
 
@@ -99,71 +95,6 @@ async def on_tool_error(self, agent: Any, tool, error):
         self.logger.error(f"Tool调用失败,错误信息:{error}")
 
 
-class FaissSearch:
-    def __init__(self, db, embeddings):
-        self.db = db
-        self.embeddings = embeddings
-
-    def search(self, query: str, top_k: int = 10, **kwargs):
-        docs = self.db.similarity_search(query, top_k)
-        para_result = self.embeddings.embed_documents([i.page_content for i in docs])
-        query_result = self.embeddings.embed_query(query)
-        similarities = cosine_similarity([query_result], para_result).reshape((-1,))
-        retrieval_results = []
-        for index, doc in enumerate(docs):
-            retrieval_results.append(
-                {
-                    "content": doc.page_content,
-                    "score": similarities[index],
-                    "title": doc.metadata["name"],
-                    "url": doc.metadata["url"],
-                }
-            )
-        return retrieval_results
-
-
-def build_index(faiss_name, embeddings, path=None, abstract=False, origin_data=None, use_data=False):
-    if os.path.exists(faiss_name):
-        db = FAISS.load_local(faiss_name, embeddings)
-    elif abstract and not use_data:
-        all_docs = []
-        with jsonlines.open(path) as reader:
-            for obj in reader:
-                if type(obj) is list:
-                    for item in obj:
-                        if "url" in item:
-                            metadata = {"url": item["url"], "name": item["name"]}
-                        else:
-                            metadata = {"name": item["name"]}
-                        doc = Document(page_content=item["page_content"], metadata=metadata)
-                        all_docs.append(doc)
-                elif type(obj) is dict:
-                    if "url" in obj:
-                        metadata = {"url": obj["url"], "name": obj["name"]}
-                    else:
-                        metadata = {"name": obj["name"]}
-                    doc = Document(page_content=obj["page_content"], metadata=metadata)
-                    all_docs.append(doc)
-        db = FAISS.from_documents(all_docs, embeddings)
-        db.save_local(faiss_name)
-    elif not abstract and not use_data:
-        loader = PyPDFDirectoryLoader(path)
-        documents = loader.load()
-        text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0)
-        docs = text_splitter.split_documents(documents)
-        docs_tackle = []
-        for item in docs:
-            item.metadata["name"] = item.metadata["source"].split("/")[-1].replace(".pdf", "")
-            item.metadata["url"] = item.metadata["source"]
-            docs_tackle.append(item)
-        db = FAISS.from_documents(docs_tackle, embeddings)
-        db.save_local(faiss_name)
-    elif use_data:
-        db = FAISS.from_documents(origin_data, embeddings)
-        db.save_local(faiss_name)
-    return db
-
-
 def write_to_file(filename: str, text: str) -> None:
     """Write text to a file
 
@@ -215,7 +146,7 @@ def write_to_json(filename: str, list_data: list, mode="w") -> None:
             file.write(item)
 
 
-def add_citation(paragraphs, faiss_name, embeddings):
+def add_citation(paragraphs, faiss_name, embeddings, build_index, FaissSearch):
     if os.path.exists(faiss_name):
         shutil.rmtree(faiss_name)
     list_data = []
@@ -225,7 +156,7 @@ def add_citation(paragraphs, faiss_name, embeddings):
     faiss_db = build_index(
         faiss_name=faiss_name, use_data=True, embeddings=embeddings, origin_data=list_data
     )
-    faiss_search = FaissSearch(db=faiss_db, embeddings=embeddings)
+    faiss_search = FaissSearch(db=faiss_db)
     return faiss_search
 
diff --git a/erniebot-agent/applications/erniebot_researcher/ui.py b/erniebot-agent/applications/erniebot_researcher/ui.py
index e4a22ac4..0e61cfd4 100644
--- a/erniebot-agent/applications/erniebot_researcher/ui.py
+++ b/erniebot-agent/applications/erniebot_researcher/ui.py
@@ -6,7 +6,7 @@
 import gradio as gr
 from editor_actor_agent import EditorActorAgent
 from fact_check_agent import FactCheckerAgent
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings
 from polish_agent import PolishAgent
 from ranking_agent import RankingAgent
 from research_agent import ResearchAgent
@@ -14,12 +14,13 @@
 from reviser_actor_agent import ReviserActorAgent
 from tools.intent_detection_tool import IntentDetectionTool
 from tools.outline_generation_tool import OutlineGenerationTool
+from tools.preprocessing import get_retriver_by_type
 from tools.ranking_tool import TextRankingTool
 from tools.report_writing_tool import ReportWritingTool
 from tools.semantic_citation_tool import SemanticCitationTool
 from tools.summarization_tool import TextSummarizationTool
 from tools.task_planning_tool import TaskPlanningTool
-from tools.utils import FaissSearch, ReportCallbackHandler, build_index, setup_logging
+from tools.utils import ReportCallbackHandler, setup_logging
 
 from erniebot_agent.chat_models import ERNIEBot
 from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
@@ -69,9 +70,17 @@
 parser.add_argument(
     "--embedding_type",
     type=str,
+    choices=["openai_embedding", "ernie_embedding", "baizhong"],
     default="openai_embedding",
     help="['openai_embedding','baizhong','ernie_embedding']",
 )
+parser.add_argument(
+    "--use_frame",
+    type=str,
+    default="langchain",
+    choices=["langchain", "llama_index"],
+    help="['langchain','llama_index']",
+)
 parser.add_argument("--save_path", type=str, default="./output/erniebot", help="The report save path")
 parser.add_argument("--server_name", type=str, default="0.0.0.0", help="the host of server")
 parser.add_argument("--server_port", type=int, default=8878, help="the port of server")
@@ -88,19 +97,19 @@ def get_logs(path=args.log_path):
     return content
 
 
-def get_retrievers():
+def get_retrievers(build_index_function, retrieval_tool):
     if args.embedding_type == "openai_embedding":
-        embeddings = OpenAIEmbeddings(deployment="text-embedding-ada")
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        embeddings = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db)
+        retriever_search = retrieval_tool(paper_db)
     elif args.embedding_type == "ernie_embedding":
         embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
-        paper_db = build_index(faiss_name=args.index_name_full_text, embeddings=embeddings)
-        abstract_db = build_index(faiss_name=args.index_name_abstract, embeddings=embeddings)
-        abstract_search = FaissSearch(abstract_db, embeddings=embeddings)
-        retriever_search = FaissSearch(paper_db, embeddings=embeddings)
+        paper_db = build_index_function(faiss_name=args.index_name_full_text, embeddings=embeddings)
+        abstract_db = build_index_function(faiss_name=args.index_name_abstract, embeddings=embeddings)
+        abstract_search = retrieval_tool(abstract_db)
+        retriever_search = retrieval_tool(paper_db)
     elif args.embedding_type == "baizhong":
         embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
         retriever_search = BaizhongSearch(
@@ -136,7 +145,9 @@ def get_tools(llm, llm_long):
     }
 
 
-def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
+def get_agents(
+    retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
+):
     research_actor = []
     for i in range(args.num_research_agent):
         agents_name = "agent_" + str(i)
@@ -185,6 +196,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
         report_type=args.report_type,
         citation_tool=tool_sets["semantic_citation"],
         callbacks=ReportCallbackHandler(logger=logger),
+        build_index_function=build_index_function,
+        search_tool=retrieval_tool,
     )
     return {
         "research_actor": research_actor,
@@ -203,9 +216,13 @@ def generate_report(query, history=[]):
     os.makedirs(target_path, exist_ok=True)
     llm = ERNIEBot(model="ernie-4.0")
     llm_long = ERNIEBot(model="ernie-longtext")
-    retriever_sets = get_retrievers()
+    build_index_function, retrieval_tool = get_retriver_by_type(args.use_frame)
+
+    retriever_sets = get_retrievers(build_index_function, retrieval_tool)
     tool_sets = get_tools(llm, llm_long)
-    agent_sets = get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path)
+    agent_sets = get_agents(
+        retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
+    )
     team_actor = ResearchTeam(**agent_sets, use_reflection=True)
     report, path = asyncio.run(team_actor.run(query, args.iterations))
     return report, path
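The call sites updated above (research_agent.py, fact_check_agent.py, semantic_citation_tool.py) all assume the retrieval tool returned by `get_retriver_by_type` is awaitable and yields a dict with a `"documents"` list whose items carry `"content"`, `"score"`, and a `"meta"` dict with `"url"`/`"name"`. That contract is not shown in this diff, so the sketch below is a minimal, hypothetical stand-in (`DummyRetrievalTool` is not the actual `LangChainRetrievalTool`/`LlamaIndexRetrievalTool` implementation) illustrating the assumed result shape that the new code relies on.

```python
import asyncio
from typing import Any, Dict, List


class DummyRetrievalTool:
    """Hypothetical stand-in that mimics the result schema assumed by the updated call sites:
    an awaitable returning {"documents": [{"content", "score", "meta": {"url", "name"}}, ...]}."""

    def __init__(self, docs: List[Dict[str, Any]]):
        self.docs = docs

    async def __call__(self, query: str, top_k: int = 10, **kwargs) -> Dict[str, Any]:
        # A real tool would embed `query` and run a vector search; here we just
        # return the first `top_k` stored documents to demonstrate the shape.
        return {"documents": self.docs[:top_k]}


async def demo() -> None:
    retriever = DummyRetrievalTool(
        [
            {
                "content": "Agent 的起源与发展……",
                "score": 0.82,
                "meta": {"url": "https://zhuanlan.zhihu.com/p/659457816", "name": "Ai_Agent的起源"},
            }
        ]
    )
    # Mirrors how research_agent.run_search_summary consumes the result after this change.
    results = await retriever("agent的发展", top_k=3)
    for doc in results["documents"]:
        print(doc["meta"]["name"], doc["meta"]["url"], doc["content"][:20])


if __name__ == "__main__":
    asyncio.run(demo())
```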