Skip to content

Commit

Permalink
update cache and drawer
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Dec 10, 2024
1 parent 7ec66f8 commit 3c51120
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 53 deletions.
28 changes: 14 additions & 14 deletions examples/paper_provable_scaling_law/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ def load_knockout(
return {}
return json.load(open(knockout_file, "r", encoding="utf-8"))


def save_knockout_stats(
self,
stats: dict,
n: int,
k: int,
category: str,
) -> None:
knockout_stats_dir = os.path.join(
self.competition_dir,
Expand All @@ -210,24 +210,24 @@ def save_knockout_stats(
with open(
os.path.join(
knockout_stats_dir,
f"{n}_{k}.json",
f"{category}_{n}_{k}.json",
),
"w",
encoding="utf-8",
) as f:
json.dump(stats, f, ensure_ascii=False, indent=2)

def load_knockout_stats(
self,
n: int,
k: int,
self, n: int, k: int, categories: List[str]
) -> dict:
knockout_stats_file = os.path.join(
self.competition_dir,
"knockout_stats",
f"{n}_{k}.json",
)
return {} if not os.path.exists(knockout_stats_file) else json.load(
open(knockout_stats_file, "r", encoding="utf-8")
)

result = {}
for category in categories:
knockout_stats_file = os.path.join(
self.competition_dir,
"knockout_stats",
f"{category}_{n}_{k}.json",
)
result[category] = json.load(
open(knockout_stats_file, "r", encoding="utf-8")
)
return result
184 changes: 155 additions & 29 deletions examples/paper_provable_scaling_law/utils/competition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""Competition module."""
from __future__ import annotations
from abc import abstractmethod
from tqdm import tqdm
from typing import List
from loguru import logger
from agentscope.rpc import async_func, RpcMeta
Expand Down Expand Up @@ -49,7 +48,14 @@ def calculate_stats(


class Knockout(Competition):
def __init__(self, judge: MixedJudge, cache: Cache, n: int, k: int):
def __init__(
self,
judge: MixedJudge,
cache: Cache,
n: int,
k: int,
skip_same: bool = True,
):
"""
Args:
n (`int`): the number of candidates
Expand All @@ -58,6 +64,7 @@ def __init__(self, judge: MixedJudge, cache: Cache, n: int, k: int):
super().__init__(judge, cache)
self.n = n
self.k = k
self.skip_same = skip_same

def competition(
self,
Expand All @@ -83,43 +90,60 @@ def competition(
"final": None,
"detail": {},
}
all_same = False
while len(candidates) > 1:
round_num += 1
winners = []
if len(candidates) % 2 == 1:
winners.append(candidates[-1])
pairs = []
if not all_same or self.skip_same:
seen_answers = set()
for i in range(len(candidates)):
seen_answers.add(candidates[i]["answer"])
if len(seen_answers) == 1:
all_same = True
for i in range(1, len(candidates), 2):
# pair-wise compare
pairs.append(
self.judge.pairwise_compare(
question=question,
candidate_a=candidates[i - 1],
candidate_b=candidates[i],
k=self.k,
),
)
if all_same:
pairs.append(None)
else:
pairs.append(
self.judge.pairwise_compare(
question=question,
candidate_a=candidates[i - 1],
candidate_b=candidates[i],
k=self.k,
),
)
rounds_detail = []
for i, pair in tqdm(
enumerate(pairs),
desc=f"Round {round_num}",
total=len(pairs),
leave=False,
):
pair = pair.result()
rounds_detail.append(
{
"winner": pair["winner"],
"a": pair["a"],
"b": pair["b"],
"score_a": pair["score_a"],
"score_b": pair["score_b"],
},
)
if pair["winner"] == candidates[i * 2]["cid"]:
for i, pair in enumerate(pairs):
if all_same:
rounds_detail.append(
{
"winner": candidates[i * 2]["cid"],
"a": candidates[i * 2]["cid"],
"b": candidates[i * 2 + 1]["cid"],
"score_a": 0,
"score_b": 0,
},
)
winners.append(candidates[i * 2])
else:
winners.append(candidates[i * 2 + 1])
pair = pair.result()
rounds_detail.append(
{
"winner": pair["winner"],
"a": pair["a"],
"b": pair["b"],
"score_a": pair["score_a"],
"score_b": pair["score_b"],
},
)
if pair["winner"] == candidates[i * 2]["cid"]:
winners.append(candidates[i * 2])
else:
winners.append(candidates[i * 2 + 1])
knockout_traj["detail"][f"round_{round_num}"] = rounds_detail
candidates = winners
logger.info(f"Round {round_num} done")
Expand Down Expand Up @@ -218,10 +242,112 @@ def calculate_stats(
category_stats[category]["acc"][
candidate_num
] /= category_stats[category]["cnt"]
self.cache.save_knockout_stats(category_stats, n, k)
self.cache.save_knockout_stats(
category_stats[category], n, k, category
)
logger.info("Finished calculating knockout stats")


class UCB(Competition):
    """Competition that prunes candidates with upper/lower confidence bounds.

    In every round each still-active candidate is compared by the judge
    against a sample of the other active candidates. The accumulated
    scores give an average win rate per candidate; a candidate is
    deactivated once its upper bound drops below the best lower bound
    among the active candidates.
    """

    def __init__(
        self,
        judge: MixedJudge,
        cache: Cache,
        n: int,
        k: int,
        t: int,
        n_opponent: int,
        c_bonous: float,
    ):
        """
        Args:
            judge (`MixedJudge`): judge used for the pairwise comparisons.
            cache (`Cache`): cache object forwarded to the base class.
            n (`int`): the maximum number of candidates taken into account.
            k (`int`): forwarded as `k` to `judge.pairwise_compare`.
            t (`int`): the maximum number of rounds.
            n_opponent (`int`): the number of opponents each active
                candidate is compared against per round.
            c_bonous (`float`): numerator of the confidence bonus
                `sqrt(c_bonous / total_count)`. The spelling is kept
                because it is part of the public signature.
        """
        super().__init__(judge, cache)
        self.n = n
        self.k = k
        self.t = t
        self.n_opponent = n_opponent
        self.c_bonous = c_bonous

    def competition(
        self,
        question: dict,
        candidates: List[dict],
    ) -> dict:
        """Run ucb competition.

        Args:
            question (`dict`): the question passed through to the judge.
            candidates (`List[dict]`): candidate records; the `"raw"` and
                `"answer"` entries of each record are read here.
        """
        import numpy as np

        candidates = candidates[: self.n]
        # Size every array by the candidates actually available, so that
        # passing fewer than `self.n` candidates does not index past the
        # end of the list.
        num = len(candidates)
        ucb = np.ones(num, dtype=np.float64)
        lcb = np.zeros(num, dtype=np.float64)
        avg_win_rate = np.full(num, 0.5, dtype=np.float64)
        win_cnt_matrix = np.zeros((num, num), dtype=np.float64)
        lose_cnt_matrix = np.zeros((num, num), dtype=np.float64)
        # whether the candidate is active or not
        active_signal = np.ones(num, dtype=np.bool_)
        for _ in range(self.t):
            # ids of the candidates that are still active
            active_candidate_ids = np.where(active_signal)[0]
            for idx in active_candidate_ids:
                candidate_opponent_list = [
                    x for x in active_candidate_ids if x != idx
                ]
                if not candidate_opponent_list:
                    # Nobody left to compare against; without this guard
                    # the fill loop below would never terminate.
                    continue
                opponent_num = self.n_opponent
                opponent_list = []
                # Take every opponent once per full cycle, then sample the
                # remainder without replacement.
                while opponent_num > len(candidate_opponent_list):
                    opponent_list.extend(candidate_opponent_list)
                    opponent_num -= len(candidate_opponent_list)
                if opponent_num > 0:
                    opponent_list.extend(
                        np.random.choice(
                            candidate_opponent_list,
                            size=opponent_num,
                            replace=False,
                        ),
                    )
                # Submit all comparisons first, then collect the results.
                futures = [
                    self.judge.pairwise_compare(
                        question,
                        candidates[idx]["raw"],
                        candidates[opponent_id]["raw"],
                        k=self.k,
                    )
                    for opponent_id in opponent_list
                ]
                # Pair each result with the opponent it was requested for
                # instead of crediting everything to the last opponent.
                for opponent_id, future in zip(opponent_list, futures):
                    result = future.result()
                    win_cnt_matrix[idx][opponent_id] += result["score_a"]
                    lose_cnt_matrix[idx][opponent_id] += result["score_b"]

            # Repeat the elimination until no candidate is removed, since
            # dropping one changes the counts restricted to active ones.
            while True:
                for idx in np.where(active_signal)[0]:
                    total_win_count = np.sum(
                        win_cnt_matrix[idx] * active_signal,
                    )
                    total_lose_count = np.sum(
                        lose_cnt_matrix[idx] * active_signal,
                    )
                    total_count = total_win_count + total_lose_count
                    if total_count >= 1:
                        avg_win_rate[idx] = total_win_count / total_count
                        bonus = np.sqrt(self.c_bonous / total_count)
                        ucb[idx] = min(avg_win_rate[idx] + bonus, 1.0)
                        lcb[idx] = max(avg_win_rate[idx] - bonus, 0.0)
                max_lcb = np.max(lcb * active_signal)
                update_active_signal = False
                for idx in np.where(active_signal)[0]:
                    if ucb[idx] < max_lcb:
                        active_signal[idx] = False
                        update_active_signal = True
                if not update_active_signal:
                    break

            # Stop early once all remaining candidates share one answer.
            seen_answers = {
                candidates[idx]["answer"]
                for idx in np.where(active_signal)[0]
            }
            if len(seen_answers) <= 1:
                break
        # NOTE(review): the signature promises a dict, but no result is
        # built or returned here, so callers currently receive None. The
        # expected result schema is not visible from this block; confirm
        # it and add the return value.


class League(Competition):
def __init__(self, judge: MixedJudge, cache: Cache, n: int, k: int):
"""
Expand Down
56 changes: 46 additions & 10 deletions examples/paper_provable_scaling_law/utils/drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def draw_acc(
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(n=config["n"], k=config["k"])
).load_knockout_stats(
n=config["n"], k=config["k"], categories=categories
)
for config in configs
]
for category in categories:
Expand All @@ -37,7 +39,9 @@ def draw_acc(
color=configs[i]["color"],
)
ax.set_title(f"{dataset_name}: {category}")
ax.grid(True, linestyle="dashed", linewidth=1, color="gray", alpha=0.5)
ax.grid(
True, linestyle="dashed", linewidth=1, color="gray", alpha=0.5
)
ax.set_xlabel("N")
ax.set_ylabel("Accuracy")
ax.legend(
Expand Down Expand Up @@ -74,18 +78,26 @@ def draw_p_cmp(
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(n=config["n"], k=config["k"])
).load_knockout_stats(
n=config["n"], k=config["k"], categories=categories
)
for config in configs
]
for i, stats in enumerate(run_stats):
for category in categories:
all_correct_cnt = 0
all_wrong_cnt = 0
fig, ax = plt.subplots(figsize=(3.5, 3))
p_gens = []
p_cmps = []
for qid, stat in stats[category]["details"].items():
if stat["cmp"]["valid"] > 0:
p_gens.append(stat["acc"]["1"])
p_cmps.append(stat["cmp"]["p_cmp"])
if stat["acc"]["1"] == 0:
all_wrong_cnt += 1
if stat["acc"]["1"] == 1:
all_correct_cnt += 1
ax.scatter(
p_gens,
p_cmps,
Expand All @@ -95,28 +107,52 @@ def draw_p_cmp(
)
above_count = sum(1 for p_cmp in p_cmps if p_cmp > 0.5)
below_count = sum(1 for p_cmp in p_cmps if p_cmp <= 0.5)
ax.set_title(f"{dataset_name}: {category}")
ax.set_title(
f"({configs[i]['label']}) {dataset_name}: {category}"
)
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.10, 1.10)
ax.set_xlabel("$P_{gen}$")
ax.set_ylabel("$P_{comp}$")
ax.axhline(y=0.5, color="black", linestyle="dotted", linewidth=1.0)
ax.grid(True, linestyle="dashed", linewidth=1, color="gray", alpha=0.5)
ax.axhline(
y=0.5, color="black", linestyle="dotted", linewidth=1.0
)
ax.grid(
True,
linestyle="dashed",
linewidth=1,
color="gray",
alpha=0.5,
)
ax.text(
-0.05,
1.05,
"#Above = " + str(above_count),
1.03,
"#[$P_{comp}>0.5$] = " + str(above_count),
fontsize=9,
verticalalignment="center",
)
ax.text(
-0.05,
-0.05,
"#Below = " + str(below_count),
"#[$P_{comp}≤0.5$] = " + str(below_count),
fontsize=9,
verticalalignment="center",
)

ax.text(
-0.05,
-0.33,
"#[$P_{gen}$=0] = " + str(all_wrong_cnt),
fontsize=9,
verticalalignment="center",
)
ax.text(
0.65,
-0.33,
"#[$P_{gen}$=1] = " + str(all_correct_cnt),
fontsize=9,
verticalalignment="center",
)
ax.legend(loc="upper right")
plt.tight_layout()
plt.savefig(
os.path.join(
Expand Down

0 comments on commit 3c51120

Please sign in to comment.