Skip to content

Commit

Permalink
update drawer
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Dec 18, 2024
1 parent c358cfb commit a8f9002
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 79 deletions.
1 change: 1 addition & 0 deletions examples/paper_provable_scaling_law/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def main(conf: dict) -> None:
for i, w in enumerate(conf["judgement"]["workers"])
],
cache=cache,
random=config["judgement"].get("random", True),
to_dist={
"host": master_launcher.host,
"port": master_launcher.port,
Expand Down
33 changes: 33 additions & 0 deletions examples/paper_provable_scaling_law/competitions/ucb.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def calculate_stats(self, dataset: Dataset):
"acc": {
"0": 0,
},
"pool_size": {
"0": 0,
},
"cnt": 0,
"details": {},
}
Expand All @@ -274,6 +277,8 @@ def calculate_stats(self, dataset: Dataset):
"avg": sum(1 for x in candidates if x["answer"] == target)
/ len(candidates),
}
question_stats["pool_size"] = {"0": self.n}
category_stats[question["category"]]["pool_size"]["0"] += self.n
question_stats["acc"]["0"] = question_stats["acc"]["avg"]
category_stats[question["category"]]["acc"]["0"] += question_stats[
"acc"
Expand Down Expand Up @@ -309,6 +314,15 @@ def calculate_stats(self, dataset: Dataset):
)
* active_signal
)
elif self.win_indicator == "lcb":
scores = (
np.array(
ucb_result["detail"][f"round_{round_num}"][
"lcb"
],
)
* active_signal
)
else:
scores = (
np.array(
Expand All @@ -322,20 +336,38 @@ def calculate_stats(self, dataset: Dataset):
final_ids = np.where(
np.isclose(scores, max_score, atol=1e-8),
)[0].tolist()
# final_ids = ucb_result["detail"][f"round_{round_num}"][
# "active_ids"
# ]
if (
str(round_num)
not in category_stats[question["category"]]["acc"]
):
category_stats[question["category"]]["acc"][
str(round_num)
] = 0
category_stats[question["category"]]["pool_size"][
str(round_num)
] = 0
question_stats["acc"][str(round_num)] = sum(
int(candidates[final_idx]["answer"] == target)
for final_idx in final_ids
) / len(final_ids)
if (
question_stats["acc"][str(round_num)] == 0
or question_stats["acc"][str(round_num)] == 1
):
question_stats["pool_size"][str(round_num)] = 1
else:
question_stats["pool_size"][str(round_num)] = len(
final_ids,
)
category_stats[question["category"]]["acc"][
str(round_num)
] += question_stats["acc"][str(round_num)]
category_stats[question["category"]]["pool_size"][
str(round_num)
] += question_stats["pool_size"][str(round_num)]
category_stats[question["category"]]["cnt"] += 1
question_stats["cmp"] = {
"valid": valid_cmp,
Expand All @@ -348,4 +380,5 @@ def calculate_stats(self, dataset: Dataset):
for category, stats in category_stats.items():
for t in stats["acc"]:
stats["acc"][t] /= stats["cnt"]
stats["pool_size"][t] /= stats["cnt"]
self.cache.save_ucb_stats(stats, n, k, t, category)
135 changes: 92 additions & 43 deletions examples/paper_provable_scaling_law/utils/drawer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import os
from collections import defaultdict
from typing import List
from matplotlib import pyplot as plt
from .cache import Cache
Expand All @@ -17,28 +18,15 @@ def draw_acc(
configs: List[dict],
sub_dir: str = "default",
) -> None:
figure_dir = os.path.join(FIGURE_DIR, sub_dir)
os.makedirs(figure_dir, exist_ok=True)
stats = [
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(
n=config["n"],
k=config["k"],
categories=categories,
)
for config in configs
]
for category in categories:
def draw_line(category: str, lines: List[dict]):
fig, ax = plt.subplots(figsize=(4, 3))
for i, stat in enumerate(stats):
for line in lines:
ax.plot(
stat[category]["acc"].keys(),
stat[category]["acc"].values(),
label=configs[i]["label"],
marker=configs[i]["marker"],
color=configs[i]["color"],
line["acc"].keys(),
line["acc"].values(),
label=line["label"],
marker=line["marker"],
color=line["color"],
)
ax.set_title(f"{dataset_name}: {category}")
ax.grid(
Expand Down Expand Up @@ -70,6 +58,43 @@ def draw_acc(
pad_inches=0.02,
)

figure_dir = os.path.join(FIGURE_DIR, sub_dir)
os.makedirs(figure_dir, exist_ok=True)
stats = [
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(
n=config["n"],
k=config["k"],
categories=categories,
)
for config in configs
]
# draw categories
for category in categories:
lines = []
for i, stat in enumerate(stats):
line = {"acc": stat[category]["acc"]}
line.update(configs[i])
lines.append(line)
draw_line(category=category, lines=lines)
# draw all
all_lines = []
for i, stat in enumerate(stats):
all_cnt = 0
all_acc = defaultdict(float)
for category in categories:
for k, v in stat[category]["acc"].items():
all_acc[k] += v * stat[category]["cnt"]
all_cnt += stat[category]["cnt"]
for k in all_acc:
all_acc[k] /= all_cnt
line = {"acc": all_acc}
line.update(configs[i])
all_lines.append(line)
draw_line(category="all", lines=all_lines)

@classmethod
def draw_p_cmp(
cls,
Expand Down Expand Up @@ -115,13 +140,15 @@ def draw_p_cmp(
ax.scatter(
right_p_gens,
right_p_cmps,
15,
label=configs[i]["label"],
alpha=0.6,
color=configs[i]["color"],
)
ax.scatter(
wrong_p_gens,
wrong_p_cmps,
15,
label=configs[i]["label"],
alpha=0.6,
color=configs[i]["color"],
Expand Down Expand Up @@ -211,29 +238,15 @@ def draw_acc(
configs: List[dict],
sub_dir: str = "default",
) -> None:
figure_dir = os.path.join(FIGURE_DIR, sub_dir)
os.makedirs(figure_dir, exist_ok=True)
stats = [
Cache(
project_name=config["project"],
job_name=config["job"],
).load_ucb_stats(
n=config["n"],
k=config["k"],
t=config["t"],
categories=categories,
)
for config in configs
]
for category in categories:
def draw_line(category: str, lines: List[dict]):
fig, ax = plt.subplots(figsize=(4, 3))
for i, stat in enumerate(stats):
for line in lines:
ax.plot(
stat[category]["acc"].keys(),
stat[category]["acc"].values(),
label=configs[i]["label"],
marker=configs[i]["marker"],
color=configs[i]["color"],
line["acc"].keys(),
line["acc"].values(),
label=line["label"],
marker=line["marker"],
color=line["color"],
)
ax.set_title(f"{dataset_name}: {category}")
ax.grid(
Expand Down Expand Up @@ -271,6 +284,42 @@ def draw_acc(
pad_inches=0.02,
)

figure_dir = os.path.join(FIGURE_DIR, sub_dir)
os.makedirs(figure_dir, exist_ok=True)
stats = [
Cache(
project_name=config["project"],
job_name=config["job"],
).load_ucb_stats(
n=config["n"],
k=config["k"],
t=config["t"],
categories=categories,
)
for config in configs
]
for category in categories:
lines = []
for i, stat in enumerate(stats):
line = {"acc": stat[category]["acc"]}
line.update(configs[i])
lines.append(line)
draw_line(category=category, lines=lines)
all_lines = []
for i, stat in enumerate(stats):
all_cnt = 0
all_acc = defaultdict(float)
for category in categories:
for k, v in stat[category]["acc"].items():
all_acc[k] += v * stat[category]["cnt"]
all_cnt += stat[category]["cnt"]
for k in all_acc:
all_acc[k] /= all_cnt
line = {"acc": all_acc}
line.update(configs[i])
all_lines.append(line)
draw_line(category="all", lines=all_lines)

@classmethod
def draw_p_cmp(
cls,
Expand Down Expand Up @@ -303,8 +352,8 @@ def draw_p_cmp(
wrong_p_gens = []
wrong_p_cmps = []
for qid, stat in stats[category]["details"].items():
if stat["cmp"]["valid"] > 0:
if stat["acc"][str(configs[i]["t"])] >= 0.5:
if stat["cmp"]["valid"] > 0.5:
if stat["acc"][str(configs[i]["t"])] == 1:
right_p_gens.append(stat["acc"]["avg"])
right_p_cmps.append(stat["cmp"]["p_cmp"])
else:
Expand Down
Loading

0 comments on commit a8f9002

Please sign in to comment.