From d7ea31dc886b444b8de71e7ed415b58e46cfec44 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 24 Dec 2024 10:38:19 +0800 Subject: [PATCH] update league --- .../competitions/league.py | 16 ++--- .../utils/drawer.py | 62 ++++++++++++++----- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/examples/paper_provable_scaling_law/competitions/league.py b/examples/paper_provable_scaling_law/competitions/league.py index ebb410afd..b6d54d59d 100644 --- a/examples/paper_provable_scaling_law/competitions/league.py +++ b/examples/paper_provable_scaling_law/competitions/league.py @@ -177,23 +177,19 @@ def calculate_stats( ) # calculate acc while candidate_num < self.n: - candidate_num *= 2 + candidate_num += 1 question_stats["acc"][f"{candidate_num}"] = 0 - for i in range(0, self.n, candidate_num): - sub_matrix = score_matrix[ - i : i + candidate_num, - i : i + candidate_num, - ] + for i in range(0, self.n): + indices = [(i + j) % self.n for j in range(candidate_num)] + sub_matrix = score_matrix[np.ix_(indices, indices)] sub_board = [ sum(sub_matrix[j]) for j in range(candidate_num) ] - sub_final = candidates[i + np.argmax(sub_board)] + sub_final = candidates[indices[np.argmax(sub_board)]] question_stats["acc"][f"{candidate_num}"] += int( sub_final["answer"] == target, ) - question_stats["acc"][f"{candidate_num}"] /= ( - self.n // candidate_num - ) + question_stats["acc"][f"{candidate_num}"] /= self.n if ( str(candidate_num) not in category_stats[question["category"]]["acc"] diff --git a/examples/paper_provable_scaling_law/utils/drawer.py b/examples/paper_provable_scaling_law/utils/drawer.py index ed8f93102..a40fe7e2c 100644 --- a/examples/paper_provable_scaling_law/utils/drawer.py +++ b/examples/paper_provable_scaling_law/utils/drawer.py @@ -3,13 +3,24 @@ from collections import defaultdict from typing import List from matplotlib import pyplot as plt +import numpy as np from .cache import Cache FIGURE_DIR = os.path.join(os.path.dirname(__file__), "imgs") -class KnockoutFigureDrawer: +class CompetitionFigureDrawer: + + @classmethod + def construct_suffix(cls, competition_type: str, config: dict): + if competition_type == "knockout": + return f"{config['n']}_{config['k']}" + elif competition_type == "league": + return f"{config['n']}_{config['k']}_{config['m']}" + else: + raise ValueError(f"Unknown competition type: {competition_type}") + @classmethod def draw_line( cls, @@ -61,6 +72,7 @@ def draw_line( @classmethod def draw_acc( cls, + competition_type: str, dataset_name: str, categories: List[str], configs: List[dict], @@ -72,10 +84,10 @@ def draw_acc( Cache( project_name=config["project"], job_name=config["job"], - ).load_knockout_stats( - n=config["n"], - k=config["k"], + ).load_competition_stats( + competition_type=competition_type, categories=categories, + suffix=cls.construct_suffix(competition_type, config), ) for config in configs ] @@ -86,7 +98,12 @@ def draw_acc( line = {"acc": stat[category]["acc"]} line.update(configs[i]) lines.append(line) - cls.draw_line(category=category, lines=lines) + cls.draw_line( + dataset_name=dataset_name, + category=category, + lines=lines, + figure_dir=figure_dir, + ) # draw all all_lines = [] for i, stat in enumerate(stats): @@ -112,6 +129,7 @@ def draw_acc( @classmethod def draw_majority_vote( cls, + competition_type: str, dataset_name: str, categories: List[str], configs: List[dict], @@ -123,10 +141,10 @@ def draw_majority_vote( Cache( project_name=config["project"], job_name=config["job"], - ).load_knockout_stats( - n=config["n"], - k=config["k"], + ).load_competition_stats( + competition_type=competition_type, categories=categories, + suffix=cls.construct_suffix(competition_type, config), ) for config in configs ] @@ -153,6 +171,7 @@ def draw_majority_vote( @classmethod def draw_p_cmp( cls, + competition_type: str, dataset_name: str, categories: List[str], configs: List[dict], @@ -164,10 +183,10 @@ def draw_p_cmp( Cache( project_name=config["project"], job_name=config["job"], - ).load_knockout_stats( - n=config["n"], - k=config["k"], + ).load_competition_stats( + competition_type=competition_type, categories=categories, + suffix=cls.construct_suffix(competition_type, config), ) for config in configs ] @@ -175,7 +194,6 @@ def draw_p_cmp( for category in categories: all_correct_cnt = 0 all_wrong_cnt = 0 - fig, ax = plt.subplots(figsize=(3.5, 3)) right_p_gens = [] right_p_cmps = [] wrong_p_gens = [] @@ -192,6 +210,10 @@ def draw_p_cmp( all_wrong_cnt += 1 if stat["acc"]["1"] == 1: all_correct_cnt += 1 + fig = plt.figure(figsize=(4, 3)) + gs = fig.add_gridspec(1, 2, width_ratios=[3, 1], wspace=0.0) + ax = fig.add_subplot(gs[0]) + ax_hist = fig.add_subplot(gs[1], sharey=ax) ax.scatter( right_p_gens, right_p_cmps, @@ -209,6 +231,15 @@ def draw_p_cmp( color=configs[i]["color"], marker="x", ) + bins = np.linspace(-0.0, 1.0, 50) + ax_hist.hist( + right_p_cmps + wrong_p_cmps, + bins=bins, + orientation="horizontal", + color=configs[i]["color"], + alpha=0.6, + ) + ax_hist.set_axis_off() above_count = sum( 1 for p_cmp in right_p_cmps + wrong_p_cmps if p_cmp > 0.5 ) @@ -285,6 +316,7 @@ def draw_p_cmp( @classmethod def draw_subset_acc( cls, + competition_type: str, threshold: float, dataset_name: str, categories: List[str], @@ -297,10 +329,10 @@ def draw_subset_acc( Cache( project_name=config["project"], job_name=config["job"], - ).load_knockout_stats( - n=config["n"], - k=config["k"], + ).load_competition_stats( + competition_type=competition_type, categories=categories, + suffix=cls.construct_suffix(competition_type, config), ) for config in configs ]