diff --git a/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py b/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py index ef6cfe59..b6adc319 100644 --- a/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py +++ b/erniebot-agent/applications/erniebot_researcher/fact_check_agent.py @@ -1,6 +1,6 @@ import logging import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional from tools.utils import JsonUtil, ReportCallbackHandler @@ -66,7 +66,7 @@ def __init__( else: self._callback_manager = callbacks - async def run(self, report: Union[str, dict]): + async def run(self, report: str): await self._callback_manager.on_run_start( agent=self, agent_name=self.name, prompt=self.system_message ) @@ -145,8 +145,6 @@ async def report_fact(self, report: str): text.append(item) return "\n\n".join(text) - async def _run(self, report: Union[str, Dict[str, str]]): - if isinstance(report, dict): - report = report["report"] + async def _run(self, report: str): report = await self.report_fact(report) return report diff --git a/erniebot-agent/applications/erniebot_researcher/research_team.py b/erniebot-agent/applications/erniebot_researcher/research_team.py index fa2d057e..f904b04d 100644 --- a/erniebot-agent/applications/erniebot_researcher/research_team.py +++ b/erniebot-agent/applications/erniebot_researcher/research_team.py @@ -2,6 +2,7 @@ from typing import List, Optional from editor_actor_agent import EditorActorAgent +from fact_check_agent import FactCheckerAgent from polish_agent import PolishAgent from ranking_agent import RankingAgent from research_agent import ResearchAgent @@ -16,6 +17,7 @@ def __init__( ranker_actor: RankingAgent, editor_actor: EditorActorAgent, reviser_actor: ReviserActorAgent, + checker_actor: FactCheckerAgent, polish_actor: Optional[PolishAgent] = None, user_agent: Optional[UserProxyAgent] = None, use_reflection: bool = False, @@ -25,6 +27,7 @@ def __init__( self.revise_actor_instance = reviser_actor self.ranker_actor_instance = ranker_actor self.polish_actor_instance = polish_actor + self.checker_actor_instance = checker_actor self.user_agent = user_agent self.polish_actor = polish_actor self.use_reflection = use_reflection @@ -77,9 +80,9 @@ async def run(self, query, iterations=3): immedia_report = list_reports[0] revised_report = immedia_report - + checked_report = await self.checker_actor_instance.run(report=revised_report["report"]) revised_report, path = await self.polish_actor_instance.run( - report=revised_report["report"], + report=checked_report, summarize=revised_report["paragraphs"], ) return revised_report, path diff --git a/erniebot-agent/applications/erniebot_researcher/ui.py b/erniebot-agent/applications/erniebot_researcher/ui.py index 5222119c..e4a22ac4 100644 --- a/erniebot-agent/applications/erniebot_researcher/ui.py +++ b/erniebot-agent/applications/erniebot_researcher/ui.py @@ -5,6 +5,7 @@ import gradio as gr from editor_actor_agent import EditorActorAgent +from fact_check_agent import FactCheckerAgent from langchain.embeddings.openai import OpenAIEmbeddings from polish_agent import PolishAgent from ranking_agent import RankingAgent @@ -168,6 +169,12 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path): llm_long=llm_long, callbacks=ReportCallbackHandler(logger=logger), ) + checker_actor = FactCheckerAgent( + name="fact_check", + llm=llm, + retriever_db=retriever_sets["full_text"], + callbacks=ReportCallbackHandler(logger=logger), + ) polish_actor = PolishAgent( name="polish", llm=llm, @@ -184,6 +191,7 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path): "editor_actor": editor_actor, "reviser_actor": reviser_actor, "ranker_actor": ranker_actor, + "checker_actor": checker_actor, "polish_actor": polish_actor, }