Skip to content

Commit

Permalink
add comments
Browse files: browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Dec 6, 2024
1 parent a5c0ce7 commit 2e7f59d
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 41 deletions.
Empty file.
9 changes: 4 additions & 5 deletions examples/paper_provable_scaling_law/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
"""
from __future__ import annotations
import json
import random
import argparse
from loguru import logger
import agentscope
from agentscope.server import RpcAgentServerLauncher

from utils.dataset import Dataset
from utils.cache import Cache
from utils.worker import MixedGenerator, MixedJudge, Generator, Judge
from competition import Competition

import agentscope
from agentscope.server import RpcAgentServerLauncher


def run(
generator: MixedGenerator,
Expand Down Expand Up @@ -129,7 +128,7 @@ def main(conf: dict) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--config", "-c", type=str)
args = parser.parse_args()
config = json.load(open(args.config, "r"))
config = json.load(open(args.config, "r", encoding="utf-8"))
agentscope.init(
project=config["project"],
model_configs=config["models"],
Expand Down
24 changes: 19 additions & 5 deletions examples/paper_provable_scaling_law/competition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
"""Competition module."""
from tqdm import tqdm
from typing import List
from loguru import logger
Expand Down Expand Up @@ -32,6 +33,8 @@ def competition(
"""
if method == "knockout":
return self.knockout(question, candidates, **kwargs)
elif method == "league":
return self.league(question, candidates, **kwargs)
else:
raise NotImplementedError

Expand All @@ -57,13 +60,13 @@ def knockout(
k=k,
category=question["category"],
)
round = 0
round_num = 0
knockout_traj = {
"final": None,
"detail": {},
}
while len(candidates) > 1:
round += 1
round_num += 1
winners = []
if len(candidates) % 2 == 1:
winners.append(candidates[-1])
Expand All @@ -81,7 +84,7 @@ def knockout(
rounds_detail = []
for i, pair in tqdm(
enumerate(pairs),
desc=f"Round {round}",
desc=f"Round {round_num}",
position=1,
total=len(pairs),
):
Expand All @@ -99,9 +102,9 @@ def knockout(
winners.append(candidates[i * 2])
else:
winners.append(candidates[i * 2 + 1])
knockout_traj["detail"][f"round_{round}"] = rounds_detail
knockout_traj["detail"][f"round_{round_num}"] = rounds_detail
candidates = winners
logger.info(f"Round {round} done")
logger.info(f"Round {round_num} done")
knockout_traj["final"] = candidates[0]
self.cache.save_knockout(
detail=knockout_traj,
Expand All @@ -111,3 +114,14 @@ def knockout(
category=question["category"],
)
return candidates[0]

def league(
    self,
    question: dict,
    candidates: List[dict],
    candidate_num: int,
    k: int,
) -> dict:
    """Run league competition (not implemented yet).

    Args:
        question (`dict`): The question being answered.
        candidates (`List[dict]`): The candidate answers that would
            take part in the league.
        candidate_num (`int`): Number of candidates.
            NOTE(review): presumably mirrors the knockout setting
            of the same name; confirm once implemented.
        k (`int`): NOTE(review): presumably the number of comparisons
            per pair, as in knockout; confirm once implemented.

    Returns:
        `dict`: Intended to be the selected candidate; nothing is
        returned yet because the method always raises.

    Raises:
        NotImplementedError: Always, until the league logic exists.
    """
    # TBD: fail loudly rather than fall through and return None, which
    # would contradict the declared `dict` return type and only surface
    # later as a confusing error in the caller.
    raise NotImplementedError("league competition is not implemented yet")
79 changes: 79 additions & 0 deletions examples/paper_provable_scaling_law/configs/example_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
{
"servers": [
{
"host": "localhost",
"port": null
},
{
"host": "localhost",
"port": null
}
],
"models": [
{
"model_type": "openai_chat",
"config_name": "qwen2.5",
"model_name": "",
"api_key": "EMPTY",
"client_args": {
"base_url": "http://127.0.0.1:8010/v1/"
},
"generate_args": {
"temperature": 0.1,
"max_tokens": 2048
}
},
{
"model_type": "openai_chat",
"config_name": "llama3.1",
"model_name": "",
"api_key": "EMPTY",
"client_args": {
"base_url": "http://127.0.0.1:8011/v1/"
},
"generate_args": {
"temperature": 0.1,
"max_tokens": 2048
}
}
],
"project": "provable_scaling_law",
"job": "qwen_llama_cot",
"dataset": {
"name": "mmlu_pro",
"max_instance": 6,
"categories": [
"math",
"physics"
]
},
"generate": {
"workers": [
{
"type": "mmlu_pro",
"model": "qwen2.5"
},
{
"type": "mmlu_pro",
"model": "llama3.1"
}
]
},
"judge": {
"workers": [
{
"type": "mmlu_pro",
"model": "qwen2.5"
},
{
"type": "mmlu_pro",
"model": "llama3.1"
}
]
},
"competition": {
"method": "knockout",
"candidate_num": 16,
"k": 4
}
}
16 changes: 14 additions & 2 deletions examples/paper_provable_scaling_law/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
"""Cache module."""
import os
import json
from typing import List
Expand All @@ -8,6 +9,8 @@


class Cache:
"""A cache for storing and loading data"""

def __init__(
self,
project_name: str,
Expand All @@ -32,8 +35,13 @@ def __init__(
self.data_dir,
"comparison",
)
self.competition_dir = os.path.join(
self.data_dir,
"competition",
)
os.makedirs(self.generation_dir, exist_ok=True)
os.makedirs(self.comparsion_dir, exist_ok=True)
os.makedirs(self.competition_dir, exist_ok=True)

def save_generation_stats(
self,
Expand All @@ -51,6 +59,7 @@ def save_generation_stats(
f"{instance_id}.json",
),
"w",
encoding="utf-8",
),
ensure_ascii=False,
indent=2,
Expand All @@ -74,6 +83,7 @@ def save_generation(
f"{instance_id}.jsonl",
),
"w",
encoding="utf-8",
) as f:
for i, c in enumerate(candidates):
c["cid"] = i
Expand Down Expand Up @@ -114,6 +124,7 @@ def save_pairwise_comparison(
with open(
os.path.join(pairwise_dir, f"{cid_a}-{cid_b}.json"),
"w",
encoding="utf-8",
) as f:
json.dump(detail, f, ensure_ascii=False, indent=2)

Expand Down Expand Up @@ -143,11 +154,12 @@ def save_knockout(
k: int,
category: str = DEFAULT_CATEGORY,
) -> None:
knockout_dir = os.path.join(self.comparsion_dir, "knockout", category)
knockout_dir = os.path.join(self.competition_dir, "knockout", category)
os.makedirs(knockout_dir, exist_ok=True)
with open(
os.path.join(knockout_dir, f"{instance_id}_{n}_{k}.json"),
"w",
encoding="utf-8",
) as f:
json.dump(detail, f, ensure_ascii=False, indent=2)

Expand All @@ -159,7 +171,7 @@ def load_knockout(
category: str = DEFAULT_CATEGORY,
) -> dict:
knockout_file = os.path.join(
self.comparsion_dir,
self.competition_dir,
"knockout",
category,
f"{instance_id}_{n}_{k}.json",
Expand Down
31 changes: 17 additions & 14 deletions examples/paper_provable_scaling_law/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# preprocess and load datasets
"""preprocess and load datasets"""

from __future__ import annotations
import os
Expand All @@ -14,19 +14,21 @@


class Dataset(ABC):
"""Base class for datasets."""

@classmethod
@abstractmethod
def preprocess(cls) -> None:
pass
"""Preprocess the dataset."""

@classmethod
@abstractmethod
def format_sample(cls, sample: dict) -> dict:
"""Format a sample into a dict."""
pass

@classmethod
def from_dict(cls, config: dict) -> Dataset:
"""Load a dataset from a config."""
if config["name"] == "mmlu_pro":
return MMLUPro(
max_instance=config["max_instance"],
Expand All @@ -37,10 +39,12 @@ def from_dict(cls, config: dict) -> Dataset:

@abstractmethod
def calculate_stats(self, sample: dict, candidates: List[dict]) -> dict:
pass
"""Calculate statistics for a sample and its candidates."""


class MMLUPro(Dataset):
"""MMLU-Pro dataset."""

PROMPT_TEMPLATE = """
Question: {question}
Options:
Expand All @@ -52,6 +56,11 @@ class MMLUPro(Dataset):
def __init__(self, categories: List[str], max_instance: int):
self.categories = categories
self.max_instance = max_instance
self.cur_category_index = 0
self.cur_instance_index = 0
self.total_samples = 0
self.samples = []
self.pbar = None

@classmethod
def preprocess(cls) -> None:
Expand All @@ -72,13 +81,13 @@ def preprocess(cls) -> None:
vali_filtered = ds["validation"].filter(
lambda example: example["category"] == category,
)
category = category.replace(" ", "_").lower()
ct = category.replace(" ", "_").lower()
test_filtered.to_json(
os.path.join(
DATASET_DIR,
"mmlu_pro",
"test",
f"{category}.jsonl",
f"{ct}.jsonl",
),
lines=True,
force_ascii=False,
Expand All @@ -88,12 +97,12 @@ def preprocess(cls) -> None:
DATASET_DIR,
"mmlu_pro",
"validation",
f"{category}.jsonl",
f"{ct}.jsonl",
),
lines=True,
force_ascii=False,
)
print(f"Saved test and validation data for category: {category}")
print(f"Saved test and validation data for category: {ct}")

@classmethod
def format_sample(cls, sample: dict) -> dict:
Expand Down Expand Up @@ -185,9 +194,3 @@ def __next__(self) -> dict:
self.cur_instance_index += 1
self.pbar.update(1)
return self.format_sample(sample)


if __name__ == "__main__":
    # Manual smoke check: load a single physics sample, then print the
    # raw record followed by the question text produced by formatting it.
    first_sample = MMLUPro.load(category="physics", max_instance=1)[0]
    print(first_sample)
    formatted = MMLUPro.format_sample(first_sample)
    print(formatted["question"])
Loading

0 comments on commit 2e7f59d

Please sign in to comment.