Update docstring based on reviews
sanaAyrml committed Nov 20, 2024
1 parent 29c1b29 commit f7e1d89
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py
@@ -28,7 +28,7 @@ def __init__(
checkpointer: Optional[ClientCheckpointModule] = None,
deep_mmd_loss_weight: float = 10.0,
feature_extraction_layers_with_size: Optional[Dict[str, int]] = None,
- beta_global_update_interval: int = 20,
+ mmd_kernel_train_interval: int = 20,
num_accumulating_batches: Optional[int] = None,
) -> None:
"""
@@ -50,13 +50,13 @@ def __init__(
deep_mmd_loss_weight (float, optional): weight applied to the Deep MMD loss. Defaults to 10.0.
feature_extraction_layers_with_size (Optional[Dict[str, int]], optional): Dictionary mapping the names of the
layers from which to extract features to their respective feature sizes. Defaults to None.
- beta_global_update_interval (int, optional): interval at which to update the betas for the MK-MMD loss. If
- set to above 0, the betas will be updated based on whole distribution of latent features of data with
- the given update interval. If set to 0, the betas will not be updated. If set to -1, the betas will be
- updated after each individual batch based on only that individual batch. Defaults to 20.
+ mmd_kernel_train_interval (int, optional): interval at which to train and update the Deep MMD kernel. If
+ set above 0, the kernel will be trained on the whole distribution of latent features of the data at
+ the given interval. If set to 0, the kernel will not be trained. If set to -1, the kernel will
+ be trained after each individual batch, based only on that individual batch. Defaults to 20.
num_accumulating_batches (int, optional): Number of batches to accumulate features to approximate the whole
- distribution of the latent features for updating beta of the MK-MMD loss. This parameter is only used
- if beta_global_update_interval is set to larger than 0. Defaults to None.
+ distribution of the latent features for updating the Deep MMD kernel. This parameter is only used
+ if mmd_kernel_train_interval is set above 0. Defaults to None.
"""
super().__init__(
data_path=data_path,
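For readers of this change, the three modes of mmd_kernel_train_interval described in the updated docstring can be summarized with a small, self-contained sketch. This is illustrative only and not part of fl4health; the helper name and 1-indexed step convention are assumptions drawn from the docstring and the interval check shown further down in this diff.

def should_train_kernel(step: int, mmd_kernel_train_interval: int) -> bool:
    # Illustrative helper mirroring the docstring semantics above.
    if mmd_kernel_train_interval == -1:
        # Train the kernel on every batch, using only that batch's features.
        return True
    if mmd_kernel_train_interval == 0:
        # Never train the kernel.
        return False
    # Train the kernel periodically on features accumulated over
    # num_accumulating_batches batches (steps assumed to be 1-indexed).
    return (step - 1) % mmd_kernel_train_interval == 0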
@@ -91,7 +91,7 @@ def __init__(
self.local_feature_extractor: FeatureExtractorBuffer
self.initial_global_feature_extractor: FeatureExtractorBuffer
self.num_accumulating_batches = num_accumulating_batches
- self.beta_global_update_interval = beta_global_update_interval
+ self.mmd_kernel_train_interval = mmd_kernel_train_interval

def setup_client(self, config: Config) -> None:
super().setup_client(config)
@@ -114,31 +114,33 @@ def update_before_train(self, current_server_round: int) -> None:
)
# Register hooks to extract features from the initial global model if not already registered
self.initial_global_feature_extractor._maybe_register_hooks()
- # Enable training of Deep MMD loss layers if the beta_global_update_interval is set to -1
+ # Enable training of Deep MMD loss layers if the mmd_kernel_train_interval is set to -1
# meaning that the betas will be updated after each individual batch based on only that
# individual batch
- if self.beta_global_update_interval == -1:
+ if self.mmd_kernel_train_interval == -1:
for layer in self.flatten_feature_extraction_layers.keys():
self.deep_mmd_losses[layer].training = True

def _should_optimize_betas(self, step: int) -> bool:
- step_at_interval = (step - 1) % self.beta_global_update_interval == 0
+ step_at_interval = (step - 1) % self.mmd_kernel_train_interval == 0
valid_components_present = self.initial_global_model is not None
# If the Deep MMD loss doesn't matter, we don't bother optimizing betas
weighted_deep_mmd_loss = self.deep_mmd_loss_weight != 0
return step_at_interval and valid_components_present and weighted_deep_mmd_loss

def update_after_step(self, step: int, current_round: Optional[int] = None) -> None:
- if self.beta_global_update_interval > 0 and self._should_optimize_betas(step):
+ if self.mmd_kernel_train_interval > 0 and self._should_optimize_betas(step):
# Get the feature distribution of the local and initial global features with evaluation
# mode
local_distributions, initial_global_distributions = self.update_buffers(
self.model, self.initial_global_model
)
- # Update betas for the Deep MMD loss based on gathered features during training
+ # As we set the training mode of the Deep MMD loss layers to True, we train the
+ # kernel of the Deep MMD loss based on gathered features in the buffer and compute the
+ # Deep MMD loss
for layer, layer_deep_mmd_loss in self.deep_mmd_losses.items():
layer_deep_mmd_loss.training = True
- _ = layer_deep_mmd_loss(local_distributions[layer], initial_global_distributions[layer])
+ layer_deep_mmd_loss(local_distributions[layer], initial_global_distributions[layer])
layer_deep_mmd_loss.training = False
super().update_after_step(step)
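The pattern in update_after_step above (flipping each Deep MMD loss module into training mode, running one forward pass over the buffered feature distributions to fit its kernel, then flipping it back) can be sketched in isolation as follows. The function and argument names are assumptions for illustration, not part of fl4health.

from typing import Dict

import torch
import torch.nn as nn

def train_kernels_on_buffers(
    deep_mmd_losses: Dict[str, nn.Module],
    local_feats: Dict[str, torch.Tensor],
    global_feats: Dict[str, torch.Tensor],
) -> None:
    # Illustrative sketch: each loss module trains its kernel during the forward
    # pass when its training flag is set, so toggle it on, run one pass over the
    # accumulated features, then toggle it back off for regular loss evaluation.
    for layer, loss_module in deep_mmd_losses.items():
        loss_module.training = True
        loss_module(local_feats[layer], global_feats[layer])
        loss_module.training = False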

@@ -168,8 +170,8 @@ def update_buffers(
initial_state_local_model = local_model.training

# Set local model to evaluation mode, as we don't want to create a computational graph
- # for the local model when populating the local buffer with features to compute optimal
- # betas for the MK-MMD loss
+ # for the local model when populating the local buffer with features to train the Deep MMD
+ # kernel
local_model.eval()

# Make sure the local model is in evaluation mode before populating the local buffer
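The save/eval/restore handling referenced in the comments above (putting the local model into evaluation mode while its feature buffer is filled, then restoring its original mode) follows a common PyTorch pattern. A minimal sketch is shown below; the names are invented, and the torch.no_grad() wrapper is an illustrative addition rather than something this diff confirms the client uses.

from typing import Callable

import torch
import torch.nn as nn

def populate_buffer_in_eval_mode(model: nn.Module, populate: Callable[[], None]) -> None:
    # Remember whether the model was in training mode so it can be restored.
    was_training = model.training
    model.eval()
    try:
        # no_grad avoids building a computation graph while features are
        # collected into the buffer.
        with torch.no_grad():
            populate()
    finally:
        # Restore the model's original training/evaluation mode.
        model.train(was_training)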
@@ -287,9 +289,11 @@ def compute_training_loss(
additional losses indexed by name. Additional losses include each loss component and the global model
loss tensor.
"""
- if self.beta_global_update_interval == -1:
-     for layer_loss_module in self.deep_mmd_losses.values():
+ for layer_loss_module in self.deep_mmd_losses.values():
+     if self.mmd_kernel_train_interval == -1:
          assert layer_loss_module.training
+     else:
+         assert not layer_loss_module.training
# Check that both models are in training mode
assert self.global_model.training and self.model.training
