From abeddbca0e85ebbdc37472e355ed52de389f207c Mon Sep 17 00:00:00 2001 From: mskarlin <12701035+mskarlin@users.noreply.github.com> Date: Tue, 7 Jan 2025 11:03:52 -0800 Subject: [PATCH] Expose more contexts as output from gather_evidence tool (#793) Co-authored-by: James Braza --- paperqa/agents/tools.py | 14 ++++++++++---- paperqa/settings.py | 6 ++++++ tests/test_agents.py | 20 +++++++++++++++++++- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py index 28cd5458..fa86eef9 100644 --- a/paperqa/agents/tools.py +++ b/paperqa/agents/tools.py @@ -253,12 +253,18 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str: sorted_contexts = sorted( state.session.contexts, key=lambda x: x.score, reverse=True ) - best_evidence = ( - f" Best evidence:\n\n{sorted_contexts[0].context}" - if sorted_contexts - else "" + + top_contexts = "\n".join( + [ + f"{n + 1}. {sc.context}\n" + for n, sc in enumerate( + sorted_contexts[: self.settings.agent.agent_evidence_n] + ) + ] ) + best_evidence = f" Best evidence(s):\n\n{top_contexts}" if top_contexts else "" + if f"{self.TOOL_FN_NAME}_completed" in self.settings.agent.callbacks: await asyncio.gather( *( diff --git a/paperqa/settings.py b/paperqa/settings.py index b7243b67..2120c62f 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -467,6 +467,12 @@ class AgentSettings(BaseModel): ) search_count: int = 8 wipe_context_on_answer_failure: bool = True + agent_evidence_n: int = Field( + default=1, + ge=1, + description="Top n ranked evidences shown to the " + "agent after the GatherEvidence tool.", + ) timeout: float = Field( default=500.0, description=( diff --git a/tests/test_agents.py b/tests/test_agents.py index 6a5ed814..e239a624 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -572,12 +572,30 @@ def new_status(state: EnvironmentState) -> str: summary_llm_model=summary_llm_model, embedding_model=embedding_model, ) - await gather_evidence_tool.gather_evidence(session.question, state=env_state) + + response = await gather_evidence_tool.gather_evidence( + session.question, state=env_state + ) if callback_type == "async": gather_evidence_initialized_callback.assert_awaited_once_with(env_state) gather_evidence_completed_callback.assert_awaited_once_with(env_state) + # ensure 1 piece of top evidence is returned + assert "\n1." in response, "gather_evidence did not return any results" + assert ( + "\n2." not in response + ), "gather_evidence should return only 1 context, not 2" + + # now adjust to give the agent 2x pieces of evidence + gather_evidence_tool.settings.agent.agent_evidence_n = 2 + response = await gather_evidence_tool.gather_evidence( + session.question, state=env_state + ) + # ensure both evidences are returned + assert "\n1." in response, "gather_evidence did not return any results" + assert "\n2." in response, "gather_evidence should return 2 contexts" + assert session.contexts, "Evidence did not return any results" assert not session.answer, "Expected no answer yet"