Skip to content

Commit

Permalink
update league
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Dec 24, 2024
1 parent dae2a73 commit d7ea31d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
16 changes: 6 additions & 10 deletions examples/paper_provable_scaling_law/competitions/league.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,23 +177,19 @@ def calculate_stats(
)
# calculate acc
while candidate_num < self.n:
candidate_num *= 2
candidate_num += 1
question_stats["acc"][f"{candidate_num}"] = 0
for i in range(0, self.n, candidate_num):
sub_matrix = score_matrix[
i : i + candidate_num,
i : i + candidate_num,
]
for i in range(0, self.n):
indices = [(i + j) % self.n for j in range(candidate_num)]
sub_matrix = score_matrix[np.ix_(indices, indices)]
sub_board = [
sum(sub_matrix[j]) for j in range(candidate_num)
]
sub_final = candidates[i + np.argmax(sub_board)]
sub_final = candidates[indices[np.argmax(sub_board)]]
question_stats["acc"][f"{candidate_num}"] += int(
sub_final["answer"] == target,
)
question_stats["acc"][f"{candidate_num}"] /= (
self.n // candidate_num
)
question_stats["acc"][f"{candidate_num}"] /= self.n
if (
str(candidate_num)
not in category_stats[question["category"]]["acc"]
Expand Down
62 changes: 47 additions & 15 deletions examples/paper_provable_scaling_law/utils/drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@
from collections import defaultdict
from typing import List
from matplotlib import pyplot as plt
import numpy as np
from .cache import Cache


FIGURE_DIR = os.path.join(os.path.dirname(__file__), "imgs")


class KnockoutFigureDrawer:
class CompetitionFigureDrawer:

@classmethod
def construct_suffix(cls, competition_type: str, config: dict):
if competition_type == "knockout":
return f"{config['n']}_{config['k']}"
elif competition_type == "league":
return f"{config['n']}_{config['k']}_{config['m']}"
else:
raise ValueError(f"Unknown competition type: {competition_type}")

@classmethod
def draw_line(
cls,
Expand Down Expand Up @@ -61,6 +72,7 @@ def draw_line(
@classmethod
def draw_acc(
cls,
competition_type: str,
dataset_name: str,
categories: List[str],
configs: List[dict],
Expand All @@ -72,10 +84,10 @@ def draw_acc(
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(
n=config["n"],
k=config["k"],
).load_competition_stats(
competition_type=competition_type,
categories=categories,
suffix=cls.construct_suffix(competition_type, config),
)
for config in configs
]
Expand All @@ -86,7 +98,12 @@ def draw_acc(
line = {"acc": stat[category]["acc"]}
line.update(configs[i])
lines.append(line)
cls.draw_line(category=category, lines=lines)
cls.draw_line(
dataset_name=dataset_name,
category=category,
lines=lines,
figure_dir=figure_dir,
)
# draw all
all_lines = []
for i, stat in enumerate(stats):
Expand All @@ -112,6 +129,7 @@ def draw_acc(
@classmethod
def draw_majority_vote(
cls,
competition_type: str,
dataset_name: str,
categories: List[str],
configs: List[dict],
Expand All @@ -123,10 +141,10 @@ def draw_majority_vote(
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(
n=config["n"],
k=config["k"],
).load_competition_stats(
competition_type=competition_type,
categories=categories,
suffix=cls.construct_suffix(competition_type, config),
)
for config in configs
]
Expand All @@ -153,6 +171,7 @@ def draw_majority_vote(
@classmethod
def draw_p_cmp(
cls,
competition_type: str,
dataset_name: str,
categories: List[str],
configs: List[dict],
Expand All @@ -164,18 +183,17 @@ def draw_p_cmp(
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(
n=config["n"],
k=config["k"],
).load_competition_stats(
competition_type=competition_type,
categories=categories,
suffix=cls.construct_suffix(competition_type, config),
)
for config in configs
]
for i, stats in enumerate(run_stats):
for category in categories:
all_correct_cnt = 0
all_wrong_cnt = 0
fig, ax = plt.subplots(figsize=(3.5, 3))
right_p_gens = []
right_p_cmps = []
wrong_p_gens = []
Expand All @@ -192,6 +210,10 @@ def draw_p_cmp(
all_wrong_cnt += 1
if stat["acc"]["1"] == 1:
all_correct_cnt += 1
fig = plt.figure(figsize=(4, 3))
gs = fig.add_gridspec(1, 2, width_ratios=[3, 1], wspace=0.0)
ax = fig.add_subplot(gs[0])
ax_hist = fig.add_subplot(gs[1], sharey=ax)
ax.scatter(
right_p_gens,
right_p_cmps,
Expand All @@ -209,6 +231,15 @@ def draw_p_cmp(
color=configs[i]["color"],
marker="x",
)
bins = np.linspace(-0.0, 1.0, 50)
ax_hist.hist(
right_p_cmps + wrong_p_cmps,
bins=bins,
orientation="horizontal",
color=configs[i]["color"],
alpha=0.6,
)
ax_hist.set_axis_off()
above_count = sum(
1 for p_cmp in right_p_cmps + wrong_p_cmps if p_cmp > 0.5
)
Expand Down Expand Up @@ -285,6 +316,7 @@ def draw_p_cmp(
@classmethod
def draw_subset_acc(
cls,
competition_type: str,
threshold: float,
dataset_name: str,
categories: List[str],
Expand All @@ -297,10 +329,10 @@ def draw_subset_acc(
Cache(
project_name=config["project"],
job_name=config["job"],
).load_knockout_stats(
n=config["n"],
k=config["k"],
).load_competition_stats(
competition_type=competition_type,
categories=categories,
suffix=cls.construct_suffix(competition_type, config),
)
for config in configs
]
Expand Down

0 comments on commit d7ea31d

Please sign in to comment.