Skip to content

Commit

Permalink
refactor(framework) Update get/set parameter functions for FlowerTune…
Browse files Browse the repository at this point in the history
… template (#4217)
  • Loading branch information
yan-gao-GY authored Sep 16, 2024
1 parent 968fb4b commit d94b22d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 21 deletions.
24 changes: 7 additions & 17 deletions src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
import warnings
from collections import OrderedDict
from typing import Dict, Tuple

import torch
Expand All @@ -11,7 +10,7 @@ from flwr.common import Context
from flwr.common.config import unflatten_dict
from flwr.common.typing import NDArrays, Scalar
from omegaconf import DictConfig
from peft import get_peft_model_state_dict, set_peft_model_state_dict

from transformers import TrainingArguments
from trl import SFTTrainer

Expand All @@ -20,7 +19,12 @@ from $import_name.dataset import (
load_data,
replace_keys,
)
from $import_name.models import cosine_annealing, get_model
from $import_name.models import (
cosine_annealing,
get_model,
set_parameters,
get_parameters,
)

# Avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "true"
Expand Down Expand Up @@ -92,20 +96,6 @@ class FlowerClient(NumPyClient):
)


def set_parameters(model, parameters: NDArrays) -> None:
"""Change the parameters of the model using the given ones."""
peft_state_dict_keys = get_peft_model_state_dict(model).keys()
params_dict = zip(peft_state_dict_keys, parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
set_peft_model_state_dict(model, state_dict)


def get_parameters(model) -> NDArrays:
"""Return the parameters of the current net."""
state_dict = get_peft_model_state_dict(model)
return [val.cpu().numpy() for _, val in state_dict.items()]


def client_fn(context: Context) -> FlowerClient:
"""Create a Flower client representing a single organization."""
partition_id = context.node_config["partition-id"]
Expand Down
24 changes: 23 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@ import math

import torch
from omegaconf import DictConfig
from peft import LoraConfig, get_peft_model
from collections import OrderedDict
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
from peft.utils import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from flwr.common.typing import NDArrays


def cosine_annealing(
current_round: int,
Expand Down Expand Up @@ -54,3 +62,17 @@ def get_model(model_cfg: DictConfig):
model.config.use_cache = False

return get_peft_model(model, peft_config)


def set_parameters(model, parameters: NDArrays) -> None:
"""Change the parameters of the model using the given ones."""
peft_state_dict_keys = get_peft_model_state_dict(model).keys()
params_dict = zip(peft_state_dict_keys, parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
set_peft_model_state_dict(model, state_dict)


def get_parameters(model) -> NDArrays:
"""Return the parameters of the current net."""
state_dict = get_peft_model_state_dict(model)
return [val.cpu().numpy() for _, val in state_dict.items()]
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ from flwr.common.config import unflatten_dict
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from omegaconf import DictConfig

from $import_name.client_app import get_parameters, set_parameters
from $import_name.models import get_model
from $import_name.models import get_model, get_parameters, set_parameters
from $import_name.dataset import replace_keys
from $import_name.strategy import FlowerTuneLlm

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.10.0",
"flwr[simulation]>=1.11.1",
"flwr-datasets>=0.3.0",
"trl==0.8.1",
"bitsandbytes==0.43.0",
Expand Down

0 comments on commit d94b22d

Please sign in to comment.