Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
qingzhong1 committed Jan 4, 2024
1 parent 5054ad5 commit 1ea1a0c
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 39 deletions.
16 changes: 8 additions & 8 deletions erniebot-agent/applications/erniebot_researcher/ranking_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,28 @@ def __init__(
self._callback_manager = callbacks

async def _run(self, list_reports, query):
self._callback_manager.on_run_start(self.name, "")
await self._callback_manager.on_run_start(self.name, "")
reports = []
for item in list_reports:
if self.check_format(item):
if await self.check_format(item):
reports.append(item)
if len(reports) == 0:
if self.is_reset:
self._callback_manager.on_run_end(self.name, "所有的report都不是markdown格式,重新生成report")
await self._callback_manager.on_run_end(self.name, "所有的report都不是markdown格式,重新生成report")
logger.info("所有的report都不是markdown格式,重新生成report")
return [], None
else:
reports = list_reports
best_report = await self.ranking(reports, query)
self._callback_manager.on_run_tool(self.ranking.description, best_report)
self._callback_manager.on_run_end(self.name, "")
await self._callback_manager.on_run_tool(self.ranking.description, best_report)
await self._callback_manager.on_run_end(self.name, "")
return reports, best_report

def check_format(self, report):
async def check_format(self, report):
while True:
try:
messages = [HumanMessage(content=get_markdown_check_prompt(report))]
response = self.llm.chat(messages=messages, temperature=0.001)
response = await self.llm.chat(messages=messages, temperature=0.001)
result = response.content
l_index = result.index("{")
r_index = result.index("}")
Expand All @@ -80,6 +80,6 @@ def check_format(self, report):
elif result_dict["accept"] is False or result_dict["accept"] == "false":
return False
except Exception as e:
self._callback_manager.on_run_error("格式检查", str(e))
await self._callback_manager.on_run_error("格式检查", str(e))
logger.error(e)
continue
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
import logging
from collections import OrderedDict
from typing import List, Optional
from typing import Optional

from tools.utils import ReportCallbackHandler, add_citation

from erniebot_agent.agents.agent import Agent
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file.base import File
from erniebot_agent.memory import HumanMessage
from erniebot_agent.prompt import PromptTemplate

Expand All @@ -20,7 +18,7 @@
"""


class ResearchAgent(Agent):
class ResearchAgent:
"""
ResearchAgent, refer to
https://github.com/assafelovic/gpt-researcher/blob/master/examples/permchain_agents/research_team.py
Expand Down Expand Up @@ -100,11 +98,11 @@ async def run_search_summary(self, query):
value = doc["url"]
url_dict[key] = value
else:
print(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
logger.warning(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
break
return responses, url_dict

async def _run(self, query, files: Optional[List[File]] = None):
async def run(self, query):
"""
Runs the ResearchAgent
Returns:
Expand Down Expand Up @@ -199,5 +197,5 @@ async def _run(self, query, files: Optional[List[File]] = None):
report, url_index, self.agent_name, self.report_type, self.dir_path, citation_search
)
await self._callback_manager.on_run_tool(tool_name=self.citation.description, response=final_report)
await self._callback_manager.on_run_end(tool_name=self.name, response=f"报告存储在{path}")
await self._callback_manager.on_run_end(self, agent_name=self.name, response=f"报告存储在{path}")
return final_report, path
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pydantic import Field


from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage
from erniebot_agent.prompt import PromptTemplate
Expand Down Expand Up @@ -128,9 +129,10 @@ class ReportWritingTool(Tool):
input_type: Type[ToolParameterView] = ReportWritingToolInputView
ouptut_type: Type[ToolParameterView] = ReportWritingToolOutputView

def __init__(self, llm: BaseERNIEBot) -> None:
def __init__(self, llm: BaseERNIEBot, llm_long: BaseERNIEBot) -> None:
super().__init__()
self.llm = llm
self.llm_long = llm_long

async def __call__(
self,
Expand All @@ -145,7 +147,7 @@ async def __call__(
research_summary = research_summary[: TOKEN_MAX_LENGTH - 600]
report_type_func = get_report_by_type(report_type)
messages = [HumanMessage(report_type_func(question, research_summary, outline))]
response = await self.llm.chat(messages, system=agent_role_prompt)
response = await self.llm_long.chat(messages, system=agent_role_prompt)
final_report = response.content
if final_report == "":
raise Exception("报告生成错误")
Expand Down Expand Up @@ -177,5 +179,5 @@ async def __call__(
if meta_data:
for index, (key, val) in enumerate(meta_data.items()):
url_index[val] = {"name": key, "index": index + 1}
# final_report=postprocess(final_report)
# final_report = postprocess(final_report)
return final_report, url_index
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def __call__(
if "参考文献" in chunk_text:
output_text.append(chunk_text)
break
elif "#" in chunk_text:
elif "#" in chunk_text or "摘要" in chunk_text or "关键词" in chunk_text:
output_text.append(chunk_text)
continue
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def postprocess(report):
abstract_json = json.loads(abstract_json[l_index : r_index + 1])
abstract = abstract_json["摘要"]
key = abstract_json["关键词"]
if type(key) is list:
key = ",".join(key)
break
except Exception as e:
print(e)
Expand All @@ -196,12 +198,12 @@ def postprocess(report):
paragraphs = []
title = report_list[0]
paragraphs.append(title)
paragraphs.append("**摘要:**" + abstract)
paragraphs.append("**关键词:**" + key)
paragraphs.append("**摘要** " + abstract)
paragraphs.append("**关键词** " + key)
content = ""
for item in report_list[1:]:
if "#" not in item:
content += item
content += item + "\n"
else:
if len(content) > 300:
paragraphs.append(content)
Expand Down
35 changes: 18 additions & 17 deletions erniebot-agent/applications/erniebot_researcher/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import os

import gradio as gr
from EditorActorAgent import EditorActorAgent
from editor_actor_agent import EditorActorAgent
from langchain.embeddings.openai import OpenAIEmbeddings
from RankingAgent import RankingAgent
from ResearchAgent import ResearchAgent
from ReviserActorAgent import ReviserActorAgent
from ranking_agent import RankingAgent
from research_agent import ResearchAgent
from reviser_actor_agent import ReviserActorAgent
from tools.intent_detection_tool import IntentDetectionTool
from tools.outline_generation_tool import OutlineGenerationTool
from tools.ranking_tool import TextRankingTool
Expand All @@ -19,6 +19,7 @@
from tools.task_planning_tool import TaskPlanningTool
from tools.utils import FaissSearch, build_index, write_md_to_pdf

from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
from erniebot_agent.retrieval import BaizhongSearch

Expand Down Expand Up @@ -89,13 +90,14 @@ def generate_report(query, history=[]):
knowledge_base_name=args.knowledge_base_name_abstract,
knowledge_base_id=args.knowledge_base_id_abstract,
)

intent_detection_tool = IntentDetectionTool()
outline_generation_tool = OutlineGenerationTool()
ranking_tool = TextRankingTool()
report_writing_tool = ReportWritingTool()
llm = ERNIEBot(model="ernie-4.0")
llm_long = ERNIEBot(model="ernie-longtext")
intent_detection_tool = IntentDetectionTool(llm)
outline_generation_tool = OutlineGenerationTool(llm)
ranking_tool = TextRankingTool(llm, llm_long)
report_writing_tool = ReportWritingTool(llm, llm_long)
summarization_tool = TextSummarizationTool()
task_planning_tool = TaskPlanningTool()
task_planning_tool = TaskPlanningTool(llm=llm)
semantic_citation_tool = SemanticCitationTool()
dir_path = f"./outputs/erniebot/{hashlib.sha1(query.encode()).hexdigest()}"
target_path = f"./outputsl/erniebot/{hashlib.sha1(query.encode()).hexdigest()}/revised"
Expand All @@ -120,18 +122,17 @@ def generate_report(query, history=[]):
summarize_tool=summarization_tool,
faiss_name_citation=args.faiss_name_citation,
embeddings=embeddings,
llm=llm,
)
research_actor.append(research_agent)
editor_actor = EditorActorAgent(name="editor")
reviser_actor = ReviserActorAgent(name="reviser")
ranker_actor = RankingAgent(
name="ranker",
ranking_tool=ranking_tool,
)
editor_actor = EditorActorAgent(name="editor", llm=llm)
reviser_actor = ReviserActorAgent(name="reviser", llm=llm)
ranker_actor = RankingAgent(name="ranker", ranking_tool=ranking_tool, llm=llm)
list_reports = []
for researcher in research_actor:
report, _ = asyncio.run(researcher.run(query))
list_reports.append(report)
breakpoint()
for i in range(args.iterations):
if len(list_reports) > 1:
list_reports, immedia_report = asyncio.run(ranker_actor._run(list_reports, query))
Expand Down Expand Up @@ -185,7 +186,7 @@ def launch_ui():
clear = gr.Button("清除", variant="primary", scale=1)
submit.click(generate_report, inputs=[query_textbox], outputs=[report, report_url])
clear.click(lambda _: ([None, None]), outputs=[report, report_url])
recording = gr.Textbox(label="历史记录")
recording = gr.Textbox(label="历史记录", max_lines=10)
with gr.Row():
clear_recoding = gr.Button(value="记录清除")
submit_recoding = gr.Button(value="记录更新")
Expand Down

0 comments on commit 1ea1a0c

Please sign in to comment.