Skip to content

Commit

Permalink
add fact_check_agent
Browse files Browse the repository at this point in the history
  • Loading branch information
qingzhong1 committed Jan 17, 2024
1 parent 7e43415 commit bcf3a94
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, List, Optional

from tools.utils import JsonUtil, ReportCallbackHandler

Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
else:
self._callback_manager = callbacks

async def run(self, report: Union[str, dict]):
async def run(self, report: str):
await self._callback_manager.on_run_start(
agent=self, agent_name=self.name, prompt=self.system_message
)
Expand Down Expand Up @@ -145,8 +145,6 @@ async def report_fact(self, report: str):
text.append(item)
return "\n\n".join(text)

async def _run(self, report: Union[str, Dict[str, str]]):
if isinstance(report, dict):
report = report["report"]
async def _run(self, report: str):
report = await self.report_fact(report)
return report
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

from editor_actor_agent import EditorActorAgent
from fact_check_agent import FactCheckerAgent
from polish_agent import PolishAgent
from ranking_agent import RankingAgent
from research_agent import ResearchAgent
Expand All @@ -16,6 +17,7 @@ def __init__(
ranker_actor: RankingAgent,
editor_actor: EditorActorAgent,
reviser_actor: ReviserActorAgent,
checker_actor: FactCheckerAgent,
polish_actor: Optional[PolishAgent] = None,
user_agent: Optional[UserProxyAgent] = None,
use_reflection: bool = False,
Expand All @@ -25,6 +27,7 @@ def __init__(
self.revise_actor_instance = reviser_actor
self.ranker_actor_instance = ranker_actor
self.polish_actor_instance = polish_actor
self.checker_actor_instance = checker_actor
self.user_agent = user_agent
self.polish_actor = polish_actor
self.use_reflection = use_reflection
Expand Down Expand Up @@ -77,9 +80,9 @@ async def run(self, query, iterations=3):
immedia_report = list_reports[0]

revised_report = immedia_report

checked_report = await self.checker_actor_instance.run(report=revised_report["report"])
revised_report, path = await self.polish_actor_instance.run(
report=revised_report["report"],
report=checked_report,
summarize=revised_report["paragraphs"],
)
return revised_report, path
8 changes: 8 additions & 0 deletions erniebot-agent/applications/erniebot_researcher/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import gradio as gr
from editor_actor_agent import EditorActorAgent
from fact_check_agent import FactCheckerAgent
from langchain.embeddings.openai import OpenAIEmbeddings
from polish_agent import PolishAgent
from ranking_agent import RankingAgent
Expand Down Expand Up @@ -168,6 +169,12 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
llm_long=llm_long,
callbacks=ReportCallbackHandler(logger=logger),
)
checker_actor = FactCheckerAgent(
name="fact_check",
llm=llm,
retriever_db=retriever_sets["full_text"],
callbacks=ReportCallbackHandler(logger=logger),
)
polish_actor = PolishAgent(
name="polish",
llm=llm,
Expand All @@ -184,6 +191,7 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
"editor_actor": editor_actor,
"reviser_actor": reviser_actor,
"ranker_actor": ranker_actor,
"checker_actor": checker_actor,
"polish_actor": polish_actor,
}

Expand Down

0 comments on commit bcf3a94

Please sign in to comment.