diff --git a/flamby/benchmarks/benchmark_utils.py b/flamby/benchmarks/benchmark_utils.py
new file mode 100644
index 000000000..9c220a11c
--- /dev/null
+++ b/flamby/benchmarks/benchmark_utils.py
@@ -0,0 +1,535 @@
+import copy
+
+import numpy as np
+import pandas as pd
+from torch.utils.data import DataLoader as dl
+from tqdm import tqdm
+
+from flamby.utils import evaluate_model_on_tests
+
+
+def fill_df_with_xp_results(
+    df,
+    perf_dict,
+    hyperparams,
+    method_name,
+    columns_names,
+    results_file,
+    dump=True,
+    pooled=False,
+):
+    """Add results to the dataframe for a specific strategy with specific hyperparameters.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        The dataframe of results.
+    perf_dict : dict
+        A dictionary mapping each test set to its metric value.
+    hyperparams : dict
+        The dict of hyperparameters.
+    method_name : str
+        The name of the training method.
+    columns_names : list[str]
+        The column names in the considered dataframe.
+    results_file : str
+        The path of the CSV file the dataframe is dumped to.
+    dump : bool
+        Whether to dump the dataframe to disk after adding the results.
+        Defaults to True.
+    pooled : bool
+        Whether this is the pooled result, in which case the test name is
+        changed to distinguish it from the first local test.
+        Defaults to False.
+    """
+
+    perf_lines_dicts = df.to_dict("records")
+    if pooled:
+        assert (
+            len(perf_dict) == 1
+        ), "Your pooled perf dict has multiple keys; this should be impossible."
+        perf_dict["Pooled Test"] = perf_dict.pop(list(perf_dict)[0])
+
+    for k, v in perf_dict.items():
+        perf_lines_dicts.append(
+            prepare_dict(
+                keys=columns_names,
+                allow_new=True,
+                Test=k,
+                Metric=v,
+                Method=method_name,
+                # We add the hyperparameters used
+                **hyperparams,
+            )
+        )
+    # We update the csv and save it when the results are there
+    df = pd.DataFrame.from_dict(perf_lines_dicts)
+    if dump:
+        df.to_csv(results_file, index=False)
+    return df
+
+
+def find_xps_in_df(df, hyperparameters, sname, num_updates):
+    """Return the indices in the given dataframe where results were found for
+    the given set of hyperparameters (specified as a dict) of the sname
+    federated strategy run with num_updates batch updates.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        The dataframe of experiments.
+    hyperparameters : dict
+        A dict with keys that are columns of the dataframe and values that are
+        used to filter the dataframe.
+    sname : str
+        The name of the FL strategy to investigate.
+        Should be in the following list:
+        ["FedAvg", "Scaffold", "FedProx", "Cyclic", "FedAdam", "FedAdagrad",
+        "FedYogi"]
+    num_updates : int
+        The number of batch updates used in the strategy.
+    """
+    # This is ugly but it is the only way I found to accommodate float
+    # and object equality in a robust fashion.
+    # The non-robust version would be simpler but it doesn't handle floats well:
+    # index_of_interest = df.loc[
+    #     (df["Method"] == (sname + str(num_updates)))
+    #     & (
+    #         df[list(hyperparameters)] == pd.Series(hyperparameters)
+    #     ).all(axis=1)
+    # ].index
+    assert all(
+        [e in df.columns for e in list(hyperparameters)]
+    ), "Some hyperparameters provided are not included in the dataframe"
+    assert sname in [
+        "FedAvg",
+        "Scaffold",
+        "FedProx",
+        "Cyclic",
+        "FedAdam",
+        "FedAdagrad",
+        "FedYogi",
+    ], f"Strategy name {sname} not recognized."
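+    # Matching below is done in two passes: numeric hyperparameters are
+    # compared with np.isclose (with equal_nan=True, so NaN matches NaN),
+    # object-typed ones are compared as strings, and the resulting row
+    # indices are intersected with the rows matching the requested Method.
+    # The boolean deterministic_cycle flag is first coerced to float with
+    # missing values filled with 0.0 so that "absent" compares equal to False.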
+ found_xps = df[list(hyperparameters)] + + # Different types of data need different matching strategy + found_xps_numerical = found_xps.select_dtypes(exclude=[object]) + col_numericals = found_xps_numerical.columns + col_objects = [c for c in found_xps.columns if not (c in col_numericals)] + + # Special cases for boolean parameters + if "deterministic_cycle" in found_xps_numerical.columns: + found_xps_numerical["deterministic_cycle"] = ( + found_xps_numerical["deterministic_cycle"].fillna(0.0).astype(float) + ) + + if len(col_numericals) > 0: + bool_numerical = np.all( + np.isclose( + found_xps_numerical, + pd.Series( + { + k: float(hyperparameters[k]) + for k in list(hyperparameters.keys()) + if k in col_numericals + } + ), + equal_nan=True, + ), + axis=1, + ) + else: + bool_numerical = np.ones((len(df.index), 1)).astype("bool") + + if len(col_objects): + bool_objects = found_xps[col_objects].astype(str) == pd.Series( + { + k: str(hyperparameters[k]) + for k in list(hyperparameters.keys()) + if k in col_objects + } + ) + else: + bool_objects = np.ones((len(df.index), 1)).astype("bool") + + # We filter on the Method we want + bool_method = df["Method"] == (sname + str(num_updates)) + + index_of_interest_1 = df.loc[pd.DataFrame(bool_numerical).all(axis=1)].index + index_of_interest_2 = df.loc[pd.DataFrame(bool_objects).all(axis=1)].index + index_of_interest_3 = df.loc[pd.DataFrame(bool_method).all(axis=1)].index + index_of_interest = index_of_interest_1.intersection( + index_of_interest_2 + ).intersection(index_of_interest_3) + return index_of_interest + + +def init_data_loaders( + dataset, + pooled=False, + batch_size=1, + num_workers=1, + num_clients=None, + batch_size_test=None, + collate_fn=None, +): + """ + Initializes the data loaders for the training and test datasets. + """ + if (not pooled) and num_clients is None: + raise ValueError("num_clients must be specified for the non-pooled data") + batch_size_test = batch_size if batch_size_test is None else batch_size_test + if not pooled: + training_dls = [ + dl( + dataset(center=i, train=True, pooled=False), + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + collate_fn=collate_fn, + ) + for i in range(num_clients) + ] + test_dls = [ + dl( + dataset(center=i, train=False, pooled=False), + batch_size=batch_size_test, + shuffle=False, + num_workers=num_workers, + collate_fn=collate_fn, + ) + for i in range(num_clients) + ] + return training_dls, test_dls + else: + train_pooled = dl( + dataset(train=True, pooled=True), + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + collate_fn=collate_fn, + ) + test_pooled = dl( + dataset(train=False, pooled=True), + batch_size=batch_size_test, + shuffle=False, + num_workers=num_workers, + collate_fn=collate_fn, + ) + return train_pooled, test_pooled + + +def prepare_dict(keys, allow_new=False, **kwargs): + """ + Prepares the dictionary with the given keys and fills them with the kwargs. + If allow_new is set to False (default) + Kwargs must be one of the keys. 
If + kwarg is not given for a key the value of that key will be None + """ + if not allow_new: + # ensure all the kwargs are in the columns + assert sum([not (key in keys) for key in kwargs.keys()]) == 0, ( + "Some of the keys given were not found in the existsing columns;" + f"keys: {kwargs.keys()}, columns: {keys}" + ) + + # create the dictionary from the given keys and fill when appropriate with the kwargs + return {**dict.fromkeys(keys), **kwargs} + + +def get_logfile_name_from_strategy(dataset_name, sname, num_updates, args): + """Produce exlicit logfile name from strategy num updates and args. + + Parameters + ---------- + dataset_name : str + The name of the dataset. + sname : str + The name of the strategy. + num_updates : int + The number of batch updates used in the strategy. + args : dict + The dict of hyperparameters of the strategy + """ + basename = dataset_name + "-" + sname + f"-num-updates{num_updates}" + for k, v in args.items(): + if k in ["learning_rate", "server_learning_rate"]: + basename += "-" + "".join([e[0] for e in str(k).split("_")]) + str(v) + if k in ["mu", "deterministic_cycle"]: + basename += "-" + str(k) + str(v) + return basename + + +def evaluate_model_on_local_and_pooled_tests( + m, local_dls, pooled_dl, metric, evaluate_func, return_pred=False +): + """Evaluate the model on a list of dataloaders and on one dataloader using + the evaluate function given. + + Parameters + ---------- + m : torch.nn.Module + The model to evaluate. + local_dls : list[torch.utils.data.DataLoader] + The list of dataloader used for tests. + pooled_dl : torch.utils.data.DataLoader + The single dataloader used for test. + metric: callable + The metric to use for evaluation. + evaluate_func : callable + The function used to evaluate + return_pred: bool + Whether or not to return pred. + + Returns + ------- + Tuple(dict, dict) + Two performances dicts. + """ + + perf_dict = evaluate_func(m, local_dls, metric, return_pred=return_pred) + pooled_perf_dict = evaluate_func(m, [pooled_dl], metric, return_pred=return_pred) + + # Very ugly tuple unpacking in case we return the predictions as well + # in thee future the evaluation function should return a dict but there is + # a lot of refactoring needed + if return_pred: + perf_dict, y_true_dict, y_pred_dict = perf_dict + pooled_perf_dict, y_true_pooled_dict, y_pred_pooled_dict = pooled_perf_dict + else: + y_true_dict, y_pred_dict, y_true_pooled_dict, y_pred_pooled_dict = ( + None, + None, + None, + None, + ) + + print("Per-center performance:") + print(perf_dict) + print("Performance on pooled test set:") + print(pooled_perf_dict) + return ( + perf_dict, + pooled_perf_dict, + y_true_dict, + y_pred_dict, + y_true_pooled_dict, + y_pred_pooled_dict, + ) + + +def train_single_centric( + global_init, + train_dl, + use_gpu, + name, + opt_class, + learning_rate, + loss_class, + num_epochs, + return_pred=False, +): + """Train the global_init model usiing train_dl and defaut parameters. + + Parameters + ---------- + global_init : torch.nn.Module + The initialized model to train. + train_dl : torch.utils.data.DataLoader + The dataloader to use for training. + use_gpu : bool + Whether or not to use the GPU. + name : str + The name of the method to display. + opt_class: torch.optim + A callable with signature (list[torch.Tensor], lr) -> torch.optim + learning_rate: float + The learning rate of the optimizer. + loss_class: torch.losses._Loss + A callable return a pytorch loss. + num_epochs: int + The number of epochs on which to train. 
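+    return_pred : bool
+        Unused by this function.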
+ Returns + ------- + torch.nn.Module + The trained model. + """ + model = copy.deepcopy(global_init) + if use_gpu: + model.cuda() + bloss = loss_class() + opt = opt_class(model.parameters(), lr=learning_rate) + print(name) + for _ in tqdm(range(num_epochs)): + for X, y in train_dl: + if use_gpu: + # use GPU if requested and available + X = X.cuda() + y = y.cuda() + opt.zero_grad() + y_pred = model(X) + loss = bloss(y_pred, y) + loss.backward() + opt.step() + return model + + +def init_xp_plan( + num_clients, + nlocal, + single_centric_baseline=None, + strategy=None, + compute_ensemble_perf=False, +): + """_summary_ + + Parameters + ---------- + num_clients : int + The number of available clients. + nlocal : int + The index of the chosen client. + single_centric_baseline : str + The single centric baseline to comute. + strategy: str + The strategy to compute results for. + + Returns + ------- + dict + A dict with the plannification of xps to do. + + Raises + ------ + ValueError + _description_ + """ + do_strategy = True + do_baselines = {"Pooled": True} + for i in range(num_clients): + do_baselines[f"Local {i}"] = True + # Single client baseline computation + if single_centric_baseline is not None: + do_baselines = {"Pooled": False} + for i in range(num_clients): + do_baselines[f"Local {i}"] = False + if single_centric_baseline == "Pooled": + do_baselines[single_centric_baseline] = True + elif single_centric_baseline == "Local": + assert nlocal in range(num_clients), "The client you chose does not exist" + do_baselines[single_centric_baseline + " " + str(nlocal)] = True + # If we do a single-centric baseline we don't do the strategies + do_strategy = False + + # if we give a strategy we compute only the strategy and not the baselines + if strategy is not None: + for k, _ in do_baselines.items(): + do_baselines[k] = False + + do_all_local = all([do_baselines[f"Local {i}"] for i in range(num_clients)]) + if compute_ensemble_perf and not (do_all_local): + raise ValueError( + "Cannot compute ensemble performance if training on only one local" + ) + return do_baselines, do_strategy + + +def ensemble_perf_from_predictions( + y_true_dicts, y_pred_dicts, num_clients, metric, num_clients_test=None +): + """_summary_ + + Parameters + ---------- + y_true_dicts : dict + The ground truth dicts for all clients + y_pred_dicts :dict + The prediction array for all models and clients. + num_clients : int + The number of clients + metric : callable + (torch.Tensor, torch.Tensor) -> [0, 1.] + num_clients_test: int + When testing on pooled. 
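+        Defaults to num_clients when not provided; it is set to 1 when
+        evaluating on the single pooled test set.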
+ + Returns + ------- + dict + A dict with the predictions of all ensembles + """ + print("Computing ensemble performance") + ensemble_perf = {} + if num_clients_test is None: + num_clients_test = num_clients + for testset_idx in range(num_clients_test): + # Small safety net + for model_idx in range(1, num_clients): + assert ( + y_true_dicts[f"Local {0}"][f"client_test_{testset_idx}"] + == y_true_dicts[f"Local {model_idx}"][f"client_test_{testset_idx}"] + ).all(), "Models in the ensemble have different ground truths" + # Since they are all the same we use the first one for this specific tests as the ground truth + ensemble_true = y_true_dicts["Local 0"][f"client_test_{testset_idx}"] + + # Accumulating predictions + ensemble_pred = y_pred_dicts["Local 0"][f"client_test_{testset_idx}"] + for model_idx in range(1, num_clients): + ensemble_pred += y_pred_dicts[f"Local {model_idx}"][ + f"client_test_{testset_idx}" + ] + ensemble_pred /= float(num_clients) + ensemble_perf[f"client_test_{testset_idx}"] = metric( + ensemble_true, ensemble_pred + ) + return ensemble_perf + + +def set_dataset_specific_config(dataset_name, compute_ensemble_perf=False, use_gpu=True): + """_summary_ + + Parameters + ---------- + dataset_name : _type_ + _description_ + compute_ensemble_perf: bool + Whether or not to compute ensemble performances. Cannot be used with + KITS or LIDC. Defaults to None. + + Returns + ------- + _type_ + _description_ + """ + # Instantiate all train and test dataloaders required including pooled ones + if dataset_name == "fed_lidc_idri": + batch_size_test = 1 + from flamby.datasets.fed_lidc_idri import evaluate_dice_on_tests_by_chunks + + def evaluate_func(m, test_dls, metric, use_gpu=use_gpu, return_pred=False): + dice_dict = evaluate_dice_on_tests_by_chunks(m, test_dls, use_gpu) + if return_pred: + return dice_dict, None, None + return dice_dict + + compute_ensemble_perf = False + elif dataset_name == "fed_kits19": + from flamby.datasets.fed_kits19 import evaluate_dice_on_tests + + batch_size_test = 2 + + def evaluate_func(m, test_dls, metric, use_gpu=use_gpu, return_pred=False): + dice_dict = evaluate_dice_on_tests(m, test_dls, metric, use_gpu) + if return_pred: + return dice_dict, None, None + return dice_dict + + compute_ensemble_perf = False + + elif dataset_name == "fed_ixi": + batch_size_test = 1 + evaluate_func = evaluate_model_on_tests + compute_ensemble_perf = False + + else: + batch_size_test = None + evaluate_func = evaluate_model_on_tests + + return evaluate_func, batch_size_test, compute_ensemble_perf diff --git a/flamby/benchmarks/conf.py b/flamby/benchmarks/conf.py new file mode 100644 index 000000000..153989814 --- /dev/null +++ b/flamby/benchmarks/conf.py @@ -0,0 +1,197 @@ +import inspect +import json +from pathlib import Path + +import torch # noqa:F401 # necessary for importing optimizer + +import flamby + + +def check_config(config_path): + config = json.loads(Path(config_path).read_text()) + # ensure that dataset exists + try: + # try importing the dataset from the config file + getattr( + __import__("flamby.datasets", fromlist=[config["dataset"]]), + config["dataset"], + ) + except AttributeError: + raise AttributeError( + f"Dataset {config['dataset']} has not been found in flamby.datasets." + "Please ensure that the spelling is correct." 
+ ) + + # ensure that the strategies exist + for strategy in config["strategies"]: + try: + # try importing the strategy from the config file + getattr(__import__("flamby.strategies", fromlist=[strategy]), strategy) + except AttributeError: + raise AttributeError( + f"Strategy {strategy} has not been found in flamby.strategies." + "Please ensure that the spelling is correct." + ) + if "optimizer_class" in config["strategies"][strategy].keys(): + # ensure that optimizer (if any) comes from the torch library + if not config["strategies"][strategy]["optimizer_class"].startswith( + "torch." + ): + raise ValueError("Optimizer must be from torch") + + # ensure that the results file exists if not create it + results_file = Path(config["results_file"]) + + if not results_file.suffix == ".csv": + results_file.with_suffix(".csv") + results_file.parent.mkdir(parents=True, exist_ok=True) + return config + + +def get_dataset_args( + dataset_name, + params=[ + "BATCH_SIZE", + "LR", + "NUM_CLIENTS", + "NUM_EPOCHS_POOLED", + "Baseline", + "BaselineLoss", + "Optimizer", + "get_nb_max_rounds", + "metric", + "collate_fn", + ], +): + """Get dataset spepcific handles + + Parameters + ---------- + dataset_name : str + The name of the dataset to use. + params : list, optional + All named pparameters to be fetched, by default + [ "BATCH_SIZE", "LR", "NUM_CLIENTS", "NUM_EPOCHS_POOLED", "Baseline", + "BaselineLoss", "Optimizer", "get_nb_max_rounds", "metric", + "collate_fn", ] + + Returns + ------- + tuple(str, torch.utils.data.Dataset, list) + _description_ + """ + # We first get all parameters excluding datasets + param_list = [] + for param in params: + try: + p = getattr( + __import__(f"flamby.datasets.{dataset_name}", fromlist=param), + param, + ) + except AttributeError: + p = None + param_list.append(p) + + fed_dataset_name = dataset_name.split("_") + fed_dataset_name = "".join([name.capitalize() for name in fed_dataset_name]) + + if fed_dataset_name == "FedIxi": + fed_dataset_name = "FedIXITiny" + + fed_dataset = getattr( + __import__(f"flamby.datasets.{dataset_name}", fromlist=fed_dataset_name), + fed_dataset_name, + ) + return fed_dataset, param_list + + +def get_strategies(config, learning_rate=None, args={}): + """Parse the config to extract strategies and hyperparameters. + Parameters + ---------- + config : dict + The config dict. + learning_rate : float + The learning rate to use, by default None + args : dict, optional + The dict given by the CLI, by default {} if given will supersede the + config. + + Returns + ------- + dict + dict with all strategies and their hyperparameters. + + Raises + ------ + ValueError + Some parameter are incorrect. 
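+
+    Notes
+    -----
+    As an illustration (with arbitrary values), calling this function with
+    ``learning_rate=0.01`` and a config entry
+    ``{"FedAvg": {"learning_rate_scaler": 10}}`` would return
+    ``{"FedAvg": {"learning_rate": 0.01 * float(10)}}``: the scaler is popped
+    and replaced by the scaled client learning rate.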
+ """ + if args["strategy"] is not None: + strategies = {args["strategy"]: {}} + for k, v in args.items(): + if k in [ + "mu", + "server_learning_rate", + "learning_rate", + "optimizer_class", + "deterministic", + ] and (v is not None): + strategies[args["strategy"]][k] = v + if args["strategy"] != "Cyclic": + strategies[args["strategy"]].pop("deterministic") + else: + strategies[args["strategy"]]["deterministic_cycle"] = strategies[ + args["strategy"] + ].pop("deterministic") + + else: + strategies = config["strategies"] + + for sname, sparams in strategies.items(): + possible_parameters = dict( + inspect.signature(getattr(flamby.strategies, sname)).parameters + ) + non_compatible_parameters = [ + param + for param, _ in sparams.items() + if not ((param in possible_parameters) or (param == "learning_rate_scaler")) + ] + assert ( + len(non_compatible_parameters) == 0 + ), f"The parameter.s {non_compatible_parameters} is/are not" + "compatible with the strategy's signature. " + f"Please check the {sname} strategy documentation." + + # We occasionally apply the scaler + if ("learning_rate" in sparams) and ("learning_rate_scaler" in sparams): + raise ValueError( + "Cannot provide both a leraning rate and a learning rate scaler." + ) + elif "learning_rate" not in sparams: + scaler = ( + 1.0 + if not ("learning_rate_scaler" in sparams) + else sparams.pop("learning_rate_scaler") + ) + strategies[sname]["learning_rate"] = learning_rate * float(scaler) + + if "optimizer_class" in sparams: + strategies[sname]["optimizer_class"] = eval(sparams.pop("optimizer_class")) + + if (sname == "FedProx") and "mu" not in sparams: + raise ValueError("If using FedProx you should provide a value for mu.") + + return strategies + + +def get_results_file(config, path=None): + if path is None: + return Path(config["results_file"]) + else: + return Path(path) + + +if __name__ == "__main__": + get_strategies() + # check_config(config) diff --git a/flamby/benchmarks/fed_benchmark.py b/flamby/benchmarks/fed_benchmark.py new file mode 100644 index 000000000..39d5a5151 --- /dev/null +++ b/flamby/benchmarks/fed_benchmark.py @@ -0,0 +1,529 @@ +import argparse +import copy + +import numpy as np +import pandas as pd +import torch + +import flamby.strategies as strats +from flamby.benchmarks.benchmark_utils import ( + ensemble_perf_from_predictions, + evaluate_model_on_local_and_pooled_tests, + fill_df_with_xp_results, + find_xps_in_df, + get_logfile_name_from_strategy, + init_data_loaders, + init_xp_plan, + set_dataset_specific_config, + train_single_centric, +) +from flamby.benchmarks.conf import ( + check_config, + get_dataset_args, + get_results_file, + get_strategies, +) +from flamby.gpu_utils import use_gpu_idx + + +def main(args_cli): + """_summary_ + + Parameters + ---------- + args_cli : A namespace of hyperparameters providing the ability to overwrite + the config provided to some extents. 
+ + Returns + ------- + _type_ + _description_ + """ + # Use the same initialization for everyone in order to be fair + torch.manual_seed(args_cli.seed) + np.random.seed(args_cli.seed) + + use_gpu = use_gpu_idx(args_cli.GPU, args_cli.cpu_only) + # Find a way to provide it through hyperparameters + run_num_updates = [100] + + # ensure that the config provided by the user is ok + config = check_config(args_cli.config_file_path) + + dataset_name = config["dataset"] + + # get all the dataset specific handles + ( + FedDataset, + [ + BATCH_SIZE, + LR, + NUM_CLIENTS, + NUM_EPOCHS_POOLED, + Baseline, + BaselineLoss, + Optimizer, + get_nb_max_rounds, + metric, + collate_fn, + ], + ) = get_dataset_args(dataset_name) + + nrounds_list = [get_nb_max_rounds(num_updates) for num_updates in run_num_updates] + + if args_cli.debug: + nrounds_list = [1 for _ in run_num_updates] + NUM_EPOCHS_POOLED = 1 + + if args_cli.hard_debug: + nrounds_list = [0 for _ in run_num_updates] + NUM_EPOCHS_POOLED = 0 + + # We can now instantiate the dataset specific model on CPU + global_init = Baseline() + + # We pparse the hyperparams from the config or from the CLI if strategy is given + strategy_specific_hp_dicts = get_strategies( + config, learning_rate=LR, args=vars(args_cli) + ) + pooled_hyperparameters = { + "optimizer_class": Optimizer, + "learning_rate": LR, + "seed": args_cli.seed, + } + main_columns_names = ["Test", "Method", "Metric", "seed"] + + # We might need to dynamically add additional parameters to the csv columns + all_strategies_args = [] + # get all hparam names from all the strategies used + for strategy in strategy_specific_hp_dicts.values(): + all_strategies_args += [ + arg_names + for arg_names in strategy.keys() + if arg_names not in all_strategies_args + ] + # column names used for the results file + columns_names = list(set(main_columns_names + all_strategies_args)) + + evaluate_func, batch_size_test, compute_ensemble_perf = set_dataset_specific_config( + dataset_name, compute_ensemble_perf=True + ) + + # We compute the number of local and ensemble performances we should have + # in the results dataframe + nb_local_and_ensemble_xps = (NUM_CLIENTS + int(compute_ensemble_perf)) * ( + NUM_CLIENTS + 1 + ) + + # We init dataloader for train and test and for local and pooled datasets + + training_dls, test_dls = init_data_loaders( + dataset=FedDataset, + pooled=False, + batch_size=BATCH_SIZE, + num_workers=args_cli.workers, + num_clients=NUM_CLIENTS, + batch_size_test=batch_size_test, + collate_fn=collate_fn, + ) + train_pooled, test_pooled = init_data_loaders( + dataset=FedDataset, + pooled=True, + batch_size=BATCH_SIZE, + num_workers=args_cli.workers, + batch_size_test=batch_size_test, + collate_fn=collate_fn, + ) + + # Check if some results are already computed + results_file = get_results_file(config, path=args_cli.results_file_path) + if results_file.exists(): + df = pd.read_csv(results_file) + # Update df if new hyperparameters added + df = df.reindex( + df.columns.union(columns_names, sort=False).unique(), + axis="columns", + fill_value=None, + ) + else: + # initialize data frame with the column_names and no data if no csv was + # found + df = pd.DataFrame(columns=columns_names) + + # We compute the experiment plan given the config and user-specific hyperparams + + do_baselines, do_strategy = init_xp_plan( + NUM_CLIENTS, + args_cli.nlocal, + args_cli.single_centric_baseline, + args_cli.strategy, + compute_ensemble_perf, + ) + + # We can now proceed to the trainings + # Pooled training + # We check 
if we have already the results for pooled + index_of_interest = df.loc[ + (df["Method"] == "Pooled Training") & (df["seed"] == args_cli.seed) + ].index + + # There is no use in running the experiment if it is already found + if (len(index_of_interest) < (NUM_CLIENTS + 1)) and do_baselines["Pooled"]: + # dealing with edge case that shouldn't happen + # If some of the rows are there but not all of them we redo the + # experiments + if len(index_of_interest) > 0: + df.drop(index_of_interest, inplace=True) + m = copy.deepcopy(global_init) + m = train_single_centric( + m, + train_pooled, + use_gpu, + "Pooled", + pooled_hyperparameters["optimizer_class"], + pooled_hyperparameters["learning_rate"], + BaselineLoss, + NUM_EPOCHS_POOLED, + ) + ( + perf_dict, + pooled_perf_dict, + _, + _, + _, + _, + ) = evaluate_model_on_local_and_pooled_tests( + m, test_dls, test_pooled, metric, evaluate_func + ) + df = fill_df_with_xp_results( + df, + perf_dict, + pooled_hyperparameters, + "Pooled Training", + columns_names, + results_file, + ) + df = fill_df_with_xp_results( + df, + pooled_perf_dict, + pooled_hyperparameters, + "Pooled Training", + columns_names, + results_file, + pooled=True, + ) + + # We check if we have the results for local trainings and possibly ensemble as well + index_of_interest = df.loc[ + (df["Method"] == "Local 0") & (df["seed"] == args_cli.seed) + ].index + for i in range(1, NUM_CLIENTS): + index_of_interest = index_of_interest.union( + df.loc[(df["Method"] == f"Local {i}") & (df["seed"] == args_cli.seed)].index + ) + index_of_interest = index_of_interest.union( + df.loc[(df["Method"] == "Ensemble") & (df["seed"] == args_cli.seed)].index + ) + + if len(index_of_interest) < nb_local_and_ensemble_xps: + # The fact that we are here means some local experiments are missing or + # we need to compute ensemble as well so we need to redo all local experiments + y_true_dicts = {} + y_pred_dicts = {} + pooled_y_true_dicts = {} + pooled_y_pred_dicts = {} + + for i in range(NUM_CLIENTS): + index_of_interest = df.loc[ + (df["Method"] == f"Local {i}") & (df["seed"] == args_cli.seed) + ].index + # We do the experiments only if results are not found or we need + # ensemble performances AND this experiment is planned. + # i.e. 
we allow to not do anything else if the user specify + if ( + (len(index_of_interest) < (NUM_CLIENTS + 1)) or compute_ensemble_perf + ) and do_baselines[f"Local {i}"]: + if len(index_of_interest) > 0: + df.drop(index_of_interest, inplace=True) + + m = copy.deepcopy(global_init) + method_name = f"Local {i}" + m = train_single_centric( + m, + training_dls[i], + use_gpu, + method_name, + pooled_hyperparameters["optimizer_class"], + pooled_hyperparameters["learning_rate"], + BaselineLoss, + NUM_EPOCHS_POOLED, + ) + ( + perf_dict, + pooled_perf_dict, + y_true_dicts[f"Local {i}"], + y_pred_dicts[f"Local {i}"], + pooled_y_true_dicts[f"Local {i}"], + pooled_y_pred_dicts[f"Local {i}"], + ) = evaluate_model_on_local_and_pooled_tests( + m, test_dls, test_pooled, metric, evaluate_func, return_pred=True + ) + df = fill_df_with_xp_results( + df, + perf_dict, + pooled_hyperparameters, + method_name, + columns_names, + results_file, + ) + df = fill_df_with_xp_results( + df, + pooled_perf_dict, + pooled_hyperparameters, + method_name, + columns_names, + results_file, + pooled=True, + ) + + if compute_ensemble_perf: + print( + "Computing ensemble performance, local models need to have been trained in the same runtime" + ) + local_ensemble_perf = ensemble_perf_from_predictions( + y_true_dicts, y_pred_dicts, NUM_CLIENTS, metric + ) + pooled_ensemble_perf = ensemble_perf_from_predictions( + pooled_y_true_dicts, + pooled_y_pred_dicts, + NUM_CLIENTS, + metric, + num_clients_test=1, + ) + + df = fill_df_with_xp_results( + df, + local_ensemble_perf, + pooled_hyperparameters, + "Ensemble", + columns_names, + results_file, + ) + df = fill_df_with_xp_results( + df, + pooled_ensemble_perf, + pooled_hyperparameters, + "Ensemble", + columns_names, + results_file, + pooled=True, + ) + + # Strategies + if do_strategy: + for idx, num_updates in enumerate(run_num_updates): + for sname in strategy_specific_hp_dicts.keys(): + # Base arguments + m = copy.deepcopy(global_init) + bloss = BaselineLoss() + # We init the strategy parameters to the following default ones + args = { + "training_dataloaders": training_dls, + "model": m, + "loss": bloss, + "optimizer_class": torch.optim.SGD, + "learning_rate": LR, + "num_updates": num_updates, + "nrounds": nrounds_list[idx], + } + # We overwrite defaults with new hyperparameters from config + strategy_specific_hp_dict = strategy_specific_hp_dicts[sname] + # Overwriting arguments with strategy specific arguments + for k, v in strategy_specific_hp_dict.items(): + args[k] = v + # We fill the hyperparameters dict for later use in filling the csv by filling missing column with nans + hyperparameters = {} + for k in all_strategies_args: + if k in args: + hyperparameters[k] = args[k] + else: + hyperparameters[k] = np.nan + hyperparameters["seed"] = args_cli.seed + + index_of_interest = find_xps_in_df( + df, hyperparameters, sname, num_updates + ) + + # An experiment is finished if there are num_clients + 1 rows + if len(index_of_interest) < (NUM_CLIENTS + 1): + # Dealing with edge case that shouldn't happen + # If some of the rows are there but not all of them we redo the + # experiments + if len(index_of_interest) > 0: + df.drop(index_of_interest, inplace=True) + basename = get_logfile_name_from_strategy( + dataset_name, sname, num_updates, args + ) + + # We run the FL strategy + s = getattr(strats, sname)( + **args, log=args_cli.log, log_basename=basename + ) + print("FL strategy", sname, " num_updates ", num_updates) + m = s.run()[0] + ( + perf_dict, + pooled_perf_dict, + _, + _, + _, + _, 
+ ) = evaluate_model_on_local_and_pooled_tests( + m, test_dls, test_pooled, metric, evaluate_func + ) + + df = fill_df_with_xp_results( + df, + perf_dict, + hyperparameters, + sname + str(num_updates), + columns_names, + results_file, + ) + df = fill_df_with_xp_results( + df, + pooled_perf_dict, + hyperparameters, + sname + str(num_updates), + columns_names, + results_file, + pooled=True, + ) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "--GPU", + type=int, + default=0, + help="GPU to run the training on (if available)", + ) + parser.add_argument( + "--cpu-only", + action="store_true", + default=False, + help="Force computation on CPU.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Do 1 round and 1 epoch to check if the script is working", + ) + parser.add_argument( + "--hard-debug", + action="store_true", + default=False, + help="Do 0 round and 0 epoch to check if the script is working", + ) + parser.add_argument( + "--workers", + type=int, + default=0, + help="Numbers of workers for the dataloader", + ) + parser.add_argument( + "--learning_rate", + "-lr", + type=float, + default=None, + help="Client side learning rate if strategy is given", + ) + parser.add_argument( + "--server_learning_rate", + "-slr", + type=float, + default=None, + help="Server side learning rate if strategy is given", + ) + parser.add_argument( + "--mu", + "-mu", + type=float, + default=None, + help="FedProx mu parameter if strategy is given and that it is FedProx", + ) + parser.add_argument( + "--strategy", + "-s", + type=str, + default=None, + help="If this parameter is chosen will only run this specific strategy", + choices=[ + None, + "FedAdam", + "FedYogi", + "FedAdagrad", + "Scaffold", + "FedAvg", + "Cyclic", + "FedProx", + ], + ) + parser.add_argument( + "--optimizer-class", + "-opt", + type=str, + default="torch.optim.SGD", + help="The optimizer class to use if strategy is given", + ) + parser.add_argument( + "--deterministic", + "-d", + action="store_true", + default=False, + help="whether or not to use deterministic cycling for the cyclic strategy", + ) + parser.add_argument( + "--log", + "-l", + action="store_true", + default=False, + help="Whether or not to log the strategies", + ) + parser.add_argument( + "--config-file-path", + "-cfp", + default="./config.json", + type=str, + help="Which config file to use.", + ) + parser.add_argument( + "--results-file-path", + "-rfp", + default=None, + type=str, + help="The path to the created results (overwrite the config path)", + ) + parser.add_argument( + "--single-centric-baseline", + "-scb", + default=None, + type=str, + help="Whether or not to compute only one single-centric baseline and which one.", + choices=["Pooled", "Local"], + ) + parser.add_argument( + "--nlocal", + default=0, + type=int, + help="Will only be used if --single-centric-baseline Local, will test" + "only training on Local {nlocal}.", + ) + parser.add_argument("--seed", default=0, type=int, help="Seed") + + args = parser.parse_args() + + main(args) diff --git a/flamby/conf.py b/flamby/conf.py deleted file mode 100644 index c7828a0ee..000000000 --- a/flamby/conf.py +++ /dev/null @@ -1,139 +0,0 @@ -import json -from pathlib import Path - -import torch # noqa:F401 # necessary for importing optimizer - - -def check_config(config_path): - config = json.loads(Path(config_path).read_text()) - # ensure that dataset exists - try: - # try importing the dataset from the config file - getattr( - 
__import__("flamby.datasets", fromlist=[config["dataset"]]), - config["dataset"], - ) - except AttributeError: - raise AttributeError( - f"Dataset {config['dataset']} has not been found in flamby.datasets." - "Please ensure that the spelling is correct." - ) - - # ensure that the strategies exist - for strategy in config["strategies"]: - try: - # try importing the strategy from the config file - getattr(__import__("flamby.strategies", fromlist=[strategy]), strategy) - except AttributeError: - raise AttributeError( - f"Strategy {strategy} has not been found in flamby.strategies." - "Please ensure that the spelling is correct." - ) - if "optimizer_class" in config["strategies"][strategy].keys(): - # ensure that optimizer (if any) comes from the torch library - if not config["strategies"][strategy]["optimizer_class"].startswith( - "torch." - ): - raise ValueError("Optimizer must be from torch") - - # ensure that the results file exists if not create it - results_file = Path(config["results_file"]) - - if not results_file.suffix == ".csv": - results_file.with_suffix(".csv") - results_file.parent.mkdir(parents=True, exist_ok=True) - return config - - -def get_dataset_args( - config, - params=[ - "BATCH_SIZE", - "LR", - "NUM_CLIENTS", - "NUM_EPOCHS_POOLED", - "Baseline", - "BaselineLoss", - ], -): - param_list = [] - for param in params: - try: - p = getattr( - __import__(f"flamby.datasets.{config['dataset']}", fromlist=param), - param, - ) - except AttributeError: - p = None - param_list.append(p) - - fed_dataset_name = config["dataset"].split("_") - fed_dataset_name = "".join([name.capitalize() for name in fed_dataset_name]) - if fed_dataset_name == "FedIxi": - fed_dataset_name = "FedIXITiny" - - fed_dataset = getattr( - __import__(f"flamby.datasets.{config['dataset']}", fromlist=fed_dataset_name), - fed_dataset_name, - ) - return config["dataset"], fed_dataset, param_list - - -def get_strategies(config, learning_rate=None, args={}): - - strategies = config["strategies"] - if args and any([v is not None for k, v in args.items()]): - if args["strategy"] is not None: - strategies = { - args["strategy"]: { - "optimizer_class": args["optimizer_class"], - "learning_rate": args["learning_rate"] - if args["learning_rate"] is not None - else learning_rate, - } - } - if args["mu"] is not None: - assert args["strategy"] == "FedProx" - strategies["FedProx"]["mu"] = args["mu"] - if args["server_learning_rate"] is not None: - assert args["strategy"] in [ - "Scaffold", - "FedAdam", - "FedYogi", - "FedAdagrad", - ] - strategies[args["strategy"]]["server_learning_rate"] = args[ - "server_learning_rate" - ] - if args["strategy"] == "Cyclic": - - strategies[args["strategy"]]["deterministic_cycle"] = args["deterministic"] - - for strategy in strategies.keys(): - if "optimizer_class" in strategies[strategy].keys(): - # have optimizer as a collable param and not a string - strategies[strategy]["optimizer_class"] = eval( - strategies[strategy]["optimizer_class"] - ) - if "learning_rate_scaler" in strategies[strategy].keys(): - if learning_rate is None: - raise ValueError("Learning rate is not defined. 
Please define it") - # calculate learning rate - strategies[strategy]["learning_rate"] = ( - learning_rate / strategies[strategy]["learning_rate_scaler"] - ) - strategies[strategy].pop("learning_rate_scaler") - return strategies - - - -def get_results_file(config, path=None): - if path is None: - return Path(config["results_file"]) - else: - return Path(path) - - -if __name__ == "__main__": - get_strategies() - # check_config(config) diff --git a/flamby/datasets/fed_heart_disease/model.py b/flamby/datasets/fed_heart_disease/model.py index 06315cbf9..23cc66c07 100644 --- a/flamby/datasets/fed_heart_disease/model.py +++ b/flamby/datasets/fed_heart_disease/model.py @@ -1,8 +1,9 @@ import torch import torch.nn as nn + class Baseline(nn.Module): - def __init__(self, input_dim=16, output_dim=1): + def __init__(self, input_dim=13, output_dim=1): super(Baseline, self).__init__() self.linear = torch.nn.Linear(input_dim, output_dim) diff --git a/flamby/extract_config.py b/flamby/extract_config.py index dd06cc214..211764c54 100644 --- a/flamby/extract_config.py +++ b/flamby/extract_config.py @@ -1,65 +1,139 @@ +import argparse import inspect import json import os -from glob import glob import numpy as np import pandas as pd import torch -dir_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "results") -csv_files = glob(os.path.join(dir_path, "results_*.csv")) -dataset_names = [ - "_".join(csvf.split("/")[-1].split(".")[0].split("_")[2:]) for csvf in csv_files -] -optimizers_classes = [e[1] for e in inspect.getmembers(torch.optim, inspect.isclass)] -csvs = [pd.read_csv(e) for e in csv_files] -configs = [] -for dname, csv, csvf in zip(dataset_names, csvs, csv_files): - config = {} - config["dataset"] = dname - config["results_file"] = csvf.split("/")[-1] - config["strategies"] = {} - for stratname in [ - "Scaffold", - "Cyclic", - "FedAdam", - "FedYogi", - "FedAvg", - "FedProx", - "FedAdagrad", - ]: - config["strategies"][stratname] = {} - current = csv.loc[ - (csv["Method"] == stratname + "100") & (csv["Test"] == "Pooled Test") +def main(args_cli): + datasets = [ + "fed_kits19", + "fed_ixi", + "fed_camelyon16", + "fed_isic2019", + "fed_lidc_idri", + "fed_heart_disease", + "fed_tcga_brca", + ] + + csv_files = args_cli.path_to_results + if args_cli.dataset_name is None: + dataset_names = [ + "_".join(csvf.split("/")[-1].split(".")[0].split("_")[2:]) + for csvf in csv_files ] - current = current.reset_index() - try: - idx = current["Metric"].idxmax() - except ValueError: - print(f"For dataset {dname} missing {stratname} !!!") - continue - best_hyperparams = current.iloc[idx][ - [col for col in current.columns if col not in ["Test", "Method", "Metric"]] - ].to_dict() - best_hyperparams.pop("index") - for k, v in best_hyperparams.items(): + breakpoint() + assert all([d in datasets for d in dataset_names]) + else: + if len(args_cli.dataset_name) == len(csv_files): + dataset_names = args_cli.dataset_name + elif len(args_cli.dataset_name) == 1: + dataset_names = [args_cli.dataset_name[0] for _ in range(len(csv_files))] + else: + raise ValueError( + "You should provide as many dataset names as you gave results" + " files or 1 if they all come from the same dataset." 
+ ) + optimizers_classes = [e[1] for e in inspect.getmembers(torch.optim, inspect.isclass)] + csvs = [pd.read_csv(e) for e in csv_files] + for dname, csv, csvf in zip(dataset_names, csvs, csv_files): + config = {} + config["dataset"] = dname + config["results_file"] = csvf.split("/")[-1] + config["strategies"] = {} + for stratname in [ + "Scaffold", + "Cyclic", + "FedAdam", + "FedYogi", + "FedAvg", + "FedProx", + "FedAdagrad", + ]: + config["strategies"][stratname] = {} + current = csv.loc[ + (csv["Method"] == stratname + "100") & (csv["Test"] == "Pooled Test") + ] + current = current.reset_index() try: - isnan = np.isnan(v) - except TypeError: - isnan = False - if not (isnan): - has_corresp_opt = [ - str(v) == str(opt_class) for opt_class in optimizers_classes + idx = current["Metric"].idxmax() + except ValueError: + print(f"For dataset {dname} missing {stratname} !!!") + continue + best_hyperparams = current.iloc[idx][ + [ + col + for col in current.columns + if col not in ["Test", "Method", "Metric"] ] + ].to_dict() + best_hyperparams.pop("index") + for k, v in best_hyperparams.items(): + try: + isnan = np.isnan(v) + except TypeError: + isnan = False + if not (isnan): + has_corresp_opt = [ + str(v) == str(opt_class) for opt_class in optimizers_classes + ] + + if any(has_corresp_opt): + v = ( + "torch.optim." + + optimizers_classes[has_corresp_opt.index(True)].__name__ + ) + config["strategies"][stratname][k] = v + results_file_basename = csvf.split("/")[-1].split(".")[0] + root = f"config_{results_file_basename}" + basename = root + ".json" + c = 0 + while os.path.exists(os.path.join(args_cli.extract_to_path, basename)): + basename = root + f"_{c}.json" + c += 1 + with open(os.path.join(args_cli.extract_to_path, basename), "w") as outfile: + json.dump(config, outfile, indent=4, sort_keys=True) + - if any(has_corresp_opt): - v = ( - "torch.optim." - + optimizers_classes[has_corresp_opt.index(True)].__name__ - ) - config["strategies"][stratname][k] = v +if __name__ == "__main__": - with open(f"config_{dname}.json", "w") as outfile: - json.dump(config, outfile, indent=4, sort_keys=True) + parser = argparse.ArgumentParser() + parser.add_argument( + "--path-to-results", + type=str, + default="./results/results.csv", + nargs="+", + help="The path of the file to extract config from.", + ) + parser.add_argument( + "--extract-to-path", + type=str, + default=".", + help="The path where the config will be extracted", + ) + parser.add_argument( + "--dataset-name", + type=str, + default=None, + help="The dataset name of the associated results file." 
+ "If not provided tries to extract it from the results file name.", + nargs="+", + choices=[ + None, + "fed_kits19", + "fed_ixi", + "fed_camelyon16", + "fed_isic2019", + "fed_lidc_idri", + "fed_heart_disease", + "fed_tcga_brca", + ], + ) + args = parser.parse_args() + assert os.path.isdir( + args.extract_to_path + ), "You should provide a path towards a directory" + main(args) diff --git a/flamby/fed_benchmark.py b/flamby/fed_benchmark.py deleted file mode 100644 index a78a680a3..000000000 --- a/flamby/fed_benchmark.py +++ /dev/null @@ -1,729 +0,0 @@ -import argparse -import copy -import os - -import numpy as np -import pandas as pd -import torch -from torch.utils.data import DataLoader as dl -from tqdm import tqdm - -import flamby.strategies as strats -from flamby.conf import check_config, get_dataset_args, get_results_file, get_strategies -from flamby.utils import evaluate_model_on_tests - - -# Only 4 lines to change to evaluate different datasets (except for LIDC where the -# evaluation function is custom) -# Still some datasets might require specific augmentation strategies or collate_fn -# functions in the data loading part -def main(args_cli): - n_gpus = torch.cuda.device_count() - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = str(args_cli.GPU) - # torch.use_deterministic_algorithms(False) - use_gpu = ( - args_cli.GPU in [str(i) for i in range(n_gpus)] - ) and torch.cuda.is_available() - run_num_updates = [100, 500] - - # ensure that the config is ok - config = check_config(args_cli.config_file_path) - - params_list = [ - "BATCH_SIZE", - "LR", - "NUM_CLIENTS", - "NUM_EPOCHS_POOLED", - "Baseline", - "BaselineLoss", - "Optimizer", - "get_nb_max_rounds", - "metric", - "collate_fn", - ] - # get all the dataset args - ( - dataset_name, - FedDataset, - [ - BATCH_SIZE, - LR, - NUM_CLIENTS, - NUM_EPOCHS_POOLED, - Baseline, - BaselineLoss, - Optimizer, - get_nb_max_rounds, - metric, - collate_fn, - ], - ) = get_dataset_args(config, params_list) - results_file = get_results_file(config, path=args_cli.results_file_path) - - # One might need to iterate on the hyperparameters to some extent if performances - # are seriously degraded with default ones - strategy_specific_hp_dicts = get_strategies( - config, learning_rate=LR, args=vars(args_cli) - ) - - init_hp_additional_args = ["Test", "Method", "Metric"] - # We need to add strategy hyperparameters columns to the benchmark - hp_additional_args = [] - - # get all hparam names from all the strategies used - for strategy in strategy_specific_hp_dicts.values(): - hp_additional_args += [ - arg_names - for arg_names in strategy.keys() - if arg_names not in hp_additional_args - ] - # column names used for the results file - columns_names = init_hp_additional_args + hp_additional_args - - # Use the same initialization for everyone in order to be fair - torch.manual_seed(args_cli.seed) - np.random.seed(args_cli.seed) - global_init = Baseline() - # Instantiate all train and test dataloaders required including pooled ones - if dataset_name == "fed_lidc_idri": - batch_size_test = 1 - from flamby.datasets.fed_lidc_idri import evaluate_dice_on_tests_by_chunks - - def evaluate_func(m, test_dls, metric, use_gpu=use_gpu, return_pred=False): - dice_dict = evaluate_dice_on_tests_by_chunks(m, test_dls, use_gpu) - # dice_dict = {f"client_test_{i}": 0.5 for i in range(NUM_CLIENTS)} - if return_pred: - return dice_dict, None, None - return dice_dict - - compute_ensemble_perf = False - elif dataset_name == "fed_kits19": - from 
flamby.datasets.fed_kits19 import evaluate_dice_on_tests - - batch_size_test = 2 - - def evaluate_func(m, test_dls, metric, use_gpu=use_gpu, return_pred=False): - dice_dict = evaluate_dice_on_tests(m, test_dls, metric, use_gpu) - # dice_dict = {f"client_test_{i}": 0.5 for i in range(NUM_CLIENTS)} - if return_pred: - return dice_dict, None, None - return dice_dict - - compute_ensemble_perf = False - - else: - batch_size_test = None - evaluate_func = evaluate_model_on_tests - compute_ensemble_perf = False - - nb_local_and_ensemble_xps = (NUM_CLIENTS + int(compute_ensemble_perf)) * ( - NUM_CLIENTS + 1 - ) - - training_dls, test_dls = init_data_loaders( - dataset=FedDataset, - pooled=False, - batch_size=BATCH_SIZE, - num_workers=args_cli.workers, - num_clients=NUM_CLIENTS, - batch_size_test=batch_size_test, - collate_fn=collate_fn, - ) - train_pooled, test_pooled = init_data_loaders( - dataset=FedDataset, - pooled=True, - batch_size=BATCH_SIZE, - num_workers=args_cli.workers, - batch_size_test=batch_size_test, - collate_fn=collate_fn, - ) - - # Check if some results are already computed - if results_file.exists(): - df = pd.read_csv(results_file) - # Update df if new hyperparameters added - df = df.reindex( - df.columns.union(columns_names, sort=False), axis="columns", fill_value=None - ) - perf_lines_dicts = df.to_dict("records") - else: - # initialize data frame with the column_names and no data - df = pd.DataFrame(columns=columns_names) - perf_lines_dicts = [] - - do_strategy = True - do_baseline = {"Pooled": True} - for i in range(NUM_CLIENTS): - do_baseline[f"Local {i}"] = True - # Single client baseline computation - if args_cli.single_centric_baseline is not None: - do_baseline = {"Pooled": False} - for i in range(NUM_CLIENTS): - do_baseline[f"Local {i}"] = False - if args_cli.single_centric_baseline == "Pooled": - do_baseline[args_cli.single_centric_baseline] = True - elif args_cli.single_centric_baseline == "Local": - assert args_cli.nlocal in range( - NUM_CLIENTS - ), "The client you chose does not exist" - do_baseline[ - args_cli.single_centric_baseline + " " + str(args_cli.nlocal) - ] = True - # If we do a single-centric baseline we don't do the strategies - do_strategy = False - - # if we give a strategy we compute only the strategy and not the baselines - if args_cli.strategy is not None: - for k, _ in do_baseline.items(): - do_baseline[k] = False - - do_all_local = all([do_baseline[f"Local {i}"] for i in range(NUM_CLIENTS)]) - if compute_ensemble_perf and not (do_all_local): - raise ValueError( - "Cannot compute ensemble performance if training on only one local" - ) - - # We use the same set of parameters as found in the corresponding - # flamby/datasets/fed_mydataset/benchmark.py - # Pooled Baseline - # Throughout the experiments we only launch training if we do not have the results - # yet. Note that pooled and local baselines do not use hyperparameters. 
- index_of_interest = df.loc[df["Method"] == "Pooled Training"].index - # an experiment is finished if there are num_clients + 1 rows - if (len(index_of_interest) < (NUM_CLIENTS + 1)) and do_baseline["Pooled"]: - # dealing with edge case that shouldn't happen - # If some of the rows are there but not all of them we redo the experiments - if len(index_of_interest) > 0: - df.drop(index_of_interest, inplace=True) - perf_lines_dicts = df.to_dict("records") - model = copy.deepcopy(global_init) - if use_gpu: - model.cuda() - bloss = BaselineLoss() - opt = Optimizer(model.parameters(), lr=LR) - print("Pooled") - for _ in tqdm(range(NUM_EPOCHS_POOLED)): - for X, y in train_pooled: - if use_gpu: - # use GPU if requested and available - X = X.cuda() - y = y.cuda() - opt.zero_grad() - y_pred = model(X) - loss = bloss(y_pred, y) - loss.backward() - opt.step() - - perf_dict = evaluate_func(model, test_dls, metric, use_gpu=use_gpu) - pooled_perf_dict = evaluate_func(model, [test_pooled], metric, use_gpu=use_gpu) - print("Per-center performance:") - print(perf_dict) - print("Performance on pooled test set:") - print(pooled_perf_dict) - method = "Pooled Training" - for k, v in perf_dict.items(): - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - Test=k, - Metric=v, - Method=method, - learning_rate=str(LR), - optimizer_class=Optimizer, - ) - ) - - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - Test="Pooled Test", - Metric=pooled_perf_dict["client_test_0"], - Method=method, - learning_rate=str(LR), - optimizer_class=Optimizer, - ) - ) - # We update csv and save it when the results are there - df = pd.DataFrame.from_dict(perf_lines_dicts) - df.to_csv(results_file, index=False) - - # Local baselines and ensemble - y_true_dicts = {} - y_pred_dicts = {} - pooled_y_true_dicts = {} - pooled_y_pred_dicts = {} - - # We only launch training if it's not finished already. - index_of_interest = df.loc[df["Method"] == "Local 0"].index - for i in range(1, NUM_CLIENTS): - index_of_interest = index_of_interest.union( - df.loc[df["Method"] == f"Local {i}"].index - ) - index_of_interest = index_of_interest.union(df.loc[df["Method"] == "Ensemble"].index) - # This experiment is finished if there are num_clients + 1 rows in each local - # training and the ensemble training - - if len(index_of_interest) < nb_local_and_ensemble_xps: - # Dealing with edge case that shouldn't happen. - # If some of the rows are there but not all of them we redo the experiments. 
- for i in range(NUM_CLIENTS): - m = copy.deepcopy(global_init) - if use_gpu: - m = m.cuda() - bloss = BaselineLoss() - opt = Optimizer(m.parameters(), lr=LR) - index_of_interest = df.loc[df["Method"] == f"Local {i}"].index - - if ( - (len(index_of_interest) < (NUM_CLIENTS + 1)) or compute_ensemble_perf - ) and do_baseline[f"Local {i}"]: - if len(index_of_interest) > 0: - df.drop(index_of_interest, inplace=True) - perf_lines_dicts = df.to_dict("records") - print("Local " + str(i)) - for e in tqdm(range(NUM_EPOCHS_POOLED)): - for X, y in training_dls[i]: - if use_gpu: - X = X.cuda() - y = y.cuda() - opt.zero_grad() - y_pred = m(X) - loss = bloss(y_pred, y) - loss.backward() - opt.step() - - ( - perf_dict, - y_true_dicts[f"Local {i}"], - y_pred_dicts[f"Local {i}"], - ) = evaluate_func(m, test_dls, metric, return_pred=True) - ( - pooled_perf_dict, - pooled_y_true_dicts[f"Local {i}"], - pooled_y_pred_dicts[f"Local {i}"], - ) = evaluate_func(m, [test_pooled], metric, return_pred=True) - print("Per-center performance:") - print(perf_dict) - print("Performance on pooled test set:") - print(pooled_perf_dict) - for k, v in perf_dict.items(): - # Make sure there is no weird inplace stuff - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - Test=k, - Metric=v, - Method=f"Local {i}", - learning_rate=str(LR), - optimizer_class=Optimizer, - ) - ) - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - Test="Pooled Test", - Metric=pooled_perf_dict["client_test_0"], - Method=f"Local {i}", - learning_rate=str(LR), - optimizer_class=Optimizer, - ) - ) - # We update csv and save it when the results are there - df = pd.DataFrame.from_dict(perf_lines_dicts) - df.to_csv(results_file, index=False) - - if compute_ensemble_perf: - print("Computing ensemble performance") - for testset in range(NUM_CLIENTS): - for model in range(1, NUM_CLIENTS): - assert ( - y_true_dicts[f"Local {0}"][f"client_test_{testset}"] - == y_true_dicts[f"Local {model}"][f"client_test_{testset}"] - ).all(), "Models in the ensemble have different ground truths" - ensemble_true = y_true_dicts["Local 0"][f"client_test_{testset}"] - ensemble_pred = y_pred_dicts["Local 0"][f"client_test_{testset}"] - for model in range(1, NUM_CLIENTS): - ensemble_pred += y_pred_dicts[f"Local {model}"][ - f"client_test_{testset}" - ] - ensemble_pred /= NUM_CLIENTS - - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - Test=f"client_test_{testset}", - Metric=metric(ensemble_true, ensemble_pred), - Method="Ensemble", - learning_rate=str(LR), - optimizer_class=Optimizer, - ) - ) - # We update csv and save it when the results are there - df = pd.DataFrame.from_dict(perf_lines_dicts) - df.to_csv(results_file, index=False) - # Computing ensemble performance in the pooled case - for model in range(1, NUM_CLIENTS): - assert ( - pooled_y_true_dicts["Local 0"]["client_test_0"] - == pooled_y_true_dicts[f"Local {model}"]["client_test_0"] - ).all(), ( - "Models in the ensemble do not make predictions in the same x order" - ) - pooled_ensemble_true = pooled_y_true_dicts["Local 0"]["client_test_0"] - pooled_ensemble_pred = pooled_y_pred_dicts["Local 0"]["client_test_0"] - for model in range(1, NUM_CLIENTS): - pooled_ensemble_pred += pooled_y_pred_dicts[f"Local {model}"][ - "client_test_0" - ] - pooled_ensemble_pred /= NUM_CLIENTS - - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - Test="Pooled Test", - Metric=metric(pooled_ensemble_true, pooled_ensemble_pred), - Method="Ensemble", - learning_rate=str(LR), - 
optimizer_class=Optimizer, - ) - ) - - # We update csv and save it when the results are there - df = pd.DataFrame.from_dict(perf_lines_dicts) - df.to_csv(results_file, index=False) - - # Strategies - if do_strategy: - for num_updates in run_num_updates: - for sname in strategy_specific_hp_dicts.keys(): - # Base arguments - m = copy.deepcopy(global_init) - bloss = BaselineLoss() - args = { - "training_dataloaders": training_dls, - "model": m, - "loss": bloss, - "optimizer_class": torch.optim.SGD, - "learning_rate": LR, - "num_updates": num_updates, - "nrounds": get_nb_max_rounds(num_updates), - } - strategy_specific_hp_dict = strategy_specific_hp_dicts[sname] - # Overwriting arguments with strategy specific arguments - for k, v in strategy_specific_hp_dict.items(): - args[k] = v - # We only launch training if it's not finished already. Maybe FL - # hyperparameters need to be tuned. - hyperparameters = {} - for k in hp_additional_args: # columns_names: - if k in args: - hyperparameters[k] = args[k] - else: - hyperparameters[k] = np.nan - # This is very ugly but this is the only way I found to accomodate float - # and objects equality in a robust fashion - found_xps = df[list(hyperparameters)] - found_xps_numerical = found_xps.select_dtypes(exclude=[object]) - if "deterministic_cycle" in found_xps_numerical.columns: - found_xps_numerical["deterministic_cycle"] = ( - found_xps_numerical["deterministic_cycle"] - .fillna(0.0) - .astype(float) - ) - col_numericals = found_xps_numerical.columns - col_objects = [c for c in found_xps.columns if not (c in col_numericals)] - - if len(col_numericals) > 0: - bool_numerical = np.all( - np.isclose( - found_xps_numerical, - pd.Series( - { - k: float(hyperparameters[k]) - for k in list(hyperparameters.keys()) - if k in col_numericals - } - ), - equal_nan=True, - ), - axis=1, - ) - else: - bool_numerical = np.ones((len(df.index), 1)).astype("bool") - - if len(col_objects): - bool_objects = found_xps[col_objects].astype(str) == pd.Series( - { - k: str(hyperparameters[k]) - for k in list(hyperparameters.keys()) - if k in col_objects - } - ) - else: - bool_objects = np.ones((len(df.index), 1)).astype("bool") - - bool_method = df["Method"] == (sname + str(num_updates)) - index_of_interest_1 = df.loc[ - pd.DataFrame(bool_numerical).all(axis=1) - ].index - index_of_interest_2 = df.loc[ - pd.DataFrame(bool_objects).all(axis=1) - ].index - index_of_interest_3 = df.loc[pd.DataFrame(bool_method).all(axis=1)].index - index_of_interest = index_of_interest_1.intersection( - index_of_interest_2 - ).intersection(index_of_interest_3) - - # non-robust version - # index_of_interest = df.loc[ - # (df["Method"] == (sname + str(num_updates))) - # & ( - # df[list(hyperparameters)] == pd.Series(hyperparameters) - # ).all(axis=1) - # ].index - # An experiment is finished if there are num_clients + 1 rows - if len(index_of_interest) < (NUM_CLIENTS + 1): - # Dealing with edge case that shouldn't happen - # If some of the rows are there but not all of them we redo the - # experiments - if len(index_of_interest) > 0: - df.drop(index_of_interest, inplace=True) - perf_lines_dicts = df.to_dict("records") - basename = dataset_name + "-" + sname + f"-num-updates{num_updates}" - for k, v in args.items(): - if k in ["learning_rate", "server_learning_rate"]: - basename += ( - "-" + "".join([e[0] for e in str(k).split("_")]) + str(v) - ) - if k in ["mu", "deterministic_cycle"]: - basename += "-" + str(k) + str(v) - - # We run the FL strategy - s = getattr(strats, sname)( - **args, 
log=args_cli.log, log_basename=basename - ) - print("FL strategy", sname, " num_updates ", num_updates) - m = s.run()[0] - - perf_dict = evaluate_func(m, test_dls, metric) - pooled_perf_dict = evaluate_func(m, [test_pooled], metric) - print("Per-center performance:") - print(perf_dict) - print("Performance on pooled test set:") - print(pooled_perf_dict) - hyperparams_save = { - k: v - for k, v in hyperparameters.items() - if k not in init_hp_additional_args - } - for k, v in perf_dict.items(): - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - allow_new=True, - Test=k, - Metric=v, - Method=sname + str(num_updates), - # We add the hyperparameters used - **hyperparams_save, - ) - ) - perf_lines_dicts.append( - prepare_dict( - keys=columns_names, - allow_new=True, - Test="Pooled Test", - Metric=pooled_perf_dict["client_test_0"], - Method=sname + str(num_updates), - # We add the hyperparamters used - **hyperparams_save, - ) - ) - - # We update csv and save it when the results are there - df = pd.DataFrame.from_dict(perf_lines_dicts) - df.to_csv(results_file, index=False) - - -def init_data_loaders( - dataset, - pooled=False, - batch_size=1, - num_workers=1, - num_clients=None, - batch_size_test=None, - collate_fn=None, -): - """ - Initializes the data loaders for the training and test datasets. - """ - if (not pooled) and num_clients is None: - raise ValueError("num_clients must be specified for the non-pooled data") - batch_size_test = batch_size if batch_size_test is None else batch_size_test - if not pooled: - training_dls = [ - dl( - dataset(center=i, train=True, pooled=False), - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - collate_fn=collate_fn, - ) - for i in range(num_clients) - ] - test_dls = [ - dl( - dataset(center=i, train=False, pooled=False), - batch_size=batch_size_test, - shuffle=False, - num_workers=num_workers, - collate_fn=collate_fn, - ) - for i in range(num_clients) - ] - return training_dls, test_dls - else: - train_pooled = dl( - dataset(train=True, pooled=True), - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - collate_fn=collate_fn, - ) - test_pooled = dl( - dataset(train=False, pooled=True), - batch_size=batch_size_test, - shuffle=False, - num_workers=num_workers, - collate_fn=collate_fn, - ) - return train_pooled, test_pooled - - -def prepare_dict(keys, allow_new=False, **kwargs): - """ - Prepares the dictionary with the given keys and fills them with the kwargs. - If allow_new is set to False (default) - Kwargs must be one of the keys. 
If - kwarg is not given for a key the value of that key will be None - """ - if not allow_new: - # ensure all the kwargs are in the columns - assert sum([not (key in keys) for key in kwargs.keys()]) == 0, ( - "Some of the keys given were not found in the existsing columns;" - f"keys: {kwargs.keys()}, columns: {keys}" - ) - - # create the dictionary from the given keys and fill when appropriate with the kwargs - return {**dict.fromkeys(keys), **kwargs} - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument( - "--GPU", - type=str, - default="0", - help="GPU to run the training on (if available)", - ) - parser.add_argument( - "--workers", - type=int, - default=1, - help="Numbers of workers for the dataloader", - ) - parser.add_argument( - "--learning_rate", - "-lr", - type=float, - default=None, - help="Client side learning rate if strategy is given", - ) - parser.add_argument( - "--server_learning_rate", - "-slr", - type=float, - default=None, - help="Server side learning rate if strategy is given", - ) - parser.add_argument( - "--mu", - "-mu", - type=float, - default=None, - help="FedProx mu parameter if strategy is given and that it is FedProx", - ) - parser.add_argument( - "--strategy", - "-s", - type=str, - default=None, - help="If this parameter is chosen will only run this specific strategy", - choices=[ - None, - "FedAdam", - "FedYogi", - "FedAdagrad", - "Scaffold", - "FedAvg", - "Cyclic", - "FedProx", - ], - ) - parser.add_argument( - "--optimizer-class", - "-opt", - type=str, - default="torch.optim.SGD", - help="The optimizer class to use if strategy is given", - ) - parser.add_argument( - "--deterministic", - "-d", - action="store_true", - default=False, - help="whether or not to use deterministic cycling for the cyclic strategy", - ) - parser.add_argument( - "--log", - "-l", - action="store_true", - default=False, - help="Whether or not to log the strategies", - ) - parser.add_argument( - "--config-file-path", - "-cfgp", - default="./config.json", - type=str, - help="Which config file to use.", - ) - parser.add_argument( - "--results-file-path", - "-rfp", - default=None, - type=str, - help="The path to the created results (overwrite the config path)", - ) - parser.add_argument( - "--single-centric-baseline", - "-scb", - default=None, - type=str, - help="Whether or not to compute only one single-centric baseline and which one.", - choices=["Pooled", "Local"], - ) - parser.add_argument( - "--nlocal", - default=0, - type=int, - help="Will only be used if --single-centric-baseline Local, will test" - "only training on Local {nlocal}.", - ) - parser.add_argument("--seed", default=0, type=int, help="Seed") - - args = parser.parse_args() - - main(args) diff --git a/flamby/gpu_utils.py b/flamby/gpu_utils.py new file mode 100644 index 000000000..542ed0327 --- /dev/null +++ b/flamby/gpu_utils.py @@ -0,0 +1,31 @@ +import os + +import torch + + +def use_gpu_idx(idx, cpu_only=False): + """Small util that puts computations on the chosen GPU. + + Parameters + ---------- + idx : int + The GPU to use. + cpu_only : bool, optional + Whether to force computations to be on CPU even in the presence of GPUs, + by default False + + Returns + ------- + bool + Whether we will be using GPU or not. + """ + gpu_detected = torch.cuda.is_available() + if not (gpu_detected) or cpu_only: + return False + else: + n_gpus = torch.cuda.device_count() + assert idx < n_gpus, f"You chose GPU {idx} that does not exist."
+ # We use environment variables to expose only the chosen GPU + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(idx) + return True diff --git a/flamby/utils.py b/flamby/utils.py index 48660fa82..e3016842a 100644 --- a/flamby/utils.py +++ b/flamby/utils.py @@ -59,8 +59,9 @@ def evaluate_model_on_tests( y_true_final = np.concatenate(y_true_final) y_pred_final = np.concatenate(y_pred_final) results_dict[f"client_test_{i}"] = metric(y_true_final, y_pred_final) - y_true_dict[f"client_test_{i}"] = y_true_final - y_pred_dict[f"client_test_{i}"] = y_pred_final + if return_pred: + y_true_dict[f"client_test_{i}"] = y_true_final + y_pred_dict[f"client_test_{i}"] = y_pred_final if return_pred: return results_dict, y_true_dict, y_pred_dict else:
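A minimal usage sketch of the new use_gpu_idx helper (not part of the patch): the SimpleNet module and the GPU index are hypothetical placeholders; only flamby.gpu_utils.use_gpu_idx comes from this change.

import torch

from flamby.gpu_utils import use_gpu_idx


class SimpleNet(torch.nn.Module):
    # Hypothetical stand-in for a FLamby baseline model.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)


# use_gpu_idx returns True only when a GPU is detected and cpu_only is False;
# as a side effect it sets CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES so that
# computations target the requested device.
use_gpu = use_gpu_idx(0, cpu_only=False)
model = SimpleNet()
if use_gpu:
    model = model.cuda()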
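Similarly, a sketch of the return_pred switch added to evaluate_model_on_tests. The Fed-Heart-Disease imports, batch size and center count are illustrative assumptions borrowed from the other FLamby datasets; only the return_pred behaviour comes from this change.

from torch.utils.data import DataLoader as dl

from flamby.datasets.fed_heart_disease import Baseline, FedHeartDisease, metric
from flamby.utils import evaluate_model_on_tests

# One test dataloader per center (Fed-Heart-Disease has 4 centers).
test_dls = [
    dl(FedHeartDisease(center=i, train=False), batch_size=4, shuffle=False)
    for i in range(4)
]
model = Baseline()

# Default call: only the per-center metrics are returned and, with this patch,
# ground truths and predictions are no longer accumulated.
results = evaluate_model_on_tests(model, test_dls, metric)

# Opt-in call: ground truths and predictions are also returned, keyed by
# "client_test_{i}", e.g. to build the ensemble baseline.
results, y_true, y_pred = evaluate_model_on_tests(
    model, test_dls, metric, return_pred=True
)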