From 1ea1a0c73b7a70a5b94d6516ea725461e7e74849 Mon Sep 17 00:00:00 2001 From: qingzhong1 Date: Thu, 4 Jan 2024 05:36:47 +0000 Subject: [PATCH] debug --- .../erniebot_researcher/ranking_agent.py | 16 ++++----- .../erniebot_researcher/research_agent.py | 12 +++---- .../tools/report_writing_tool.py | 8 +++-- .../tools/semantic_citation_tool.py | 2 +- .../erniebot_researcher/tools/utils.py | 8 +++-- .../applications/erniebot_researcher/ui.py | 35 ++++++++++--------- 6 files changed, 42 insertions(+), 39 deletions(-) diff --git a/erniebot-agent/applications/erniebot_researcher/ranking_agent.py b/erniebot-agent/applications/erniebot_researcher/ranking_agent.py index 56ca68dd5..58e1fb78e 100644 --- a/erniebot-agent/applications/erniebot_researcher/ranking_agent.py +++ b/erniebot-agent/applications/erniebot_researcher/ranking_agent.py @@ -48,28 +48,28 @@ def __init__( self._callback_manager = callbacks async def _run(self, list_reports, query): - self._callback_manager.on_run_start(self.name, "") + await self._callback_manager.on_run_start(self.name, "") reports = [] for item in list_reports: - if self.check_format(item): + if await self.check_format(item): reports.append(item) if len(reports) == 0: if self.is_reset: - self._callback_manager.on_run_end(self.name, "所有的report都不是markdown格式,重新生成report") + await self._callback_manager.on_run_end(self.name, "所有的report都不是markdown格式,重新生成report") logger.info("所有的report都不是markdown格式,重新生成report") return [], None else: reports = list_reports best_report = await self.ranking(reports, query) - self._callback_manager.on_run_tool(self.ranking.description, best_report) - self._callback_manager.on_run_end(self.name, "") + await self._callback_manager.on_run_tool(self.ranking.description, best_report) + await self._callback_manager.on_run_end(self.name, "") return reports, best_report - def check_format(self, report): + async def check_format(self, report): while True: try: messages = [HumanMessage(content=get_markdown_check_prompt(report))] - response = self.llm.chat(messages=messages, temperature=0.001) + response = await self.llm.chat(messages=messages, temperature=0.001) result = response.content l_index = result.index("{") r_index = result.index("}") @@ -80,6 +80,6 @@ def check_format(self, report): elif result_dict["accept"] is False or result_dict["accept"] == "false": return False except Exception as e: - self._callback_manager.on_run_error("格式检查", str(e)) + await self._callback_manager.on_run_error("格式检查", str(e)) logger.error(e) continue diff --git a/erniebot-agent/applications/erniebot_researcher/research_agent.py b/erniebot-agent/applications/erniebot_researcher/research_agent.py index 8b6fac2f9..de4e9e1bb 100644 --- a/erniebot-agent/applications/erniebot_researcher/research_agent.py +++ b/erniebot-agent/applications/erniebot_researcher/research_agent.py @@ -1,13 +1,11 @@ import json import logging from collections import OrderedDict -from typing import List, Optional +from typing import Optional from tools.utils import ReportCallbackHandler, add_citation -from erniebot_agent.agents.agent import Agent from erniebot_agent.chat_models.erniebot import BaseERNIEBot -from erniebot_agent.file.base import File from erniebot_agent.memory import HumanMessage from erniebot_agent.prompt import PromptTemplate @@ -20,7 +18,7 @@ """ -class ResearchAgent(Agent): +class ResearchAgent: """ ResearchAgent, refer to https://github.com/assafelovic/gpt-researcher/blob/master/examples/permchain_agents/research_team.py @@ -100,11 +98,11 @@ async def run_search_summary(self, query): value = doc["url"] url_dict[key] = value else: - print(f"summary size exceed {SUMMARIZE_MAX_LENGTH}") + logger.warning(f"summary size exceed {SUMMARIZE_MAX_LENGTH}") break return responses, url_dict - async def _run(self, query, files: Optional[List[File]] = None): + async def run(self, query): """ Runs the ResearchAgent Returns: @@ -199,5 +197,5 @@ async def _run(self, query, files: Optional[List[File]] = None): report, url_index, self.agent_name, self.report_type, self.dir_path, citation_search ) await self._callback_manager.on_run_tool(tool_name=self.citation.description, response=final_report) - await self._callback_manager.on_run_end(tool_name=self.name, response=f"报告存储在{path}") + await self._callback_manager.on_run_end(self, agent_name=self.name, response=f"报告存储在{path}") return final_report, path diff --git a/erniebot-agent/applications/erniebot_researcher/tools/report_writing_tool.py b/erniebot-agent/applications/erniebot_researcher/tools/report_writing_tool.py index 1de3ad26d..aa32d1655 100644 --- a/erniebot-agent/applications/erniebot_researcher/tools/report_writing_tool.py +++ b/erniebot-agent/applications/erniebot_researcher/tools/report_writing_tool.py @@ -6,6 +6,7 @@ from pydantic import Field + from erniebot_agent.chat_models.erniebot import BaseERNIEBot from erniebot_agent.memory import HumanMessage from erniebot_agent.prompt import PromptTemplate @@ -128,9 +129,10 @@ class ReportWritingTool(Tool): input_type: Type[ToolParameterView] = ReportWritingToolInputView ouptut_type: Type[ToolParameterView] = ReportWritingToolOutputView - def __init__(self, llm: BaseERNIEBot) -> None: + def __init__(self, llm: BaseERNIEBot, llm_long: BaseERNIEBot) -> None: super().__init__() self.llm = llm + self.llm_long = llm_long async def __call__( self, @@ -145,7 +147,7 @@ async def __call__( research_summary = research_summary[: TOKEN_MAX_LENGTH - 600] report_type_func = get_report_by_type(report_type) messages = [HumanMessage(report_type_func(question, research_summary, outline))] - response = await self.llm.chat(messages, system=agent_role_prompt) + response = await self.llm_long.chat(messages, system=agent_role_prompt) final_report = response.content if final_report == "": raise Exception("报告生成错误") @@ -177,5 +179,5 @@ async def __call__( if meta_data: for index, (key, val) in enumerate(meta_data.items()): url_index[val] = {"name": key, "index": index + 1} - # final_report=postprocess(final_report) + # final_report = postprocess(final_report) return final_report, url_index diff --git a/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py b/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py index 0ef4e5b60..8315d34d2 100644 --- a/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py +++ b/erniebot-agent/applications/erniebot_researcher/tools/semantic_citation_tool.py @@ -46,7 +46,7 @@ async def __call__( if "参考文献" in chunk_text: output_text.append(chunk_text) break - elif "#" in chunk_text: + elif "#" in chunk_text or "摘要" in chunk_text or "关键词" in chunk_text: output_text.append(chunk_text) continue else: diff --git a/erniebot-agent/applications/erniebot_researcher/tools/utils.py b/erniebot-agent/applications/erniebot_researcher/tools/utils.py index 7ef42025a..a6f49f079 100644 --- a/erniebot-agent/applications/erniebot_researcher/tools/utils.py +++ b/erniebot-agent/applications/erniebot_researcher/tools/utils.py @@ -187,6 +187,8 @@ def postprocess(report): abstract_json = json.loads(abstract_json[l_index : r_index + 1]) abstract = abstract_json["摘要"] key = abstract_json["关键词"] + if type(key) is list: + key = ",".join(key) break except Exception as e: print(e) @@ -196,12 +198,12 @@ def postprocess(report): paragraphs = [] title = report_list[0] paragraphs.append(title) - paragraphs.append("**摘要:**" + abstract) - paragraphs.append("**关键词:**" + key) + paragraphs.append("**摘要** " + abstract) + paragraphs.append("**关键词** " + key) content = "" for item in report_list[1:]: if "#" not in item: - content += item + content += item + "\n" else: if len(content) > 300: paragraphs.append(content) diff --git a/erniebot-agent/applications/erniebot_researcher/ui.py b/erniebot-agent/applications/erniebot_researcher/ui.py index 8dd19945b..7783d0c41 100644 --- a/erniebot-agent/applications/erniebot_researcher/ui.py +++ b/erniebot-agent/applications/erniebot_researcher/ui.py @@ -5,11 +5,11 @@ import os import gradio as gr -from EditorActorAgent import EditorActorAgent +from editor_actor_agent import EditorActorAgent from langchain.embeddings.openai import OpenAIEmbeddings -from RankingAgent import RankingAgent -from ResearchAgent import ResearchAgent -from ReviserActorAgent import ReviserActorAgent +from ranking_agent import RankingAgent +from research_agent import ResearchAgent +from reviser_actor_agent import ReviserActorAgent from tools.intent_detection_tool import IntentDetectionTool from tools.outline_generation_tool import OutlineGenerationTool from tools.ranking_tool import TextRankingTool @@ -19,6 +19,7 @@ from tools.task_planning_tool import TaskPlanningTool from tools.utils import FaissSearch, build_index, write_md_to_pdf +from erniebot_agent.chat_models import ERNIEBot from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings from erniebot_agent.retrieval import BaizhongSearch @@ -89,13 +90,14 @@ def generate_report(query, history=[]): knowledge_base_name=args.knowledge_base_name_abstract, knowledge_base_id=args.knowledge_base_id_abstract, ) - - intent_detection_tool = IntentDetectionTool() - outline_generation_tool = OutlineGenerationTool() - ranking_tool = TextRankingTool() - report_writing_tool = ReportWritingTool() + llm = ERNIEBot(model="ernie-4.0") + llm_long = ERNIEBot(model="ernie-longtext") + intent_detection_tool = IntentDetectionTool(llm) + outline_generation_tool = OutlineGenerationTool(llm) + ranking_tool = TextRankingTool(llm, llm_long) + report_writing_tool = ReportWritingTool(llm, llm_long) summarization_tool = TextSummarizationTool() - task_planning_tool = TaskPlanningTool() + task_planning_tool = TaskPlanningTool(llm=llm) semantic_citation_tool = SemanticCitationTool() dir_path = f"./outputs/erniebot/{hashlib.sha1(query.encode()).hexdigest()}" target_path = f"./outputsl/erniebot/{hashlib.sha1(query.encode()).hexdigest()}/revised" @@ -120,18 +122,17 @@ def generate_report(query, history=[]): summarize_tool=summarization_tool, faiss_name_citation=args.faiss_name_citation, embeddings=embeddings, + llm=llm, ) research_actor.append(research_agent) - editor_actor = EditorActorAgent(name="editor") - reviser_actor = ReviserActorAgent(name="reviser") - ranker_actor = RankingAgent( - name="ranker", - ranking_tool=ranking_tool, - ) + editor_actor = EditorActorAgent(name="editor", llm=llm) + reviser_actor = ReviserActorAgent(name="reviser", llm=llm) + ranker_actor = RankingAgent(name="ranker", ranking_tool=ranking_tool, llm=llm) list_reports = [] for researcher in research_actor: report, _ = asyncio.run(researcher.run(query)) list_reports.append(report) + breakpoint() for i in range(args.iterations): if len(list_reports) > 1: list_reports, immedia_report = asyncio.run(ranker_actor._run(list_reports, query)) @@ -185,7 +186,7 @@ def launch_ui(): clear = gr.Button("清除", variant="primary", scale=1) submit.click(generate_report, inputs=[query_textbox], outputs=[report, report_url]) clear.click(lambda _: ([None, None]), outputs=[report, report_url]) - recording = gr.Textbox(label="历史记录") + recording = gr.Textbox(label="历史记录", max_lines=10) with gr.Row(): clear_recoding = gr.Button(value="记录清除") submit_recoding = gr.Button(value="记录更新")