Skip to content

Commit

Permalink
Fixing gather_evidence and complete response messages (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Jan 15, 2025
1 parent cf377fc commit 7bb570c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 32 deletions.
15 changes: 6 additions & 9 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def make_clinical_trial_status(


def clinical_trial_status(state: "EnvironmentState") -> str:
relevant_contexts = state.get_relevant_contexts()
return make_clinical_trial_status(
total_paper_count=len(
{
Expand All @@ -172,9 +173,8 @@ def clinical_trial_status(state: "EnvironmentState") -> str:
relevant_paper_count=len(
{
c.text.doc.dockey
for c in state.session.contexts
if c.score > state.RELEVANT_SCORE_CUTOFF
and CLINICAL_TRIALS_BASE
for c in relevant_contexts
if CLINICAL_TRIALS_BASE
not in getattr(c.text.doc, "other", {}).get("client_source", [])
}
),
Expand All @@ -189,15 +189,12 @@ def clinical_trial_status(state: "EnvironmentState") -> str:
relevant_clinical_trials=len(
{
c.text.doc.dockey
for c in state.session.contexts
if c.score > state.RELEVANT_SCORE_CUTOFF
and CLINICAL_TRIALS_BASE
for c in relevant_contexts
if CLINICAL_TRIALS_BASE
in getattr(c.text.doc, "other", {}).get("client_source", [])
}
),
evidence_count=len(
[c for c in state.session.contexts if c.score > state.RELEVANT_SCORE_CUTOFF]
),
evidence_count=len(relevant_contexts),
cost=state.session.cost,
)

Expand Down
39 changes: 16 additions & 23 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from paperqa.docs import Docs
from paperqa.settings import Settings
from paperqa.sources.clinical_trials import add_clinical_trials_to_docs
from paperqa.types import DocDetails, PQASession
from paperqa.types import Context, DocDetails, PQASession

from .search import get_directory_index

Expand All @@ -35,18 +35,11 @@ def make_status(


def default_status(state: "EnvironmentState") -> str:
relevant_contexts = state.get_relevant_contexts()
return make_status(
total_paper_count=len(state.docs.docs),
relevant_paper_count=len(
{
c.text.doc.dockey
for c in state.session.contexts
if c.score > state.RELEVANT_SCORE_CUTOFF
}
),
evidence_count=len(
[c for c in state.session.contexts if c.score > state.RELEVANT_SCORE_CUTOFF]
),
relevant_paper_count=len({c.text.doc.dockey for c in relevant_contexts}),
evidence_count=len(relevant_contexts),
cost=state.session.cost,
)

Expand Down Expand Up @@ -80,6 +73,11 @@ def status(self) -> str:
return self.status_fn(cast(Self, self))
return default_status(self)

def get_relevant_contexts(self) -> list[Context]:
return [
c for c in self.session.contexts if c.score > self.RELEVANT_SCORE_CUTOFF
]

def record_action(self, action: ToolRequestMessage) -> None:
self.session.add_tokens(action)
self.session.tool_history.append([tc.function.name for tc in action.tool_calls])
Expand Down Expand Up @@ -227,7 +225,8 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:

logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.")
original_question = state.session.question
l1_all = l1_relevant = l0 = len(state.session.contexts)
l1 = l0 = len(state.session.contexts)
l1_relevant = l0_relevant = len(state.get_relevant_contexts())

try:
# Swap out the question with the more specific question
Expand All @@ -245,14 +244,8 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
f"{self.TOOL_FN_NAME}_aget_evidence"
),
)
l1_all = len(state.session.contexts)
l1_relevant = len(
[
c
for c in state.session.contexts
if c.score > state.RELEVANT_SCORE_CUTOFF
]
)
l1 = len(state.session.contexts)
l1_relevant = len(state.get_relevant_contexts())
finally:
state.session.question = original_question

Expand Down Expand Up @@ -284,7 +277,7 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
)

return (
f"Added {l1_all - l0} pieces of evidence, {l1_relevant - l0} of which were"
f"Added {l1 - l0} pieces of evidence, {l1_relevant - l0_relevant} of which were"
f" relevant.{best_evidence}\n\n" + status
)

Expand Down Expand Up @@ -412,10 +405,10 @@ async def complete(

logger.info(
f"Completing '{state.session.question}' as"
f" '{'a success' if has_successful_answer else 'unsure'}'."
f" '{'certain' if has_successful_answer else 'unsure'}'."
)
# Return answer and status to simplify postprocessing of tool response
return f"{'Success' if has_successful_answer else 'Unsure'} | {state.status}"
return f"{'Certain' if has_successful_answer else 'Unsure'} | {state.status}"


class ClinicalTrialsSearch(NamedTool):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,17 @@ def new_status(state: EnvironmentState) -> str:
gather_evidence_initialized_callback.assert_awaited_once_with(env_state)
gather_evidence_completed_callback.assert_awaited_once_with(env_state)

split = re.split(
r"(\d+) pieces of evidence, (\d+) of which were relevant",
response,
maxsplit=1,
)
assert len(split) == 4, "Unexpected response shape"
total_added_1, relevant_added_1 = int(split[1]), int(split[2])
assert all(
x >= 0 for x in (total_added_1, relevant_added_1)
), "Expected non-negative counts"
assert len(env_state.get_relevant_contexts()) == relevant_added_1
# ensure 1 piece of top evidence is returned
assert "\n1." in response, "gather_evidence did not return any results"
assert (
Expand All @@ -591,6 +602,21 @@ def new_status(state: EnvironmentState) -> str:
response = await gather_evidence_tool.gather_evidence(
session.question, state=env_state
)

split = re.split(
r"(\d+) pieces of evidence, (\d+) of which were relevant",
response,
maxsplit=1,
)
assert len(split) == 4, "Unexpected response shape"
total_added_2, relevant_added_2 = int(split[1]), int(split[2])
assert all(
x >= 0 for x in (total_added_2, relevant_added_2)
), "Expected non-negative counts"
assert (
len(env_state.get_relevant_contexts())
== relevant_added_1 + relevant_added_2
)
# ensure both evidences are returned
assert "\n1." in response, "gather_evidence did not return any results"
assert "\n2." in response, "gather_evidence should return 2 contexts"
Expand Down

0 comments on commit 7bb570c

Please sign in to comment.