Skip to content

Commit

Permalink
Moved to MultipleChoiceQuestion/MultipleChoiceEvaluation from `av…
Browse files Browse the repository at this point in the history
…iary` (#768)
  • Loading branch information
jamesbraza authored Jan 5, 2025
1 parent 48a6167 commit 8f73d9e
Show file tree
Hide file tree
Showing 11 changed files with 308 additions and 578 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ repos:
- aiohttp>=3.10.6 # Match pyproject.toml
- PyMuPDF>=1.24.12
- anyio
- fhaviary[llm]>=0.10.2 # Match pyproject.toml
- ldp>=0.14.5 # Match pyproject.toml
- fhaviary[llm]>=0.14 # Match pyproject.toml
- ldp>=0.17 # Match pyproject.toml
- html2text
- fh-llm-client
- httpx
Expand Down
9 changes: 8 additions & 1 deletion paperqa/_ldp_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"UIndexMemoryModel",
"_Memories",
"discounted_returns",
"evaluate_consensus",
"set_training_mode",
]

Expand All @@ -29,7 +30,12 @@
SimpleAgent,
SimpleAgentState,
)
from ldp.alg import Callback, ComputeTrajectoryMetricsMixin, RolloutManager
from ldp.alg import (
Callback,
ComputeTrajectoryMetricsMixin,
RolloutManager,
evaluate_consensus,
)
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.op_utils import set_training_mode
from ldp.utils import discounted_returns
Expand Down Expand Up @@ -57,4 +63,5 @@ class Callback: # type: ignore[no-redef]
SimpleAgentState = None # type: ignore[assignment,misc]
UIndexMemoryModel = None # type: ignore[assignment,misc]
discounted_returns = None # type: ignore[assignment]
evaluate_consensus = None # type: ignore[assignment]
set_training_mode = None # type: ignore[assignment]
4 changes: 3 additions & 1 deletion paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.utils import MultipleChoiceQuestion
from llmclient import EmbeddingModel, LiteLLMModel

from paperqa.docs import Docs
Expand Down Expand Up @@ -238,10 +239,11 @@ def make_initial_state(self) -> EnvironmentState:
):
status_fn = clinical_trial_status

query: str | MultipleChoiceQuestion = self._query.query
return EnvironmentState(
docs=self._docs,
session=PQASession(
question=self._query.query,
question=query if isinstance(query, str) else query.question_prompt,
config_md5=self._query.settings.md5,
id=self._query.id,
),
Expand Down
9 changes: 8 additions & 1 deletion paperqa/agents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, ClassVar, Protocol
from uuid import UUID, uuid4

from aviary.utils import MultipleChoiceQuestion
from llmclient import LiteLLMModel, LLMModel
from pydantic import (
BaseModel,
Expand Down Expand Up @@ -55,7 +56,13 @@ class MismatchedModelsError(Exception):
class QueryRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

query: str = ""
query: str | MultipleChoiceQuestion = Field(
default="",
description=(
"The query to be answered. Set to a multiple choice question when grading"
" (e.g. for training)."
),
)
id: UUID = Field(
default_factory=uuid4,
description="Identifier which will be propagated to the Answer object.",
Expand Down
184 changes: 147 additions & 37 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,49 @@
import logging
import re
from abc import ABC
from collections.abc import Awaitable, Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Self, assert_never
from typing import TYPE_CHECKING, Any, Self, assert_never, cast

from aviary.core import (
TASK_DATASET_REGISTRY,
Environment,
Frame,
Messages,
TaskDataset,
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.env import ENV_REGISTRY
from aviary.utils import (
DEFAULT_EVAL_MODEL_NAME,
MultipleChoiceEvaluation,
MultipleChoiceQuestion,
)
from llmclient import EmbeddingModel, LiteLLMModel, LLMModel

from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin
from paperqa._ldp_shims import (
Callback,
ComputeTrajectoryMetricsMixin,
evaluate_consensus,
)
from paperqa.docs import Docs
from paperqa.litqa import (
DEFAULT_EVAL_MODEL_NAME,
DEFAULT_LABBENCH_HF_HUB_NAME,
DEFAULT_REWARD_MAPPING,
LitQAEvaluation,
read_litqa_v2_from_hub,
)
from paperqa.types import DocDetails, PQASession

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .search import SearchIndex, maybe_get_manifest
from .tools import Complete
from .tools import Complete, EnvironmentState

if TYPE_CHECKING:
from ldp.data_structures import Trajectory
from ldp.agent import Agent
from ldp.data_structures import Trajectory, Transition

logger = logging.getLogger(__name__)

Expand All @@ -58,26 +67,22 @@ def __init__(
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
evaluation_from_answer: (
Callable[[PQASession | str], Awaitable[LitQAEvaluation]] | None
) = None,
sources: str | list[str] | None = None,
rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None,
evaluation_callback: (
Callable[[MultipleChoiceEvaluation], Awaitable] | None
) = None,
**env_kwargs,
):
super().__init__(
query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs
)
self._evaluation_from_answer = evaluation_from_answer
# Enables checking an Index has the right DOI(s)
self.sources: list[str] | None = (
[sources] if isinstance(sources, str) else sources
)
self._evaluation_callback = evaluation_callback
self._rewards = rewards
self.answer = ""
self.ideal = ""

async def validate_sources(
self, manifest_or_index: dict[str, DocDetails] | SearchIndex | None = None
Expand Down Expand Up @@ -120,7 +125,7 @@ async def step(
self, action: ToolRequestMessage
) -> tuple[Messages, float, bool, bool]:
messages, reward, done, truncated = await super().step(action)
if not done or not self._evaluation_from_answer:
if not done or not isinstance(self._query.query, MultipleChoiceQuestion):
return messages, reward, done, truncated
# If the ensuring evaluation fails (e.g. due to OpenAI being down), we can:
# - Suppress the exception and declare the evaluation as incorrect, which can
Expand All @@ -130,23 +135,13 @@ async def step(
# incorrectly reward what otherwise was a good trajectory.
# - Don't suppress the exception, which leads to the trajectory failing, and
# removes it from the learnable pool. This is the only safe default behavior.
evaluation = await self._evaluation_from_answer(self.state.session.answer)
evaluation, self.state.session.graded_answer = await self._query.query.grade(
self.state.session.answer
)
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)
self.answer = evaluation.answer or ""
self.ideal = evaluation.ideal or ""
return messages, reward + self._rewards[evaluation.value], done, truncated

def export_frame(self) -> Frame:
return Frame(
state=self.state,
info={
"query": self._query,
"answer": self.answer,
"ideal": self.ideal,
},
)

def __deepcopy__(self, memo) -> Self:
copy_state = deepcopy(self.state, memo)
# We don't know the side effects of deep copying a litellm.Router,
Expand All @@ -162,7 +157,6 @@ def __deepcopy__(self, memo) -> Self:
copy_self = type(self)(
query=deepcopy(self._query, memo), # deepcopy for _docs_name
docs=copy_state.docs,
evaluation_from_answer=self._evaluation_from_answer,
sources=self.sources,
rewards=self._rewards,
evaluation_callback=self._evaluation_callback,
Expand All @@ -182,6 +176,120 @@ def __deepcopy__(self, memo) -> Self:
)


async def evaluate_consensus_sampling(
data: Iterable[GradablePaperQAEnvironment | Frame],
exclude_no_answer: bool = False,
num_samples: int = 1,
seed: int | None = None,
) -> tuple[dict[str, list[tuple[str, int]]], float]:
"""
Create consensus groups based on question and evaluate the consensus for each.
Args:
data: Data to evaluate consensus upon, either gradable environments or frames.
exclude_no_answer: Opt-in flag to filter out empty answers (due to the
Environment/Frame not having a graded answer). Use of this flag does not
affect the accuracy term of the return.
num_samples: Passed through to evaluate_consensus.
seed: Passed through to evaluate_consensus.
Returns:
Two-tuple of consensus list generated by collections.Counter.most_common (keys
are question, values are list of (answer, vote count)) and the proportion of
groups for which the consensus matches the ideal.
"""

def extract_question(x: GradablePaperQAEnvironment | Frame) -> str:
if isinstance(x, GradablePaperQAEnvironment):
query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
else:
qr: QueryRequest | dict[str, Any] = x.info["query"] # type: ignore[call-overload,index,assignment]
query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
if isinstance(query, str):
return query
if isinstance(query, MultipleChoiceQuestion):
return query.question_prompt
return query["question"]

def extract_answer(x: GradablePaperQAEnvironment | Frame) -> str:
ses: PQASession | dict[str, Any] = (
x.state.session
if isinstance(x.state, EnvironmentState)
else cast(PQASession | dict[str, Any], x.state["session"]) # type: ignore[call-overload,index]
)
graded_answer = (
ses.graded_answer if isinstance(ses, PQASession) else ses["graded_answer"]
)
# One can filter the below empty string injection via the exclude_no_answer arg
return graded_answer or ""

def extract_ideal(x: GradablePaperQAEnvironment | Frame) -> str:
if isinstance(x, GradablePaperQAEnvironment):
query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
else:
qr: QueryRequest | dict[str, Any] = x.info["query"] # type: ignore[call-overload,index,assignment]
query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
if isinstance(query, str):
raise ValueError( # noqa: TRY004
f"We require a {MultipleChoiceQuestion.__name__} variant to extract"
" ideal answer, not a string."
)
if isinstance(query, MultipleChoiceQuestion):
return query.ideal_answer
return query["ideal_answer"]

try:
consensus, accuracy = await evaluate_consensus(
data=data,
grouping_fn=extract_question,
extract_answer_fn=extract_answer,
ideal_answer_fn=extract_ideal,
num_samples=num_samples,
seed=seed,
)
except TypeError:
raise ImportError(
"Evaluating consensus requires the 'ldp' extra for 'ldp'. Please:"
" `pip install paper-qa[ldp]`."
) from None
if exclude_no_answer:
consensus = {
q: [(a, c) for a, c in answers if a] for q, answers in consensus.items()
}
return consensus, accuracy


class StoreForConsensusSamplingCallback(Callback):
"""Store environments or frames for later consensus sampling."""

def __init__(self):
super().__init__()
self.stored: list[GradablePaperQAEnvironment | Frame] = []

async def after_transition(
self,
traj_id: str, # noqa: ARG002
agent: "Agent", # noqa: ARG002
env: Environment,
transition: "Transition",
) -> None:
if not isinstance(env, GradablePaperQAEnvironment):
raise NotImplementedError(
f"So far only handled {GradablePaperQAEnvironment} in this callback,"
f" not {type(env)}."
)
if transition.done and not transition.failed: # Only store once
return
self.stored.append(env.export_frame())

async def evaluate_consensus_sampling(
self, num_samples: int = 1, seed: int | None = None
) -> tuple[dict[str, list[tuple[str, int]]], float]:
return await evaluate_consensus_sampling(
data=self.stored, num_samples=num_samples, seed=seed
)


class LitQATaskDataset(
TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
Expand Down Expand Up @@ -218,24 +326,26 @@ def __init__(

def _make_gradable_environment(
self,
ideal: str,
ideal_answer: str,
distractors: str | list[str],
question: str,
sources: str | list[str] | None = None,
) -> GradablePaperQAEnvironment:
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
mc_question = MultipleChoiceQuestion(
question=question,
eval_model=self._eval_model,
options=(
distractors
if isinstance(distractors, list)
else MultipleChoiceQuestion.split_options(distractors)
),
ideal_answer=ideal_answer,
**(self._question_kwargs or {}),
)
query = self._base_query.model_copy()
query.query = qa_prompt
query.query = mc_question
return GradablePaperQAEnvironment(
query=query,
docs=self._base_docs.model_copy(),
evaluation_from_answer=evaluation_from_answer,
sources=sources,
rewards=self._rewards,
**self._env_kwargs,
Expand Down Expand Up @@ -338,7 +448,7 @@ def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
) from exc
sources.append(doi)
return self._make_gradable_environment(
ideal=self.data.iloc[idx].ideal,
ideal_answer=self.data.iloc[idx].ideal,
distractors=self.data.iloc[idx].distractors,
question=self.data.iloc[idx].question,
sources=sources,
Expand Down
Loading

0 comments on commit 8f73d9e

Please sign in to comment.