diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl index 19d1e20bacc..415898ba117 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl @@ -2,7 +2,6 @@ import os import warnings -from collections import OrderedDict from typing import Dict, Tuple import torch @@ -11,7 +10,7 @@ from flwr.common import Context from flwr.common.config import unflatten_dict from flwr.common.typing import NDArrays, Scalar from omegaconf import DictConfig -from peft import get_peft_model_state_dict, set_peft_model_state_dict + from transformers import TrainingArguments from trl import SFTTrainer @@ -20,7 +19,12 @@ from $import_name.dataset import ( load_data, replace_keys, ) -from $import_name.models import cosine_annealing, get_model +from $import_name.models import ( + cosine_annealing, + get_model, + set_parameters, + get_parameters, +) # Avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "true" @@ -92,20 +96,6 @@ class FlowerClient(NumPyClient): ) -def set_parameters(model, parameters: NDArrays) -> None: - """Change the parameters of the model using the given ones.""" - peft_state_dict_keys = get_peft_model_state_dict(model).keys() - params_dict = zip(peft_state_dict_keys, parameters) - state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) - set_peft_model_state_dict(model, state_dict) - - -def get_parameters(model) -> NDArrays: - """Return the parameters of the current net.""" - state_dict = get_peft_model_state_dict(model) - return [val.cpu().numpy() for _, val in state_dict.items()] - - def client_fn(context: Context) -> FlowerClient: """Create a Flower client representing a single organization.""" partition_id = context.node_config["partition-id"] diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl index a548ba9abee..3f3f95c8b8e 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl @@ -4,10 +4,18 @@ import math import torch from omegaconf import DictConfig -from peft import LoraConfig, get_peft_model +from collections import OrderedDict +from peft import ( + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, +) from peft.utils import prepare_model_for_kbit_training from transformers import AutoModelForCausalLM, BitsAndBytesConfig +from flwr.common.typing import NDArrays + def cosine_annealing( current_round: int, @@ -54,3 +62,17 @@ def get_model(model_cfg: DictConfig): model.config.use_cache = False return get_peft_model(model, peft_config) + + +def set_parameters(model, parameters: NDArrays) -> None: + """Change the parameters of the model using the given ones.""" + peft_state_dict_keys = get_peft_model_state_dict(model).keys() + params_dict = zip(peft_state_dict_keys, parameters) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + set_peft_model_state_dict(model, state_dict) + + +def get_parameters(model) -> NDArrays: + """Return the parameters of the current net.""" + state_dict = get_peft_model_state_dict(model) + return [val.cpu().numpy() for _, val in state_dict.items()] diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl index 586b929be06..7d4de0f73db 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl @@ -8,8 +8,7 @@ from flwr.common.config import unflatten_dict from flwr.server import ServerApp, ServerAppComponents, ServerConfig from omegaconf import DictConfig -from $import_name.client_app import get_parameters, set_parameters -from $import_name.models import get_model +from $import_name.models import get_model, get_parameters, set_parameters from $import_name.dataset import replace_keys from $import_name.strategy import FlowerTuneLlm diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 5046a6f89f2..8c15739e92f 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -8,7 +8,7 @@ version = "1.0.0" description = "" license = "Apache-2.0" dependencies = [ - "flwr[simulation]>=1.10.0", + "flwr[simulation]>=1.11.1", "flwr-datasets>=0.3.0", "trl==0.8.1", "bitsandbytes==0.43.0",