From ed66b87d0db0e8e14d1a7828c28f27b3c1066e71 Mon Sep 17 00:00:00 2001 From: jeandut Date: Tue, 14 Jun 2022 14:14:52 +0200 Subject: [PATCH] adding seed in column --- flamby/benchmarks/fed_benchmark.py | 39 ++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/flamby/benchmarks/fed_benchmark.py b/flamby/benchmarks/fed_benchmark.py index e9926e422..1b30b9de0 100644 --- a/flamby/benchmarks/fed_benchmark.py +++ b/flamby/benchmarks/fed_benchmark.py @@ -86,8 +86,12 @@ def main(args_cli): strategy_specific_hp_dicts = get_strategies( config, learning_rate=LR, args=vars(args_cli) ) - pooled_hyperparameters = {"optimizer_class": Optimizer, "learning_rate": LR} - main_columns_names = ["Test", "Method", "Metric"] + pooled_hyperparameters = { + "optimizer_class": Optimizer, + "learning_rate": LR, + "seed": args_cli.seed, + } + main_columns_names = ["Test", "Method", "Metric", "seed"] # We might need to dynamically add additional parameters to the csv columns hp_additional_args = [] @@ -99,7 +103,7 @@ def main(args_cli): if arg_names not in hp_additional_args ] # column names used for the results file - columns_names = main_columns_names + hp_additional_args + columns_names = list(set(main_columns_names + hp_additional_args)) evaluate_func, batch_size_test, compute_ensemble_perf = set_dataset_specific_config( dataset_name, compute_ensemble_perf=True @@ -159,7 +163,9 @@ def main(args_cli): # We can now proceed to the trainings # Pooled training # We check if we have already the results for pooled - index_of_interest = df.loc[df["Method"] == "Pooled Training"].index + index_of_interest = df.loc[ + (df["Method"] == "Pooled Training") & (df["seed"] == args_cli.seed) + ].index # There is no use in running the experiment if it is already found if (len(index_of_interest) < (NUM_CLIENTS + 1)) and do_baselines["Pooled"]: @@ -207,13 +213,17 @@ def main(args_cli): pooled=True, ) - # We check if we have the results for local and possibly ensemble as well - index_of_interest = df.loc[df["Method"] == "Local 0"].index + # We check if we have the results for local trainings and possibly ensemble as well + index_of_interest = df.loc[ + (df["Method"] == "Local 0") & (df["seed"] == args_cli.seed) + ].index for i in range(1, NUM_CLIENTS): index_of_interest = index_of_interest.union( - df.loc[df["Method"] == f"Local {i}"].index + df.loc[(df["Method"] == f"Local {i}") & (df["seed"] == args_cli.seed)].index ) - index_of_interest = index_of_interest.union(df.loc[df["Method"] == "Ensemble"].index) + index_of_interest = index_of_interest.union( + df.loc[(df["Method"] == "Ensemble") & (df["seed"] == args_cli.seed)].index + ) if len(index_of_interest) < nb_local_and_ensemble_xps: # The fact that we are here means some local experiments are missing or @@ -224,7 +234,9 @@ def main(args_cli): pooled_y_pred_dicts = {} for i in range(NUM_CLIENTS): - index_of_interest = df.loc[df["Method"] == f"Local {i}"].index + index_of_interest = df.loc[ + (df["Method"] == f"Local {i}") & (df["seed"] == args_cli.seed) + ].index # We do the experiments only if results are not found or we need # ensemble performances AND this experiment is planned. # i.e. we allow to not do anything else if the user specify @@ -241,8 +253,8 @@ def main(args_cli): training_dls[i], use_gpu, method_name, - Optimizer, - LR, + pooled_hyperparameters["optimizer_class"], + pooled_hyperparameters["learning_rate"], BaselineLoss, NUM_EPOCHS_POOLED, ) @@ -259,7 +271,7 @@ def main(args_cli): df = fill_df_with_xp_results( df, perf_dict, - {"learning_rate": LR, "optimizer_class": Optimizer}, + pooled_hyperparameters, method_name, columns_names, results_file, @@ -267,7 +279,7 @@ def main(args_cli): df = fill_df_with_xp_results( df, pooled_perf_dict, - {"learning_rate": LR, "optimizer_class": Optimizer}, + pooled_hyperparameters, method_name, columns_names, results_file, @@ -336,6 +348,7 @@ def main(args_cli): hyperparameters[k] = args[k] else: hyperparameters[k] = np.nan + hyperparameters["seed"] = args_cli.seed index_of_interest = find_xps_in_df( df, hyperparameters, sname, num_updates