Skip to content

Commit

Permalink
Apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sanaAyrml committed Nov 3, 2023
1 parent 8ea9419 commit b27b3eb
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 41 deletions.
4 changes: 2 additions & 2 deletions fl4health/clients/fenda_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ def get_perFCL_loss(
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
PerFCL loss consists of two contrastive losses.
First one airms to enhance the similarity between the current global features and aggregated global feature
First one aims to enhance the similarity between the current global features and aggregated global features
as positive pairs while reducing the similarity between the current global features and old global
features as negative pairs.
Second one airms to enhance the similarity between the current local features and old local feature
Second one aims to enhance the similarity between the current local features and old local features
as positive pairs while reducing the similarity between the current local features and aggregated global
features as negative pairs.
"""
Expand Down
2 changes: 1 addition & 1 deletion research/flamby/fed_heart_disease/fenda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
type_run: str = "cos_sim",
type_run: str = "vanilla",
) -> None:
super().__init__(
data_path=data_path,
Expand Down
10 changes: 2 additions & 8 deletions research/flamby/fed_heart_disease/fenda/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)


def main(config: Dict[str, Any], server_address: str, run_name: str) -> None:
def main(config: Dict[str, Any], server_address: str) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down Expand Up @@ -77,14 +77,8 @@ def main(config: Dict[str, Any], server_address: str, run_name: str) -> None:
help="Server Address to be used to communicate with the clients",
default="0.0.0.0:8080",
)
parser.add_argument(
"--run_name",
action="store",
help="Name of the run, model checkpoints will be saved under a subfolder with this name",
required=True,
)
args = parser.parse_args()

config = load_config(args.config_path)
log(INFO, f"Server Address: {args.server_address}")
main(config, args.server_address, args.run_name)
main(config, args.server_address)
27 changes: 2 additions & 25 deletions research/flamby/fed_isic2019/fedavg/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any, Dict

import flwr as fl
import torch
from flamby.datasets.fed_isic2019 import Baseline
from flwr.common.logger import log
from flwr.server.client_manager import SimpleClientManager
Expand All @@ -22,7 +21,7 @@
)


def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, pretrain: bool) -> None:
def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand All @@ -37,28 +36,6 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
client_manager = SimpleClientManager()
model = Baseline()

log(INFO, f"if pretrain: {pretrain}")
if pretrain:
dir = (
"/ssd003/projects/aieng/public/FL_env/models/fed_isic2019/fedavg/hp_sweep_results/lr_0.001/"
+ run_name
+ "/server_best_model.pkl"
)
fedavg_model_state = torch.load(dir).state_dict()
model_state = model.state_dict()
matching_state = {}
for k, v in fedavg_model_state.items():
if k in model_state:
if v.size() == model_state[k].size():
matching_state[k] = v
elif model_state[k].size()[1:] == v.size()[1:]:
repeat = model_state[k].size()[0] // v.size()[0]
original_size = tuple([1] * (len(model_state[k].size()) - 1))
matching_state[k] = v.repeat((repeat,) + original_size)
log(INFO, f"matching state: {len(matching_state)}")
model_state.update(matching_state)
model.load_state_dict(model_state)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
min_fit_clients=config["n_clients"],
Expand Down Expand Up @@ -125,4 +102,4 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_

config = load_config(args.config_path)
log(INFO, f"Server Address: {args.server_address}")
main(config, args.server_address, args.artifact_dir, args.run_name, args.pretrain)
main(config, args.server_address, args.artifact_dir, args.run_name)
2 changes: 1 addition & 1 deletion research/flamby/fed_isic2019/fenda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
type_run: str = "cos_sim",
type_run: str = "vanilla",
) -> None:
super().__init__(
data_path=data_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ do
--client_number ${c} \
--learning_rate ${CLIENT_LR} \
--server_address ${SERVER_ADDRESS} \
--type_run cos_sim \
> ${CLIENT_LOG_PATH} 2>&1 &
done

Expand Down
4 changes: 2 additions & 2 deletions research/flamby/fed_isic2019/fenda/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)


def main(config: Dict[str, Any], server_address: str, run_name: str) -> None:
def main(config: Dict[str, Any], server_address: str) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down Expand Up @@ -87,4 +87,4 @@ def main(config: Dict[str, Any], server_address: str, run_name: str) -> None:

config = load_config(args.config_path)
log(INFO, f"Server Address: {args.server_address}")
main(config, args.server_address, args.run_name)
main(config, args.server_address)
2 changes: 1 addition & 1 deletion research/flamby/fed_ixi/fenda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
metric_meter_type: MetricMeterType = MetricMeterType.ACCUMULATION,
checkpointer: Optional[TorchCheckpointer] = None,
type_run: str = "cos_sim",
type_run: str = "vanilla",
) -> None:

super().__init__(
Expand Down

0 comments on commit b27b3eb

Please sign in to comment.