
Commit

Added aviary-paper-data to the splits so there is now a test split
jamesbraza committed Jan 9, 2025
1 parent 8b0ed16 commit acbb08b
Showing 4 changed files with 33 additions and 16 deletions.
11 changes: 8 additions & 3 deletions paperqa/agents/task.py
@@ -39,6 +39,7 @@
 )
 from paperqa.docs import Docs
 from paperqa.litqa import (
+    DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
     DEFAULT_LABBENCH_HF_HUB_NAME,
     DEFAULT_REWARD_MAPPING,
     read_litqa_v2_from_hub,
@@ -408,6 +409,7 @@ def compute_trajectory_metrics(
 class LitQAv2TaskSplit(StrEnum):
     TRAIN = "train"
     EVAL = "eval"
+    TEST = "test"


 class LitQAv2TaskDataset(LitQATaskDataset):
@@ -416,20 +418,23 @@ class LitQAv2TaskDataset(LitQATaskDataset):
     def __init__(
         self,
         *args,
-        labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+        train_eval_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+        test_dataset: str = DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
         read_data_kwargs: Mapping[str, Any] | None = None,
         split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        train_df, eval_df = read_litqa_v2_from_hub(
-            labbench_dataset, **(read_data_kwargs or {})
+        train_df, eval_df, test_df = read_litqa_v2_from_hub(
+            train_eval_dataset, test_dataset, **(read_data_kwargs or {})
         )
         split = LitQAv2TaskSplit(split)
         if split == LitQAv2TaskSplit.TRAIN:
             self.data = train_df
         elif split == LitQAv2TaskSplit.EVAL:
             self.data = eval_df
+        elif split == LitQAv2TaskSplit.TEST:
+            self.data = test_df
         else:
             assert_never(split)

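Usage sketch (not part of this commit): with the new TEST member, the task dataset can be pointed at the held-out split. This assumes the `datasets` extra is installed, the Hub datasets are reachable, and the base LitQATaskDataset defaults are sufficient for construction; the split may also be passed as the plain string "test", since __init__ coerces it via LitQAv2TaskSplit(split).

from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit

# Build the dataset against the new test split (questions come from
# futurehouse/aviary-paper-data rather than the LAB-Bench train/eval data).
dataset = LitQAv2TaskDataset(split=LitQAv2TaskSplit.TEST)
print(len(dataset))  # number of test questions
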
30 changes: 19 additions & 11 deletions paperqa/litqa.py
@@ -37,26 +37,30 @@ def make_discounted_returns(


 DEFAULT_LABBENCH_HF_HUB_NAME = "futurehouse/lab-bench"
+# Test split from Aviary paper's section 4.3: https://doi.org/10.48550/arXiv.2412.21154
+DEFAULT_AVIARY_PAPER_HF_HUB_NAME = "futurehouse/aviary-paper-data"


 def read_litqa_v2_from_hub(
-    labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+    train_eval_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+    test_dataset: str = DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
     randomize: bool = True,
     seed: int | None = None,
     train_eval_split: float = 0.8,
-) -> tuple[pd.DataFrame, pd.DataFrame]:
+) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
     """
-    Read LitQA v2 JSONL into train and eval DataFrames.
+    Read LitQA v2 JSONL into train, eval, and test DataFrames.

     Args:
-        labbench_dataset: The Hugging Face Hub dataset's name corresponding with the
-            LAB-Bench dataset.
+        train_eval_dataset: Hugging Face Hub dataset's name corresponding with train
+            and eval splits.
+        test_dataset: Hugging Face Hub dataset's name corresponding with a test split.
         randomize: Opt-out flag to shuffle the dataset after loading in by question.
         seed: Random seed to use for the shuffling.
         train_eval_split: Train/eval split fraction, default is 80% train 20% eval.

     Raises:
-        DatasetNotFoundError: If the LAB-Bench dataset is not found, or the
+        DatasetNotFoundError: If any of the datasets are not found, or the
             user is unauthenticated.
     """
     try:
@@ -67,9 +71,13 @@ def read_litqa_v2_from_hub(
             " `pip install paper-qa[datasets]`."
         ) from exc

-    litqa_v2 = load_dataset(labbench_dataset, "LitQA2")["train"].to_pandas()
-    litqa_v2["distractors"] = litqa_v2["distractors"].apply(list)
+    train_eval = load_dataset(train_eval_dataset, "LitQA2")["train"].to_pandas()
+    test = load_dataset(test_dataset, "LitQA2")["test"].to_pandas()
+    # Convert to list so it's not unexpectedly a numpy array
+    train_eval["distractors"] = train_eval["distractors"].apply(list)
+    test["distractors"] = test["distractors"].apply(list)
     if randomize:
-        litqa_v2 = litqa_v2.sample(frac=1, random_state=seed)
-    num_train = int(len(litqa_v2) * train_eval_split)
-    return litqa_v2[:num_train], litqa_v2[num_train:]
+        train_eval = train_eval.sample(frac=1, random_state=seed)
+        test = test.sample(frac=1, random_state=seed)
+    num_train = int(len(train_eval) * train_eval_split)
+    return train_eval[:num_train], train_eval[num_train:], test
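
Usage sketch (not part of the diff): the function now returns three DataFrames, with train/eval drawn from futurehouse/lab-bench and test from futurehouse/aviary-paper-data. This assumes the `datasets` extra is installed and both Hub datasets are accessible.

from paperqa.litqa import read_litqa_v2_from_hub

# Unpack the new three-way return; pass a seed for reproducible shuffling.
train_df, eval_df, test_df = read_litqa_v2_from_hub(seed=42)
print(len(train_df), len(eval_df), len(test_df))
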
2 changes: 1 addition & 1 deletion tests/test_litqa.py
@@ -25,7 +25,7 @@ def test_make_discounted_returns(

 def test_creating_litqa_questions() -> None:
     """Test making LitQA eval questions after downloading from Hugging Face Hub."""
-    _, eval_split = read_litqa_v2_from_hub(seed=42)
+    eval_split = read_litqa_v2_from_hub(seed=42)[1]
     assert len(eval_split) > 3
     assert [
         MultipleChoiceQuestion(
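
Since the function now returns a three-element tuple, indexing position 1 keeps this test on the eval split. An equivalent unpacking (illustrative only, not part of the commit):

from paperqa.litqa import read_litqa_v2_from_hub

# Same as read_litqa_v2_from_hub(seed=42)[1]: discard the train and test frames.
_, eval_split, _ = read_litqa_v2_from_hub(seed=42)
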
6 changes: 5 additions & 1 deletion tests/test_task.py
@@ -98,7 +98,11 @@ class TestTaskDataset:

     @pytest.mark.parametrize(
         ("split", "expected_length"),
-        [(LitQAv2TaskSplit.TRAIN, 159), (LitQAv2TaskSplit.EVAL, 40)],
+        [
+            (LitQAv2TaskSplit.TRAIN, 159),
+            (LitQAv2TaskSplit.EVAL, 40),
+            (LitQAv2TaskSplit.TEST, 49),
+        ],
     )
     @pytest.mark.asyncio
     async def test___len__(
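
The expected lengths line up with the 80/20 split in read_litqa_v2_from_hub: 159 + 40 = 199 train/eval questions, of which int(199 * 0.8) = 159 go to train and the remaining 40 to eval, while the 49 test questions come from the separate aviary-paper-data dataset. A quick sanity check (illustrative only, not part of the commit):

# Sanity-check the parametrized lengths against the split arithmetic.
total_train_eval = 159 + 40  # questions in the LAB-Bench LitQA2 train/eval pool
assert int(total_train_eval * 0.8) == 159  # train split
assert total_train_eval - int(total_train_eval * 0.8) == 40  # eval split
# The 49 test questions are loaded from futurehouse/aviary-paper-data.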
