From 6779867ec9891e70fc7364a0420aba3a2c606bca Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 1 Jul 2024 11:37:52 +0800 Subject: [PATCH 01/50] Add new dataset&trainer for 2.2 Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/algo/pdss/pdss_trainer.py | 239 ++++++++++++++++++ .../fate_llm/dataset/input_output_dataset.py | 113 +++++++++ 2 files changed, 352 insertions(+) create mode 100644 python/fate_llm/algo/pdss/pdss_trainer.py create mode 100644 python/fate_llm/dataset/input_output_dataset.py diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py new file mode 100644 index 0000000..1c34954 --- /dev/null +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -0,0 +1,239 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import pickle +import time +from torch import nn +from fate.ml.aggregator.base import Aggregator +from dataclasses import dataclass +from typing import List, Optional, Callable, Literal +from fate.arch import Context +from torch.utils.data import DataLoader, Dataset +from transformers.trainer_callback import TrainerCallback +from transformers import PreTrainedTokenizer +import logging +import torch +import torch.distributed as dist +from fate_llm.dataset.pdss_dataset import PrefixDataset +from transformers.modeling_utils import unwrap_model +from transformers import PreTrainedTokenizer, PreTrainedModel +from typing import Dict, Any +from transformers import Seq2SeqTrainingArguments +from transformers.trainer_utils import EvalPrediction +from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer, Seq2SeqTrainingArguments +from fate_llm.algo.inferdpt.inference.inference_base import Inference +from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer + + +logger = logging.getLogger(__name__) +_MODE = ['train_only', 'inferdpt_only', 'inferdpt_and_train'] + + +# share obj between ranks in an easy way +def save_to(obj, filepath, filename='tmp.pkl'): + if not os.path.exists(filepath): + os.mkdir(filepath) + path = filepath + filename + with open(path, 'wb') as f: + pickle.dump(obj, f) + dist.barrier() + os.remove(path) + + +def load(filepath, filename='tmp.pkl'): + path = filepath + filename + while not os.path.exists(path): + time.sleep(0.1) + while True: + try: + with open(path, 'rb') as f: + d = pickle.load(f) + break + except (EOFError, pickle.UnpicklingError): + time.sleep(0.1) + + dist.barrier() + return d + + +class DSSTrainerClient(Seq2SeqTrainer): + + def __init__(self, + ctx: Context, + model: nn.Module, + training_args: Seq2SeqTrainingArguments, + train_set: Dataset, + val_set: Dataset = None, + alpha: float = 0.5, + optimizer: torch.optim.Optimizer = None, + data_collator: Callable = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + callbacks: Optional[List[TrainerCallback]] = [], + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + 
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> None: + + self.alpha = alpha + self.ctx = ctx + Seq2SeqTrainer.__init__( + self, + model=model, + args=training_args, + train_dataset=train_set, + eval_dataset=val_set, + data_collator=data_collator, + optimizers=(optimizer, scheduler), + tokenizer=tokenizer, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + compute_metrics=compute_metrics, + callbacks=callbacks, + ) + + def compute_loss(self, model, inputs, return_outputs=False): + + label_outputs = model(**inputs['predict']) + cot_outputs = model(**inputs['rationale']) + loss = self.alpha * cot_outputs.loss + (1. - self.alpha) * label_outputs.loss + return (loss, {'rationale_loss': cot_outputs, 'predict_loss': label_outputs}) if return_outputs else loss + + +class PDSSTrainerClient(DSSTrainerClient): + + def __init__(self, + ctx: Context, + training_args: Seq2SeqTrainingArguments, + train_set: PrefixDataset, + val_set: Dataset = None, + model: nn.Module = None, + optimizer: torch.optim.Optimizer = None, + data_collator: Callable = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + callbacks: Optional[List[TrainerCallback]] = [], + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + alpha: float = 0.5, + mode: Literal['train_only', 'inferdpt_only', 'inferdpt_and_train'] = 'inferdpt_and_train', + inferdpt_client: InferDPTClient = None, + doc_template: str = None, + instruction_template: str = None, + decode_template: str = None, + result_key: str = 'inferdpt_result', + verbose: bool = False, + remote_inference_kwargs: dict = {}, + local_inference_kwargs: dict = {}, + ) -> None: + + self.mode = mode + self.inferdpt_client = inferdpt_client + self.inferdpt_result = None + self.inferdpt_predict_kwargs = { + 'doc_template': doc_template, + 'instruction_template': instruction_template, + 'decode_template': decode_template, + 'result_key': result_key, + 'verbose': verbose, + 'remote_inference_kwargs': remote_inference_kwargs, + 'local_inference_kwargs': local_inference_kwargs + } + self.inferdpt_result = None + + assert mode in _MODE, "mode should be one of {}".format(_MODE) + if training_args.local_rank == 0: + if mode == 'inferdpt_only' or mode == 'inferdpt_and_train': + if self.inferdpt_client is None: + raise ValueError('You must provide inference_inst for remote inference') + + if mode != 'inferdpt_only': + training_args.remove_unused_columns = False # this parameter is neccessary + DSSTrainerClient.__init__( + self, + ctx=ctx, + model=model, + training_args=training_args, + train_set=train_set, + val_set=val_set, + data_collator=data_collator, + optimizer=optimizer, + scheduler=scheduler, + tokenizer=tokenizer, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + compute_metrics=compute_metrics, + callbacks=callbacks, + alpha=alpha + ) + else: + # skip trainer initialzation becuase training is not needed + self.args = training_args + self.train_dataset = train_set + + def inferdpt(self) -> List[str]: + + if self.args.local_rank == 0: # other rank will skip federation step + assert isinstance(self.train_dataset, PrefixDataset), "train_set should be an instance of PrefixDataset" + dict_dataset = self.train_dataset.get_raw_dataset() + inferdpt_result = self.inferdpt_client.predict(dict_dataset, 
**self.inferdpt_predict_kwargs) + self.inferdpt_result = inferdpt_result + rationale_list = [i[self.inferdpt_predict_kwargs['result_key']] for i in self.inferdpt_result] + self.train_dataset.load_rationale(rationale_list) + logger.info('Rationale loaded: {}'.format(rationale_list)) + + if self.mode == 'inferdpt_and_train': + if self.args.world_size > 1: # sync dataset with other ranks + print('scattering obj') + save_to(rationale_list, self.args.output_dir) + + if self.args.local_rank > 0: + if self.mode == 'inferdpt_and_train': + # wait until inferdpt is done + print('waiting for obj') + rationale_list = load(self.args.output_dir) + self.train_dataset.load_rationale(rationale_list) + logger.info('Rationale loaded') + logger.info('inferdpt done') + + def train(self): + + if self.mode == 'train_only': + logger.info("Train only mode") + super().train() + elif self.mode == 'inferdpt_only': + logger.info("Inferdpt only mode, skip training") + self.inferdpt() + elif self.mode == 'inferdpt_and_train': + logger.info("Inferdpt and train mode") + self.inferdpt() + super().train() + + def get_inferdpt_result(self): + return self.inferdpt_result + + +class PDSSTraineServer(object): + + def __init__(self, ctx: Context): + super().__init__() + self.ctx = ctx + + def train(self): + pass + + def predict(self): + pass + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/python/fate_llm/dataset/input_output_dataset.py b/python/fate_llm/dataset/input_output_dataset.py new file mode 100644 index 0000000..59308cc --- /dev/null +++ b/python/fate_llm/dataset/input_output_dataset.py @@ -0,0 +1,113 @@ +from fate.ml.nn.dataset.base import Dataset +from transformers.trainer_pt_utils import LabelSmoother +from typing import List, Dict, Union, Literal +import logging +from jinja2 import Template +from transformers import AutoTokenizer + + +logger = logging.getLogger(__name__) + + +class InputOutputDataset(Dataset): + + def __init__(self, + tokenizer_path, + input_template: str = '', + output_template: str = '', + max_input_length: int = 256, + max_target_length: int = 256, + pad_token: int = -100, + load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk', + split_key: str = None + ): + + super().__init__() + self.tokenizer = None + self.tokenizer_path = tokenizer_path + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True) + self.max_input_length = max_input_length + self.max_target_length = max_target_length + self.dataset = None + self.load_from = load_from + self.input_template = Template(input_template) + self.output_template = Template(output_template) + self.pad_token = pad_token + self.split_key = split_key + self.max_seq_length = max_input_length + max_target_length + 1 + + def load(self, path): + if self.load_from == 'hf_load_from_disk': + import datasets + self.dataset = [i for i in datasets.load_from_disk(path)] + if self.split_key is not None: + self.dataset = self.dataset[self.split_key] + elif self.load_from == 'jsonl': + import json + with open(path, 'r') as f: + json_lines = f.read().split('\n') + self.dataset = [] + for i in json_lines: + try: + self.dataset.append(json.loads(i)) + except: + print('skip line') + elif self.load_from == 'hf_load_dataset': + from datasets import load_dataset + self.dataset = load_dataset(path) + if self.split_key is not None: + self.dataset = self.dataset[self.split_key] + else: + raise ValueError('unknown load format') + + if not isinstance(self.dataset, list) or not 
isinstance(self.dataset[0], dict): + logger.warn('loaded dataset is expected to be a list of dict') + + def get_raw_dataset(self): + return self.dataset + + def __len__(self): + return len(self.dataset) + + def get_str_item(self, i) -> dict: + + data_item = self.dataset[i] + in_ = self.input_template.render(**data_item) + out_ = self.output_template.render(**data_item) + return { + 'input': in_, + 'output': out_ + } + + def _process_item(self, data_item): + + a_ids = self.tokenizer.encode(text=data_item['input'], add_special_tokens=True, truncation=True, + max_length=self.max_input_length) + b_ids = self.tokenizer.encode(text=data_item['output'], add_special_tokens=False, truncation=True, + max_length=self.max_target_length) + + context_length = len(a_ids) + input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id] + labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id] + + pad_len = self.max_seq_length - len(input_ids) + input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len + labels = labels + [self.tokenizer.pad_token_id] * pad_len + labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels] + + assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}" + + return { + "input_ids": input_ids, + "labels": labels + } + + def get_tokenized_item(self, i) -> dict: + + str_item = self.get_str_item(i) + ret_dict = self._process_item(str_item) + return ret_dict + + def __getitem__(self, i) -> dict: + item = self.get_tokenized_item(i) + return item From 5b5f9c69e9bf3a815f5a19a5f7dd6dff5f6b1b41 Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 1 Jul 2024 14:54:35 +0800 Subject: [PATCH 02/50] Add pdss dataset Signed-off-by: weijingchen Signed-off-by: cwj --- .../fate_llm/dataset/input_output_dataset.py | 6 +- python/fate_llm/dataset/pdss_dataset.py | 62 +++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 python/fate_llm/dataset/pdss_dataset.py diff --git a/python/fate_llm/dataset/input_output_dataset.py b/python/fate_llm/dataset/input_output_dataset.py index 59308cc..bf4978c 100644 --- a/python/fate_llm/dataset/input_output_dataset.py +++ b/python/fate_llm/dataset/input_output_dataset.py @@ -13,11 +13,10 @@ class InputOutputDataset(Dataset): def __init__(self, tokenizer_path, - input_template: str = '', - output_template: str = '', + input_template: str, + output_template: str, max_input_length: int = 256, max_target_length: int = 256, - pad_token: int = -100, load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk', split_key: str = None ): @@ -32,7 +31,6 @@ def __init__(self, self.load_from = load_from self.input_template = Template(input_template) self.output_template = Template(output_template) - self.pad_token = pad_token self.split_key = split_key self.max_seq_length = max_input_length + max_target_length + 1 diff --git a/python/fate_llm/dataset/pdss_dataset.py b/python/fate_llm/dataset/pdss_dataset.py new file mode 100644 index 0000000..545b32b --- /dev/null +++ b/python/fate_llm/dataset/pdss_dataset.py @@ -0,0 +1,62 @@ +from fate_llm.dataset.input_output_dataset import InputOutputDataset +from transformers.trainer_pt_utils import LabelSmoother +from typing import List, Dict, Union, Literal +import logging +from jinja2 import Template +from transformers import AutoTokenizer + + + +logger = logging.getLogger(__name__) + + +class PrefixDataset(InputOutputDataset): + + def __init__(self, + tokenizer_path, + 
predict_template_input: str, + predict_template_output: str, + rationale_template_input: str, + rationale_template_output: str, + max_input_length: int = 256, + max_target_length: int = 256, + load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk', + split_key: str = None + ): + + super().__init__(tokenizer_path, predict_template_input, predict_template_output, max_input_length, max_target_length, load_from, split_key) + self.r_input_template = Template(rationale_template_input) + self.r_output_template = Template(rationale_template_output) + + def load_rationale(self, result_list): + for d, r in zip(self.dataset, result_list): + d['rationale'] = r + + def get_str_item(self, i) -> dict: + + data_item = self.dataset[i] + p_in = self.input_template.render(data_item) + p_out = self.output_template.render(data_item) + r_in = self.r_input_template.render(data_item) + r_out = self.r_output_template.render(data_item) + ret_dict = { + 'predict':{ + 'input': p_in, + 'output': p_out + }, + 'rationale':{ + 'input': r_in, + 'output': r_out + } + } + return ret_dict + + def get_tokenized_item(self, i) -> dict: + + str_item = self.get_str_item(i) + ret_dict = { + 'predict': self._process_item(str_item['predict']), + 'rationale': self._process_item(str_item['rationale']) + } + + return ret_dict From 96ea64f2228a11012676eeb88a328d7d81bb02ec Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 1 Jul 2024 16:05:16 +0800 Subject: [PATCH 03/50] Update trainer & dataset Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/algo/pdss/pdss_trainer.py | 6 ++---- python/fate_llm/dataset/pdss_dataset.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py index 1c34954..906159d 100644 --- a/python/fate_llm/algo/pdss/pdss_trainer.py +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -72,7 +72,6 @@ def load(filepath, filename='tmp.pkl'): class DSSTrainerClient(Seq2SeqTrainer): def __init__(self, - ctx: Context, model: nn.Module, training_args: Seq2SeqTrainingArguments, train_set: Dataset, @@ -88,7 +87,6 @@ def __init__(self, ) -> None: self.alpha = alpha - self.ctx = ctx Seq2SeqTrainer.__init__( self, model=model, @@ -186,7 +184,7 @@ def inferdpt(self) -> List[str]: if self.args.local_rank == 0: # other rank will skip federation step assert isinstance(self.train_dataset, PrefixDataset), "train_set should be an instance of PrefixDataset" dict_dataset = self.train_dataset.get_raw_dataset() - inferdpt_result = self.inferdpt_client.predict(dict_dataset, **self.inferdpt_predict_kwargs) + inferdpt_result = self.inferdpt_client.inference(dict_dataset, **self.inferdpt_predict_kwargs) self.inferdpt_result = inferdpt_result rationale_list = [i[self.inferdpt_predict_kwargs['result_key']] for i in self.inferdpt_result] self.train_dataset.load_rationale(rationale_list) @@ -236,4 +234,4 @@ def predict(self): pass if __name__ == '__main__': - pass \ No newline at end of file + pass diff --git a/python/fate_llm/dataset/pdss_dataset.py b/python/fate_llm/dataset/pdss_dataset.py index 545b32b..2c4dae7 100644 --- a/python/fate_llm/dataset/pdss_dataset.py +++ b/python/fate_llm/dataset/pdss_dataset.py @@ -6,7 +6,6 @@ from transformers import AutoTokenizer - logger = logging.getLogger(__name__) @@ -14,19 +13,19 @@ class PrefixDataset(InputOutputDataset): def __init__(self, tokenizer_path, - predict_template_input: str, - predict_template_output: str, - rationale_template_input: str, - 
rationale_template_output: str, + predict_input_template: str, + predict_output_template: str, + rationale_input_template: str, + rationale_output_template: str, max_input_length: int = 256, max_target_length: int = 256, load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk', split_key: str = None ): - super().__init__(tokenizer_path, predict_template_input, predict_template_output, max_input_length, max_target_length, load_from, split_key) - self.r_input_template = Template(rationale_template_input) - self.r_output_template = Template(rationale_template_output) + super().__init__(tokenizer_path, predict_input_template, predict_output_template, max_input_length, max_target_length, load_from, split_key) + self.r_input_template = Template(rationale_input_template) + self.r_output_template = Template(rationale_output_template) def load_rationale(self, result_list): for d, r in zip(self.dataset, result_list): From 7a345904fe251f56383e86e203cf7a6117d6e97a Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 2 Jul 2024 15:19:01 +0800 Subject: [PATCH 04/50] Update pdss inference Signed-off-by: weijingchen Signed-off-by: cwj --- .../encoder_decoder/slm_encoder_decoder.py | 79 +++++++++++++++++++ python/fate_llm/algo/pdss/pdss_trainer.py | 6 +- 2 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py diff --git a/python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py b/python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py new file mode 100644 index 0000000..f6eca62 --- /dev/null +++ b/python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py @@ -0,0 +1,79 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
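+#
+# SLMEncoderDecoderClient below keeps the InferDPT client interface, but its
+# encode() step renders each doc with a Jinja2 template and "perturbs" it via a
+# local small-language-model inference pass (local_inference_inst) instead of
+# the InferDPT kit; decode() and inference() delegate to the InferDPTClient
+# parent. A minimal usage sketch, assuming an API-style Inference backend such
+# as APICompletionInference from the init helpers (URL/model values below are
+# placeholders, not part of this patch):
+#
+#   inference_inst = APICompletionInference(api_url='<api-url>',
+#                                           model_name='<model>', api_key='EMPTY')
+#   client = SLMEncoderDecoderClient(ctx, inference_inst)
+#   docs_out = client.inference(docs, encode_template, instruction_template,
+#                               decode_template)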
+# +import copy +from jinja2 import Template +from tqdm import tqdm +from fate.arch import Context +from typing import List, Dict, Union +from fate.ml.nn.dataset.base import Dataset +from fate_llm.algo.inferdpt.utils import InferDPTKit +from openai import OpenAI +import logging +from fate_llm.algo.inferdpt.inference.inference_base import Inference +from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer +from fate_llm.dataset.hf_dataset import HuggingfaceDataset + + +logger = logging.getLogger(__name__) + + +class SLMEncoderDecoderClient(InferDPTClient): + + def __init__(self, ctx: Context, local_inference_inst: Inference) -> None: + self.ctx = ctx + self.comm_idx = 0 + self.local_inference_inst = local_inference_inst + self.local_inference_kwargs = {} + + def encode(self, docs: List[Dict[str, str]], format_template: str = None, verbose=False, perturb_doc_key: str ='perturbed_doc') -> List[Dict[str, str]]: + + template = Template(format_template) + copy_docs = copy.deepcopy(docs) + doc_to_infer = [] + for doc in tqdm(copy_docs): + rendered_doc = template.render(**doc) + doc_to_infer.append(rendered_doc) + # perturb using local model inference + self.doc_to_infer = doc_to_infer + infer_result = self.local_inference_inst.inference(doc_to_infer, self.local_inference_kwargs) + for doc, pr in zip(copy_docs, infer_result): + doc[perturb_doc_key] = pr + self.doc_with_p = copy_docs + return copy_docs + + def decode(self, p_docs: List[Dict[str, str]], instruction_template: str = None, decode_template: str = None, verbose=False, + perturbed_response_key: str = 'perturbed_response', result_key: str = 'result', + remote_inference_kwargs: dict = {}, local_inference_kwargs: dict = {}): + return super().decode(p_docs, instruction_template, decode_template, verbose, perturbed_response_key, result_key, remote_inference_kwargs, local_inference_kwargs) + + def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDataset], + encode_template: str, + instruction_template: str, + decode_template: str, + verbose: bool = False, + remote_inference_kwargs: dict = {}, + local_inference_kwargs: dict = {}, + perturb_doc_key: str = 'perturbed_doc', + perturbed_response_key: str = 'perturbed_response', + result_key: str = 'result', + ) -> List[Dict[str, str]]: + self.local_inference_kwargs = local_inference_kwargs + return super().inference(docs, encode_template, instruction_template, decode_template, verbose, remote_inference_kwargs, \ + local_inference_kwargs, perturb_doc_key, perturbed_response_key, result_key) + + +class SLMEncoderDecoderServer(InferDPTServer): + pass diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py index 906159d..aa1de77 100644 --- a/python/fate_llm/algo/pdss/pdss_trainer.py +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -35,7 +35,7 @@ from transformers.trainer_utils import EvalPrediction from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer, Seq2SeqTrainingArguments from fate_llm.algo.inferdpt.inference.inference_base import Inference -from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer +from fate_llm.algo.inferdpt._encode_decode import EncoderDecoder logger = logging.getLogger(__name__) @@ -125,8 +125,8 @@ def __init__(self, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, alpha: float = 0.5, - mode: Literal['train_only', 'inferdpt_only', 'inferdpt_and_train'] = 
'inferdpt_and_train', - inferdpt_client: InferDPTClient = None, + mode: Literal['train_only', 'infer_only', 'infer_and_train'] = 'infer_and_train', + inferdpt_client: EncoderDecoder = None, doc_template: str = None, instruction_template: str = None, decode_template: str = None, From f20a50de37950789c833ad3143324de666b53ca4 Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 3 Jul 2024 10:30:59 +0800 Subject: [PATCH 05/50] Update trainer, add collator Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/algo/pdss/pdss_trainer.py | 68 +++++++++---------- .../data/data_collator/pdss_collator.py | 13 ++++ 2 files changed, 47 insertions(+), 34 deletions(-) create mode 100644 python/fate_llm/data/data_collator/pdss_collator.py diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py index aa1de77..b08255d 100644 --- a/python/fate_llm/algo/pdss/pdss_trainer.py +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -17,9 +17,7 @@ import pickle import time from torch import nn -from fate.ml.aggregator.base import Aggregator -from dataclasses import dataclass -from typing import List, Optional, Callable, Literal +from typing import List, Optional, Callable, Literal, Union from fate.arch import Context from torch.utils.data import DataLoader, Dataset from transformers.trainer_callback import TrainerCallback @@ -35,11 +33,12 @@ from transformers.trainer_utils import EvalPrediction from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer, Seq2SeqTrainingArguments from fate_llm.algo.inferdpt.inference.inference_base import Inference -from fate_llm.algo.inferdpt._encode_decode import EncoderDecoder +from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer +from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer logger = logging.getLogger(__name__) -_MODE = ['train_only', 'inferdpt_only', 'inferdpt_and_train'] +_MODE = ['train_only', 'infer_only', 'infer_and_train'] # share obj between ranks in an easy way @@ -126,20 +125,20 @@ def __init__(self, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, alpha: float = 0.5, mode: Literal['train_only', 'infer_only', 'infer_and_train'] = 'infer_and_train', - inferdpt_client: EncoderDecoder = None, + infer_client: Union[SLMEncoderDecoderClient, InferDPTClient] = None, doc_template: str = None, instruction_template: str = None, decode_template: str = None, - result_key: str = 'inferdpt_result', + result_key: str = 'infer_result', verbose: bool = False, remote_inference_kwargs: dict = {}, local_inference_kwargs: dict = {}, ) -> None: self.mode = mode - self.inferdpt_client = inferdpt_client - self.inferdpt_result = None - self.inferdpt_predict_kwargs = { + self.infer_client = infer_client + self.infer_result = None + self.infer_predict_kwargs = { 'doc_template': doc_template, 'instruction_template': instruction_template, 'decode_template': decode_template, @@ -148,15 +147,15 @@ def __init__(self, 'remote_inference_kwargs': remote_inference_kwargs, 'local_inference_kwargs': local_inference_kwargs } - self.inferdpt_result = None + self.infer_result = None assert mode in _MODE, "mode should be one of {}".format(_MODE) if training_args.local_rank == 0: - if mode == 'inferdpt_only' or mode == 'inferdpt_and_train': - if self.inferdpt_client is None: + if mode == 'infer_only' or mode == 'infer_and_train': + if self.infer_client is None: raise ValueError('You must provide inference_inst for remote 
inference') - if mode != 'inferdpt_only': + if mode != 'infer_only': training_args.remove_unused_columns = False # this parameter is neccessary DSSTrainerClient.__init__( self, @@ -179,59 +178,60 @@ def __init__(self, self.args = training_args self.train_dataset = train_set - def inferdpt(self) -> List[str]: + def infer(self) -> List[str]: if self.args.local_rank == 0: # other rank will skip federation step assert isinstance(self.train_dataset, PrefixDataset), "train_set should be an instance of PrefixDataset" dict_dataset = self.train_dataset.get_raw_dataset() - inferdpt_result = self.inferdpt_client.inference(dict_dataset, **self.inferdpt_predict_kwargs) - self.inferdpt_result = inferdpt_result - rationale_list = [i[self.inferdpt_predict_kwargs['result_key']] for i in self.inferdpt_result] + infer_result = self.infer_client.inference(dict_dataset, **self.infer_predict_kwargs) + self.infer_result = infer_result + rationale_list = [i[self.infer_predict_kwargs['result_key']] for i in self.infer_result] self.train_dataset.load_rationale(rationale_list) logger.info('Rationale loaded: {}'.format(rationale_list)) - if self.mode == 'inferdpt_and_train': + if self.mode == 'infer_and_train': if self.args.world_size > 1: # sync dataset with other ranks print('scattering obj') save_to(rationale_list, self.args.output_dir) if self.args.local_rank > 0: - if self.mode == 'inferdpt_and_train': - # wait until inferdpt is done + if self.mode == 'infer_and_train': + # wait until infer is done print('waiting for obj') rationale_list = load(self.args.output_dir) self.train_dataset.load_rationale(rationale_list) logger.info('Rationale loaded') - logger.info('inferdpt done') + logger.info('infer done') def train(self): if self.mode == 'train_only': logger.info("Train only mode") super().train() - elif self.mode == 'inferdpt_only': - logger.info("Inferdpt only mode, skip training") - self.inferdpt() - elif self.mode == 'inferdpt_and_train': - logger.info("Inferdpt and train mode") - self.inferdpt() + elif self.mode == 'infer_only': + logger.info("infer only mode, skip training") + self.infer() + elif self.mode == 'infer_and_train': + logger.info("infer and train mode") + self.infer() super().train() - def get_inferdpt_result(self): - return self.inferdpt_result + def get_infer_result(self): + return self.infer_result class PDSSTraineServer(object): - def __init__(self, ctx: Context): + def __init__(self, ctx: Context, infer_server: Union[SLMEncoderDecoderServer, InferDPTServer]): super().__init__() self.ctx = ctx + self.infer_server = infer_server def train(self): - pass + logger.info('Server side start inference') + self.infer_server.inference() + logger.info('Server inference done') - def predict(self): - pass if __name__ == '__main__': pass diff --git a/python/fate_llm/data/data_collator/pdss_collator.py b/python/fate_llm/data/data_collator/pdss_collator.py new file mode 100644 index 0000000..2a8007a --- /dev/null +++ b/python/fate_llm/data/data_collator/pdss_collator.py @@ -0,0 +1,13 @@ +from transformers import DataCollatorForSeq2Seq +import pandas as pd + +class PrefixDataCollator(DataCollatorForSeq2Seq): + def __call__(self, features, return_tensors=None): + features_df = pd.DataFrame(features) + cot = super().__call__(list(features_df['predict']), return_tensors) + label = super().__call__(list(features_df['rationale']), return_tensors) + + return { + 'predict': cot, + 'rationale': label + } \ No newline at end of file From 9e4f06eca00bd1b3eddd7202348cabdab417d445 Mon Sep 17 00:00:00 2001 From: cwj 
Date: Wed, 3 Jul 2024 16:45:10 +0800 Subject: [PATCH 06/50] Update pdss Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/algo/inferdpt/inferdpt.py | 4 ++-- python/fate_llm/algo/pdss/pdss_trainer.py | 15 ++++++--------- python/fate_llm/dataset/pdss_dataset.py | 4 ++-- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/python/fate_llm/algo/inferdpt/inferdpt.py b/python/fate_llm/algo/inferdpt/inferdpt.py index 25641ae..ea99cf2 100644 --- a/python/fate_llm/algo/inferdpt/inferdpt.py +++ b/python/fate_llm/algo/inferdpt/inferdpt.py @@ -113,7 +113,7 @@ def decode(self, p_docs: List[Dict[str, str]], instruction_template: str = None, return docs_with_infer_result def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDataset], - doc_template: str, + encode_template: str, instruction_template: str, decode_template: str, verbose: bool = False, @@ -128,7 +128,7 @@ def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDataset], # perturb doc if isinstance(docs, HuggingfaceDataset): docs = [docs[i] for i in range(len(docs))] - docs_with_p = self.encode(docs, format_template=doc_template, verbose=verbose, perturb_doc_key=perturb_doc_key) + docs_with_p = self.encode(docs, format_template=encode_template, verbose=verbose, perturb_doc_key=perturb_doc_key) logger.info('encode done') # inference using perturbed doc final_result = self.decode( diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py index b08255d..14ed944 100644 --- a/python/fate_llm/algo/pdss/pdss_trainer.py +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -126,7 +126,7 @@ def __init__(self, alpha: float = 0.5, mode: Literal['train_only', 'infer_only', 'infer_and_train'] = 'infer_and_train', infer_client: Union[SLMEncoderDecoderClient, InferDPTClient] = None, - doc_template: str = None, + encode_template: str = None, instruction_template: str = None, decode_template: str = None, result_key: str = 'infer_result', @@ -139,7 +139,7 @@ def __init__(self, self.infer_client = infer_client self.infer_result = None self.infer_predict_kwargs = { - 'doc_template': doc_template, + 'encode_template': encode_template, 'instruction_template': instruction_template, 'decode_template': decode_template, 'result_key': result_key, @@ -159,7 +159,6 @@ def __init__(self, training_args.remove_unused_columns = False # this parameter is neccessary DSSTrainerClient.__init__( self, - ctx=ctx, model=model, training_args=training_args, train_set=train_set, @@ -186,22 +185,20 @@ def infer(self) -> List[str]: infer_result = self.infer_client.inference(dict_dataset, **self.infer_predict_kwargs) self.infer_result = infer_result rationale_list = [i[self.infer_predict_kwargs['result_key']] for i in self.infer_result] - self.train_dataset.load_rationale(rationale_list) - logger.info('Rationale loaded: {}'.format(rationale_list)) - + self.train_dataset.load_rationale(rationale_list, key=self.infer_predict_kwargs['result_key']) + logger.info('infer done') if self.mode == 'infer_and_train': if self.args.world_size > 1: # sync dataset with other ranks - print('scattering obj') + logger.info('scattering obj') save_to(rationale_list, self.args.output_dir) if self.args.local_rank > 0: if self.mode == 'infer_and_train': # wait until infer is done - print('waiting for obj') + logger.info('waiting for obj') rationale_list = load(self.args.output_dir) self.train_dataset.load_rationale(rationale_list) logger.info('Rationale loaded') - logger.info('infer done') def train(self): diff --git 
a/python/fate_llm/dataset/pdss_dataset.py b/python/fate_llm/dataset/pdss_dataset.py index 2c4dae7..bee8d99 100644 --- a/python/fate_llm/dataset/pdss_dataset.py +++ b/python/fate_llm/dataset/pdss_dataset.py @@ -27,9 +27,9 @@ def __init__(self, self.r_input_template = Template(rationale_input_template) self.r_output_template = Template(rationale_output_template) - def load_rationale(self, result_list): + def load_rationale(self, result_list, key='rationale'): for d, r in zip(self.dataset, result_list): - d['rationale'] = r + d[key] = r def get_str_item(self, i) -> dict: From e129e2ac14439621a8d9bb41dd8a1f11d07f73d5 Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 3 Jul 2024 18:07:16 +0800 Subject: [PATCH 07/50] Update pdss runner, developing Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/algo/pdss/pdss_trainer.py | 2 +- python/fate_llm/runner/inferdpt_runner.py | 6 +- python/fate_llm/runner/pdss_runner.py | 224 ++++++++++++++++++++++ 3 files changed, 228 insertions(+), 4 deletions(-) create mode 100644 python/fate_llm/runner/pdss_runner.py diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py index 14ed944..7a7a1f4 100644 --- a/python/fate_llm/algo/pdss/pdss_trainer.py +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -153,7 +153,7 @@ def __init__(self, if training_args.local_rank == 0: if mode == 'infer_only' or mode == 'infer_and_train': if self.infer_client is None: - raise ValueError('You must provide inference_inst for remote inference') + raise ValueError('You must provide an inference instance for remote inference') if mode != 'infer_only': training_args.remove_unused_columns = False # this parameter is neccessary diff --git a/python/fate_llm/runner/inferdpt_runner.py b/python/fate_llm/runner/inferdpt_runner.py index 12ec5a3..0c204f0 100644 --- a/python/fate_llm/runner/inferdpt_runner.py +++ b/python/fate_llm/runner/inferdpt_runner.py @@ -47,7 +47,7 @@ class InferDPTRunner(NNRunner): def __init__( self, inferdpt_init_conf: Dict, - doc_template: str = None, + encode_template: str = None, instruction_template: str = None, decode_template: str = None, dataset_conf: Optional[Dict] = None, @@ -58,7 +58,7 @@ def __init__( result_key: str = 'inferdpt_result', ) -> None: self.inferdpt_init_conf = inferdpt_init_conf - self.doc_template = doc_template + self.encode_template = encode_template self.instruction_template = instruction_template self.decode_template = decode_template self.dataset_conf = dataset_conf @@ -128,7 +128,7 @@ def train( logger.info('initializing inst') client_inst = self.client_setup() pred_rs = client_inst.inference( - dataset_0, self.doc_template, self.instruction_template, self.decode_template, \ + dataset_0, self.encode_template, self.instruction_template, self.decode_template, \ remote_inference_kwargs=self.remote_inference_kwargs, local_inference_kwargs=self.local_inference_kwargs ) diff --git a/python/fate_llm/runner/pdss_runner.py b/python/fate_llm/runner/pdss_runner.py new file mode 100644 index 0000000..7232c12 --- /dev/null +++ b/python/fate_llm/runner/pdss_runner.py @@ -0,0 +1,224 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fate.components.components.nn.nn_runner import ( + NNRunner, + load_model_dict_from_path, + dir_warning, + loader_load_from_conf, + run_dataset_func, +) +from fate.components.components.nn.runner.homo_default_runner import DefaultRunner +from fate.components.components.nn.loader import Loader +from fate.ml.nn.trainer.trainer_base import HomoTrainerServer +from fate.arch.dataframe import DataFrame +from typing import Dict +from fate_llm.algo.pdss.pdss_trainer import PDSSTrainerClient, PDSSTraineServer +from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer +from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer +import torch.nn as nn +import torch.optim as optim +from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments, HomoSeq2SeqTrainerClient +from typing import Union, Type, Callable, Optional +from transformers.trainer_utils import get_last_checkpoint +from typing import Literal +import logging + + +logger = logging.getLogger(__name__) + + +def _check_instances( + model: nn.Module = None, + optimizer: optim.Optimizer = None, + train_args: Seq2SeqTrainingArguments = None, + data_collator: Callable = None, +) -> None: + + if model is not None and not issubclass(type(model), nn.Module): + raise TypeError(f"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}") + + if optimizer is not None and not issubclass(type(optimizer), optim.Optimizer): + raise TypeError( + f"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}" + ) + + if train_args is not None and not isinstance(train_args, Seq2SeqTrainingArguments): + raise TypeError( + f"SetupReturn Error: train_args must be an instance of Seq2SeqTrainingArguments " + f"but got {type(train_args)}" + ) + + if data_collator is not None and not callable(data_collator): + raise TypeError(f"SetupReturn Error: data_collator must be callable but got {type(data_collator)}") + + +class Seq2SeqRunner(DefaultRunner): + def __init__( + self, + model_conf: Optional[Dict] = None, + dataset_conf: Optional[Dict] = None, + optimizer_conf: Optional[Dict] = None, + training_args_conf: Optional[Dict] = None, + data_collator_conf: Optional[Dict] = None, + tokenizer_conf: Optional[Dict] = None, + mode: Literal['train_only', 'infer_only', 'infer_and_train'] = False, + infer_client_init_conf: Dict = None, + encode_template: str = None, + instruction_template: str = None, + decode_template: str = None, + remote_inference_kwargs: Dict = {}, + local_inference_kwargs: Dict = {}, + perturb_doc_key: str = 'perturbed_doc', + perturbed_response_key: str = 'perturbed_response', + result_key: str = 'infer_result', + ) -> None: + super(NNRunner, self).__init__() + self.model_conf = model_conf + self.dataset_conf = dataset_conf + self.optimizer_conf = optimizer_conf + self.training_args_conf = training_args_conf + self.data_collator_conf = data_collator_conf + self.mode = mode + self.tokenizer_conf = tokenizer_conf + self.infer_client_init_conf = infer_client_init_conf + self.encode_template = 
encode_template + self.instruction_template = instruction_template + self.decode_template = decode_template + self.remote_inference_kwargs = remote_inference_kwargs + self.local_inference_kwargs = local_inference_kwargs + self.perturb_doc_key = perturb_doc_key + self.perturbed_response_key = perturbed_response_key + self.result_key = result_key + + assert isinstance(self.local_mode, bool), "local should be bool" + # setup var + self.trainer = None + self.training_args = None + + + def _get_inferdpt_inst(self, init_conf): + loader = Loader.from_dict(init_conf) + init_inst = loader.load_item()(self.get_context()) + assert isinstance(init_inst, InferDPTInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) + inferdpt_inst = init_inst.get_inferdpt_inst() + logger.info('inferdpt inst loaded') + return inferdpt_inst + + + def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage="train"): + + ctx = self.get_context() + model = loader_load_from_conf(self.model_conf) + + if model is None: + raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") + if output_dir is None: + output_dir = "./" + + resume_path = None + if saved_model is not None: + model_dict = load_model_dict_from_path(saved_model) + model.load_state_dict(model_dict) + logger.info(f"loading model dict from {saved_model} to model done") + if get_last_checkpoint(saved_model) is not None: + resume_path = saved_model + logger.info(f"checkpoint detected, resume_path set to {resume_path}") + + # load optimizer + if self.optimizer_conf: + optimizer_loader = Loader.from_dict(self.optimizer_conf) + optimizer_ = optimizer_loader.load_item() + optimizer_params = optimizer_loader.kwargs + optimizer = optimizer_(model.parameters(), **optimizer_params) + else: + optimizer = None + + # load collator func + data_collator = loader_load_from_conf(self.data_collator_conf) + + # load tokenizer if import conf provided + tokenizer = loader_load_from_conf(self.tokenizer_conf) + + # args + dir_warning(self.training_args_conf) + training_args = Seq2SeqTrainingArguments(**self.training_args_conf) + self.training_args = training_args + # reset to default, saving to arbitrary path is not allowed in + # DefaultRunner + training_args.output_dir = output_dir + training_args.resume_from_checkpoint = resume_path # resume path + + # prepare trainer + trainer = PDSSTrainerClient( + ctx=ctx, + training_args=training_args, + train_set=train_set, + val_set=validate_set, + model=model, + tokenizer=tokenizer, + mode=self.mode, + encode_template=self.encode_template, + decode_template=self.decode_template, + instruction_template=self.instruction_template, + local_inference_kwargs=self.local_inference_kwargs, + remote_inference_kwargs=self.remote_inference_kwargs + ) + + return trainer + + def server_setup(self, stage="train"): + trainer = None + return trainer + + def predict(self, test_data: Union[str, DataFrame], saved_model_path: str = None) -> Union[DataFrame, None]: + if self.is_client(): + test_set = self._prepare_data(test_data, "test_data") + if self.trainer is not None: + trainer = self.trainer + logger.info("trainer found, skip setting up") + else: + trainer = self.client_setup(saved_model=saved_model_path, stage="predict") + + classes = run_dataset_func(test_set, "get_classes") + match_ids = run_dataset_func(test_set, "get_match_ids") + sample_ids = run_dataset_func(test_set, "get_sample_ids") + match_id_name = run_dataset_func(test_set, "get_match_id_name") + 
sample_id_name = run_dataset_func(test_set, "get_sample_id_name") + + if not self.training_args.predict_with_generate: + return + + pred_rs = trainer.predict(test_set) + + if self.training_args and self.training_args.deepspeed and self.training_args.local_rank != 0: + return + + rs_df = self.get_nn_output_dataframe( + self.get_context(), + pred_rs.predictions, + pred_rs.label_ids if hasattr(pred_rs, "label_ids") else None, + match_ids, + sample_ids, + match_id_name=match_id_name, + sample_id_name=sample_id_name, + dataframe_format="dist_df", + task_type=self.task_type, + classes=classes, + ) + return rs_df + else: + # server not predict + return From 306cf69c00cfc43077ce09d5bc0d00dc25b12784 Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 4 Jul 2024 10:35:09 +0800 Subject: [PATCH 08/50] Update interface Signed-off-by: weijingchen Signed-off-by: cwj --- doc/tutorial/inferdpt/inferdpt_tutorial.ipynb | 14 +++-- python/fate_llm/algo/inferdpt/init/_init.py | 5 +- .../algo/inferdpt/init/default_init.py | 12 ++-- .../pdss/encoder_decoder/init/default_init.py | 56 +++++++++++++++++++ python/fate_llm/runner/inferdpt_runner.py | 10 ++-- python/fate_llm/runner/pdss_runner.py | 8 ++- 6 files changed, 83 insertions(+), 22 deletions(-) create mode 100644 python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py diff --git a/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb b/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb index e853319..5d51738 100644 --- a/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb +++ b/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb @@ -693,13 +693,14 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.init._init import InferDPTInit\n", + "from fate_llm.algo.inferdpt.init._init import InferClientInit\n", "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", "from fate_llm.algo.inferdpt import inferdpt\n", "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", + "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", "\n", "\n", - "class InferDPTAPIClientInit(InferDPTInit):\n", + "class InferDPTAPIClientInit(InferClientInit):\n", "\n", " api_url = ''\n", " api_model_name = ''\n", @@ -711,14 +712,14 @@ " super().__init__(ctx)\n", " self.ctx = ctx\n", "\n", - " def get_inferdpt_inst(self):\n", + " def get_inst(self)-> InferDPTClient:\n", " inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n", " kit = InferDPTKit.load_from_path(self.inferdpt_kit_path)\n", " inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps)\n", " return inferdpt_client\n", "\n", "\n", - "class InferDPTAPIServerInit(InferDPTInit):\n", + "class InferDPTAPIServerInit(InferClientInit):\n", "\n", " api_url = ''\n", " api_model_name = ''\n", @@ -728,10 +729,11 @@ " super().__init__(ctx)\n", " self.ctx = ctx\n", "\n", - " def get_inferdpt_inst(self):\n", + " def get_inst(self)-> InferDPTServer:\n", " inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n", " inferdpt_server = inferdpt.InferDPTServer(self.ctx,inference_inst=inference)\n", - " return inferdpt_server\n" + " return inferdpt_server\n", + " " ] }, { diff --git a/python/fate_llm/algo/inferdpt/init/_init.py b/python/fate_llm/algo/inferdpt/init/_init.py index e57e830..8781771 100644 --- a/python/fate_llm/algo/inferdpt/init/_init.py +++ b/python/fate_llm/algo/inferdpt/init/_init.py @@ -15,15 +15,14 @@ # from fate.arch import Context -from 
fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer from typing import Union -class InferDPTInit(object): +class InferClientInit(object): def __init__(self, ctx: Context): self.ctx = ctx - def get_inferdpt_inst(self) -> Union[InferDPTClient, InferDPTServer]: + def get_inst(self): pass diff --git a/python/fate_llm/algo/inferdpt/init/default_init.py b/python/fate_llm/algo/inferdpt/init/default_init.py index 30648cb..bca3ef9 100644 --- a/python/fate_llm/algo/inferdpt/init/default_init.py +++ b/python/fate_llm/algo/inferdpt/init/default_init.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from fate_llm.algo.inferdpt.init._init import InferDPTInit +from fate_llm.algo.inferdpt.init._init import InferClientInit from fate_llm.algo.inferdpt.inference.api import APICompletionInference from fate_llm.algo.inferdpt import inferdpt from fate_llm.algo.inferdpt.utils import InferDPTKit +from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer -class InferDPTAPIClientInit(InferDPTInit): +class InferDPTAPIClientInit(InferClientInit): api_url = '' api_model_name = '' @@ -32,14 +32,14 @@ def __init__(self, ctx): super().__init__(ctx) self.ctx = ctx - def get_inferdpt_inst(self): + def get_inst(self)-> InferDPTClient: inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key) kit = InferDPTKit.load_from_path(self.inferdpt_kit_path) inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps) return inferdpt_client -class InferDPTAPIServerInit(InferDPTInit): +class InferDPTAPIServerInit(InferClientInit): api_url = '' api_model_name = '' @@ -49,7 +49,7 @@ def __init__(self, ctx): super().__init__(ctx) self.ctx = ctx - def get_inferdpt_inst(self): + def get_inst(self)-> InferDPTServer: inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key) inferdpt_server = inferdpt.InferDPTServer(self.ctx,inference_inst=inference) return inferdpt_server diff --git a/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py new file mode 100644 index 0000000..44d1fee --- /dev/null +++ b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py @@ -0,0 +1,56 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
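+#
+# Init helpers for the PDSS encoder-decoder flow: each *Init class keeps the
+# API endpoint settings (api_url / api_model_name / api_key) as class
+# attributes, builds an APICompletionInference backend from them, and returns
+# the client- or server-side inference instance that the runner loads through
+# its init conf via Loader. Subclass and override the class attributes to point
+# at a concrete endpoint (the empty strings below are placeholders).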
+# + +from fate_llm.algo.inferdpt.init._init import InferClientInit +from fate_llm.algo.inferdpt.inference.api import APICompletionInference +from fate_llm.algo.inferdpt import inferdpt +from fate_llm.algo.inferdpt.utils import InferDPTKit + + + +class InferDPTAPIClientInit(InferClientInit): + + api_url = '' + api_model_name = '' + api_key = 'EMPTY' + inferdpt_kit_path = '' + eps = 3.0 + + def __init__(self, ctx): + super().__init__(ctx) + self.ctx = ctx + + def get_inferdpt_inst(self): + inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key) + kit = InferDPTKit.load_from_path(self.inferdpt_kit_path) + inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps) + return inferdpt_client + + +class InferDPTAPIServerInit(InferClientInit): + + api_url = '' + api_model_name = '' + api_key = 'EMPTY' + + def __init__(self, ctx): + super().__init__(ctx) + self.ctx = ctx + + def get_inferdpt_inst(self): + inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key) + inferdpt_server = inferdpt.InferDPTServer(self.ctx,inference_inst=inference) + return inferdpt_server diff --git a/python/fate_llm/runner/inferdpt_runner.py b/python/fate_llm/runner/inferdpt_runner.py index 0c204f0..96c2fc8 100644 --- a/python/fate_llm/runner/inferdpt_runner.py +++ b/python/fate_llm/runner/inferdpt_runner.py @@ -32,7 +32,7 @@ from typing import Literal import logging from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer -from fate_llm.algo.inferdpt.init._init import InferDPTInit +from fate_llm.algo.inferdpt.init._init import InferClientInit from fate.components.components.nn.loader import Loader from fate_llm.dataset.hf_dataset import HuggingfaceDataset, Dataset from fate.arch.dataframe import DataFrame @@ -71,18 +71,18 @@ def __init__( def _get_inferdpt_inst(self): loader = Loader.from_dict(self.inferdpt_init_conf) init_inst = loader.load_item()(self.get_context()) - assert isinstance(init_inst, InferDPTInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) - inferdpt_inst = init_inst.get_inferdpt_inst() + assert isinstance(init_inst, InferClientInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) + inferdpt_inst = init_inst.get_inst() logger.info('inferdpt inst loaded') return inferdpt_inst def client_setup(self): - client_inst = self._get_inferdpt_inst() + client_inst = self._get_inst() assert isinstance(client_inst, InferDPTClient), 'Client need to get an InferDPTClient class to run the algo' return client_inst def server_setup(self): - server_inst = self._get_inferdpt_inst() + server_inst = self._get_inst() assert isinstance(server_inst, InferDPTServer), 'Server need to get an InferDPTServer class to run the algo' return server_inst diff --git a/python/fate_llm/runner/pdss_runner.py b/python/fate_llm/runner/pdss_runner.py index 7232c12..431e937 100644 --- a/python/fate_llm/runner/pdss_runner.py +++ b/python/fate_llm/runner/pdss_runner.py @@ -109,7 +109,9 @@ def __init__( self.training_args = None - def _get_inferdpt_inst(self, init_conf): + def _get_inferclient_inst(self, init_conf): + if init_conf is None: + return None loader = Loader.from_dict(init_conf) init_inst = loader.load_item()(self.get_context()) assert isinstance(init_inst, InferDPTInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) @@ -174,7 +176,9 @@ def client_setup(self, train_set=None, 
validate_set=None, output_dir=None, saved decode_template=self.decode_template, instruction_template=self.instruction_template, local_inference_kwargs=self.local_inference_kwargs, - remote_inference_kwargs=self.remote_inference_kwargs + remote_inference_kwargs=self.remote_inference_kwargs, + data_collator=data_collator, + optimizer=optimizer ) return trainer From 25a9e469639e6cc896c71b0d453b9b61c78ad678 Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 4 Jul 2024 11:33:17 +0800 Subject: [PATCH 09/50] Add PDSS runner Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/algo/inferdpt/init/_init.py | 2 +- .../algo/inferdpt/init/default_init.py | 6 +- .../pdss/encoder_decoder/init/default_init.py | 22 ++-- python/fate_llm/runner/inferdpt_runner.py | 4 +- python/fate_llm/runner/pdss_runner.py | 120 ++++++++++-------- 5 files changed, 83 insertions(+), 71 deletions(-) diff --git a/python/fate_llm/algo/inferdpt/init/_init.py b/python/fate_llm/algo/inferdpt/init/_init.py index 8781771..aaa4074 100644 --- a/python/fate_llm/algo/inferdpt/init/_init.py +++ b/python/fate_llm/algo/inferdpt/init/_init.py @@ -18,7 +18,7 @@ from typing import Union -class InferClientInit(object): +class InferInit(object): def __init__(self, ctx: Context): self.ctx = ctx diff --git a/python/fate_llm/algo/inferdpt/init/default_init.py b/python/fate_llm/algo/inferdpt/init/default_init.py index bca3ef9..bdcb56d 100644 --- a/python/fate_llm/algo/inferdpt/init/default_init.py +++ b/python/fate_llm/algo/inferdpt/init/default_init.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from fate_llm.algo.inferdpt.init._init import InferClientInit +from fate_llm.algo.inferdpt.init._init import InferInit from fate_llm.algo.inferdpt.inference.api import APICompletionInference from fate_llm.algo.inferdpt import inferdpt from fate_llm.algo.inferdpt.utils import InferDPTKit from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer -class InferDPTAPIClientInit(InferClientInit): +class InferDPTAPIClientInit(InferInit): api_url = '' api_model_name = '' @@ -39,7 +39,7 @@ def get_inst(self)-> InferDPTClient: return inferdpt_client -class InferDPTAPIServerInit(InferClientInit): +class InferDPTAPIServerInit(InferInit): api_url = '' api_model_name = '' diff --git a/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py index 44d1fee..0de0f54 100644 --- a/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py +++ b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py @@ -14,33 +14,28 @@ # limitations under the License. 
# -from fate_llm.algo.inferdpt.init._init import InferClientInit +from fate_llm.algo.inferdpt.init._init import InferInit from fate_llm.algo.inferdpt.inference.api import APICompletionInference -from fate_llm.algo.inferdpt import inferdpt -from fate_llm.algo.inferdpt.utils import InferDPTKit +from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer - -class InferDPTAPIClientInit(InferClientInit): +class PDSSEDAPIClientInit(InferInit): api_url = '' api_model_name = '' api_key = 'EMPTY' - inferdpt_kit_path = '' - eps = 3.0 def __init__(self, ctx): super().__init__(ctx) self.ctx = ctx - def get_inferdpt_inst(self): + def get_inst(self): inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key) - kit = InferDPTKit.load_from_path(self.inferdpt_kit_path) - inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps) - return inferdpt_client + client = SLMEncoderDecoderClient(self.ctx, inference) + return client -class InferDPTAPIServerInit(InferClientInit): +class PDSSEDAPIServerInit(InferInit): api_url = '' api_model_name = '' @@ -52,5 +47,4 @@ def __init__(self, ctx): def get_inferdpt_inst(self): inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key) - inferdpt_server = inferdpt.InferDPTServer(self.ctx,inference_inst=inference) - return inferdpt_server + return SLMEncoderDecoderServer(self.ctx, inference) diff --git a/python/fate_llm/runner/inferdpt_runner.py b/python/fate_llm/runner/inferdpt_runner.py index 96c2fc8..20ec20c 100644 --- a/python/fate_llm/runner/inferdpt_runner.py +++ b/python/fate_llm/runner/inferdpt_runner.py @@ -32,7 +32,7 @@ from typing import Literal import logging from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer -from fate_llm.algo.inferdpt.init._init import InferClientInit +from fate_llm.algo.inferdpt.init._init import InferInit from fate.components.components.nn.loader import Loader from fate_llm.dataset.hf_dataset import HuggingfaceDataset, Dataset from fate.arch.dataframe import DataFrame @@ -71,7 +71,7 @@ def __init__( def _get_inferdpt_inst(self): loader = Loader.from_dict(self.inferdpt_init_conf) init_inst = loader.load_item()(self.get_context()) - assert isinstance(init_inst, InferClientInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) + assert isinstance(init_inst, InferInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) inferdpt_inst = init_inst.get_inst() logger.info('inferdpt inst loaded') return inferdpt_inst diff --git a/python/fate_llm/runner/pdss_runner.py b/python/fate_llm/runner/pdss_runner.py index 431e937..c748b57 100644 --- a/python/fate_llm/runner/pdss_runner.py +++ b/python/fate_llm/runner/pdss_runner.py @@ -18,19 +18,17 @@ load_model_dict_from_path, dir_warning, loader_load_from_conf, - run_dataset_func, ) -from fate.components.components.nn.runner.homo_default_runner import DefaultRunner from fate.components.components.nn.loader import Loader -from fate.ml.nn.trainer.trainer_base import HomoTrainerServer from fate.arch.dataframe import DataFrame +from fate.ml.nn.dataset.base import Dataset from typing import Dict from fate_llm.algo.pdss.pdss_trainer import PDSSTrainerClient, PDSSTraineServer from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer -from fate_llm.algo.inferdpt.inferdpt import 
InferDPTClient, InferDPTServer +from fate_llm.algo.inferdpt.init._init import InferInit import torch.nn as nn import torch.optim as optim -from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments, HomoSeq2SeqTrainerClient +from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments from typing import Union, Type, Callable, Optional from transformers.trainer_utils import get_last_checkpoint from typing import Literal @@ -65,7 +63,7 @@ def _check_instances( raise TypeError(f"SetupReturn Error: data_collator must be callable but got {type(data_collator)}") -class Seq2SeqRunner(DefaultRunner): +class Seq2SeqRunner(NNRunner): def __init__( self, model_conf: Optional[Dict] = None, @@ -75,7 +73,7 @@ def __init__( data_collator_conf: Optional[Dict] = None, tokenizer_conf: Optional[Dict] = None, mode: Literal['train_only', 'infer_only', 'infer_and_train'] = False, - infer_client_init_conf: Dict = None, + infer_inst_init_conf: Dict = None, encode_template: str = None, instruction_template: str = None, decode_template: str = None, @@ -93,7 +91,7 @@ def __init__( self.data_collator_conf = data_collator_conf self.mode = mode self.tokenizer_conf = tokenizer_conf - self.infer_client_init_conf = infer_client_init_conf + self.infer_inst_init_conf = infer_inst_init_conf self.encode_template = encode_template self.instruction_template = instruction_template self.decode_template = decode_template @@ -109,15 +107,43 @@ def __init__( self.training_args = None - def _get_inferclient_inst(self, init_conf): + def _get_infer_inst(self, init_conf): if init_conf is None: return None loader = Loader.from_dict(init_conf) init_inst = loader.load_item()(self.get_context()) - assert isinstance(init_inst, InferDPTInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) - inferdpt_inst = init_inst.get_inferdpt_inst() + assert isinstance(init_inst, InferInit), 'Need a InferInit class for initialization, but got {}'.format(type(init_inst)) + infer_inst = init_inst.get_inst() logger.info('inferdpt inst loaded') - return inferdpt_inst + return infer_inst + + + def _prepare_data(self, data, data_name): + if data is None: + return None + if isinstance(data, DataFrame) and self.dataset_conf is None: + raise RuntimeError('DataFrame format dataset is not supported, please use bind path to load your dataset') + else: + dataset = loader_load_from_conf(self.dataset_conf) + if hasattr(dataset, "load"): + logger.info("load path is {}".format(data)) + load_output = dataset.load(data) + if load_output is not None: + dataset = load_output + return dataset + else: + raise ValueError( + f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \ + Please implement this method in your dataset class. You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \ + for the necessary interfaces to implement." 
+ ) + if dataset is not None and not issubclass(type(dataset), Dataset): + raise TypeError( + f"SetupReturn Error: {data_name}_set must be a subclass of fate built-in Dataset but got {type(dataset)}, \n" + f"You can get the class via: from fate.ml.nn.dataset.table import Dataset" + ) + + return dataset def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage="train"): @@ -178,51 +204,43 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved local_inference_kwargs=self.local_inference_kwargs, remote_inference_kwargs=self.remote_inference_kwargs, data_collator=data_collator, - optimizer=optimizer + optimizer=optimizer, + infer_client=self._get_infer_inst(self.infer_inst_init_conf) ) return trainer def server_setup(self, stage="train"): - trainer = None + trainer = PDSSTraineServer( + ctx=self.get_context(), + infer_server=self._get_infer_inst(self.infer_inst_init_conf) + ) return trainer - def predict(self, test_data: Union[str, DataFrame], saved_model_path: str = None) -> Union[DataFrame, None]: + def train( + self, + train_data: Optional[Union[str]] = None, + validate_data: Optional[Union[str]] = None, + output_dir: str = None, + saved_model_path: str = None, + ): if self.is_client(): - test_set = self._prepare_data(test_data, "test_data") - if self.trainer is not None: - trainer = self.trainer - logger.info("trainer found, skip setting up") - else: - trainer = self.client_setup(saved_model=saved_model_path, stage="predict") - - classes = run_dataset_func(test_set, "get_classes") - match_ids = run_dataset_func(test_set, "get_match_ids") - sample_ids = run_dataset_func(test_set, "get_sample_ids") - match_id_name = run_dataset_func(test_set, "get_match_id_name") - sample_id_name = run_dataset_func(test_set, "get_sample_id_name") - - if not self.training_args.predict_with_generate: - return - - pred_rs = trainer.predict(test_set) - - if self.training_args and self.training_args.deepspeed and self.training_args.local_rank != 0: - return - - rs_df = self.get_nn_output_dataframe( - self.get_context(), - pred_rs.predictions, - pred_rs.label_ids if hasattr(pred_rs, "label_ids") else None, - match_ids, - sample_ids, - match_id_name=match_id_name, - sample_id_name=sample_id_name, - dataframe_format="dist_df", - task_type=self.task_type, - classes=classes, + train_set = self._prepare_data(train_data, "train_data") + validate_set = self._prepare_data(validate_data, "val_data") + trainer = self.client_setup( + train_set=train_set, validate_set=validate_set, output_dir=output_dir, saved_model=saved_model_path ) - return rs_df - else: - # server not predict - return + self.trainer = trainer + trainer.train() + if output_dir is not None: + if self.training_args.deepspeed and self.training_args.local_rank != 0: + pass + else: + trainer.save_model(output_dir) + elif self.is_server(): + trainer = self.server_setup() + trainer.train() + + def predict(self, test_data: Union[str], saved_model_path: str = None) -> None: + logger.warning('The prediction mode is not supported by this algorithm in the current version. 
Please perform inference using locally saved models.') + return \ No newline at end of file From 36f6b2513a3b24fbc784fe3b35d3575f90459dc6 Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 4 Jul 2024 16:26:16 +0800 Subject: [PATCH 10/50] Update import Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/data/data_collator/pdss_collator.py | 9 ++++++++- python/fate_llm/dataset/input_output_dataset.py | 4 +++- python/fate_llm/runner/pdss_runner.py | 11 ++++++----- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/fate_llm/data/data_collator/pdss_collator.py b/python/fate_llm/data/data_collator/pdss_collator.py index 2a8007a..568276b 100644 --- a/python/fate_llm/data/data_collator/pdss_collator.py +++ b/python/fate_llm/data/data_collator/pdss_collator.py @@ -1,4 +1,5 @@ from transformers import DataCollatorForSeq2Seq +from transformers import AutoTokenizer import pandas as pd class PrefixDataCollator(DataCollatorForSeq2Seq): @@ -10,4 +11,10 @@ def __call__(self, features, return_tensors=None): return { 'predict': cot, 'rationale': label - } \ No newline at end of file + } + + +def get_prefix_data_collator(tokenizer_name_or_path): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + data_collator = PrefixDataCollator(tokenizer) + return data_collator diff --git a/python/fate_llm/dataset/input_output_dataset.py b/python/fate_llm/dataset/input_output_dataset.py index bf4978c..aa59988 100644 --- a/python/fate_llm/dataset/input_output_dataset.py +++ b/python/fate_llm/dataset/input_output_dataset.py @@ -37,9 +37,10 @@ def __init__(self, def load(self, path): if self.load_from == 'hf_load_from_disk': import datasets - self.dataset = [i for i in datasets.load_from_disk(path)] + self.dataset = datasets.load_from_disk(path) if self.split_key is not None: self.dataset = self.dataset[self.split_key] + self.dataset = [i for i in self.dataset] elif self.load_from == 'jsonl': import json with open(path, 'r') as f: @@ -55,6 +56,7 @@ def load(self, path): self.dataset = load_dataset(path) if self.split_key is not None: self.dataset = self.dataset[self.split_key] + self.dataset = [i for i in self.dataset] else: raise ValueError('unknown load format') diff --git a/python/fate_llm/runner/pdss_runner.py b/python/fate_llm/runner/pdss_runner.py index c748b57..15599ca 100644 --- a/python/fate_llm/runner/pdss_runner.py +++ b/python/fate_llm/runner/pdss_runner.py @@ -19,6 +19,7 @@ dir_warning, loader_load_from_conf, ) +from fate_llm.model_zoo.hf_model import HFAutoModelForCausalLM from fate.components.components.nn.loader import Loader from fate.arch.dataframe import DataFrame from fate.ml.nn.dataset.base import Dataset @@ -63,7 +64,7 @@ def _check_instances( raise TypeError(f"SetupReturn Error: data_collator must be callable but got {type(data_collator)}") -class Seq2SeqRunner(NNRunner): +class PDSSRunner(NNRunner): def __init__( self, model_conf: Optional[Dict] = None, @@ -101,12 +102,10 @@ def __init__( self.perturbed_response_key = perturbed_response_key self.result_key = result_key - assert isinstance(self.local_mode, bool), "local should be bool" # setup var self.trainer = None self.training_args = None - def _get_infer_inst(self, init_conf): if init_conf is None: return None @@ -117,7 +116,6 @@ def _get_infer_inst(self, init_conf): logger.info('inferdpt inst loaded') return infer_inst - def _prepare_data(self, data, data_name): if data is None: return None @@ -150,6 +148,8 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved ctx = 
self.get_context() model = loader_load_from_conf(self.model_conf) + if isinstance(model, HFAutoModelForCausalLM): + model = model.load() if model is None: raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") @@ -243,4 +243,5 @@ def train( def predict(self, test_data: Union[str], saved_model_path: str = None) -> None: logger.warning('The prediction mode is not supported by this algorithm in the current version. Please perform inference using locally saved models.') - return \ No newline at end of file + return + \ No newline at end of file From 7e05fa2727888887b6e696dfdac3a1ac03de8754 Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 8 Jul 2024 14:51:04 +0800 Subject: [PATCH 11/50] Update args for fixing to_dict bug in transformers Signed-off-by: weijingchen Signed-off-by: cwj --- .../fate_llm/algo/pdss/encoder_decoder/init/default_init.py | 2 +- python/fate_llm/trainer/seq2seq_trainer.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py index 0de0f54..ff89ebc 100644 --- a/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py +++ b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py @@ -45,6 +45,6 @@ def __init__(self, ctx): super().__init__(ctx) self.ctx = ctx - def get_inferdpt_inst(self): + def get_inst(self): inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key) return SLMEncoderDecoderServer(self.ctx, inference) diff --git a/python/fate_llm/trainer/seq2seq_trainer.py b/python/fate_llm/trainer/seq2seq_trainer.py index 5778031..fa8ea5e 100644 --- a/python/fate_llm/trainer/seq2seq_trainer.py +++ b/python/fate_llm/trainer/seq2seq_trainer.py @@ -19,6 +19,7 @@ from fate.ml.nn.trainer.trainer_base import HomoTrainerMixin, FedArguments, get_ith_checkpoint import os import torch +import copy from torch import nn from typing import Any, Dict, List, Callable from enum import Enum @@ -54,7 +55,7 @@ class _S2STrainingArguments(_hf_Seq2SeqTrainingArguments): log_level: str = field(default="info") deepspeed: Optional[str] = field(default=None) save_safetensors: bool = field(default=False) - use_cpu: bool = field(default=True) + use_cpu: bool = field(default=False) def __post_init__(self): self.push_to_hub = False @@ -68,6 +69,7 @@ def __post_init__(self): super().__post_init__() +DEFAULT_ARGS = _S2STrainingArguments().to_dict() @dataclass class Seq2SeqTrainingArguments(_S2STrainingArguments): @@ -77,7 +79,7 @@ def to_dict(self): # Call the superclass's to_dict method all_args = super().to_dict() # Get a dict with default values for all fields - default_args = _S2STrainingArguments().to_dict() + default_args = copy.deepcopy(DEFAULT_ARGS) # Filter out args that are equal to their default values set_args = {name: value for name, value in all_args.items() if value != default_args.get(name)} return set_args From 0addf820b9a4b1de338daf59234fe9a0b57d78be Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 9 Jul 2024 15:03:25 +0800 Subject: [PATCH 12/50] Support distributed training with eggroll Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/algo/pdss/pdss_trainer.py | 14 ++++-- python/fate_llm/runner/pdss_runner.py | 56 +++++++++++++++-------- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py index 7a7a1f4..8ada9cb 100644 --- 
a/python/fate_llm/algo/pdss/pdss_trainer.py +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -82,7 +82,7 @@ def __init__(self, tokenizer: Optional[PreTrainedTokenizer] = None, callbacks: Optional[List[TrainerCallback]] = [], compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None ) -> None: self.alpha = alpha @@ -133,6 +133,7 @@ def __init__(self, verbose: bool = False, remote_inference_kwargs: dict = {}, local_inference_kwargs: dict = {}, + tmp_data_share_path: str = None ) -> None: self.mode = mode @@ -148,6 +149,7 @@ def __init__(self, 'local_inference_kwargs': local_inference_kwargs } self.infer_result = None + self.tmp_data_share_path = tmp_data_share_path assert mode in _MODE, "mode should be one of {}".format(_MODE) if training_args.local_rank == 0: @@ -189,14 +191,16 @@ def infer(self) -> List[str]: logger.info('infer done') if self.mode == 'infer_and_train': if self.args.world_size > 1: # sync dataset with other ranks - logger.info('scattering obj') - save_to(rationale_list, self.args.output_dir) + tmp_path = self.tmp_data_share_path if self.tmp_data_share_path is not None else self.args.output_dir + logger.info('scattering obj, save to temp path {}'.format(tmp_path)) + save_to(rationale_list, tmp_path) if self.args.local_rank > 0: if self.mode == 'infer_and_train': # wait until infer is done - logger.info('waiting for obj') - rationale_list = load(self.args.output_dir) + tmp_path = self.tmp_data_share_path if self.tmp_data_share_path is not None else self.args.output_dir + logger.info('waiting for obj, load frm temp path {}'.format(tmp_path)) + rationale_list = load(tmp_path) self.train_dataset.load_rationale(rationale_list) logger.info('Rationale loaded') diff --git a/python/fate_llm/runner/pdss_runner.py b/python/fate_llm/runner/pdss_runner.py index 15599ca..a2a89f5 100644 --- a/python/fate_llm/runner/pdss_runner.py +++ b/python/fate_llm/runner/pdss_runner.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
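To make the intent of the new `tmp_data_share_path` argument added to the trainer above concrete, here is a hedged construction sketch: rank 0 performs the remote inference and drops the pickled rationales into this directory, the other ranks poll the same location, so it should be a path every worker process can read. The remaining constructor arguments are abbreviated, and the path itself is a placeholder.

```python
# Hedged sketch: only the arguments relevant to rationale sharing are shown;
# see PDSSTrainerClient for the full signature used elsewhere in this patch.
trainer = PDSSTrainerClient(
    ctx=ctx,
    model=model,
    training_args=training_args,
    train_set=train_set,
    mode='infer_and_train',
    infer_client=infer_client,                    # only rank 0 needs a real client
    tmp_data_share_path='/mnt/shared/pdss_tmp/',  # placeholder, readable by all ranks
)
```

When `tmp_data_share_path` is left as None, the trainer falls back to `training_args.output_dir`, as the diff above shows.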
- +import torch from fate.components.components.nn.nn_runner import ( NNRunner, load_model_dict_from_path, @@ -39,6 +39,7 @@ logger = logging.getLogger(__name__) + def _check_instances( model: nn.Module = None, optimizer: optim.Optimizer = None, @@ -67,13 +68,13 @@ def _check_instances( class PDSSRunner(NNRunner): def __init__( self, + mode: Literal['train_only', 'infer_only', 'infer_and_train'], model_conf: Optional[Dict] = None, dataset_conf: Optional[Dict] = None, optimizer_conf: Optional[Dict] = None, training_args_conf: Optional[Dict] = None, data_collator_conf: Optional[Dict] = None, tokenizer_conf: Optional[Dict] = None, - mode: Literal['train_only', 'infer_only', 'infer_and_train'] = False, infer_inst_init_conf: Dict = None, encode_template: str = None, instruction_template: str = None, @@ -101,6 +102,7 @@ def __init__( self.perturb_doc_key = perturb_doc_key self.perturbed_response_key = perturbed_response_key self.result_key = result_key + self._temp_data_path = '' # setup var self.trainer = None @@ -125,10 +127,15 @@ def _prepare_data(self, data, data_name): dataset = loader_load_from_conf(self.dataset_conf) if hasattr(dataset, "load"): logger.info("load path is {}".format(data)) - load_output = dataset.load(data) - if load_output is not None: - dataset = load_output - return dataset + import os + if os.path.exists(data) and os.path.isdir(data): + self._temp_data_path = data + load_output = dataset.load(data) + if load_output is not None: + dataset = load_output + return dataset + else: + raise RuntimeError('You must offer an existing folder path as data input, but got {}'.format(data)) else: raise ValueError( f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \ @@ -143,13 +150,12 @@ def _prepare_data(self, data, data_name): return dataset - def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage="train"): ctx = self.get_context() model = loader_load_from_conf(self.model_conf) if isinstance(model, HFAutoModelForCausalLM): - model = model.load() + model = model.load().cuda() if model is None: raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") @@ -183,12 +189,13 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved # args dir_warning(self.training_args_conf) training_args = Seq2SeqTrainingArguments(**self.training_args_conf) - self.training_args = training_args # reset to default, saving to arbitrary path is not allowed in # DefaultRunner training_args.output_dir = output_dir + logger.info('output dir is {}'.format(output_dir)) training_args.resume_from_checkpoint = resume_path # resume path - + self.training_args = training_args + # prepare trainer trainer = PDSSTrainerClient( ctx=ctx, @@ -205,7 +212,8 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved remote_inference_kwargs=self.remote_inference_kwargs, data_collator=data_collator, optimizer=optimizer, - infer_client=self._get_infer_inst(self.infer_inst_init_conf) + infer_client=self._get_infer_inst(self.infer_inst_init_conf), + tmp_data_share_path=self._temp_data_path ) return trainer @@ -232,16 +240,26 @@ def train( ) self.trainer = trainer trainer.train() - if output_dir is not None: - if self.training_args.deepspeed and self.training_args.local_rank != 0: - pass - else: - trainer.save_model(output_dir) + + if self.mode == 'infer_only': + # save result dataset to the output dir + saving_path = output_dir + '/' + 'inference_result.pkl' + 
torch.save(train_set.dataset, saving_path) + logger.info('inference result saved to {}'.format(saving_path)) + else: + if output_dir is not None: + if self.training_args.deepspeed and self.training_args.local_rank != 0: + pass + else: + trainer.save_model(output_dir) + elif self.is_server(): - trainer = self.server_setup() - trainer.train() + if self.mode == 'train_only': + return + else: + trainer = self.server_setup() + trainer.train() def predict(self, test_data: Union[str], saved_model_path: str = None) -> None: logger.warning('The prediction mode is not supported by this algorithm in the current version. Please perform inference using locally saved models.') return - \ No newline at end of file From 9a191f51b47cc9c7da38f108cf3abe78166f4820 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 9 Jul 2024 17:10:19 +0800 Subject: [PATCH 13/50] Fix infer client init Signed-off-by: weijingchen Signed-off-by: cwj --- python/fate_llm/runner/pdss_runner.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/fate_llm/runner/pdss_runner.py b/python/fate_llm/runner/pdss_runner.py index a2a89f5..423dbf3 100644 --- a/python/fate_llm/runner/pdss_runner.py +++ b/python/fate_llm/runner/pdss_runner.py @@ -155,7 +155,7 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved ctx = self.get_context() model = loader_load_from_conf(self.model_conf) if isinstance(model, HFAutoModelForCausalLM): - model = model.load().cuda() + model = model.load() if model is None: raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") @@ -192,9 +192,13 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved # reset to default, saving to arbitrary path is not allowed in # DefaultRunner training_args.output_dir = output_dir - logger.info('output dir is {}'.format(output_dir)) training_args.resume_from_checkpoint = resume_path # resume path self.training_args = training_args + + if self.training_args.world_size > 0 and self.training_args.local_rank == 0: + infer_client = self._get_infer_inst(self.infer_inst_init_conf) + else: + infer_client = None # only rank 0 need to load the client # prepare trainer trainer = PDSSTrainerClient( @@ -212,7 +216,7 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved remote_inference_kwargs=self.remote_inference_kwargs, data_collator=data_collator, optimizer=optimizer, - infer_client=self._get_infer_inst(self.infer_inst_init_conf), + infer_client=infer_client, tmp_data_share_path=self._temp_data_path ) From 7f315b6253e186a1b98fda404a6107969c1cb3d1 Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 10 Jul 2024 19:08:31 +0800 Subject: [PATCH 14/50] Add pdss doc Signed-off-by: weijingchen Signed-off-by: cwj --- doc/tutorial/pdss/pdss_tutorial.ipynb | 1184 +++++++++++++++++++++++++ 1 file changed, 1184 insertions(+) create mode 100644 doc/tutorial/pdss/pdss_tutorial.ipynb diff --git a/doc/tutorial/pdss/pdss_tutorial.ipynb b/doc/tutorial/pdss/pdss_tutorial.ipynb new file mode 100644 index 0000000..588aab8 --- /dev/null +++ b/doc/tutorial/pdss/pdss_tutorial.ipynb @@ -0,0 +1,1184 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9234355d-389f-484f-9fc2-7b17563b3390", + "metadata": {}, + "source": [ + "# PDSS Tutorial\n", + "\n", + "## Introduction to PDSS\n", + "\n", + "PDSS is a novel framework designed to distill knowledge from large language models (LLMs) to small language models (SLMs) while ensuring data privacy. 
The framework addresses two major challenges faced by LLM deployment in real-world applications: the privacy of domain-specific knowledge and resource constraints.\n", + "\n", + "PDSS adopts a server-client architecture where the client sends perturbed prompts to the server-side LLM for inference, generating perturbed rationales. The client then decodes these rationales and uses them to enrich the training of its task-specific SLM, ultimately enhancing its performance.\n", + "\n", + "PDSS introduces two privacy protection strategies: \n", + "- **the Exponential Mechanism Strategy**\n", + "- **the Encoder-Decoder Strategy**\n", + " \n", + "The Exponential Mechanism Strategy utilizes a DP(differential privacy) based exponential mechanism to obfuscate user prompts, while the Encoder-Decoder Strategy employs a specialized Encoder-Decoder SLM to encode and decode perturbed prompts and rationales. These strategies effectively balance user privacy and the usability of rationales, allowing for secure and enhanced training of the client's SLM without compromising on privacy concerns.\n", + "\n", + "Through experiments on various text generation tasks, PDSS demonstrates its effectiveness in training task-specific SLMs with enhanced performance, significantly improving the SLM's capabilities while prioritizing data privacy protection. For more details, please refer to the [original paper](https://arxiv.org/pdf/2406.12403).\n", + "\n", + "**Before reading this tutorial, we strongly recommend that you first read [the InferDPT](./) tutorial.**\n", + "\n", + "## Use the Infer Client & Server\n", + "\n", + "In this section, we are going to introduce the inference part, which is the key part of PDSS that generates useful rationales with privacy-preserving. You can use InferDPT(which utilize the Exponential Mechanism Strategy) or specifically trained SLM as the text encoder & decoder. In this section, we retrieve a sample from the arc-easy dataset as an example:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c443c920-31ff-446a-801f-d7a02409a8c0", + "metadata": {}, + "outputs": [], + "source": [ + "test_example = {'id': 'Mercury_7220990',\n", + "'question': 'Which factor will most likely cause a person to develop a fever?',\n", + "'choices': {'text': ['a leg muscle relaxing after exercise',\n", + "'a bacterial population in the bloodstream',\n", + "'several viral particles on the skin',\n", + "'carbohydrates being digested in the stomach'],\n", + "'label': ['A', 'B', 'C', 'D']},\n", + "'answerKey': 'B'}" + ] + }, + { + "cell_type": "markdown", + "id": "46646b18-46bb-476d-8b1d-1ef661446929", + "metadata": {}, + "source": [ + "### Fate Context\n", + "\n", + "We need to create fate context to enable the communication between client and server. Then, we can initialize infer client(who will encodes the raw prompt and decodes the perturbed response) and server(who deploys the LLM) to enable secure inference." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0cc8e8f8-88d7-45ab-a988-5ead06356418", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "name = \"fed1\"\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession(data_dir=\"./session_dir\")\n", + " return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))" + ] + }, + { + "cell_type": "markdown", + "id": "c75dbcda-1a40-421d-ab1b-92eca5600866", + "metadata": {}, + "source": [ + "### The DP based Strategy(InferDPT)\n", + "\n", + "As outlined in the [InferDPT tutorial](./), you can initialize the InferDPT client and server to facilitate secure and private inference. Prior to executing the InferDPT component, it is recommended to generate the InferDPT kit by following the step-by-step instructions provided in the tutorial.\n", + "\n", + "#### Client-Side Code\n", + "\n", + "On the client side, we load the pre-computed inferdpt-kit and deploy a local SLM as the decoding model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff0f317f-414f-4b9f-84e6-b992b31350cb", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.algo.inferdpt import inferdpt\n", + "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", + "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", + "from jinja2 import Template\n", + "from fate.arch import Context\n", + "import sys\n", + "\n", + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "name = \"fed1\"\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession(data_dir=\"./session_dir\")\n", + " return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n", + "\n", + "ctx = create_ctx(guest)\n", + "save_kit_path = 'your path'\n", + "kit = InferDPTKit.load_from_path(save_kit_path)\n", + "# local deployed small model as decoding model\n", + "inference = APICompletionInference(api_url=\"http://127.0.0.1:8887/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n", + "\n", + 
"test_example = {'id': 'Mercury_7220990',\n", + "'question': 'Which factor will most likely cause a person to develop a fever?',\n", + "'choices': {'text': ['a leg muscle relaxing after exercise',\n", + "'a bacterial population in the bloodstream',\n", + "'several viral particles on the skin',\n", + "'carbohydrates being digested in the stomach'],\n", + "'label': ['A', 'B', 'C', 'D']},\n", + "'answerKey': 'B'}\n", + "\n", + "\n", + "doc_template = \"\"\"{{question}} \n", + "Choices:{{choices.text}}\n", + "\"\"\"\n", + "\n", + "instruction_template=\"\"\"\n", + "[INST]\n", + "Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n", + "Use to finish your rationle.\"\n", + "\n", + "Example(s):\n", + "Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n", + "Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n", + "Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n", + "\n", + "Please explain:\n", + "Question:{{perturbed_doc}}\n", + "Rationale:\n", + "[/INST]\n", + "\"\"\"\n", + "\n", + "decode_template = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n", + "Use to finish your rationle.\"\n", + "\n", + "Example(s):\n", + "Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n", + "Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n", + "Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. 
Therefore, the answer is 'dry palms'.\n", + "\n", + "Question:{{perturbed_doc}}\n", + "Rationale:{{perturbed_response | replace('\\n', '')}}\n", + "\n", + "Please explain:\n", + "Question:{{question}} \n", + "Choices:{{choices.text}}\n", + "Rationale:\n", + "\"\"\"\n", + "\n", + "inferdpt_client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\n", + "result = inferdpt_client.inference([test_example], doc_template, instruction_template, decode_template, \\\n", + " remote_inference_kwargs={\n", + " 'stop': ['<\\s>'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " },\n", + " local_inference_kwargs={\n", + " 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " })\n", + "print('result is {}'.format(result[0]['inferdpt_result']))" + ] + }, + { + "cell_type": "markdown", + "id": "96fbcb01-6907-432f-8393-ae1746559c3a", + "metadata": {}, + "source": [ + "#### Server Side Code" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "960a476c-50a5-40fb-847d-02101cea27ae", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", + "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", + "from jinja2 import Template\n", + "from fate.arch import Context\n", + "import sys\n", + "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "\n", + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "name = \"fed1\"\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession(data_dir=\"./session_dir\")\n", + " return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n", + "\n", + "ctx = create_ctx(arbiter)\n", + "# Api to a LLM\n", + "inference_server = APICompletionInference(api_url=\"http://127.0.0.1:8888/v1\", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')\n", + "inferdpt_server = InferDPTServer(ctx, inference_server)\n", + "inferdpt_server.inference()" + ] + }, + { + "cell_type": "markdown", + "id": "16f908a7-9187-461a-93db-9945456d502d", + "metadata": {}, + "source": [ + "Start two terminal and launch client&server scripts simultaneously. On the client side we can get the answer:\n", + "\n", + "```\n", + "The given question asks which factor will most likely cause a person to develop a fever. The factors mentioned are a leg muscle relaxing after exercise, a bacterial population in the bloodstream, several viral particles on the skin, and carbohydrates being digested in the stomach. The question is asking which factor is most likely to cause a person to develop a fever. The factors are all related to the body's internal environment, but the most likely factor is a bacterial population in the bloodstream. 
This is because bacteria can cause a fever, and the body's immune system responds to the infection by producing antibodies that can fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "fb36a485-2fa8-4629-a2cf-2d53fdbbcc5f", + "metadata": {}, + "source": [ + "### The Encoder-Decoder Model Strategy\n", + "\n", + "Similar to the InferDPT, we can initialize SLMEncoderDecoderClient and SLMEncoderDecoderServer to enable secure inference.\n", + "The client will encode the raw prompt using local slm model and then decoded it with the same model\n", + "\n", + "#### Client Side Code" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cd174244-8640-4cb2-8609-ac6468f5a6f5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient\n", + "\n", + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "name = \"fed1\"\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession(data_dir=\"./session_dir\")\n", + " return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n", + "\n", + "\n", + "test_example = {'id': 'Mercury_7220990',\n", + "'question': 'Which factor will most likely cause a person to develop a fever?',\n", + "'choices': {'text': ['a leg muscle relaxing after exercise',\n", + "'a bacterial population in the bloodstream',\n", + "'several viral particles on the skin',\n", + "'carbohydrates being digested in the stomach'],\n", + "'label': ['A', 'B', 'C', 'D']},\n", + "'answerKey': 'B'\n", + "}\n", + "\n", + "\n", + "encode_prompt = \"\"\"Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use to end your reply.\n", + "Origin Doc:Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "Perturb Doc: \n", + "\"\"\"\n", + "\n", + "decode_prompt = \"\"\"This is a perturbed question and its corresponding answer(rationale). And following is the original question. Try to recover the correct rationale from docs provided.\n", + "\n", + "Perturbed doc and rationale:\n", + "{{perturbed_doc}}\n", + "Rationale:{{perturbed_response}}\n", + "\n", + "Original Doc:\n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "\n", + "Recover Rationale:\n", + "\"\"\"\n", + "\n", + "instruction_template = \"\"\"<|im_start|>system\n", + "You are a helpful assistant<|im_end|>\n", + "<|im_start|>user\n", + "Select Answer from Choices and explain it in \"Rationale\" with few words. 
Please refer to the example to write the rationale.\n", + "Use to finish your rationle.\n", + "\n", + "Example(s):\n", + "Question:Which factor will most likely cause a person to develop a fever?\n", + "Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\n", + "Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\n", + "\n", + "Please explain:\n", + "{{perturbed_doc}}\n", + "Rationale:\n", + "<|im_end|>\n", + "<|im_start|>assistant\n", + "\"\"\"\n", + "\n", + "ctx = create_ctx(guest)\n", + "model_name = 'Deploy your encoder decoder model'\n", + "# api_url to your locally deployed encoder decoder\n", + "api = APICompletionInference(api_url='http://127.0.0.1:8887/v1', api_key='EMPTY', model_name=model_name)\n", + "client = SLMEncoderDecoderClient(ctx, api)\n", + "result = client.inference([test_example], encode_prompt, instruction_template, decode_prompt, \\\n", + " remote_inference_kwargs={\n", + " 'stop': ['<\\s>'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " },\n", + " local_inference_kwargs={\n", + " 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " })\n", + "print('result is {}'.format(result[0]['inferdpt_result']))" + ] + }, + { + "cell_type": "markdown", + "id": "1a865536-7814-40a2-a814-d00e46f2787f", + "metadata": {}, + "source": [ + "#### Server Side Code" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "cced44b0-0dcb-4427-8efe-a04135b246ac", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderServer\n", + "\n", + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "name = \"fed1\"\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession(data_dir=\"./session_dir\")\n", + " return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n", + "\n", + "ctx = create_ctx(arbiter)\n", + "# api url&name are depolyed LLM\n", + "model_name = '/data/cephfs/llm/models/Qwen1.5-14B-Chat/'\n", + "api = APICompletionInference(api_url='http://127.0.0.1:8888/v1', api_key='EMPTY', model_name=model_name)\n", + "server = SLMEncoderDecoderServer(ctx, api)\n", + "server.inference()" + ] + }, + { + "cell_type": "markdown", + "id": "c38ed7a6-2eb2-4f46-b59c-eaafcc9a5b7a", + "metadata": {}, + "source": [ + "Start two terminal and launch client&server scripts simultaneously. 
On the client side we can get the answer:\n", + "\n", + "```\n", + "A fever is typically caused by a bacterial population in the bloodstream, as it is a response to an infection. So the answer is 'a bacterial population in the bloodstream'.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "41fbbefd-e931-4e95-9d28-9675ff7865a3", + "metadata": {}, + "source": [ + "## Prefix Dataset & PDSS Trainer\n", + "\n", + "Now that we can carry out privacy-preserving inference and acquire rationales, the next step is to train a new task-specific model, enhanced by the rationales generated by the LLMs.\n", + "\n", + "In this section, we will introduce the PrefixDataset and PDSSTrainer, which facilitate training tasks with the added benefit of supplementary rationales. The PrefixDataset allows you to assign various text prefixes, guiding the model to produce different text targets. With PDSSTrainer, the model is trained to generate both text labels and text rationales at each update step, ultimately leading to superior performance compared to training on the raw dataset alone.\n", + "\n", + "### Prepare dataset\n", + "In this tutorial, we will use the arc-easy dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e25377d0-1a7e-4e8c-aa9f-3bcb03ae0c45", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "dataset = load_dataset(\"arc_easy\")\n", + "dataset.save_to_disk('path_to_save/arce')" + ] + }, + { + "cell_type": "markdown", + "id": "9166110f-bf67-4bf1-9da8-04c16bd79423", + "metadata": {}, + "source": [ + "Let’s proceed with testing the PrefixDataset. We can utilize Jinja2 templates to structure the text and append prefixes or suffixes to our training data.\n", + "\n", + "Please note that at this stage, the dataset does not contain rationales. In the 'rationale_output_template', the key used for the inference results is ‘infer_result’. We can perform secure inference using the PDSSTrainer and then integrate the rationale results, keyed as ‘infer_result’, into the PrefixDataset." 
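To make the wiring concrete, the hedged sketch below attaches a generated rationale to each raw record under the 'infer_result' key and shows that the previously empty rationale target is then filled. During training, PDSSTrainerClient does this for you through the dataset's load_rationale hook; attaching the key by hand here, and the assumption that get_str_item re-renders the templates against the updated records, are for illustration only.

```python
# Hedged sketch: the rationale text is a made-up example; in practice it comes
# from the privacy-preserving inference step described in the previous sections.
rationales = [
    "A bacterial infection in the bloodstream triggers an immune response, "
    "often causing a fever. Therefore, the answer is "
    "'a bacterial population in the bloodstream'."
]

# attach each rationale to its record under the key used by rationale_output_template
# (only the first record is updated in this toy example)
for record, rationale in zip(pds.dataset, rationales):
    record['infer_result'] = rationale

# the rationale output rendered from rationale_output_template is no longer empty
print(pds.get_str_item(0)['rationale']['output'])
```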
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "fdbd93d6-45f3-404f-813e-9ca1fd6def04", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "from fate_llm.dataset.pdss_dataset import PrefixDataset\n", + "\n", + "pds = PrefixDataset(\n", + " tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',\n", + " predict_input_template=\"\"\"Predict:\n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "Answer:\n", + " \"\"\",\n", + " predict_output_template=\"\"\"{{choices.text[choices.label.index(answerKey)]}}\"\"\",\n", + " rationale_input_template=\"\"\"Explain:\n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "Rationale:\n", + " \"\"\",\n", + " rationale_output_template=\"\"\"{{infer_result}}\"\"\",\n", + " max_input_length=128,\n", + " max_target_length=128,\n", + " split_key='train'\n", + " )\n", + "\n", + "\n", + "pds.load('path_to_save/arce')" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "100eeb69-8bd2-4e66-b1cc-667f95e47f23", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'id': 'Mercury_7220990',\n", + " 'question': 'Which factor will most likely cause a person to develop a fever?',\n", + " 'choices': {'text': ['a leg muscle relaxing after exercise',\n", + " 'a bacterial population in the bloodstream',\n", + " 'several viral particles on the skin',\n", + " 'carbohydrates being digested in the stomach'],\n", + " 'label': ['A', 'B', 'C', 'D']},\n", + " 'answerKey': 'B'}" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pds.dataset[0] # the structure is the same as hf dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "6f0356ef-f94b-41db-ab66-b1d0eb862eca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'predict': {'input': \"Predict:\\nQuestion:Which factor will most likely cause a person to develop a fever?\\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\nAnswer:\\n \",\n", + " 'output': 'a bacterial population in the bloodstream'},\n", + " 'rationale': {'input': \"Explain:\\nQuestion:Which factor will most likely cause a person to develop a fever?\\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\nRationale:\\n \",\n", + " 'output': '\\n '}}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pds.get_str_item(0) # we can see that the output of rationale term is empty" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "6a227af7-f24a-46bd-9af7-78584a381b33", + "metadata": {}, + "outputs": [], + "source": [ + "print(pds[0]) # show tokenized, for the sake of breif we dont show it in this tutorial doc" + ] + }, + { + "cell_type": "markdown", + "id": "e0382a33-7a45-43a3-8ed3-58ed1d1b07d8", + "metadata": {}, + "source": [ + "### The PDSSTrainer\n", + "\n", + "Here we introduce the PDSSTrainer which is develop based on Huggingface trainer and supports collaboratively training a task with raw labels and additional rationales. 
Here is how the compute_loss function is implemented:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b40b7d99-9ef8-43f9-8e28-db96d96af62a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_loss(self, model, inputs, return_outputs=False):\n",
+    "\n",
+    "    label_outputs = model(**inputs['predict'])\n",
+    "    cot_outputs = model(**inputs['rationale'])\n",
+    "    loss = self.alpha * cot_outputs.loss + (1. - self.alpha) * label_outputs.loss\n",
+    "    return (loss, {'rationale_loss': cot_outputs, 'predict_loss': label_outputs}) if return_outputs else loss"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ff1cee5d-68e1-4caf-96b9-132b27b46dca",
+   "metadata": {},
+   "source": [
+    "You can choose one of three modes, 'infer_only', 'train_only', or 'infer_and_train', depending on your requirements:\n",
+    "- infer_only: only generate the rationales; they are saved to the output_dir\n",
+    "- train_only: local training only\n",
+    "- infer_and_train: generate rationales, load them into the PrefixDataset, and then start training\n",
+    " \n",
+    "In this instance, we opt for the 'infer_and_train' mode so that rationales are first generated with the assistance of the remote LLM. To activate the inference process, the infer client and server must be initialized for the client-side and server-side trainers respectively, as demonstrated in the preceding sections.\n",
+    "\n",
+    "Below is a PDSS example. We ran this example on a machine equipped with 4 V100-32G GPUs, launching the client script with DeepSpeed; the LLM is deployed on another machine."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c559341a-d133-4a24-8f1a-35cd6d2a26d3",
+   "metadata": {},
+   "source": [
+    "## PDSS Example\n",
+    "\n",
+    "### Client Script (deepspeed_run.py)\n",
+    "\n",
+    "This script shows how to set up a PDSS task on the client side."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4710fda-904a-4e90-bc65-beec7594703f", + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "import sys\n", + "from transformers import (\n", + " AutoTokenizer,\n", + " HfArgumentParser,\n", + " Seq2SeqTrainingArguments,\n", + ")\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from typing import List\n", + "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", + "from fate_llm.dataset.pdss_dataset import PrefixDataset\n", + "from fate_llm.algo.pdss.pdss_trainer import PDSSTrainerClient\n", + "from fate_llm.data.data_collator.pdss_collator import PrefixDataCollator\n", + "\n", + "\n", + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "name = \"fed1\"\n", + "\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession(data_dir=\"./session_dir\")\n", + " return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n", + "\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "doc_template = \"\"\"{{question}} \n", + "Choices:{{choices.text}}\n", + "\"\"\"\n", + "\n", + "instruction_template=\"\"\"\n", + "[INST]\n", + "Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n", + "Use to finish your rationle.\"\n", + "\n", + "Example(s):\n", + "Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n", + "Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n", + "Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n", + "\n", + "Please explain:\n", + "Question:{{perturbed_doc}}\n", + "Rationale:\n", + "[/INST]\n", + "\"\"\"\n", + "\n", + "decode_template = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n", + "Use to finish your rationle.\"\n", + "\n", + "Example(s):\n", + "Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n", + "Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n", + "Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. 
Therefore, the answer is 'dry palms'.\n", + "\n", + "Question:{{perturbed_doc}}\n", + "Rationale:{{perturbed_response | replace('\\n', '')}}\n", + "\n", + "Please explain:\n", + "Question:{{question}} \n", + "Choices:{{choices.text}}\n", + "Rationale:\n", + "\"\"\"\n", + " \n", + "\n", + "if __name__ == \"__main__\":\n", + " \n", + " parser = HfArgumentParser(Seq2SeqTrainingArguments)\n", + " if len(sys.argv) == 2 and sys.argv[1].endswith(\".json\"):\n", + " training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]\n", + " else:\n", + " training_args = parser.parse_args_into_dataclasses()[0]\n", + "\n", + " model_path = '/data/cephfs/llm/models/Qwen1.5-0.5B/'\n", + " pds = PrefixDataset(\n", + " tokenizer_path=model_path,\n", + " predict_input_template=\"\"\"Predict:\n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "Answer:\n", + " \"\"\",\n", + " predict_output_template=\"\"\"{{choices.text[choices.label.index(answerKey)]}}\"\"\",\n", + " rationale_input_template=\"\"\"Explain:\n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "Rationale:\n", + " \"\"\",\n", + " rationale_output_template=\"\"\"{{infer_result}}\n", + " \"\"\",\n", + " max_input_length=128,\n", + " max_target_length=128,\n", + " )\n", + " pds.load('/data/cephfs/llm/datasets/arce/')\n", + " \n", + " model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda()\n", + " tokenizer = AutoTokenizer.from_pretrained(model_path)\n", + " model.gradient_checkpointing_enable()\n", + " model.enable_input_require_grads()\n", + "\n", + " ctx = create_ctx(guest)\n", + " if self.training_args.local_rank == 0:\n", + " # only rank 0 need to load infer instance\n", + " save_kit_path = 'your path'\n", + " kit = InferDPTKit.load_from_path(save_kit_path)\n", + " # local deployed small model as decoding model\n", + " inference = APICompletionInference(api_url=\"http://xxxx/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n", + " client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\n", + " else:\n", + " client = None\n", + " \n", + " trainer = PDSSTrainerClient(\n", + " ctx=ctx,\n", + " model=model,\n", + " training_args=training_args,\n", + " tokenizer=tokenizer, \n", + " train_set=pds,\n", + " data_collator=PrefixDataCollator(tokenizer),\n", + " mode='infer_and_train',\n", + " infer_client=client,\n", + " encode_template=doc_template,\n", + " decode_template=decode_template,\n", + " instruction_template=instruction_template,\n", + " remote_inference_kwargs={\n", + " 'stop': ['<\\s>'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " },\n", + " local_inference_kwargs={\n", + " 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " }\n", + " )\n", + "\n", + " trainer.train()\n", + "\n", + " if training_args.local_rank == 0:\n", + " model.save_pretrained(training_args.output_dir)\n", + " tokenizer.save_pretrained(training_args.output_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "962dd399-1dec-4164-bd86-15aa8550c50b", + "metadata": {}, + "source": [ + "### Server Script(server.py)\n", + "\n", + "This script show how to setup a pdss task on the server side." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91b42972-5308-4ccf-a768-f7dfa087313e", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n", + "from fate_llm.algo.pdss.pdss_trainer import PDSSTraineServer\n", + "from jinja2 import Template\n", + "from fate.arch import Context\n", + "import sys\n", + "\n", + "\n", + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "name = \"fed1\"\n", + "\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession(data_dir=\"./session_dir\")\n", + " return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n", + "\n", + "\n", + "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "api = APICompletionInference(api_url='http://xxxx:8080/v1', api_key='EMPTY', model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat')\n", + "\n", + "ctx = create_ctx(arbiter)\n", + "server_api = InferDPTServer(ctx, api)\n", + "server = PDSSTraineServer(ctx, server_api)\n", + "server.train()" + ] + }, + { + "cell_type": "markdown", + "id": "125dd68e-c7d4-41aa-9972-4881b1330fb6", + "metadata": {}, + "source": [ + "### Start script\n", + "\n", + "You can launch client side training with following script:" + ] + }, + { + "cell_type": "markdown", + "id": "8a0eccf1-8807-42b7-8473-6ec8ab350626", + "metadata": {}, + "source": [ + "deepspeed --num_nodes 1 --num_gpus 4 deepspeed_run.py \\\n", + " --output_dir \"./\" \\\n", + " --per_device_train_batch_size \"1\" \\\n", + " --gradient_accumulation_steps \"8\" \\\n", + " --max_steps \"750\" \\\n", + " --fp16 \\\n", + " --logging_steps 10 \\\n", + " --save_only_model \\\n", + " --deepspeed \"./ds_config.json\" " + ] + }, + { + "cell_type": "markdown", + "id": "0b506c1c-51f4-448d-9b0b-adf1a71cc7cf", + "metadata": {}, + "source": [ + "and the ds_config.json is\n", + "```\n", + "{ \n", + " \"train_micro_batch_size_per_gpu\": 1,\n", + " \"gradient_accumulation_steps\": 8,\n", + " \"optimizer\": {\n", + " \"type\": \"AdamW\",\n", + " \"params\": {\n", + " \"lr\": 5e-5\n", + " }\n", + " },\n", + " \"fp16\": {\n", + " \"enabled\": true\n", + " },\n", + " \"zero_optimization\": {\n", + " \"stage\": 0\n", + " }\n", + "}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "613fbfb6-ac9e-485b-8587-ffef1e2361c1", + "metadata": {}, + "source": [ + "And server side:" + ] + }, + { + "cell_type": "markdown", + "id": "5b50adf0-8f9c-40e5-9a7d-40a70e30a420", + "metadata": {}, + "source": [ + "```python server.py```" + ] + }, + { + "cell_type": "markdown", + "id": "28a5de71-25fd-4042-a6b7-0ec2c505eaee", + "metadata": {}, + "source": [ + "## PDSS Pipeline Example\n", + "\n", + "You have the capability to submit a PDSS task within the FATE pipeline. 
By appropriately configuring the necessary settings, you can execute PDSS in a production environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52f1e19b-da8e-4977-adb1-42fb84dee407", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.runner.pdss_runner import PDSSRunner\n", + "from fate.components.components.nn.nn_runner import loader_load_from_conf\n", + "from fate.components.components.nn.loader import Loader\n", + "from fate_llm.dataset.pdss_dataset import PrefixDataset\n", + "from fate_client.pipeline.components.fate.nn.loader import ModelLoader, DatasetLoader, CustFuncLoader, Loader\n", + "from transformers import (\n", + " AutoConfig,\n", + " AutoModel,\n", + " AutoTokenizer,\n", + " DataCollatorForSeq2Seq,\n", + " HfArgumentParser,\n", + " Seq2SeqTrainingArguments,\n", + " set_seed,\n", + " Trainer\n", + ")\n", + "import argparse\n", + "from fate_client.pipeline.utils import test_utils\n", + "from fate_client.pipeline.components.fate.evaluation import Evaluation\n", + "from fate_client.pipeline.components.fate.reader import Reader\n", + "from fate_client.pipeline import FateFlowPipeline\n", + "from fate_client.pipeline.components.fate.nn.torch import nn, optim\n", + "from fate_client.pipeline.components.fate.nn.torch.base import Sequential\n", + "from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner\n", + "from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\n", + "\n", + "def main(config=\"../../config.yaml\", namespace=\"\"):\n", + " # obtain config\n", + " if isinstance(config, str):\n", + " config = test_utils.load_job_config(config)\n", + " parties = config.parties\n", + " guest = '9999'\n", + " host = parties.host[0]\n", + " arbiter = '10000'\n", + "\n", + " pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n", + "\n", + " reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n", + " reader_0.guest.task_parameters(\n", + " namespace=\"experiment\",\n", + " name=\"arc_e_example\"\n", + " )\n", + "\n", + " model_conf = Loader(module_name='fate_llm.model_zoo.hf_model', item_name='HFAutoModelForCausalLM', \n", + " pretrained_model_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()\n", + " data_collator_conf = Loader(module_name='fate_llm.data.data_collator.pdss_collator', item_name='get_prefix_data_collator', tokenizer_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()\n", + "\n", + " infer_init_conf_client = {\n", + " 'module_name': 'fate_llm.algo.inferdpt.init.default_init',\n", + " 'item_name': 'InferDPTAPIClientInit'\n", + " }\n", + "\n", + " infer_init_conf_server = {\n", + " 'module_name': 'fate_llm.algo.inferdpt.init.default_init',\n", + " 'item_name': 'InferDPTAPIServerInit'\n", + " }\n", + "\n", + " dataset_conf = {\n", + " 'module_name': 'fate_llm.dataset.pdss_dataset',\n", + " 'item_name': 'PrefixDataset',\n", + " 'kwargs':dict(\n", + " tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',\n", + " predict_input_template=\"\"\"Predict:\n", + " Question:{{question}}\n", + " Choices:{{choices.text}}\n", + " \"\"\",\n", + " predict_output_template=\"\"\"{{choices.text[choices.label.index(answerKey)]}}\"\"\",\n", + " rationale_input_template=\"\"\"Explain:\n", + " Question:{{question}}\n", + " Choices:{{choices.text}}\n", + " \"\"\",\n", + " rationale_output_template=\"\"\"{{infer_result}}\n", + " \"\"\",\n", + " max_input_length=128,\n", + " max_target_length=128,\n", + " 
split_key='train'\n", + " )\n", + " }\n", + "\n", + " encoder_prompt = \"\"\"{{question}}\n", + "Choices:{{choices.text}}\n", + "\"\"\"\n", + "\n", + " decoder_prompt = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.Use to finish your rationle.\n", + "\n", + "Example(s):\n", + "Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n", + "Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n", + "Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n", + "\n", + "Question:{{perturbed_doc}}\n", + "Rationale:{{perturbed_response | replace('\\n', '')}}\n", + "\n", + "Please explain:\n", + "Question:{{question}} \n", + "Choices:{{choices.text}}\n", + " \"\"\"\n", + "\n", + " instruction_prompt = \"\"\"<|im_start|>system\n", + "You are a helpful assistant<|im_end|>\n", + "<|im_start|>user\n", + "Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n", + "Use to finish your rationle.\n", + "\n", + "Example(s):\n", + "Question:Which factor will most likely cause a person to develop a fever?\n", + "Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\n", + "Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. 
Therefore, the answer is 'a bacterial population in the bloodstream'\n", + "\n", + "Please explain:\n", + "Question:{{perturbed_doc}}\n", + "Rationale:\n", + "<|im_end|>\n", + "<|im_start|>assistant\n", + " \"\"\"\n", + "\n", + " remote_inference_kwargs={\n", + " 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " }\n", + "\n", + " local_inference_kwargs={\n", + " 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + " }\n", + "\n", + " ds_config = { \n", + " \"train_micro_batch_size_per_gpu\": 1,\n", + " \"gradient_accumulation_steps\": 8,\n", + " \"optimizer\": {\n", + " \"type\": \"AdamW\",\n", + " \"params\": {\n", + " \"lr\": 5e-5\n", + " }\n", + " },\n", + " \"fp16\": {\n", + " \"enabled\": True\n", + " },\n", + " \"zero_optimization\": {\n", + " \"stage\": 0\n", + " }\n", + " }\n", + "\n", + " training_args_dict = dict(\n", + " per_device_train_batch_size=1, \n", + " gradient_accumulation_steps=8,\n", + " logging_steps=10,\n", + " max_steps=30,\n", + " fp16=True,\n", + " log_level='debug'\n", + " )\n", + "\n", + " mode = 'infer_and_train'\n", + "\n", + " client_conf = dict(\n", + " model_conf=model_conf,\n", + " dataset_conf=dataset_conf,\n", + " training_args_conf=training_args_dict,\n", + " data_collator_conf=data_collator_conf,\n", + " mode=mode,\n", + " infer_inst_init_conf=infer_init_conf_client,\n", + " encode_template=encoder_prompt,\n", + " instruction_template=instruction_prompt,\n", + " decode_template=decoder_prompt,\n", + " remote_inference_kwargs=remote_inference_kwargs,\n", + " local_inference_kwargs=local_inference_kwargs,\n", + " perturb_doc_key='perturbed_doc',\n", + " perturbed_response_key='perturbed_response',\n", + " result_key='infer_result'\n", + " )\n", + "\n", + " server_conf = dict(\n", + " infer_inst_init_conf=infer_init_conf_server,\n", + " mode=mode\n", + " )\n", + "\n", + " homo_nn_0 = HomoNN(\n", + " 'nn_0',\n", + " train_data=reader_0.outputs[\"output_data\"],\n", + " runner_module=\"pdss_runner\",\n", + " runner_class=\"PDSSRunner\"\n", + " )\n", + "\n", + " homo_nn_0.guest.task_parameters(runner_conf=client_conf)\n", + " homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\n", + "\n", + " homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\")\n", + "\n", + " pipeline.add_tasks([reader_0, homo_nn_0])\n", + " pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 4}))\n", + " pipeline.compile()\n", + " pipeline.fit()\n", + "\n", + "if __name__ == \"__main__\":\n", + " parser = argparse.ArgumentParser(\"PIPELINE DEMO\")\n", + " parser.add_argument(\"--config\", type=str, default=\"../config.yaml\",\n", + " help=\"config file\")\n", + " parser.add_argument(\"--namespace\", type=str, default=\"\",\n", + " help=\"namespace for data stored in FATE\")\n", + " args = parser.parse_args()\n", + " main(config=args.config, namespace=args.namespace)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 2e261de7ffba00acfb87d29a3a24ef298c3d7538 Mon Sep 17 
00:00:00 2001 From: cwj Date: Thu, 11 Jul 2024 14:10:13 +0800 Subject: [PATCH 15/50] Fix doc Signed-off-by: weijingchen Signed-off-by: cwj --- doc/tutorial/pdss/pdss_tutorial.ipynb | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/doc/tutorial/pdss/pdss_tutorial.ipynb b/doc/tutorial/pdss/pdss_tutorial.ipynb index 588aab8..a80e729 100644 --- a/doc/tutorial/pdss/pdss_tutorial.ipynb +++ b/doc/tutorial/pdss/pdss_tutorial.ipynb @@ -666,6 +666,7 @@ "from fate_llm.dataset.pdss_dataset import PrefixDataset\n", "from fate_llm.algo.pdss.pdss_trainer import PDSSTrainerClient\n", "from fate_llm.data.data_collator.pdss_collator import PrefixDataCollator\n", + "from fate_llm.algo.inferdpt import inferdpt\n", "\n", "\n", "arbiter = (\"arbiter\", 10000)\n", @@ -761,6 +762,7 @@ " \"\"\",\n", " max_input_length=128,\n", " max_target_length=128,\n", + " split_key='train'\n", " )\n", " pds.load('/data/cephfs/llm/datasets/arce/')\n", " \n", @@ -770,11 +772,12 @@ " model.enable_input_require_grads()\n", "\n", " ctx = create_ctx(guest)\n", - " if self.training_args.local_rank == 0:\n", + " if training_args.local_rank == 0:\n", " # only rank 0 need to load infer instance\n", " save_kit_path = 'your path'\n", " kit = InferDPTKit.load_from_path(save_kit_path)\n", " # local deployed small model as decoding model\n", + " from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", " inference = APICompletionInference(api_url=\"http://xxxx/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n", " client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\n", " else:\n", @@ -877,14 +880,9 @@ "source": [ "### Start script\n", "\n", - "You can launch client side training with following script:" - ] - }, - { - "cell_type": "markdown", - "id": "8a0eccf1-8807-42b7-8473-6ec8ab350626", - "metadata": {}, - "source": [ + "You can launch client side training with following script:\n", + "\n", + "```\n", "deepspeed --num_nodes 1 --num_gpus 4 deepspeed_run.py \\\n", " --output_dir \"./\" \\\n", " --per_device_train_batch_size \"1\" \\\n", @@ -893,7 +891,8 @@ " --fp16 \\\n", " --logging_steps 10 \\\n", " --save_only_model \\\n", - " --deepspeed \"./ds_config.json\" " + " --deepspeed \"./ds_config.json\" \n", + "```" ] }, { From 50c1ccaf95dd39cd5e7b23a51cca47b88b3161ab Mon Sep 17 00:00:00 2001 From: weijingchen Date: Fri, 19 Jul 2024 15:45:31 +0800 Subject: [PATCH 16/50] Update inference structure Signed-off-by: weijingchen --- doc/tutorial/inferdpt/inferdpt_tutorial.ipynb | 19 ++------- doc/tutorial/pdss/pdss_tutorial.ipynb | 40 ++++--------------- python/fate_llm/algo/inferdpt/__init__.py | 0 python/fate_llm/algo/inferdpt/inferdpt.py | 2 +- .../algo/inferdpt/init/default_init.py | 2 +- .../pdss/encoder_decoder/init/default_init.py | 2 +- .../encoder_decoder/slm_encoder_decoder.py | 2 +- python/fate_llm/algo/pdss/pdss_trainer.py | 2 +- python/fate_llm/inference/__init__.py | 0 .../{algo/inferdpt => }/inference/api.py | 2 +- .../{algo/inferdpt => }/inference/hf_qw.py | 2 +- .../inferdpt => }/inference/inference_base.py | 0 .../{algo/inferdpt => }/inference/vllm.py | 2 +- 13 files changed, 19 insertions(+), 56 deletions(-) create mode 100644 python/fate_llm/algo/inferdpt/__init__.py create mode 100644 python/fate_llm/inference/__init__.py rename python/fate_llm/{algo/inferdpt => }/inference/api.py (95%) rename python/fate_llm/{algo/inferdpt => }/inference/hf_qw.py (95%) rename python/fate_llm/{algo/inferdpt => }/inference/inference_base.py (100%) rename 
python/fate_llm/{algo/inferdpt => }/inference/vllm.py (95%) diff --git a/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb b/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb index 5d51738..135a52f 100644 --- a/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb +++ b/doc/tutorial/inferdpt/inferdpt_tutorial.ipynb @@ -319,7 +319,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.inference.api import APICompletionInference\n", "# for client\n", "inference_client = APICompletionInference(api_url=\"http://127.0.0.1:8887/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n", "# for server\n", @@ -498,12 +498,9 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.inference.api import APICompletionInference\n", "from fate_llm.algo.inferdpt import inferdpt\n", "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", - "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", - "from jinja2 import Template\n", - "from fate.arch import Context\n", "import sys\n", "\n", "\n", @@ -615,10 +612,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", - "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", - "from jinja2 import Template\n", - "from fate.arch import Context\n", + "from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n", "import sys\n", "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", "\n", @@ -694,7 +688,7 @@ "outputs": [], "source": [ "from fate_llm.algo.inferdpt.init._init import InferClientInit\n", - "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.inference.api import APICompletionInference\n", "from fate_llm.algo.inferdpt import inferdpt\n", "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", @@ -816,13 +810,8 @@ "source": [ "import argparse\n", "from fate_client.pipeline.utils import test_utils\n", - "from fate_client.pipeline.components.fate.evaluation import Evaluation\n", "from fate_client.pipeline.components.fate.reader import Reader\n", "from fate_client.pipeline import FateFlowPipeline\n", - "from fate_client.pipeline.components.fate.nn.torch import nn, optim\n", - "from fate_client.pipeline.components.fate.nn.torch.base import Sequential\n", - "from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner\n", - "from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\n", "\n", "\n", "def main(config=\"../../config.yaml\", namespace=\"\"):\n", diff --git a/doc/tutorial/pdss/pdss_tutorial.ipynb b/doc/tutorial/pdss/pdss_tutorial.ipynb index a80e729..a074488 100644 --- a/doc/tutorial/pdss/pdss_tutorial.ipynb +++ b/doc/tutorial/pdss/pdss_tutorial.ipynb @@ -110,12 +110,9 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.inference.api import APICompletionInference\n", "from fate_llm.algo.inferdpt import inferdpt\n", "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", - "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", - "from jinja2 import Template\n", - "from fate.arch import Context\n", "import sys\n", "\n", "arbiter = (\"arbiter\", 10000)\n", @@ -225,12 
+222,9 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", - "from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n", - "from jinja2 import Template\n", - "from fate.arch import Context\n", + "from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n", "import sys\n", - "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.inference.api import APICompletionInference\n", "\n", "arbiter = (\"arbiter\", 10000)\n", "guest = (\"guest\", 10000)\n", @@ -297,7 +291,7 @@ }, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.inference.api import APICompletionInference\n", "from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient\n", "\n", "arbiter = (\"arbiter\", 10000)\n", @@ -407,7 +401,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n", + "from fate_llm.inference.api import APICompletionInference\n", "from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderServer\n", "\n", "arbiter = (\"arbiter\", 10000)\n", @@ -833,8 +827,6 @@ "source": [ "from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n", "from fate_llm.algo.pdss.pdss_trainer import PDSSTraineServer\n", - "from jinja2 import Template\n", - "from fate.arch import Context\n", "import sys\n", "\n", "\n", @@ -954,30 +946,12 @@ "metadata": {}, "outputs": [], "source": [ - "from fate_llm.runner.pdss_runner import PDSSRunner\n", - "from fate.components.components.nn.nn_runner import loader_load_from_conf\n", - "from fate.components.components.nn.loader import Loader\n", - "from fate_llm.dataset.pdss_dataset import PrefixDataset\n", - "from fate_client.pipeline.components.fate.nn.loader import ModelLoader, DatasetLoader, CustFuncLoader, Loader\n", - "from transformers import (\n", - " AutoConfig,\n", - " AutoModel,\n", - " AutoTokenizer,\n", - " DataCollatorForSeq2Seq,\n", - " HfArgumentParser,\n", - " Seq2SeqTrainingArguments,\n", - " set_seed,\n", - " Trainer\n", - ")\n", + "from fate_client.pipeline.components.fate.nn.loader import Loader\n", "import argparse\n", "from fate_client.pipeline.utils import test_utils\n", - "from fate_client.pipeline.components.fate.evaluation import Evaluation\n", "from fate_client.pipeline.components.fate.reader import Reader\n", "from fate_client.pipeline import FateFlowPipeline\n", - "from fate_client.pipeline.components.fate.nn.torch import nn, optim\n", - "from fate_client.pipeline.components.fate.nn.torch.base import Sequential\n", - "from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner\n", - "from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\n", + "\n", "\n", "def main(config=\"../../config.yaml\", namespace=\"\"):\n", " # obtain config\n", diff --git a/python/fate_llm/algo/inferdpt/__init__.py b/python/fate_llm/algo/inferdpt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/algo/inferdpt/inferdpt.py b/python/fate_llm/algo/inferdpt/inferdpt.py index ea99cf2..ea5f7c7 100644 --- a/python/fate_llm/algo/inferdpt/inferdpt.py +++ b/python/fate_llm/algo/inferdpt/inferdpt.py @@ -23,7 +23,7 @@ from fate_llm.algo.inferdpt.utils import InferDPTKit from openai import OpenAI import logging -from fate_llm.algo.inferdpt.inference.inference_base import Inference +from 
fate_llm.inference.inference_base import Inference from fate_llm.algo.inferdpt._encode_decode import EncoderDecoder from fate_llm.dataset.hf_dataset import HuggingfaceDataset diff --git a/python/fate_llm/algo/inferdpt/init/default_init.py b/python/fate_llm/algo/inferdpt/init/default_init.py index bdcb56d..0050bfd 100644 --- a/python/fate_llm/algo/inferdpt/init/default_init.py +++ b/python/fate_llm/algo/inferdpt/init/default_init.py @@ -14,7 +14,7 @@ # limitations under the License. # from fate_llm.algo.inferdpt.init._init import InferInit -from fate_llm.algo.inferdpt.inference.api import APICompletionInference +from fate_llm.inference.api import APICompletionInference from fate_llm.algo.inferdpt import inferdpt from fate_llm.algo.inferdpt.utils import InferDPTKit from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer diff --git a/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py index ff89ebc..7163521 100644 --- a/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py +++ b/python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py @@ -15,7 +15,7 @@ # from fate_llm.algo.inferdpt.init._init import InferInit -from fate_llm.algo.inferdpt.inference.api import APICompletionInference +from fate_llm.inference.api import APICompletionInference from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer diff --git a/python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py b/python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py index f6eca62..059c25c 100644 --- a/python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py +++ b/python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py @@ -22,7 +22,7 @@ from fate_llm.algo.inferdpt.utils import InferDPTKit from openai import OpenAI import logging -from fate_llm.algo.inferdpt.inference.inference_base import Inference +from fate_llm.inference.inference_base import Inference from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer from fate_llm.dataset.hf_dataset import HuggingfaceDataset diff --git a/python/fate_llm/algo/pdss/pdss_trainer.py b/python/fate_llm/algo/pdss/pdss_trainer.py index 8ada9cb..fdf4e1b 100644 --- a/python/fate_llm/algo/pdss/pdss_trainer.py +++ b/python/fate_llm/algo/pdss/pdss_trainer.py @@ -32,7 +32,7 @@ from transformers import Seq2SeqTrainingArguments from transformers.trainer_utils import EvalPrediction from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer, Seq2SeqTrainingArguments -from fate_llm.algo.inferdpt.inference.inference_base import Inference +from fate_llm.inference.inference_base import Inference from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer diff --git a/python/fate_llm/inference/__init__.py b/python/fate_llm/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/algo/inferdpt/inference/api.py b/python/fate_llm/inference/api.py similarity index 95% rename from python/fate_llm/algo/inferdpt/inference/api.py rename to python/fate_llm/inference/api.py index bf87930..dff1cf4 100644 --- a/python/fate_llm/algo/inferdpt/inference/api.py +++ b/python/fate_llm/inference/api.py @@ -14,7 +14,7 @@ # limitations under the License. 
# -from fate_llm.algo.inferdpt.inference.inference_base import Inference +from fate_llm.inference.inference_base import Inference from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import GenerationConfig from typing import List diff --git a/python/fate_llm/algo/inferdpt/inference/hf_qw.py b/python/fate_llm/inference/hf_qw.py similarity index 95% rename from python/fate_llm/algo/inferdpt/inference/hf_qw.py rename to python/fate_llm/inference/hf_qw.py index ab42288..dd94c2a 100644 --- a/python/fate_llm/algo/inferdpt/inference/hf_qw.py +++ b/python/fate_llm/inference/hf_qw.py @@ -14,7 +14,7 @@ # limitations under the License. # -from fate_llm.algo.inferdpt.inference.inference_base import Inference +from fate_llm.inference.inference_base import Inference from transformers import AutoModelForCausalLM, AutoTokenizer from typing import List import tqdm diff --git a/python/fate_llm/algo/inferdpt/inference/inference_base.py b/python/fate_llm/inference/inference_base.py similarity index 100% rename from python/fate_llm/algo/inferdpt/inference/inference_base.py rename to python/fate_llm/inference/inference_base.py diff --git a/python/fate_llm/algo/inferdpt/inference/vllm.py b/python/fate_llm/inference/vllm.py similarity index 95% rename from python/fate_llm/algo/inferdpt/inference/vllm.py rename to python/fate_llm/inference/vllm.py index 74d3d41..825d3ab 100644 --- a/python/fate_llm/algo/inferdpt/inference/vllm.py +++ b/python/fate_llm/inference/vllm.py @@ -14,7 +14,7 @@ # limitations under the License. # -from fate_llm.algo.inferdpt.inference.inference_base import Inference +from fate_llm.inference.inference_base import Inference from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import GenerationConfig import logging From 1ae8865df05c976b9c738d9a464daeb99145e4ea Mon Sep 17 00:00:00 2001 From: weijingchen Date: Fri, 19 Jul 2024 15:57:58 +0800 Subject: [PATCH 17/50] Add new docs Signed-off-by: weijingchen --- .../pdss/encoder_decoder_tutorial.ipynb | 468 ++++++++++++++++++ python/fate_llm/algo/pdss/__init__.py | 0 .../algo/pdss/slm_encoder_decoder_trainer.py | 29 ++ 3 files changed, 497 insertions(+) create mode 100644 doc/tutorial/pdss/encoder_decoder_tutorial.ipynb create mode 100644 python/fate_llm/algo/pdss/__init__.py create mode 100644 python/fate_llm/algo/pdss/slm_encoder_decoder_trainer.py diff --git a/doc/tutorial/pdss/encoder_decoder_tutorial.ipynb b/doc/tutorial/pdss/encoder_decoder_tutorial.ipynb new file mode 100644 index 0000000..404bc2a --- /dev/null +++ b/doc/tutorial/pdss/encoder_decoder_tutorial.ipynb @@ -0,0 +1,468 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a163d9c2-f9d6-4c61-a8e8-76a3f66c38ae", + "metadata": {}, + "source": [ + "# PDSS - Train a SLM Encoder Decoder" + ] + }, + { + "cell_type": "markdown", + "id": "f2b56772-26d5-44fe-9c51-7bc662478b98", + "metadata": {}, + "source": [ + "PDSS is an innovative framework designed to distill knowledge from large language models (LLMs) to small language models (SLMs) while ensuring data privacy. This method involves a strategy that trains a small language model (SLM) to learn from perturbed and recovered texts. The SLM can then encode raw text, produce results similar to differential privacy mechanisms, and return higher quality recovered text.\n", + "\n", + "In this tutorial, we will introduce how to train an SLM using the built-in trainer." 
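Concretely, one SLM is trained on two objectives at the same time: an encoder objective, where the model learns to rewrite raw text into a perturbed version, and a decoder objective, where it learns to recover a high-quality rationale from the perturbed text, the perturbed LLM reply, and the original text. The snippet below is only a sketch of this combined loss, with `alpha` as the mixing weight between the two terms; the actual implementation is the `EncoderDecoderPrefixTrainer` introduced at the end of this tutorial.

```python
# Sketch of the joint objective used to train the SLM encoder-decoder.
# encoder_batch / decoder_batch are tokenized (input_ids, labels) pairs for the two tasks,
# and alpha in [0, 1] balances the encoding and decoding losses.
def joint_loss(model, encoder_batch, decoder_batch, alpha=0.5):
    encode_loss = model(**encoder_batch).loss   # raw doc -> perturbed doc
    decode_loss = model(**decoder_batch).loss   # perturbed doc + perturbed reply + raw doc -> recovered rationale
    return alpha * encode_loss + (1.0 - alpha) * decode_loss
```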
+ ] + }, + { + "cell_type": "markdown", + "id": "62c6d18a-cc91-4cf5-9cfd-0f97095f7041", + "metadata": {}, + "source": [ + "## Prepare Data\n", + "\n", + "Several steps need to be done to prepare data for training a SLM encoder-decoder model:\n", + "- Sample data from original dataset(For example 50%)\n", + "- Organize raw text and get a direct rationale reply from a remote LLM\n", + "- Perturb doc using InferDPTKit to get perturbed docs\n", + "- Get perturbed replies from a remote LLM\n", + "- Organize training data\n", + "\n", + "### Sample data\n", + "Here we will use the arc-easy data as an example, and take first 50% of the original dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "40cc1bb8-a17c-4abc-9279-0849e98ca116", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset, load_from_disk\n", + "ds = load_dataset('arc_easy')['train']\n", + "ds = [ds[i] for i in range(len(ds)//2)]" + ] + }, + { + "cell_type": "markdown", + "id": "0caff897-5b2b-4409-8601-10f973133b10", + "metadata": {}, + "source": [ + "### Get Direct Replies from A Remote LLM\n", + "\n", + "We use the inference class to create an API for remote LLMs, or you can implement this part on your own." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "cf128b46-dea2-4eb4-bf31-568e56b9b78e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "from fate_llm.inference.api import APICompletionInference\n", + "from jinja2 import Template\n", + "from transformers import AutoTokenizer\n", + "\n", + "# We are using a Qwen 14B model as the remote model\n", + "# You can change the setting\n", + "api = APICompletionInference(\n", + " api_url='http://172.21.140.2:8081/v1',\n", + " api_key='EMPTY',\n", + " model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat'\n", + ")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B-Chat/')\n", + "\n", + "arc_e_template_r = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n", + "Use to finish your rationle.\n", + "\n", + "Example(s):\n", + "Question:Which factor will most likely cause a person to develop a fever?\n", + "Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\n", + "Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. 
Therefore, the answer is 'a bacterial population in the bloodstream'\n", + "\n", + "Please explain:\n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "Rationale:\n", + "\"\"\"\n", + "\n", + "template = Template(arc_e_template_r)\n", + "docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in ds]\n", + "results = api.inference(docs_to_infer, {\n", + " 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + "})\n", + "\n", + "for i, r in zip(ds, results):\n", + " i['rationale'] = r" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "212822ab-9f64-49a2-bb95-ef8ee2de8e49", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A fever is a response to an infection, typically caused by bacteria or viruses. So, the answer is 'a bacterial population in the bloodstream' because it indicates an immune response to a foreign invader. 'Several viral particles on the skin' could also lead to a fever if they enter the body, but bloodstream presence is more direct. The other choices are unrelated to fever development.\n" + ] + } + ], + "source": [ + "print(results[0])" + ] + }, + { + "cell_type": "markdown", + "id": "0f6a0039-1530-4b87-a098-fd2eb01805c2", + "metadata": {}, + "source": [ + "### Perturb Docs & Replies\n", + "\n", + "You can refer to the InferDPT tutorial for guidance on using the InferDPTKit to generate perturbed documents: [InferDPT Document](./)\n", + "We can produce perturbed doc using InferDPTKit:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "39249747-bfaa-43bf-8b66-896568941ab8", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.inferdpt.utils import InferDPTKit\n", + "path_to_kit = '/data/projects/inferdpt/test_fate_llm/'\n", + "kit = InferDPTKit.load_from_path(path_to_kit)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "39b9cefa-dfdb-4bac-b313-4ca3bc118aee", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "tmp_ds = copy.deepcopy(ds)\n", + "\n", + "q_doc = [kit.perturb(i, epsilon=1.0) for i in [Template(\"\"\"{{question}}\"\"\").render(i) for i in tmp_ds]]\n", + "c_doc = [kit.perturb(i, epsilon=1.0) for i in [Template(\"\"\"{{choices.text}}\"\"\").render(i) for i in tmp_ds]]\n", + "for i,q,c in zip(tmp_ds,q_doc,c_doc):\n", + " i['question'] = q\n", + " i['choices']['text'] = c" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "61b30886-746c-43c5-889a-a6583dc939d0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'id': 'Mercury_7179953',\n", + " 'question': 'stuff two alpha Rogers are today chap in Department?',\n", + " 'choices': {'text': \"['muscular and skeletal', 'digestive and muscular', 'skeletal and pasteiratory', 'respiratory and exhibive']\",\n", + " 'label': ['A', 'B', 'C', 'D']},\n", + " 'answerKey': 'A',\n", + " 'rationale': {...}}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tmp_ds[6]" + ] + }, + { + "cell_type": "markdown", + "id": "fed90297-9957-4f8b-a53c-37a03d516c78", + "metadata": {}, + "source": [ + "And then send formatted docs to remote LLM for perturbed responses:" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": 
"5b8bd833-fb0f-418b-bd9b-6452e8ae4d6c", + "metadata": {}, + "outputs": [], + "source": [ + "template = Template(arc_e_template_r)\n", + "docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in tmp_ds]\n", + "p_results = api.inference(docs_to_infer, {\n", + " 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n", + " 'temperature': 0.01,\n", + " 'max_tokens': 256\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "187361fa-8b73-4a01-9039-f52ec98a5791", + "metadata": {}, + "outputs": [], + "source": [ + "for i, r in zip(ds, p_results):\n", + " i['p_rationale'] = r\n", + "\n", + "for i,q,c in zip(ds, q_doc, c_doc):\n", + " i['p_question'] = q\n", + " i['p_choice'] = c" + ] + }, + { + "cell_type": "markdown", + "id": "927b2265-4e87-4275-98dc-7f33d405e19a", + "metadata": {}, + "source": [ + "### Organize Training Data\n", + "\n", + "As described in the original paper, we need to train the encoder and decoder in one model.\n", + "We can organize the training data using templates below:" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "9292ad25-12c7-418a-9e77-b433b95f57ac", + "metadata": {}, + "outputs": [], + "source": [ + "train_data = []\n", + "\n", + "encoder_prompt = Template(\"\"\"Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use to end your reply.\n", + "Origin Doc: \n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "\n", + "Perturbed Doc:\n", + "\"\"\")\n", + "\n", + "encoder_out = Template(\"\"\"\n", + "Question:{{p_question}}\n", + "Choices:{{p_choice}}\n", + "\"\"\")\n", + "\n", + "decoder_in = Template(\"\"\"This is a perturbed question and its corresponding answer(rationale). And following is the original question. Try to recover the correct rationale from docs provided.\n", + "\n", + "Perturbed doc and rationale:\n", + "Question:{{p_question}}\n", + "Choices:{{p_choice}}\n", + "Rationale:{{p_rationale}}\n", + "\n", + "Original Doc:\n", + "Question:{{question}}\n", + "Choices:{{choices.text}}\n", + "\n", + "Recover Rationale:\n", + "\"\"\")\n", + "\n", + "decoder_out = Template(\"\"\"{{rationale}}\"\"\")\n", + "\n", + "\n", + "for i in ds:\n", + " a = {}\n", + " a['encoder_in'] = encoder_prompt.render(i)\n", + " a['encoder_out'] = encoder_out.render(i)\n", + " a['decoder_in'] = decoder_in.render(i)\n", + " a['decoder_out'] = decoder_out.render(i)\n", + " train_data.append(a)\n", + "\n", + "import torch\n", + "torch.save(train_data, './slm_ed_train_data.pkl')" + ] + }, + { + "cell_type": "markdown", + "id": "dd73db44-4e73-4c1e-8f27-755522587636", + "metadata": {}, + "source": [ + "## Train Script\n", + "\n", + "The key step: preparing data is now done. Then we can train a SLM model using the train data. You can use following dataset&trainer class to train an encoder-decoder slm model. Here we use Qwen-0.5B as the example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "eb01c591-3c04-4317-8bb0-f55846fb1b66", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "f0da4e10-af80-4216-8ff8-5816dabc8526", + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModelForCausalLM.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/').half().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "634fc973-29c8-499e-a99e-d50b7ee54124", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset\n", + "\n", + "class EDDataset(Dataset):\n", + "\n", + " def __init__(self, tokenizer, train_data, max_input_length=64, max_target_length=64):\n", + " self.tokenizer = tokenizer\n", + " self.dataset = train_data\n", + " self.max_input_length = max_input_length\n", + " self.max_target_length = max_target_length\n", + " self.max_seq_length = max_input_length + max_target_length + 1\n", + "\n", + " def get_str_item(self, i) -> dict:\n", + "\n", + " data_item = self.dataset[i]\n", + " ret_dict = {\n", + " 'encoder':{\n", + " 'input': data_item['encoder_in'],\n", + " 'output': data_item['encoder_out']\n", + " },\n", + " 'decoder':{\n", + " 'input': data_item['decoder_in'],\n", + " 'output': data_item['decoder_out']\n", + " }\n", + " }\n", + " return ret_dict\n", + "\n", + " def _process_item(self, data_item):\n", + "\n", + " a_ids = self.tokenizer.encode(text=data_item['input'], add_special_tokens=True, truncation=True,\n", + " max_length=self.max_input_length)\n", + " b_ids = self.tokenizer.encode(text=data_item['output'], add_special_tokens=False, truncation=True,\n", + " max_length=self.max_target_length)\n", + " context_length = len(a_ids)\n", + " input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]\n", + " labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]\n", + " pad_len = self.max_seq_length - len(input_ids)\n", + " input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len\n", + " labels = labels + [self.tokenizer.pad_token_id] * pad_len\n", + " labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]\n", + "\n", + " assert len(input_ids) == len(labels), f\"length mismatch: {len(input_ids)} vs {len(labels)}\"\n", + "\n", + " return {\n", + " \"input_ids\": input_ids,\n", + " \"labels\": labels\n", + " }\n", + "\n", + " def get_tokenized_item(self, i) -> dict: \n", + "\n", + " str_item = self.get_str_item(i)\n", + " ret_dict = {\n", + " 'encoder': self._process_item(str_item['encoder']),\n", + " 'docoder': self._process_item(str_item['decoder'])\n", + " }\n", + " return ret_dict\n", + "\n", + " def __getitem__(self, i) -> dict:\n", + " item = self.get_tokenized_item(i)\n", + " return item" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "5f914b1f-cf14-4bdc-acc9-ae1b73cf857c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "train_ds = EDDataset(AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/'), train_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "817084b2-2439-45d8-aa1b-da0b1a8a2846", + "metadata": {}, + "outputs": [], + "source": [ + "print(train_ds.get_str_item(0))\n", + 
"print(train_ds[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "303bcb23-d54b-4375-bad2-bf5450c14f28", + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.pdss.slm_encoder_decoder_trainer import EncoderDecoderPrefixTrainer, EDPrefixDataCollator" + ] + }, + { + "cell_type": "markdown", + "id": "aa5a0b4f-cd03-4867-8753-fc5bcb036c69", + "metadata": {}, + "source": [ + "After completing the setup, you can utilize the EncoderDecoderPrefixTrainer, EDPrefixDataCollator, and the training dataset to train an SLM encoder-decoder model following the Huggingface approach! " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/python/fate_llm/algo/pdss/__init__.py b/python/fate_llm/algo/pdss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/algo/pdss/slm_encoder_decoder_trainer.py b/python/fate_llm/algo/pdss/slm_encoder_decoder_trainer.py new file mode 100644 index 0000000..5083abc --- /dev/null +++ b/python/fate_llm/algo/pdss/slm_encoder_decoder_trainer.py @@ -0,0 +1,29 @@ +from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer +from transformers import DataCollatorForSeq2Seq +from transformers import AutoTokenizer +import pandas as pd + + +class EDPrefixDataCollator(DataCollatorForSeq2Seq): + def __call__(self, features, return_tensors=None): + features_df = pd.DataFrame(features) + a = super().__call__(list(features_df['encoder']), return_tensors) + b = super().__call__(list(features_df['decoder']), return_tensors) + + return { + 'encoder': a, + 'decoder': b + } + + +class EncoderDecoderPrefixTrainer(Seq2SeqTrainer): + + def __init__(self, alpha=0.5, *args, **kwargs): + super().__init__(*args, **kwargs) + self.alpha = alpha + + def compute_loss(self, model, inputs, return_outputs=False): + out_a = model(**inputs['encoder']) + out_b = model(**inputs['decoder']) + loss = self.alpha * out_a.loss + (1. 
- self.alpha) * out_b.loss + return (loss, {'out_a': out_a, 'out_b': out_b}) if return_outputs else loss From aaf795a6abca20ebae2668ff8e7125cba4e9f35b Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Mon, 22 Jul 2024 09:49:01 +0800 Subject: [PATCH 18/50] update fdkt algo Signed-off-by: mgqa34 --- python/fate_llm/algo/dp/__init__.py | 2 + python/fate_llm/algo/dp/dp_trainer.py | 158 ++++++++++ .../algo/dp/opacus_compatibility/__init__.py | 16 ++ .../grad_sample/__init__.py | 0 .../grad_sample/embedding.py | 44 +++ .../optimizers/__init__.py | 0 .../optimizers/optimizer.py | 45 +++ .../transformers_compate.py | 30 ++ python/fate_llm/algo/fdkt/__init__.py | 26 ++ python/fate_llm/algo/fdkt/cluster/__init__.py | 15 + python/fate_llm/algo/fdkt/cluster/cluster.py | 39 +++ .../algo/fdkt/cluster/cluster_method.py | 35 +++ python/fate_llm/algo/fdkt/fdkt_data_aug.py | 270 ++++++++++++++++++ python/fate_llm/algo/fdkt/utils/__init__.py | 15 + python/fate_llm/algo/fdkt/utils/dp_loss.py | 52 ++++ .../fate_llm/algo/fdkt/utils/text_generate.py | 79 +++++ .../embedding_transformer/__init__.py | 0 .../embedding_transformer/st_model.py | 71 +++++ python/fate_llm/model_zoo/hf_model.py | 4 + python/fate_llm/runner/fdkt_runner.py | 195 +++++++++++++ 20 files changed, 1096 insertions(+) create mode 100644 python/fate_llm/algo/dp/__init__.py create mode 100644 python/fate_llm/algo/dp/dp_trainer.py create mode 100644 python/fate_llm/algo/dp/opacus_compatibility/__init__.py create mode 100644 python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py create mode 100644 python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py create mode 100644 python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py create mode 100644 python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py create mode 100644 python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py create mode 100644 python/fate_llm/algo/fdkt/__init__.py create mode 100644 python/fate_llm/algo/fdkt/cluster/__init__.py create mode 100644 python/fate_llm/algo/fdkt/cluster/cluster.py create mode 100644 python/fate_llm/algo/fdkt/cluster/cluster_method.py create mode 100644 python/fate_llm/algo/fdkt/fdkt_data_aug.py create mode 100644 python/fate_llm/algo/fdkt/utils/__init__.py create mode 100644 python/fate_llm/algo/fdkt/utils/dp_loss.py create mode 100644 python/fate_llm/algo/fdkt/utils/text_generate.py create mode 100644 python/fate_llm/model_zoo/embedding_transformer/__init__.py create mode 100644 python/fate_llm/model_zoo/embedding_transformer/st_model.py create mode 100644 python/fate_llm/runner/fdkt_runner.py diff --git a/python/fate_llm/algo/dp/__init__.py b/python/fate_llm/algo/dp/__init__.py new file mode 100644 index 0000000..ce0bd2e --- /dev/null +++ b/python/fate_llm/algo/dp/__init__.py @@ -0,0 +1,2 @@ +from .opacus_compatibility.transformers_compate import get_model_class +from .dp_trainer import DPTrainer, DPTrainingArguments diff --git a/python/fate_llm/algo/dp/dp_trainer.py b/python/fate_llm/algo/dp/dp_trainer.py new file mode 100644 index 0000000..c159cc8 --- /dev/null +++ b/python/fate_llm/algo/dp/dp_trainer.py @@ -0,0 +1,158 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import opacus +import os +import torch +from dataclasses import dataclass, field +from transformers.training_args_seq2seq import Seq2SeqTrainingArguments +from torch.utils.data import DataLoader +from tqdm import tqdm +from typing import Optional, Callable +from .opacus_compatibility import add_layer_compatibility, add_optimizer_compatibility +from .opacus_compatibility.transformers_compate import prepare_position_ids + +logger = logging.getLogger(__name__) + + +@dataclass +class DPTrainingArguments(Seq2SeqTrainingArguments): + target_epsilon: float = field(default=3) + target_delta: float = field(default=1e-5) + freeze_embedding: bool = field(default=False) + device_id: int = field(default=0) + + +class DPTrainer(object): + def __init__( + self, + model: torch.nn.Module, + training_args: DPTrainingArguments, + train_set, + loss_fn, + optimizer: torch.optim.Optimizer = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + data_collator: Callable = None, + use_tqdm: bool = False, + ): + self.module = model + self.training_args = training_args + self.ori_optimizer = optimizer + self.lr_scheduler = scheduler + self.train_set = train_set + self.data_collator = data_collator + self.loss_fn = loss_fn + self.use_tqdm = use_tqdm + + self.data_loader = DataLoader( + dataset=self.train_set, + shuffle=True, + batch_size=self.training_args.per_device_train_batch_size, + collate_fn=self.data_collator + ) + + if not self.training_args.use_cpu: + self.module.cuda(self.training_args.device_id) + + if self.training_args.freeze_embedding: + self.freeze_model_embedding() + + self.dp_model = None + self.dp_optimizer = None + self.privacy_engine = None + self._init_dp_model() + + def _init_dp_model(self): + self.module.train() + + # add compatibility for layer hooks + add_layer_compatibility(opacus) + + self.privacy_engine = opacus.PrivacyEngine(accountant="rdp") + self.dp_model, self.dp_optimizer, _ = self.privacy_engine.make_private_with_epsilon( + module=self.module, + optimizer=self.ori_optimizer, + data_loader=self.data_loader, + target_delta=self.training_args.target_delta, + target_epsilon=self.training_args.target_epsilon, + max_grad_norm=self.training_args.max_grad_norm, + epochs=int(self.training_args.num_train_epochs), + ) + + add_optimizer_compatibility(self.dp_optimizer) + + def train(self): + for epoch in range(int(self.training_args.num_train_epochs)): + self._train_an_epoch() + + def _train_an_epoch(self): + if self.use_tqdm: + data_loader = tqdm(self.data_loader) + else: + data_loader = self.data_loader + + for batch_idx, batch_data in enumerate(tqdm(data_loader)): + input_ids = batch_data["input_ids"] + labels = batch_data["labels"] + + if "attention_mask" not in batch_data: + attention_mask = torch.ones(input_ids.shape) + else: + attention_mask = batch_data["attention_mask"] + + if not self.training_args.use_cpu: + input_ids = input_ids.to(self.module.device) + labels = labels.to(self.module.device) + attention_mask = attention_mask.to(self.module.device) + + inputs = self._prepare_batch_input(input_ids) + logits = 
self.dp_model(**inputs).logits + + loss = self.loss_fn(logits, labels, attention_mask) + + loss = loss.mean() + loss.backward() + + if (batch_idx + 1) % self.training_args.gradient_accumulation_steps == 0 or \ + batch_idx + 1 == len(self.data_loader): + self.dp_optimizer.step() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + self.dp_optimizer.zero_grad() + else: + self.dp_optimizer.step() + self.dp_optimizer.zero_grad() + + def _prepare_batch_input(self, input_ids) -> dict: + position_ids = prepare_position_ids(self.module, input_ids) + if not self.training_args.use_cpu: + position_ids = position_ids.to(self.module.device) + + return dict(input_ids=input_ids, position_ids=position_ids) + + def freeze_model_embedding(self): + self.module.get_input_embeddings().requires_grad_(False) + + def save_model( + self, + output_dir="./" + ): + if hasattr(self.module, "save_pretrained"): + self.module.save_pretrained(output_dir) + else: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + torch.save(self.module.state_dict(), output_dir + '/pytorch_model.bin') diff --git a/python/fate_llm/algo/dp/opacus_compatibility/__init__.py b/python/fate_llm/algo/dp/opacus_compatibility/__init__.py new file mode 100644 index 0000000..8d065fe --- /dev/null +++ b/python/fate_llm/algo/dp/opacus_compatibility/__init__.py @@ -0,0 +1,16 @@ +from .grad_sample.embedding import compute_embedding_grad_sample +from .optimizers.optimizer import add_noise_wrapper + + +def add_layer_compatibility(opacus): + replace_method = [] + for k, v in opacus.GradSampleModule.GRAD_SAMPLERS.items(): + if v.__name__ == "compute_embedding_grad_sample": + replace_method.append(k) + + for k in replace_method: + opacus.GradSampleModule.GRAD_SAMPLERS[k] = compute_embedding_grad_sample + + +def add_optimizer_compatibility(optimizer): + add_noise_wrapper(optimizer) diff --git a/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py b/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py b/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py new file mode 100644 index 0000000..fa77da6 --- /dev/null +++ b/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py @@ -0,0 +1,44 @@ +from typing import Dict +import types + +import torch +import torch.nn as nn + + +def compute_embedding_grad_sample( + layer: nn.Embedding, activations: torch.Tensor, backprops: torch.Tensor +) -> Dict[nn.Parameter, torch.Tensor]: + """ + Computes per sample gradients for ``nn.Embedding`` layer. 
+ + Args: + layer: Layer + activations: Activations + backprops: Backpropagations + """ + activations = activations[0] + ret = {} + if layer.weight.requires_grad: + saved = torch.backends.cudnn.deterministic + torch.backends.cudnn.deterministic = True + + batch_size = activations.shape[0] + if batch_size == 0: + ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0) + return ret + + index = ( + activations.unsqueeze(-1) + .expand(*activations.shape, layer.embedding_dim) + .reshape(batch_size, -1, layer.embedding_dim) + ) + grad_sample = torch.zeros( + batch_size, *layer.weight.shape, device=layer.weight.device, dtype=backprops.dtype + ) + grad_sample.scatter_add_( + 1, index, backprops.reshape(batch_size, -1, layer.embedding_dim) + ) + torch.backends.cudnn.deterministic = saved + ret[layer.weight] = grad_sample + + return ret diff --git a/python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py b/python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py b/python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py new file mode 100644 index 0000000..360188b --- /dev/null +++ b/python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py @@ -0,0 +1,45 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import types +from opacus.optimizers.optimizer import ( + _check_processed_flag, + _generate_noise, + _mark_as_processed +) + + +def add_noise(self): + """ + Adds noise to clipped gradients. 
Stores clipped and noised result in ``p.grad`` + """ + + for p in self.params: + _check_processed_flag(p.summed_grad) + + noise = _generate_noise( + std=self.noise_multiplier * self.max_grad_norm, + reference=p.summed_grad, + generator=self.generator, + secure_mode=self.secure_mode, + ) + noise = noise.to(p.summed_grad.dtype) + p.grad = (p.summed_grad + noise).view_as(p) + + _mark_as_processed(p.summed_grad) + + +def add_noise_wrapper(optimizer): + optimizer.add_noise = types.MethodType(add_noise, optimizer) diff --git a/python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py b/python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py new file mode 100644 index 0000000..2235aff --- /dev/null +++ b/python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py @@ -0,0 +1,30 @@ +import torch +import transformers +from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM +from transformers.modeling_utils import unwrap_model + + +def get_model_class(model): + if isinstance(model, PELLM): + model = model._pe_lm + + model = unwrap_model(model) + + return model.__class__ + + +def prepare_position_ids(model, input_ids): + if get_model_class(model) == transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel: + return _get_position_ids_for_gpt2(input_ids) + else: + raise ValueError(f"Can not prepare position_ids for model_type={model.__class__}") + + +def _get_position_ids_for_gpt2(input_ids): + past_length = 0 + position_ids = torch.arange(past_length, input_ids.shape[-1] + past_length, dtype=torch.long, + device=input_ids.device) + position_ids = position_ids.unsqueeze(0) + position_ids = position_ids.repeat(input_ids.shape[0], 1) + + return position_ids diff --git a/python/fate_llm/algo/fdkt/__init__.py b/python/fate_llm/algo/fdkt/__init__.py new file mode 100644 index 0000000..256395b --- /dev/null +++ b/python/fate_llm/algo/fdkt/__init__.py @@ -0,0 +1,26 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .fdkt_data_aug import ( + FDKTSLM, + FDKTLLM, + FDKTTrainingArguments +) + +__all__ = [ + "FDKTSLM", + "FDKTLLM", + "FDKTTrainingArguments" +] diff --git a/python/fate_llm/algo/fdkt/cluster/__init__.py b/python/fate_llm/algo/fdkt/cluster/__init__.py new file mode 100644 index 0000000..ef471ba --- /dev/null +++ b/python/fate_llm/algo/fdkt/cluster/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# \ No newline at end of file diff --git a/python/fate_llm/algo/fdkt/cluster/cluster.py b/python/fate_llm/algo/fdkt/cluster/cluster.py new file mode 100644 index 0000000..92e0cd7 --- /dev/null +++ b/python/fate_llm/algo/fdkt/cluster/cluster.py @@ -0,0 +1,39 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List +from .cluster_method import get_cluster_runner + + +class SentenceCluster(object): + def __init__(self, model, cluster_method="kmeans", n_clusters=8, **other_cluster_args): + self.model = model + self.cluster_method = cluster_method + self.n_clusters = n_clusters + self.other_cluster_args = other_cluster_args + + def get_embeddings(self, sentences: List[str]): + return self.model.encode(sentences) + + def cluster(self, sentences): + embeddings = self.get_embeddings(sentences) + + cluster_runner = get_cluster_runner(method=self.cluster_method, + n_clusters=self.n_clusters, + **self.other_cluster_args) + + cluster_rets = cluster_runner.fit(embeddings) + + return cluster_rets diff --git a/python/fate_llm/algo/fdkt/cluster/cluster_method.py b/python/fate_llm/algo/fdkt/cluster/cluster_method.py new file mode 100644 index 0000000..f9b4247 --- /dev/null +++ b/python/fate_llm/algo/fdkt/cluster/cluster_method.py @@ -0,0 +1,35 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from sklearn.cluster import KMeans + + +class KMeansRunner(object): + def __init__(self, n_clusters, **other_cluster_args): + self.n_clusters = n_clusters + self.other_cluster_args = other_cluster_args + + def fit(self, x): + model = KMeans(n_clusters=self.n_clusters, **self.other_cluster_args) + model.fit(x) + + return model.labels_ + + +def get_cluster_runner(method, n_clusters, **other_cluster_args): + if method.lower() == "kmeans": + return KMeansRunner(n_clusters, **other_cluster_args) + else: + raise ValueError(f"cluster method={method} is not implemented") diff --git a/python/fate_llm/algo/fdkt/fdkt_data_aug.py b/python/fate_llm/algo/fdkt/fdkt_data_aug.py new file mode 100644 index 0000000..501b7fc --- /dev/null +++ b/python/fate_llm/algo/fdkt/fdkt_data_aug.py @@ -0,0 +1,270 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import torch +import logging +from dataclasses import dataclass, field +from ...trainer.seq2seq_trainer import Seq2SeqTrainingArguments +from typing import Optional, Callable +from fate.arch import Context +from transformers import PreTrainedTokenizer +from .utils.text_generate import slm_text_generate, general_text_generate +from .cluster.cluster import SentenceCluster + + +logger = logging.getLogger(__name__) +SLM_SYNTHETIC_DATA = "slm_synthetic_data" +LLM_AUG_DATA = "llm_aug_data" + + +@dataclass +class FDKTTrainingArguments(Seq2SeqTrainingArguments): + """ + slm parameters + """ + dp_training: bool = field(default=True) + target_epsilon: float = field(default=3) + target_delta: float = field(default=1e-5) + freeze_embedding: bool = field(default=False) + device_id: int = field(default=0) + + """ + slm generation config + """ + seq_num_for_single_category: int = field(default=None) + + """ + dp loss params + """ + label_smoothing_factor = 0.02 + loss_reduce = True + + """ + llm parameters + """ + sample_num_per_cluster: int = field(default=None) + filter_data_batch_size: int = field(default=2) + filter_prompt_max_length: int = field(default=2048) + filter_generation_config: dict = field(default=None) + + aug_generation_config: dict = field(default=None) + aug_prompt_num: int = field(default=None) + aug_data_batch_size: int = field(default=2) + aug_prompt_max_length: int = field(default=2048) + + def to_dict(self): + from dataclasses import fields + from enum import Enum + d = {field.name: getattr(self, field.name) for field in fields(self) if field.init} + + for k, v in d.items(): + if isinstance(v, Enum): + d[k] = v.value + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): + d[k] = [x.value for x in v] + if k.endswith("_token"): + d[k] = f"<{k.upper()}>" + return d + + +class FDKTSLM(object): + def __init__( + self, + ctx: Context, + model: torch.nn.Module, + training_args: FDKTTrainingArguments, + train_set, + optimizer: torch.optim.Optimizer = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + data_collator: Callable = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + ): + super(FDKTSLM, self).__init__() + self.ctx = ctx + self.training_args = training_args + self.train_set = train_set + self.model = model + self.tokenizer = tokenizer + self.optimizer = optimizer + self.scheduler = scheduler + + self.data_collator = data_collator + + if not self.training_args.use_cpu: + self.model.cuda(self.training_args.device_id) + + def aug_data(self): + if self.training_args.dp_training: + self.dp_train() + + prefix_prompt_ids_dict = self.train_set.get_generate_prompt(tokenize=True) + generated_text = slm_text_generate( + self.model, + self.tokenizer, + prompt_ids_dict=prefix_prompt_ids_dict, + seq_num_for_single_category=self.training_args.seq_num_for_single_category, + batch_size=self.training_args.per_device_train_batch_size, + use_cpu=self.training_args.use_cpu, + generation_config=self.training_args.generation_config + ) + + self.sync_synthetic_dataset(generated_text) + + def dp_train(self): + from ..dp import 
DPTrainer, DPTrainingArguments, get_model_class + from .utils.dp_loss import SequenceCrossEntropyLoss + dp_training_args = DPTrainingArguments( + target_delta=self.training_args.target_delta, + target_epsilon=self.training_args.target_epsilon, + freeze_embedding=self.training_args.freeze_embedding, + device_id=self.training_args.device_id, + num_train_epochs=self.training_args.num_train_epochs, + per_device_train_batch_size=self.training_args.per_device_train_batch_size, + output_dir="/" if self.training_args.output_dir is None else self.training_args.output_dir + ) + + loss_fn = SequenceCrossEntropyLoss( + get_model_class(self.model).__name__, + label_smoothing=self.training_args.label_smoothing_factor, + reduce=self.training_args.loss_reduce + ) + + dp_trainer = DPTrainer( + model=self.model, + training_args=dp_training_args, + train_set=self.train_set, + optimizer=self.optimizer, + scheduler=self.scheduler, + data_collator=self.data_collator, + loss_fn=loss_fn + ) + + dp_trainer.train() + + def sync_synthetic_dataset(self, data): + self.ctx.arbiter.put(SLM_SYNTHETIC_DATA, data) + + def sync_aug_data(self): + return self.ctx.arbiter.get(LLM_AUG_DATA) + + def save_model( + self, + output_dir="./" + ): + if hasattr(self.model, "save_pretrained"): + self.model.save_pretrained(output_dir) + else: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + torch.save(self.model.state_dict(), output_dir + '/pytorch_model.bin') + + +class FDKTLLM(object): + def __init__( + self, + ctx: Context, + model: torch.nn.Module, + embedding_model: torch.nn.Module, + training_args: FDKTTrainingArguments, + dataset, + tokenizer: Optional[PreTrainedTokenizer] = None, + ): + super(FDKTLLM, self).__init__() + self.ctx = ctx + self.model = model + self.embedding_model = embedding_model + self.dataset = dataset + self.training_args = training_args + self.tokenizer = tokenizer + + if not self.training_args.use_cpu: + self.model.cuda(self.training_args.device_id) + + def sync_synthetic_data(self): + return self.ctx.guest.get(SLM_SYNTHETIC_DATA) + + def sync_aug_data(self, aug_data): + self.ctx.guest.put(LLM_AUG_DATA, aug_data) + + def aug_data(self): + slm_data = self.sync_synthetic_data() + + filter_data = self.filter_data(slm_data) + + aug_prompts = self.dataset.prepare_augment( + filter_data["inputs"], + filter_data["labels"], + aug_prompt_num=self.training_args.aug_prompt_num + ) + + aug_data = self._aug(aug_prompts) + self.sync_aug_data(aug_data) + + def _aug(self, aug_prompts): + aug_responses = general_text_generate( + model=self.model, + tokenizer=self.tokenizer, + generation_config=self.training_args.aug_generation_config, + prompts=aug_prompts, + batch_size=self.training_args.aug_data_batch_size, + use_cpu=self.training_args.use_cpu, + prompt_max_length=self.training_args.aug_prompt_max_length + ) + + aug_data = self.dataset.abstract_from_augmented(aug_responses) + + return aug_data + + def filter_data(self, slm_data): + clustered_sentences, clustered_labels = self.cluster_data(slm_data) + filter_prompts = self.dataset.prepare_query_to_filter_clustered(clustered_sentences, clustered_labels) + filter_responses = general_text_generate( + model=self.model, + tokenizer=self.tokenizer, + generation_config=self.training_args.filter_generation_config, + prompts=filter_prompts, + batch_size=self.training_args.filter_data_batch_size, + use_cpu=self.training_args.use_cpu, + prompt_max_length=self.training_args.filter_prompt_max_length + ) + + filtered_sentences, filtered_labels = 
self.dataset.parse_clustered_response( + clustered_sentence=clustered_sentences, + clustered_labels=clustered_labels, + response_list=filter_responses + ) + + return dict( + inputs=filtered_sentences, + labels=filtered_labels + ) + + def cluster_data(self, slm_data): + sentences = slm_data["inputs"] + labels = slm_data["labels"] + + n_clusters = (len(sentences) + self.training_args.sample_num_per_cluster - 1) // self.training_args.sample_num_per_cluster + + cluster_ret = SentenceCluster(model=self.embedding_model, n_clusters=n_clusters).cluster(sentences) + + clustered_sentences = [[] for _ in range(n_clusters)] + clustered_labels = [[] for _ in range(n_clusters)] + + for sentence_id, cluster_id in enumerate(cluster_ret): + clustered_sentences[cluster_id].append(sentences[sentence_id]) + clustered_labels[cluster_id].append(labels[sentence_id]) + + return clustered_sentences, clustered_labels diff --git a/python/fate_llm/algo/fdkt/utils/__init__.py b/python/fate_llm/algo/fdkt/utils/__init__.py new file mode 100644 index 0000000..ef471ba --- /dev/null +++ b/python/fate_llm/algo/fdkt/utils/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# \ No newline at end of file diff --git a/python/fate_llm/algo/fdkt/utils/dp_loss.py b/python/fate_llm/algo/fdkt/utils/dp_loss.py new file mode 100644 index 0000000..4c91058 --- /dev/null +++ b/python/fate_llm/algo/fdkt/utils/dp_loss.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + + +NUMERICAL_STABILITY_CONSTANT = 1e-13 + + +class SequenceCrossEntropyLoss(nn.Module): + def __init__(self, model_type, label_smoothing=-1, reduce=None): + super().__init__() + self.model_type = model_type + self.label_smoothing = label_smoothing + self.reduce = reduce + + def forward(self, logits, targets, mask): + return sequence_cross_entropy_with_logits(logits, targets, mask, self.label_smoothing, self.reduce, self.model_type) + + +def sequence_cross_entropy_with_logits(logits, targets, mask, label_smoothing, reduce, model_type): + if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + logits = logits[:, :-1].contiguous() + targets = targets[:, 1:] + mask = torch.ones_like(targets).float() + + logits_flat = logits.view(-1, logits.size(-1)) + log_probs_flat = F.log_softmax(logits_flat, dim=-1) + targets_flat = targets.reshape(-1, 1).long() + + if label_smoothing > 0.0: + num_classes = logits.size(-1) + smoothing_value = label_smoothing / float(num_classes) + one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing) + smoothed_targets = one_hot_targets + smoothing_value + negative_log_likelihood_flat = -log_probs_flat * smoothed_targets + negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True) + else: + negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat) + 
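+    # What follows: the flat NLL is reshaped back to (batch, seq_len); with
+    # reduce=True each sequence's loss is summed and divided by the number of
+    # valid tokens (mask.sum), plus a small constant for numerical stability.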
+ negative_log_likelihood = negative_log_likelihood_flat.view(-1, logits.shape[1]) + + loss = negative_log_likelihood + # loss = negative_log_likelihood * mask + + if reduce: + loss = loss.sum(1) / (mask.sum(1) + NUMERICAL_STABILITY_CONSTANT) + + if reduce is "batch": + loss = loss.mean() + + return loss diff --git a/python/fate_llm/algo/fdkt/utils/text_generate.py b/python/fate_llm/algo/fdkt/utils/text_generate.py new file mode 100644 index 0000000..408c547 --- /dev/null +++ b/python/fate_llm/algo/fdkt/utils/text_generate.py @@ -0,0 +1,79 @@ +from tqdm import tqdm +from typing import Any, Dict, List + + +def slm_text_generate( + model, + tokenizer, + prompt_ids_dict, + seq_num_for_single_category, + batch_size, + use_cpu, + generation_config +): + generated_ret = dict( + inputs=list(), + labels=list(), + ) + for label, prompt_ids in prompt_ids_dict.items(): + prompt_length = len(prompt_ids) + batch_num = (seq_num_for_single_category + batch_size - 1) // batch_size + for batch_idx in tqdm(range(batch_num)): + if batch_idx + 1 == batch_num: + cur_batch_size = seq_num_for_single_category - batch_idx * batch_size + else: + cur_batch_size = batch_size + input_ids = prompt_ids.repeat(cur_batch_size, 1) + + if not use_cpu: + input_ids = input_ids.to(model.device) + + output_sequences = model.generate( + input_ids=input_ids, + **generation_config + ) + output_sequences = output_sequences[:, prompt_length:] + + generated_sequences = tokenizer.batch_decode(output_sequences, skip_special_tokens=True) + + for g in generated_sequences: + generated_ret["inputs"].append(g) + generated_ret["labels"].append(label) + + return generated_ret + + +def general_text_generate( + model, + tokenizer, + generation_config: Dict[Any, Any], + prompts: List[str], + batch_size, + use_cpu: bool, + prompt_max_length +): + generate_texts = [] + batch_num = (len(prompts) + batch_size - 1) // batch_size + for batch_idx in range(batch_num): + batch_data = prompts[batch_idx * batch_size: (batch_idx + 1) * batch_size] + + inputs = tokenizer(batch_data, return_tensors="pt", padding="longest", truncation=True, + max_length=prompt_max_length) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + if not use_cpu: + input_ids = input_ids.to(model.device) + attention_mask = attention_mask.to(model.device) + + output = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + **generation_config + ) + + batch_responses = tokenizer.batch_decode(output[:, input_ids.shape[1]:], skip_special_tokens=True) + + generate_texts.extend(batch_responses) + + return generate_texts diff --git a/python/fate_llm/model_zoo/embedding_transformer/__init__.py b/python/fate_llm/model_zoo/embedding_transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/model_zoo/embedding_transformer/st_model.py b/python/fate_llm/model_zoo/embedding_transformer/st_model.py new file mode 100644 index 0000000..3896069 --- /dev/null +++ b/python/fate_llm/model_zoo/embedding_transformer/st_model.py @@ -0,0 +1,71 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from sentence_transformers import SentenceTransformer +from typing import Any, Optional, Dict, Union + + +class SentenceTransformerModel(object): + def __init__( + self, + model_name_or_path: Optional[str] = None, + device: Optional[str] = None, + prompts: Optional[Dict[str, str]] = None, + default_prompt_name: Optional[str] = None, + cache_folder: Optional[str] = None, + trust_remote_code: bool = False, + revision: Optional[str] = None, + local_files_only: bool = False, + token: Optional[Union[bool, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + truncate_dim: Optional[int] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + config_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.model_name_or_path = model_name_or_path + self.device = device + self.prompts = prompts + self.default_prompt_name = default_prompt_name + self.cache_folder = cache_folder + self.trust_remote_code = trust_remote_code + self.revision = revision + self.local_files_only = local_files_only + self.token = token + self.use_auth_token = use_auth_token + self.truncate_dim = truncate_dim + self.model_kwargs = model_kwargs + self.tokenizer_kwargs = tokenizer_kwargs + self.config_kwargs = config_kwargs + + def load(self): + model = SentenceTransformer( + model_name_or_path=self.model_name_or_path, + device=self.device, + prompts=self.prompts, + default_prompt_name=self.default_prompt_name, + cache_folder=self.cache_folder, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + local_files_only=self.local_files_only, + token=self.token, + use_auth_token=self.use_auth_token, + truncate_dim=self.truncate_dim, + model_kwargs=self.model_kwargs, + tokenizer_kwargs=self.tokenizer_kwargs, + config_kwargs=self.config_kwargs + ) + + return model diff --git a/python/fate_llm/model_zoo/hf_model.py b/python/fate_llm/model_zoo/hf_model.py index a7701c5..ecc6c00 100644 --- a/python/fate_llm/model_zoo/hf_model.py +++ b/python/fate_llm/model_zoo/hf_model.py @@ -1,3 +1,4 @@ +import torch from transformers import AutoModelForCausalLM @@ -7,6 +8,9 @@ def __init__(self, pretrained_model_name_or_path, *model_args, **kwargs) -> None self.pretrained_model_name_or_path = pretrained_model_name_or_path self.model_args = model_args self.kwargs = kwargs + if "torch_dtype" in self.kwargs and self.kwargs["torch_dtype"] != "auto": + dtype = self.kwargs.pop("torch_dtype") + self.kwargs["torch_dtype"] = getattr(torch, dtype) def load(self): model = AutoModelForCausalLM.from_pretrained( diff --git a/python/fate_llm/runner/fdkt_runner.py b/python/fate_llm/runner/fdkt_runner.py new file mode 100644 index 0000000..8e4d611 --- /dev/null +++ b/python/fate_llm/runner/fdkt_runner.py @@ -0,0 +1,195 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd +from fate.components.components.nn.nn_runner import ( + load_model_dict_from_path, + dir_warning, + loader_load_from_conf, + run_dataset_func, +) +from typing import Dict +from fate.arch.dataframe import PandasReader +from fate.components.components.nn.loader import Loader +from typing import Union, Optional, Literal +from transformers.trainer_utils import get_last_checkpoint +import logging +from fate.arch.dataframe import DataFrame +from fate.components.components.nn.runner.homo_default_runner import DefaultRunner +from fate_llm.algo.fdkt import FDKTTrainingArguments, FDKTSLM, FDKTLLM + +logger = logging.getLogger(__name__) + + +class FDKTRunner(DefaultRunner): + def __init__( + self, + algo: str = "fdkt", + model_conf: Optional[Dict] = None, + embedding_model_conf: Optional[Dict] = None, + optimizer_conf: Optional[Dict] = None, + training_args_conf: Optional[Dict] = None, + dataset_conf: Optional[Dict] = None, + data_collator_conf: Optional[Dict] = None, + tokenizer_conf: Optional[Dict] = None, + task_type: Literal["causal_lm", "others"] = "causal_lm", + save_dp_model: bool = False, + ) -> None: + super(FDKTRunner, self).__init__() + self.algo = algo + self.model_conf = model_conf + self.embedding_model_conf = embedding_model_conf + self.optimizer_conf = optimizer_conf + self.training_args_conf = training_args_conf + self.dataset_conf = dataset_conf + self.data_collator_conf = data_collator_conf + self.tokenizer_conf = tokenizer_conf + self.task_type = task_type + self.save_dp_model = save_dp_model + + self.training_args = None + + # check param + if self.algo.lower() != "fdkt": + raise ValueError(f"algo should be fdkt") + if self.task_type not in ["causal_lm"]: + raise ValueError("task_type should be causal_lm") + + self.aug_data = None + + def common_setup(self, saved_model=None, output_dir=None): + ctx = self.get_context() + + if output_dir is None: + output_dir = "./" + + model = loader_load_from_conf(self.model_conf) + if model is None: + raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") + + resume_path = None + if saved_model is not None: + model_dict = load_model_dict_from_path(saved_model) + model.load_state_dict(model_dict) + logger.info(f"loading model dict from {saved_model} to model done") + if get_last_checkpoint(saved_model) is not None: + resume_path = saved_model + logger.info(f"checkpoint detected, resume_path set to {resume_path}") + + # load tokenizer if import conf provided + tokenizer = loader_load_from_conf(self.tokenizer_conf) + + # args + dir_warning(self.training_args_conf) + training_args = FDKTTrainingArguments(**self.training_args_conf) + # reset to default, saving to arbitrary path is not allowed in + # DefaultRunner + training_args.output_dir = output_dir + training_args.resume_from_checkpoint = resume_path # resume path + + self.training_args = training_args + dataset = loader_load_from_conf(self.dataset_conf) + + return ctx, model, tokenizer, training_args, dataset + + def llm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None): + ctx, model, tokenizer, 
training_args, dataset = self.common_setup( + output_dir=output_dir, saved_model=saved_model) + + model = model.load() + embedding_model = loader_load_from_conf(self.embedding_model_conf) + if embedding_model is None: + raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") + embedding_model = embedding_model.load() + + trainer = FDKTLLM( + ctx=ctx, + model=model, + embedding_model=embedding_model, + training_args=training_args, + tokenizer=tokenizer, + dataset=dataset, + ) + + return trainer + + def slm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None): + ctx, model, tokenizer, training_args, dataset = self.common_setup( + output_dir=output_dir, saved_model=saved_model) + model = model.load() + + dataset.load(train_set) + + if self.data_collator_conf is not None: + data_collator = loader_load_from_conf(self.data_collator_conf) + else: + data_collator = None + + optimizer_loader = Loader.from_dict(self.optimizer_conf) + optimizer_ = optimizer_loader.load_item() + optimizer_params = optimizer_loader.kwargs + optimizer = optimizer_(model.parameters(), **optimizer_params) + + trainer = FDKTSLM( + ctx=ctx, + model=model, + training_args=training_args, + tokenizer=tokenizer, + train_set=dataset, + data_collator=data_collator, + optimizer=optimizer, + ) + + return trainer + + def train( + self, + train_data: Optional[Union[str, DataFrame]] = None, + validate_data: Optional[Union[str, DataFrame]] = None, + output_dir: str = None, + saved_model_path: str = None, + ): + + if self.is_client(): + trainer = self.slm_setup(train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path) + self.aug_data = trainer.aug_data() + + if self.save_dp_model: + trainer.save_model(output_dir) + + else: + trainer = self.llm_setup( + train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path + ) + trainer.aug_data() + + def predict(self, *args, **kwargs): + if self.is_client(): + ctx = self.get_context() + df = pd.DataFrame() + texts = self.aug_data["inputs"] + labels = self.aug_data["labels"] + + sample_id_name = "sample_id" + match_id_name = "match_id" + df[sample_id_name] = list(map(str, range(len(texts)))) + df[match_id_name] = list(map(str, range(len(texts)))) + import json + texts = [json.dumps(text) for text in texts] + labels = [json.dumps(label) for label in labels] + df["inputs"] = texts + df["labels"] = labels + + reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, dtype="object") + return reader.to_frame(ctx, df) From 854a48890354e83cac8d4c1b79e0d60550fe99df Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 22 Jul 2024 16:04:09 +0800 Subject: [PATCH 19/50] add flex dataset for FDKT augmentation to rebased branch; add new dataset configs to rebased branch Signed-off-by: Yu Wu --- .../dataset/data_config/default_ag_news.yaml | 25 ++ .../data_config/default_yelp_review.yaml | 26 ++ python/fate_llm/dataset/flex_dataset.py | 368 ++++++++++++++++++ 3 files changed, 419 insertions(+) create mode 100644 python/fate_llm/dataset/data_config/default_ag_news.yaml create mode 100644 python/fate_llm/dataset/data_config/default_yelp_review.yaml create mode 100644 python/fate_llm/dataset/flex_dataset.py diff --git a/python/fate_llm/dataset/data_config/default_ag_news.yaml b/python/fate_llm/dataset/data_config/default_ag_news.yaml new file mode 100644 index 0000000..42a126c --- /dev/null +++ b/python/fate_llm/dataset/data_config/default_ag_news.yaml @@ 
-0,0 +1,25 @@ +dataset_kwargs: + data_files: ag_news_review/AGnews/train.json +dataset_path: json +doc_to_target: '{{label}}' +metric_list: +- aggregation: mean + higher_is_better: true + metric: accuracy +output_type: generate_until +task: ag-news +validation_split: train +label_key: label +text_key: text +sub_domain: AGnews +few_shot_num_per_label: 2 +tokenize_format: "Product type: {{sub_domain}} | Text Category: {{label}}" +few_shot_format: "- : {{label}}.\n- : {{text}}\n\n" +augment_format: "The news' topics belong to the following 4 categories: 0.world 1.sports 2.business 3.science and technology. Please generate news according to the following format, bearing in mind that the generated results should not resemble the examples, but should align with the specified category: \n" +text_with_label_format: "******\n {{i}}.\nNews: {{text}}\nCategory: {{label}}.\n" +filter_format: "I will give you some news samples with their categories, The news' topics belong to the following 4 categories: 0.world 1.sports 2.business 3.science and technology. the samples are delimited by '******':\n {text_with_label} Please filter out texts that are ambiguous, do not belong to news or do not meet the categories, and leave news texts that meet the categories.\n You should also filter out news text that are too similar to other samples and keep the most representative ones. Your answer should begin with 'The eligible samples:\n\n' and the indexes of the texts you choose, use spaces to separate the indexes and do not provide duplicate indices or indices that exceed the maximum index of samples." +label_list: + - 'world' + - 'sports' + - 'business' + - 'science and technology' \ No newline at end of file diff --git a/python/fate_llm/dataset/data_config/default_yelp_review.yaml b/python/fate_llm/dataset/data_config/default_yelp_review.yaml new file mode 100644 index 0000000..b123aa5 --- /dev/null +++ b/python/fate_llm/dataset/data_config/default_yelp_review.yaml @@ -0,0 +1,26 @@ +dataset_kwargs: + data_files: yelp_review/Health/train.json +dataset_path: json +doc_to_target: '{{label}}' +metric_list: +- aggregation: mean + higher_is_better: true + metric: accuracy +output_type: generate_until +task: yelp-review +label_key: stars +text_key: text +validation_split: train +sub_domain: Health +few_shot_num_per_label: 2 +tokenize_format: "Product type: {{sub_domain}} | Review Score: {{label}}" +text_with_label_format: "******\n {{i}}.\nReview: {{text}}\nRating stars: {{label}}.\n" +few_shot_format: "******\n- : {{label}} stars.\n- : {{text}}\n\n" +augment_format: "The reviews are rated from 1 to 5 stars, with 1 being the worst, 3 being neutral and 5 being the best. Please generate more similar samples for each rating star about the Health domain as shown in the following format, bearing in mind that the generated results should not copy or resemble the examples, and should align with the {{sub_domain}} domain and the rating stars.\nThe examples are delimited by '******'." +filter_format: "I will give you some customer review text samples with their rating stars, these samples are indexed starting from 0, the samples are delimited by '******':\n {{text_with_label}}. These reviews gradually shift from negative to positive from 1 star to 5 stars. 1 star represents the worst, 2 stars are better than 1 star, but still indicate a negative review. 3 stars represent a neutral review. 4 stars indicate a positive review, but less positive than 5 stars. 
5 stars represent perfection.\n Please filter out text that does not belong to customer reviews or does not meet the rating stars, and leave review texts that meet the labels.\n You should also filter out text that are too similar to other samples and keep the most representative ones. Your answer should begin with 'The eligible samples:\n\n' and the indexes of the texts you choose, use spaces to separate the indexes and do not provide duplicate indices or indices that exceed the maximum index of samples." +label_list: + - 1 + - 2 + - 3 + - 4 + - 5 \ No newline at end of file diff --git a/python/fate_llm/dataset/flex_dataset.py b/python/fate_llm/dataset/flex_dataset.py new file mode 100644 index 0000000..0a387c5 --- /dev/null +++ b/python/fate_llm/dataset/flex_dataset.py @@ -0,0 +1,368 @@ +# +# Copyright 2024 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import re +from datasets import load_dataset +from fastchat.model import get_conversation_template +from jinja2 import Template +from ruamel import yaml +from transformers import AutoTokenizer +from typing import Union, Literal + +from fate.ml.nn.dataset.base import Dataset +from fate_llm.dataset.data_config import DATA_CONFIG_TEMPLATE + +logger = logging.getLogger(__name__) + + +""" +Implementation of FDKT augmentation process, adopted from https://arxiv.org/abs/2405.14212 +""" + + +def get_jinjax_placeholders(jinjax_text, placeholder_count=2): + pattern = r"<([^>]+)>" + matches = re.findall(pattern, jinjax_text) + + return matches[:placeholder_count] + + +def regex_replace(string, pattern, repl, count: int = 0): + """ + adopted from lm-evaluation-harness/lm-eval/utils.py for offline use + Parameters + ---------- + string + pattern + repl + count + + Returns + ------- + + """ + return re.sub(pattern, repl, string, count=count) + + +def apply_template(template, data): + """ + adopted from lm-evaluation-harness/lm-eval/utils.py for offline use + Parameters + ---------- + template + data + + Returns + ------- + + """ + return Template(template).render(data) + + +def tokenize_flex_dataset(raw_datasets, tokenizer, sub_domain, tokenize_format, text_key, label_key, data_part="train", + save_path=None, max_prompt_len=256): + tokenizer.pad_token = tokenizer.eos_token + column_names = raw_datasets[data_part].column_names + + def tokenize_function(examples): + texts = tokenizer(examples[text_key]) + + label_processed = [apply_template(tokenize_format,{"sub_domain": sub_domain,"label": label}) + for label in examples[label_key]] + labels = tokenizer(label_processed) + input_ids = [i2 + i1 for i1, i2 in zip(texts['input_ids'], labels['input_ids'])] + attention_mask = [i2 + i1 for i1, i2 in zip(texts['attention_mask'], labels['attention_mask'])] + + """ + cut off max prompt length + """ + input_ids = [t[: max_prompt_len] for t in input_ids] + attention_mask = [t[: max_prompt_len] for t in attention_mask] + + out = {"input_ids": input_ids, + "attention_mask": attention_mask, + "labels": 
input_ids} + return out + + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=4, + remove_columns=column_names, + desc="Running tokenizer on dataset", + ) + + if save_path is not None: + tokenized_datasets.save_to_disk(save_path) + + return tokenized_datasets + + +class FlexDataset(Dataset): + def __init__(self, + tokenizer_path, + dataset_name: str, + load_from: Literal['json'] = 'json', + data_part: str = None, + config: Union[dict, str] = None, + need_preprocess: bool = True, + random_state: int = None, + max_prompt_len: int = 256, + select_num: int = None, + few_shot_num_per_label: int = None + ): + + super().__init__() + self.tokenizer = None + self.tokenizer_path = tokenizer_path + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True) + self.dataset_name = dataset_name + if self.dataset_name: + config = DATA_CONFIG_TEMPLATE.get("datasets", {}) + self.load_from = load_from + self.data_part = data_part + self.random_state = random_state + self.need_preprocess = need_preprocess + self.max_prompt_len = max_prompt_len + self.select_num = select_num + self.dataset = None + self.ds = None + self.label_key = None + self.text_key = None + self.augment_format = None + self.filter_format = None + self.few_shot_format = None + self.tokenize_format = None + self.sub_domain = None + self.label_list = None + self.text_with_label_format = None + self.few_shot_num_per_label = few_shot_num_per_label + self.config = config + if isinstance(config, str): + with open(config, 'r') as f: + self.config = yaml.safe_load(f) + self.parse_config() + + def parse_config(self, config=None): + if config is None: + config = self.config + self.label_key = config.get("label_key", None) + self.text_key = config.get("text_key", None) + self.augment_format = config.get("augment_format", None) + self.filter_format = config.get("filter_format", None) + self.tokenize_format = config.get("tokenize_format", None) + self.sub_domain = config.get("sub_domain", None) + self.label_list = config.get("label_list", None) + self.few_shot_format = config.get("few_shot_format", None) + self.text_with_label_format = config.get("text_with_label_format", None) + if self.few_shot_num_per_label is None: + self.few_shot_num_per_label = config.get("few_shot_num_per_label", 2) + + def get_generate_prompt(self, tokenize=True, return_tensors="pt"): + prompt_list = [apply_template(self.tokenize_format, + {"sub_domain": self.sub_domain, + "label": label}) for label in self.label_list] + if tokenize: + tokenized_prompts = self.tokenizer(prompt_list, return_tensors=return_tensors) + prompt_list = tokenized_prompts['input_ids'] + + return {label: prompt for label, prompt in zip(self.label_list, prompt_list)} + + @staticmethod + def construct_prompt_list(samples_dict, num_shot_per_label, prompt_num, format_template, random_state=None): + from sklearn.utils import resample + from collections import deque + + label_samples = {label: deque(resample(samples, + replace=False, + n_samples=len(samples))) for label, samples in samples_dict.items()} + def get_samples_for_label(label): + samples = [] + while len(samples) < num_shot_per_label: + remaining_needed = num_shot_per_label - len(samples) + if len(label_samples[label]) < remaining_needed: + batch_samples = list(label_samples[label]) + samples.extend(batch_samples) + # reset to allow repetition + label_samples[label] = deque(resample(samples_dict[label], + replace=False, + n_samples=len(samples_dict[label]))) + else: + batch_samples 
= [label_samples[label].popleft() for _ in range(remaining_needed)] + samples.extend(batch_samples) + return samples + + result = [] + for _ in range(prompt_num): + prompt = '' + for label in samples_dict.keys(): + samples = get_samples_for_label(label) + for text in samples: + prompt += apply_template(format_template, {"text": text, "label": label}) + result.append(prompt) + return result + + @staticmethod + def group_text_label_list(text_list, label_list): + group_data = [{"text": text, "label": label} for text, label in zip(text_list, label_list)] + return group_data + + def prepare_few_shot(self, text_list, label_list, aug_prompt_num): + from collections import defaultdict + data_dict = defaultdict(list) + for text, label in zip(text_list, label_list): + # in case extra labels are present, ignore + if label in self.label_list: + data_dict[label].append(text) + few_shot_list = FlexDataset.construct_prompt_list(samples_dict=data_dict, + num_shot_per_label=self.few_shot_num_per_label, + prompt_num=aug_prompt_num, + format_template=self.few_shot_format, + random_state=self.random_state) + + return few_shot_list + + def prepare_augment(self, text_list, label_list, aug_prompt_num): + few_shot_samples = self.prepare_few_shot(text_list, label_list, aug_prompt_num) + result = [] + instruction = apply_template(self.augment_format, {"sub_domain": self.sub_domain}) + for i, sample in enumerate(few_shot_samples): + query = instruction + '\n' + sample + formatted_query = self.apply_chat_template(query) + result.append(formatted_query) + return result + + def abstract_from_augmented(self, sample_list): + label_key, text_key = get_jinjax_placeholders(self.few_shot_format, 2) + res = {'inputs': [], 'labels': []} + print(f"text_key: {text_key}, label_key: {label_key}") + for sample in sample_list: + data_list = sample.split('\n\n-') + print(f"data list: {data_list}") + for entry in data_list: + temp = entry.split(f"<{text_key}>:") + print(f"temp: {temp}") + if len(temp) == 2 and f"<{label_key}>" in temp[0]: + label_str, input_str = temp + label = label_str.split(f"<{label_key}>:")[1].strip() + if label[0].isdigit(): + label = int(label[0]) + elif re.match(r'^\d+\.\d*?$', label): + label = float(label[0]) + text = input_str.replace('', '').rstrip('*') + text = text.strip() + res['inputs'].append(text) + res['labels'].append(label) + return res + + def prepare_query_to_filter_clustered(self, clustered_sentences_list, clustered_labels_list): + prompt_list = [] + for clustered_sentences, clustered_labels in zip(clustered_sentences_list, clustered_labels_list): + text_with_label = '' + for i in range(len(clustered_sentences)): + formatted_entry = apply_template(self.text_with_label_format, {"i": i, + "text": clustered_sentences[i], + "label": clustered_labels[i]}) + text_with_label += formatted_entry + cluster_query = apply_template(self.filter_format, {"text_with_label": text_with_label}) + prompt_list.append(self.apply_chat_template(cluster_query)) + return prompt_list + + def parse_clustered_response(self, clustered_sentence, clustered_labels, response_list): + """ + Parse the response from the clustering model and filter the data per cluster. 
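+        Each response is expected to contain a line of the form
+        ``The eligible samples: 0 2 5``; the listed indexes pick which
+        sentences/labels of the corresponding cluster are kept.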
+ :param clustered_sentence: nested list of clustered sentences + :param clustered_labels: nested list of clustered labels + :param response_list: list of responses from the clustering model + """ + def parse_response(response): + pattern = r'The eligible samples:\s*((?:\b\d+\b[\s.,]*)+)' + matches = re.search(pattern, response, re.MULTILINE) + if matches: + digits = [int(i) for i in re.findall(r'\b\d+\b', matches.group())] + else: + digits = [] + return list(set(digits)) + + filtered_text_list = [] + filtered_label_list = [] + for i in range(len(clustered_sentence)): + parsed_response = parse_response(response_list[i]) + for idx in parsed_response: + if idx < len(clustered_sentence[i]): + filtered_label_list.append(clustered_labels[i][idx]) + filtered_text_list.append(clustered_sentence[i][idx]) + return filtered_text_list, filtered_label_list + + @staticmethod + def group_data_list(data_list, text_key, label_key): + inputs = [entry[text_key] for entry in data_list] + labels = [entry[label_key] for entry in data_list] + data_dict = {text_key: inputs, label_key: labels} + return data_dict + + def load(self, path): + local_data = load_dataset('json', data_files={self.data_part: path}) + self.dataset = local_data + if not self.need_preprocess: + self.ds = local_data + else: + tokenized_ds = tokenize_flex_dataset( + raw_datasets=local_data, + tokenizer=self.tokenizer, + sub_domain=self.sub_domain, + tokenize_format=self.tokenize_format, + text_key=self.text_key, + label_key=self.label_key + ) + self.ds = tokenized_ds[self.data_part] + + if self.select_num is not None: + self.ds = self.ds.select(range(self.select_num)) + + def apply_chat_template(self, query): + tokenizer = self.tokenizer + + if "llama-3" in self.tokenizer_path.lower(): + msg = [ + {"role": "system", "content": "You are a helpful assistant. 
"}, + {"role": "user", "content": query} + ] + prompt = tokenizer.apply_chat_template(msg, add_generation_prompt=True, tokenize=False) + else: + conv = get_conversation_template(self.tokenizer_path) + conv.append_message(conv.roles[0], query) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + return prompt + + def get_raw_dataset(self): + return self.dataset + + def __len__(self): + return len(self.ds) + + def get_item(self, i): + return self.dataset[self.data_part][i] + + def get_item_dict(self, i): + return {"text": self.dataset[self.data_part][self.text_key][i], + "label": self.dataset[self.data_part][self.label_key][i]} + + def __getitem__(self, i) -> dict: + return self.ds[i] From 66e99e66c8b5f8f13186f4ece29180540aa69f2e Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 22 Jul 2024 16:14:55 +0800 Subject: [PATCH 20/50] add new dataset configs to rebased branch Signed-off-by: Yu Wu --- python/fate_llm/dataset/data_config/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 python/fate_llm/dataset/data_config/__init__.py diff --git a/python/fate_llm/dataset/data_config/__init__.py b/python/fate_llm/dataset/data_config/__init__.py new file mode 100644 index 0000000..6056e3a --- /dev/null +++ b/python/fate_llm/dataset/data_config/__init__.py @@ -0,0 +1,6 @@ +import os +# absolute path to current directory +parent_dir = os.path.dirname(os.path.realpath(__file__)) + +DATA_CONFIG_TEMPLATE = {"ag_news": os.path.join(parent_dir, "default_ag_news.yaml"), + "yelp_review": os.path.join(parent_dir, "default_yelp_review.yaml"),} \ No newline at end of file From c5d693624e3dd973f9de18f0872f34f8b5111fec Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 23 Jul 2024 15:31:49 +0800 Subject: [PATCH 21/50] only retain text-label pair if label types match Signed-off-by: Yu Wu --- python/fate_llm/dataset/flex_dataset.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/fate_llm/dataset/flex_dataset.py b/python/fate_llm/dataset/flex_dataset.py index 0a387c5..3a06731 100644 --- a/python/fate_llm/dataset/flex_dataset.py +++ b/python/fate_llm/dataset/flex_dataset.py @@ -15,6 +15,7 @@ # import logging +import pickle import re from datasets import load_dataset from fastchat.model import get_conversation_template @@ -131,8 +132,8 @@ def __init__(self, self.tokenizer_path = tokenizer_path self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True) self.dataset_name = dataset_name - if self.dataset_name: - config = DATA_CONFIG_TEMPLATE.get("datasets", {}) + if self.dataset_name and config is None: + config = DATA_CONFIG_TEMPLATE.get(self.dataset_name, "") self.load_from = load_from self.data_part = data_part self.random_state = random_state @@ -246,27 +247,31 @@ def prepare_augment(self, text_list, label_list, aug_prompt_num): result.append(formatted_query) return result - def abstract_from_augmented(self, sample_list): + def abstract_from_augmented(self, augmented_responses_path): + with open(augmented_responses_path, "rb") as fin: + sample_list = pickle.loads(fin.read()) label_key, text_key = get_jinjax_placeholders(self.few_shot_format, 2) res = {'inputs': [], 'labels': []} - print(f"text_key: {text_key}, label_key: {label_key}") for sample in sample_list: data_list = sample.split('\n\n-') - print(f"data list: {data_list}") for entry in data_list: temp = entry.split(f"<{text_key}>:") - print(f"temp: {temp}") + # print(f"temp: {temp}") if len(temp) == 2 and f"<{label_key}>" in temp[0]: label_str, 
input_str = temp label = label_str.split(f"<{label_key}>:")[1].strip() - if label[0].isdigit(): + if isinstance(self.label_list[0], int) and label[0].isdigit(): label = int(label[0]) - elif re.match(r'^\d+\.\d*?$', label): + elif isinstance(self.label_list[0], float) and re.match(r'^\d+\.\d*?$', label): label = float(label[0]) + # abstracted label value does not match the original label type + elif isinstance(self.label_list[0], int) or isinstance(self.label_list[0], float): + continue text = input_str.replace('', '').rstrip('*') text = text.strip() res['inputs'].append(text) res['labels'].append(label) + # print(f"res: {res}") return res def prepare_query_to_filter_clustered(self, clustered_sentences_list, clustered_labels_list): From 8a5ed5f1c2444f5bd3f23f31f16ac61845b0109a Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 23 Jul 2024 15:49:57 +0800 Subject: [PATCH 22/50] update fdkt: support api requests for llm generate Signed-off-by: mgqa34 --- python/fate_llm/algo/fdkt/fdkt_data_aug.py | 28 +++++++-- python/fate_llm/algo/fdkt/inference_inst.py | 10 ++++ .../fate_llm/algo/fdkt/utils/text_generate.py | 42 +++++++------ python/fate_llm/runner/fdkt_runner.py | 60 +++++++++---------- 4 files changed, 86 insertions(+), 54 deletions(-) create mode 100644 python/fate_llm/algo/fdkt/inference_inst.py diff --git a/python/fate_llm/algo/fdkt/fdkt_data_aug.py b/python/fate_llm/algo/fdkt/fdkt_data_aug.py index 501b7fc..c913699 100644 --- a/python/fate_llm/algo/fdkt/fdkt_data_aug.py +++ b/python/fate_llm/algo/fdkt/fdkt_data_aug.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import os +import os.path + import torch import logging from dataclasses import dataclass, field @@ -23,6 +24,7 @@ from transformers import PreTrainedTokenizer from .utils.text_generate import slm_text_generate, general_text_generate from .cluster.cluster import SentenceCluster +from fate_llm.inference.inference_base import Inference logger = logging.getLogger(__name__) @@ -40,6 +42,8 @@ class FDKTTrainingArguments(Seq2SeqTrainingArguments): target_delta: float = field(default=1e-5) freeze_embedding: bool = field(default=False) device_id: int = field(default=0) + slm_generation_config: dict = field(default=None) + slm_generation_batch_size: dict = field(default=None) """ slm generation config @@ -116,13 +120,19 @@ def aug_data(self): self.tokenizer, prompt_ids_dict=prefix_prompt_ids_dict, seq_num_for_single_category=self.training_args.seq_num_for_single_category, - batch_size=self.training_args.per_device_train_batch_size, + batch_size=self.training_args.slm_generation_batch_size, use_cpu=self.training_args.use_cpu, - generation_config=self.training_args.generation_config + generation_config=self.training_args.slm_generation_config ) + if not self.training_args.use_cpu: + self.model.cpu() + torch.cuda.empty_cache() + self.sync_synthetic_dataset(generated_text) + return self.sync_aug_data() + def dp_train(self): from ..dp import DPTrainer, DPTrainingArguments, get_model_class from .utils.dp_loss import SequenceCrossEntropyLoss @@ -176,21 +186,25 @@ class FDKTLLM(object): def __init__( self, ctx: Context, - model: torch.nn.Module, embedding_model: torch.nn.Module, training_args: FDKTTrainingArguments, dataset, + model: Optional[torch.nn.Module] = None, tokenizer: Optional[PreTrainedTokenizer] = None, + inference_inst: Optional[Inference] = None, ): super(FDKTLLM, self).__init__() self.ctx = ctx - self.model = model + self.inference_inst = inference_inst 
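+        # When inference_inst is provided, filter/augment prompts are answered
+        # through it (e.g. an API endpoint) and the local model/tokenizer become
+        # optional; otherwise generation falls back to model.generate.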
self.embedding_model = embedding_model self.dataset = dataset self.training_args = training_args + self.model = model self.tokenizer = tokenizer - if not self.training_args.use_cpu: + if self.inference_inst is None and (self.model is None or self.tokenizer is None): + raise ValueError("Inference_inst and Model are both empty, should provided one") + if self.model is not None and self.training_args.device_id is not None and not self.training_args.use_cpu: self.model.cuda(self.training_args.device_id) def sync_synthetic_data(self): @@ -215,6 +229,7 @@ def aug_data(self): def _aug(self, aug_prompts): aug_responses = general_text_generate( + inference_inst=self.inference_inst, model=self.model, tokenizer=self.tokenizer, generation_config=self.training_args.aug_generation_config, @@ -232,6 +247,7 @@ def filter_data(self, slm_data): clustered_sentences, clustered_labels = self.cluster_data(slm_data) filter_prompts = self.dataset.prepare_query_to_filter_clustered(clustered_sentences, clustered_labels) filter_responses = general_text_generate( + inference_inst=self.inference_inst, model=self.model, tokenizer=self.tokenizer, generation_config=self.training_args.filter_generation_config, diff --git a/python/fate_llm/algo/fdkt/inference_inst.py b/python/fate_llm/algo/fdkt/inference_inst.py new file mode 100644 index 0000000..627e136 --- /dev/null +++ b/python/fate_llm/algo/fdkt/inference_inst.py @@ -0,0 +1,10 @@ +from fate_llm.inference.api import APICompletionInference + + +def init(api_url: str, model_name: str, api_key: str = 'EMPTY', api_timeout=3600): + return APICompletionInference( + api_url=api_url, + model_name=model_name, + api_key=api_key, + api_timeout=api_timeout + ) \ No newline at end of file diff --git a/python/fate_llm/algo/fdkt/utils/text_generate.py b/python/fate_llm/algo/fdkt/utils/text_generate.py index 408c547..4e31b88 100644 --- a/python/fate_llm/algo/fdkt/utils/text_generate.py +++ b/python/fate_llm/algo/fdkt/utils/text_generate.py @@ -11,6 +11,7 @@ def slm_text_generate( use_cpu, generation_config ): + model.eval() generated_ret = dict( inputs=list(), labels=list(), @@ -44,6 +45,7 @@ def slm_text_generate( def general_text_generate( + inference_inst, model, tokenizer, generation_config: Dict[Any, Any], @@ -52,28 +54,32 @@ def general_text_generate( use_cpu: bool, prompt_max_length ): - generate_texts = [] - batch_num = (len(prompts) + batch_size - 1) // batch_size - for batch_idx in range(batch_num): - batch_data = prompts[batch_idx * batch_size: (batch_idx + 1) * batch_size] + if inference_inst is not None: + generate_texts = inference_inst.inference(prompts, generation_config) + else: + model.eval() + generate_texts = [] + batch_num = (len(prompts) + batch_size - 1) // batch_size + for batch_idx in range(batch_num): + batch_data = prompts[batch_idx * batch_size: (batch_idx + 1) * batch_size] - inputs = tokenizer(batch_data, return_tensors="pt", padding="longest", truncation=True, - max_length=prompt_max_length) - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] + inputs = tokenizer(batch_data, return_tensors="pt", padding="longest", truncation=True, + max_length=prompt_max_length) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] - if not use_cpu: - input_ids = input_ids.to(model.device) - attention_mask = attention_mask.to(model.device) + if not use_cpu: + input_ids = input_ids.to(model.device) + attention_mask = attention_mask.to(model.device) - output = model.generate( - input_ids=input_ids, - 
attention_mask=attention_mask, - **generation_config - ) + output = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + **generation_config + ) - batch_responses = tokenizer.batch_decode(output[:, input_ids.shape[1]:], skip_special_tokens=True) + batch_responses = tokenizer.batch_decode(output[:, input_ids.shape[1]:], skip_special_tokens=True) - generate_texts.extend(batch_responses) + generate_texts.extend(batch_responses) return generate_texts diff --git a/python/fate_llm/runner/fdkt_runner.py b/python/fate_llm/runner/fdkt_runner.py index 8e4d611..3594727 100644 --- a/python/fate_llm/runner/fdkt_runner.py +++ b/python/fate_llm/runner/fdkt_runner.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pandas as pd +import logging +import torch from fate.components.components.nn.nn_runner import ( load_model_dict_from_path, dir_warning, @@ -20,22 +21,23 @@ run_dataset_func, ) from typing import Dict -from fate.arch.dataframe import PandasReader from fate.components.components.nn.loader import Loader from typing import Union, Optional, Literal from transformers.trainer_utils import get_last_checkpoint -import logging from fate.arch.dataframe import DataFrame from fate.components.components.nn.runner.homo_default_runner import DefaultRunner from fate_llm.algo.fdkt import FDKTTrainingArguments, FDKTSLM, FDKTLLM logger = logging.getLogger(__name__) +AUG_DATA_SAVED_PATH_SUFFIX = "aug_data.pkl" +DP_MODEL_SAVED_PATH_SUFFIX = "dp_model" class FDKTRunner(DefaultRunner): def __init__( self, algo: str = "fdkt", + inference_inst_conf: Optional[Dict] = None, model_conf: Optional[Dict] = None, embedding_model_conf: Optional[Dict] = None, optimizer_conf: Optional[Dict] = None, @@ -48,6 +50,7 @@ def __init__( ) -> None: super(FDKTRunner, self).__init__() self.algo = algo + self.inference_inst_conf = inference_inst_conf self.model_conf = model_conf self.embedding_model_conf = embedding_model_conf self.optimizer_conf = optimizer_conf @@ -66,17 +69,16 @@ def __init__( if self.task_type not in ["causal_lm"]: raise ValueError("task_type should be causal_lm") - self.aug_data = None - def common_setup(self, saved_model=None, output_dir=None): ctx = self.get_context() if output_dir is None: output_dir = "./" - model = loader_load_from_conf(self.model_conf) - if model is None: - raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") + if self.model_conf is not None: + model = loader_load_from_conf(self.model_conf) + else: + model = None resume_path = None if saved_model is not None: @@ -88,7 +90,10 @@ def common_setup(self, saved_model=None, output_dir=None): logger.info(f"checkpoint detected, resume_path set to {resume_path}") # load tokenizer if import conf provided - tokenizer = loader_load_from_conf(self.tokenizer_conf) + if self.tokenizer_conf is not None: + tokenizer = loader_load_from_conf(self.tokenizer_conf) + else: + tokenizer = None # args dir_warning(self.training_args_conf) @@ -107,7 +112,13 @@ def llm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_mo ctx, model, tokenizer, training_args, dataset = self.common_setup( output_dir=output_dir, saved_model=saved_model) - model = model.load() + if model is not None: + model = model.load() + + inference_inst = None + if self.inference_inst_conf is not None: + inference_inst = loader_load_from_conf(self.inference_inst_conf) + 
embedding_model = loader_load_from_conf(self.embedding_model_conf) if embedding_model is None: raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") @@ -115,6 +126,7 @@ def llm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_mo trainer = FDKTLLM( ctx=ctx, + inference_inst=inference_inst, model=model, embedding_model=embedding_model, training_args=training_args, @@ -163,10 +175,15 @@ def train( if self.is_client(): trainer = self.slm_setup(train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path) - self.aug_data = trainer.aug_data() + aug_data = trainer.aug_data() + + data_saved_path = output_dir + '/' + AUG_DATA_SAVED_PATH_SUFFIX + logger.info('result save to path {}'.format(data_saved_path)) + torch.save(aug_data, data_saved_path) if self.save_dp_model: - trainer.save_model(output_dir) + model_save_dir = output_dir + "/" + DP_MODEL_SAVED_PATH_SUFFIX + trainer.save_model(model_save_dir) else: trainer = self.llm_setup( @@ -175,21 +192,4 @@ def train( trainer.aug_data() def predict(self, *args, **kwargs): - if self.is_client(): - ctx = self.get_context() - df = pd.DataFrame() - texts = self.aug_data["inputs"] - labels = self.aug_data["labels"] - - sample_id_name = "sample_id" - match_id_name = "match_id" - df[sample_id_name] = list(map(str, range(len(texts)))) - df[match_id_name] = list(map(str, range(len(texts)))) - import json - texts = [json.dumps(text) for text in texts] - labels = [json.dumps(label) for label in labels] - df["inputs"] = texts - df["labels"] = labels - - reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, dtype="object") - return reader.to_frame(ctx, df) + pass From 8e3954f75e9ad7a7f83397bb196565bda047866e Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 23 Jul 2024 15:51:14 +0800 Subject: [PATCH 23/50] fix func input Signed-off-by: Yu Wu --- python/fate_llm/dataset/flex_dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/fate_llm/dataset/flex_dataset.py b/python/fate_llm/dataset/flex_dataset.py index 3a06731..683cd6a 100644 --- a/python/fate_llm/dataset/flex_dataset.py +++ b/python/fate_llm/dataset/flex_dataset.py @@ -247,9 +247,7 @@ def prepare_augment(self, text_list, label_list, aug_prompt_num): result.append(formatted_query) return result - def abstract_from_augmented(self, augmented_responses_path): - with open(augmented_responses_path, "rb") as fin: - sample_list = pickle.loads(fin.read()) + def abstract_from_augmented(self, sample_list): label_key, text_key = get_jinjax_placeholders(self.few_shot_format, 2) res = {'inputs': [], 'labels': []} for sample in sample_list: From 39d4c0ec07461cc823b56502199e99b87ebc3334 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Fri, 26 Jul 2024 14:09:03 +0800 Subject: [PATCH 24/50] update fdkt: support vllm offline inference in slm Signed-off-by: mgqa34 --- python/fate_llm/algo/dp/__init__.py | 15 +++++ python/fate_llm/algo/dp/dp_trainer.py | 4 +- .../algo/dp/opacus_compatibility/__init__.py | 15 +++++ .../grad_sample/__init__.py | 15 +++++ .../grad_sample/embedding.py | 22 +++++- .../optimizers/__init__.py | 15 +++++ .../optimizers/optimizer.py | 3 + .../transformers_compate.py | 15 +++++ python/fate_llm/algo/fdkt/cluster/__init__.py | 2 +- python/fate_llm/algo/fdkt/fdkt_data_aug.py | 44 +++++++++++- python/fate_llm/algo/fdkt/inference_inst.py | 33 +++++++-- python/fate_llm/algo/fdkt/utils/__init__.py | 2 +- python/fate_llm/algo/fdkt/utils/dp_loss.py | 18 ++++- 
.../fate_llm/algo/fdkt/utils/text_generate.py | 67 +++++++++++++------ .../embedding_transformer/__init__.py | 15 +++++ python/fate_llm/model_zoo/hf_model.py | 15 +++++ python/fate_llm/runner/fdkt_runner.py | 1 + 17 files changed, 264 insertions(+), 37 deletions(-) diff --git a/python/fate_llm/algo/dp/__init__.py b/python/fate_llm/algo/dp/__init__.py index ce0bd2e..7e1c95d 100644 --- a/python/fate_llm/algo/dp/__init__.py +++ b/python/fate_llm/algo/dp/__init__.py @@ -1,2 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from .opacus_compatibility.transformers_compate import get_model_class from .dp_trainer import DPTrainer, DPTrainingArguments diff --git a/python/fate_llm/algo/dp/dp_trainer.py b/python/fate_llm/algo/dp/dp_trainer.py index c159cc8..81d7b35 100644 --- a/python/fate_llm/algo/dp/dp_trainer.py +++ b/python/fate_llm/algo/dp/dp_trainer.py @@ -32,7 +32,7 @@ class DPTrainingArguments(Seq2SeqTrainingArguments): target_epsilon: float = field(default=3) target_delta: float = field(default=1e-5) - freeze_embedding: bool = field(default=False) + freeze_embedding: bool = field(default=True) device_id: int = field(default=0) @@ -95,7 +95,9 @@ def _init_dp_model(self): add_optimizer_compatibility(self.dp_optimizer) def train(self): + logger.info(f"begin dp training, total epochs={self.training_args.num_train_epochs}") for epoch in range(int(self.training_args.num_train_epochs)): + logger.info(f"dp training on epoch={epoch}") self._train_an_epoch() def _train_an_epoch(self): diff --git a/python/fate_llm/algo/dp/opacus_compatibility/__init__.py b/python/fate_llm/algo/dp/opacus_compatibility/__init__.py index 8d065fe..8a0b176 100644 --- a/python/fate_llm/algo/dp/opacus_compatibility/__init__.py +++ b/python/fate_llm/algo/dp/opacus_compatibility/__init__.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from .grad_sample.embedding import compute_embedding_grad_sample from .optimizers.optimizer import add_noise_wrapper diff --git a/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py b/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py index e69de29..878d3a9 100644 --- a/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py +++ b/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py b/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py index fa77da6..fcb4464 100644 --- a/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py +++ b/python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py @@ -1,10 +1,26 @@ -from typing import Dict -import types - +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch import torch.nn as nn +from typing import Dict +# the function is modified from https://github.com/pytorch/opacus/blob/main/opacus/grad_sample/embedding.py#L25, +# avoid dtype error when backprops's dtype isn't torch.float32 def compute_embedding_grad_sample( layer: nn.Embedding, activations: torch.Tensor, backprops: torch.Tensor ) -> Dict[nn.Parameter, torch.Tensor]: diff --git a/python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py b/python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py index e69de29..878d3a9 100644 --- a/python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py +++ b/python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py b/python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py index 360188b..add3104 100644 --- a/python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py +++ b/python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py @@ -1,4 +1,5 @@ # +# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright 2019 The FATE Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +22,8 @@ ) +# modified from https://github.com/pytorch/opacus/blob/main/opacus/optimizers/optimizer.py#L424 +# avoid dtype error when summed_grad's dtype isn't torch.float32 def add_noise(self): """ Adds noise to clipped gradients. Stores clipped and noised result in ``p.grad`` diff --git a/python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py b/python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py index 2235aff..c26eab5 100644 --- a/python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py +++ b/python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch import transformers from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM diff --git a/python/fate_llm/algo/fdkt/cluster/__init__.py b/python/fate_llm/algo/fdkt/cluster/__init__.py index ef471ba..878d3a9 100644 --- a/python/fate_llm/algo/fdkt/cluster/__init__.py +++ b/python/fate_llm/algo/fdkt/cluster/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/python/fate_llm/algo/fdkt/fdkt_data_aug.py b/python/fate_llm/algo/fdkt/fdkt_data_aug.py index c913699..bab5744 100644 --- a/python/fate_llm/algo/fdkt/fdkt_data_aug.py +++ b/python/fate_llm/algo/fdkt/fdkt_data_aug.py @@ -14,6 +14,7 @@ # limitations under the License. 
# import os.path +import shutil import torch import logging @@ -44,6 +45,8 @@ class FDKTTrainingArguments(Seq2SeqTrainingArguments): device_id: int = field(default=0) slm_generation_config: dict = field(default=None) slm_generation_batch_size: dict = field(default=None) + inference_method: str = field(default="native") + inference_inst_init_conf: dict = field(default=None) """ slm generation config @@ -111,20 +114,30 @@ def __init__( self.model.cuda(self.training_args.device_id) def aug_data(self): + logging.info("Start aug data process") + logging.debug(f"dp_training={self.training_args.dp_training}") if self.training_args.dp_training: + logging.info("Start dp training") self.dp_train() + logging.info("End dp training") + + inference_inst = self._create_inference_inst() + prefix_prompt_dict = self.train_set.get_generate_prompt( + tokenize=True if inference_inst is None else False) - prefix_prompt_ids_dict = self.train_set.get_generate_prompt(tokenize=True) generated_text = slm_text_generate( + inference_inst, self.model, self.tokenizer, - prompt_ids_dict=prefix_prompt_ids_dict, + prompt_dict=prefix_prompt_dict, seq_num_for_single_category=self.training_args.seq_num_for_single_category, batch_size=self.training_args.slm_generation_batch_size, use_cpu=self.training_args.use_cpu, generation_config=self.training_args.slm_generation_config ) + self._destroy_inference_inst() + if not self.training_args.use_cpu: self.model.cpu() torch.cuda.empty_cache() @@ -164,6 +177,29 @@ def dp_train(self): dp_trainer.train() + def _create_inference_inst(self): + if self.training_args.inference_method == "native": + return None + elif self.training_args.inference_method == "vllm": + from .inference_inst import vllm_init + + self.model.cpu() + model_temp_path = self.training_args.output_dir + "./model_for_inference" + self.tokenizer.save_pretrained(model_temp_path) + self.model.save_pretrained(model_temp_path) + + return vllm_init(model_temp_path) if self.training_args.inference_inst_init_conf is None \ + else vllm_init(model_temp_path, **self.training_args.inference_inst_init_conf) + + else: + raise ValueError(f"not supported inference_method={self.training_args.inference_method}") + + def _destroy_inference_inst(self): + if self.training_args.inference_method == "vllm": + shutil.rmtree(self.training_args.output_dir + "./model_for_inference") + elif not self.training_args.use_cpu: + self.model.cpu() + def sync_synthetic_dataset(self, data): self.ctx.arbiter.put(SLM_SYNTHETIC_DATA, data) @@ -214,16 +250,20 @@ def sync_aug_data(self, aug_data): self.ctx.guest.put(LLM_AUG_DATA, aug_data) def aug_data(self): + logging.info("sync slm synthetic_data") slm_data = self.sync_synthetic_data() + logging.info("filter slm synthetic data") filter_data = self.filter_data(slm_data) + logging.info("prepare prompts for aug") aug_prompts = self.dataset.prepare_augment( filter_data["inputs"], filter_data["labels"], aug_prompt_num=self.training_args.aug_prompt_num ) + logging.info("aug_data") aug_data = self._aug(aug_prompts) self.sync_aug_data(aug_data) diff --git a/python/fate_llm/algo/fdkt/inference_inst.py b/python/fate_llm/algo/fdkt/inference_inst.py index 627e136..3c0b336 100644 --- a/python/fate_llm/algo/fdkt/inference_inst.py +++ b/python/fate_llm/algo/fdkt/inference_inst.py @@ -1,10 +1,33 @@ -from fate_llm.inference.api import APICompletionInference - - -def init(api_url: str, model_name: str, api_key: str = 'EMPTY', api_timeout=3600): +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +def api_init(api_url: str, model_name: str, api_key: str = 'EMPTY', api_timeout=3600): + from fate_llm.inference.api import APICompletionInference return APICompletionInference( api_url=api_url, model_name=model_name, api_key=api_key, api_timeout=api_timeout - ) \ No newline at end of file + ) + + +def vllm_init(model_path: str, num_gpu=1, dtype='float16', gpu_memory_utilization=0.9): + from fate_llm.inference.vllm import VLLMInference + return VLLMInference( + model_path=model_path, + num_gpu=num_gpu, + dtype=dtype, + gpu_memory_utilization=gpu_memory_utilization + ) diff --git a/python/fate_llm/algo/fdkt/utils/__init__.py b/python/fate_llm/algo/fdkt/utils/__init__.py index ef471ba..878d3a9 100644 --- a/python/fate_llm/algo/fdkt/utils/__init__.py +++ b/python/fate_llm/algo/fdkt/utils/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/python/fate_llm/algo/fdkt/utils/dp_loss.py b/python/fate_llm/algo/fdkt/utils/dp_loss.py index 4c91058..a454cc2 100644 --- a/python/fate_llm/algo/fdkt/utils/dp_loss.py +++ b/python/fate_llm/algo/fdkt/utils/dp_loss.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch import torch.nn as nn import torch.nn.functional as F @@ -40,8 +55,7 @@ def sequence_cross_entropy_with_logits(logits, targets, mask, label_smoothing, r negative_log_likelihood = negative_log_likelihood_flat.view(-1, logits.shape[1]) - loss = negative_log_likelihood - # loss = negative_log_likelihood * mask + loss = negative_log_likelihood * mask if reduce: loss = loss.sum(1) / (mask.sum(1) + NUMERICAL_STABILITY_CONSTANT) diff --git a/python/fate_llm/algo/fdkt/utils/text_generate.py b/python/fate_llm/algo/fdkt/utils/text_generate.py index 4e31b88..60499d5 100644 --- a/python/fate_llm/algo/fdkt/utils/text_generate.py +++ b/python/fate_llm/algo/fdkt/utils/text_generate.py @@ -1,45 +1,68 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from tqdm import tqdm from typing import Any, Dict, List def slm_text_generate( + inference_inst, model, tokenizer, - prompt_ids_dict, + prompt_dict, seq_num_for_single_category, batch_size, use_cpu, generation_config ): - model.eval() generated_ret = dict( inputs=list(), labels=list(), ) - for label, prompt_ids in prompt_ids_dict.items(): - prompt_length = len(prompt_ids) - batch_num = (seq_num_for_single_category + batch_size - 1) // batch_size - for batch_idx in tqdm(range(batch_num)): - if batch_idx + 1 == batch_num: - cur_batch_size = seq_num_for_single_category - batch_idx * batch_size - else: - cur_batch_size = batch_size - input_ids = prompt_ids.repeat(cur_batch_size, 1) + if inference_inst is not None: + for label, prompt in prompt_dict.items(): + generated_sequences = inference_inst.inference([prompt] * seq_num_for_single_category, generation_config) + for g in generated_sequences: + generated_ret["inputs"].append(g) + generated_ret["labels"].append(label) + else: + model.eval() + for label, prompt_ids in prompt_dict.items(): + prompt_length = len(prompt_ids) + batch_num = (seq_num_for_single_category + batch_size - 1) // batch_size + for batch_idx in tqdm(range(batch_num)): + if batch_idx + 1 == batch_num: + cur_batch_size = seq_num_for_single_category - batch_idx * batch_size + else: + cur_batch_size = batch_size + input_ids = prompt_ids.repeat(cur_batch_size, 1) - if not use_cpu: - input_ids = input_ids.to(model.device) + if not use_cpu: + input_ids = input_ids.to(model.device) - output_sequences = model.generate( - input_ids=input_ids, - **generation_config - ) - output_sequences = output_sequences[:, prompt_length:] + output_sequences = model.generate( + input_ids=input_ids, + **generation_config + ) + output_sequences = output_sequences[:, prompt_length:] - generated_sequences = tokenizer.batch_decode(output_sequences, skip_special_tokens=True) + generated_sequences = tokenizer.batch_decode(output_sequences, skip_special_tokens=True) - for g in generated_sequences: - generated_ret["inputs"].append(g) - generated_ret["labels"].append(label) + for g in generated_sequences: + generated_ret["inputs"].append(g) + generated_ret["labels"].append(label) return generated_ret diff --git a/python/fate_llm/model_zoo/embedding_transformer/__init__.py b/python/fate_llm/model_zoo/embedding_transformer/__init__.py index e69de29..878d3a9 100644 --- a/python/fate_llm/model_zoo/embedding_transformer/__init__.py +++ b/python/fate_llm/model_zoo/embedding_transformer/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/fate_llm/model_zoo/hf_model.py b/python/fate_llm/model_zoo/hf_model.py index ecc6c00..58fd1f3 100644 --- a/python/fate_llm/model_zoo/hf_model.py +++ b/python/fate_llm/model_zoo/hf_model.py @@ -1,3 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import torch from transformers import AutoModelForCausalLM diff --git a/python/fate_llm/runner/fdkt_runner.py b/python/fate_llm/runner/fdkt_runner.py index 3594727..66d852f 100644 --- a/python/fate_llm/runner/fdkt_runner.py +++ b/python/fate_llm/runner/fdkt_runner.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# import logging import torch from fate.components.components.nn.nn_runner import ( From 454ad9e0d14459591a35038c25afcf1138b4490c Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Fri, 26 Jul 2024 16:04:50 +0800 Subject: [PATCH 25/50] add invalid data filtering process Signed-off-by: mgqa34 --- python/fate_llm/algo/fdkt/fdkt_data_aug.py | 9 +++-- .../algo/fdkt/utils/invalid_data_filter.py | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 python/fate_llm/algo/fdkt/utils/invalid_data_filter.py diff --git a/python/fate_llm/algo/fdkt/fdkt_data_aug.py b/python/fate_llm/algo/fdkt/fdkt_data_aug.py index bab5744..a55044d 100644 --- a/python/fate_llm/algo/fdkt/fdkt_data_aug.py +++ b/python/fate_llm/algo/fdkt/fdkt_data_aug.py @@ -23,6 +23,7 @@ from typing import Optional, Callable from fate.arch import Context from transformers import PreTrainedTokenizer +from .utils.invalid_data_filter import filter_invalid_data from .utils.text_generate import slm_text_generate, general_text_generate from .cluster.cluster import SentenceCluster from fate_llm.inference.inference_base import Inference @@ -41,7 +42,7 @@ class FDKTTrainingArguments(Seq2SeqTrainingArguments): dp_training: bool = field(default=True) target_epsilon: float = field(default=3) target_delta: float = field(default=1e-5) - freeze_embedding: bool = field(default=False) + freeze_embedding: bool = field(default=True) device_id: int = field(default=0) slm_generation_config: dict = field(default=None) slm_generation_batch_size: dict = field(default=None) @@ -125,7 +126,7 @@ def aug_data(self): prefix_prompt_dict = self.train_set.get_generate_prompt( tokenize=True if inference_inst is None else False) - generated_text = slm_text_generate( + generated_texts = slm_text_generate( inference_inst, self.model, self.tokenizer, @@ -142,7 +143,8 @@ def aug_data(self): self.model.cpu() torch.cuda.empty_cache() - self.sync_synthetic_dataset(generated_text) + generated_texts = filter_invalid_data(generated_texts) + self.sync_synthetic_dataset(generated_texts) return self.sync_aug_data() @@ -265,6 +267,7 @@ def aug_data(self): 
logging.info("aug_data") aug_data = self._aug(aug_prompts) + aug_data = filter_invalid_data(aug_data) self.sync_aug_data(aug_data) def _aug(self, aug_prompts): diff --git a/python/fate_llm/algo/fdkt/utils/invalid_data_filter.py b/python/fate_llm/algo/fdkt/utils/invalid_data_filter.py new file mode 100644 index 0000000..3f26712 --- /dev/null +++ b/python/fate_llm/algo/fdkt/utils/invalid_data_filter.py @@ -0,0 +1,34 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +INVALID_CHARACTERS = "".join([' ', '-', '.', '_', '~', '/', '\\', '*', '|', '#']) +LEAST_WORDS = 10 + + +def filter_invalid_data(data_dict): + sample_num = len(data_dict["inputs"]) + new_data_dict = dict( + inputs=list(), + labels=list() + ) + for idx in range(sample_num): + text = data_dict["inputs"][idx].strip(INVALID_CHARACTERS) + if len(text.split()) < LEAST_WORDS: + continue + + new_data_dict["inputs"].append(text) + new_data_dict["labels"].append(data_dict["labels"][idx]) + + return new_data_dict From 31a0b0df9fda65ba0caf7a4381ae34c1d3b00ce0 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Fri, 26 Jul 2024 16:31:15 +0800 Subject: [PATCH 26/50] add max_prompt_len when tokenizer data Signed-off-by: mgqa34 --- python/fate_llm/dataset/flex_dataset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/fate_llm/dataset/flex_dataset.py b/python/fate_llm/dataset/flex_dataset.py index 683cd6a..305cc4e 100644 --- a/python/fate_llm/dataset/flex_dataset.py +++ b/python/fate_llm/dataset/flex_dataset.py @@ -115,7 +115,7 @@ def tokenize_function(examples): class FlexDataset(Dataset): def __init__(self, - tokenizer_path, + tokenizer_name_or_path, dataset_name: str, load_from: Literal['json'] = 'json', data_part: str = None, @@ -129,8 +129,8 @@ def __init__(self, super().__init__() self.tokenizer = None - self.tokenizer_path = tokenizer_path - self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=True) self.dataset_name = dataset_name if self.dataset_name and config is None: config = DATA_CONFIG_TEMPLATE.get(self.dataset_name, "") @@ -330,7 +330,8 @@ def load(self, path): sub_domain=self.sub_domain, tokenize_format=self.tokenize_format, text_key=self.text_key, - label_key=self.label_key + label_key=self.label_key, + max_prompt_len=self.max_prompt_len ) self.ds = tokenized_ds[self.data_part] @@ -340,14 +341,14 @@ def load(self, path): def apply_chat_template(self, query): tokenizer = self.tokenizer - if "llama-3" in self.tokenizer_path.lower(): + if "llama-3" in self.tokenizer_name_or_path.lower(): msg = [ {"role": "system", "content": "You are a helpful assistant. 
"}, {"role": "user", "content": query} ] prompt = tokenizer.apply_chat_template(msg, add_generation_prompt=True, tokenize=False) else: - conv = get_conversation_template(self.tokenizer_path) + conv = get_conversation_template(self.tokenizer_name_or_path) conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() From a7b14fcd6673b39c442ea88bb6e8bc40d5aa0486 Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 29 Jul 2024 14:26:13 +0800 Subject: [PATCH 27/50] Fix emulaotr & fix inferdpt code Signed-off-by: weijingchen Signed-off-by: cwj --- .../algo/offsite_tuning/offsite_tuning.py | 4 +- .../offsite_tuning/offsite_tuning_model.py | 48 ++++++++++++------- python/fate_llm/runner/inferdpt_runner.py | 2 +- 3 files changed, 33 insertions(+), 21 deletions(-) diff --git a/python/fate_llm/algo/offsite_tuning/offsite_tuning.py b/python/fate_llm/algo/offsite_tuning/offsite_tuning.py index 168471a..6ef6c5c 100644 --- a/python/fate_llm/algo/offsite_tuning/offsite_tuning.py +++ b/python/fate_llm/algo/offsite_tuning/offsite_tuning.py @@ -125,7 +125,7 @@ def on_train_end(self, ctx: Context, aggregator: Aggregator, fed_args: FedArgume if args.local_rank == 0: if args.world_size > 1: model = unwrap_model(model) - return_weights = model.get_submodel_weights() + return_weights = model.get_submodel_weights(with_emulator=False) ctx.arbiter.put('trained_sub_model_para', return_weights) logger.info('weights sent back to the server') @@ -153,7 +153,7 @@ def on_train_begin(self, ctx: Context, aggregator: Aggregator): def on_train_end(self, ctx: Context, aggregator: Aggregator): parameters_to_get = ctx.guest.get('trained_sub_model_para') - self.model.load_submodel_weights(parameters_to_get) + self.model.load_submodel_weights(parameters_to_get, with_emulator=False) logger.info('received trained submodel weigths from the client') def on_federation(self, ctx: Context, aggregator, agg_iter_idx: int): diff --git a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py index 73f50ef..1a6c2c2 100644 --- a/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py +++ b/python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py @@ -150,38 +150,50 @@ def get_numpy_state_dict(self, module_dict): v in v.state_dict().items()} return weight_dict - def get_submodel_weights(self) -> dict: - submodel_weights = { - "emulator": { - k: self._get_numpy_arr(v) for k, - v in self.get_emulator().state_dict().items()}, - "adapter_top": { - k: self._get_numpy_arr(v) for k, - v in self.get_adapter_top().state_dict().items()}, - "adapter_bottom": { - k: self._get_numpy_arr(v) for k, - v in self.get_adapter_bottom().state_dict().items()}} + def get_submodel_weights(self, with_emulator=True) -> dict: + if with_emulator: + submodel_weights = { + "emulator": { + k: self._get_numpy_arr(v) for k, + v in self.get_emulator().state_dict().items()}, + "adapter_top": { + k: self._get_numpy_arr(v) for k, + v in self.get_adapter_top().state_dict().items()}, + "adapter_bottom": { + k: self._get_numpy_arr(v) for k, + v in self.get_adapter_bottom().state_dict().items()}} + else: + submodel_weights = { + "adapter_top": { + k: self._get_numpy_arr(v) for k, + v in self.get_adapter_top().state_dict().items()}, + "adapter_bottom": { + k: self._get_numpy_arr(v) for k, + v in self.get_adapter_bottom().state_dict().items()}} addition_weights = self.get_additional_param_state_dict() submodel_weights.update(addition_weights) return 
submodel_weights - def load_submodel_weights(self, submodel_weights: dict): + def load_submodel_weights(self, submodel_weights: dict, with_emulator=True): + + if with_emulator: + emulator_weights = { + k: t.tensor(v) for k, + v in submodel_weights['emulator'].items()} + emulator = self.get_emulator() + emulator.load_state_dict(emulator_weights) + - emulator_weights = { - k: t.tensor(v) for k, - v in submodel_weights['emulator'].items()} adapter_top_weights = { k: t.tensor(v) for k, v in submodel_weights['adapter_top'].items()} adapter_bottom_weights = { k: t.tensor(v) for k, v in submodel_weights['adapter_bottom'].items()} - - emulator = self.get_emulator() adapter_top = self.get_adapter_top() adapter_bottom = self.get_adapter_bottom() - emulator.load_state_dict(emulator_weights) + adapter_top.load_state_dict(adapter_top_weights) adapter_bottom.load_state_dict(adapter_bottom_weights) self.load_additional_param_state_dict(submodel_weights) diff --git a/python/fate_llm/runner/inferdpt_runner.py b/python/fate_llm/runner/inferdpt_runner.py index 20ec20c..e67c9d8 100644 --- a/python/fate_llm/runner/inferdpt_runner.py +++ b/python/fate_llm/runner/inferdpt_runner.py @@ -68,7 +68,7 @@ def __init__( self.perturbed_response_key = perturbed_response_key self.result_key = result_key - def _get_inferdpt_inst(self): + def _get_inst(self): loader = Loader.from_dict(self.inferdpt_init_conf) init_inst = loader.load_item()(self.get_context()) assert isinstance(init_inst, InferInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst)) From 2a21af114de0ec253f5f86ae032b3c2717812023 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 09:43:47 +0800 Subject: [PATCH 28/50] fix prompt len cutoff problem Signed-off-by: mgqa34 --- python/fate_llm/algo/fdkt/utils/text_generate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/fate_llm/algo/fdkt/utils/text_generate.py b/python/fate_llm/algo/fdkt/utils/text_generate.py index 60499d5..72891b9 100644 --- a/python/fate_llm/algo/fdkt/utils/text_generate.py +++ b/python/fate_llm/algo/fdkt/utils/text_generate.py @@ -78,6 +78,8 @@ def general_text_generate( prompt_max_length ): if inference_inst is not None: + if prompt_max_length is not None: + prompts = [prompt[:prompt_max_length] for prompt in prompts] generate_texts = inference_inst.inference(prompts, generation_config) else: model.eval() From 5fee90d8231ce875ae144278209aab4a11f38266 Mon Sep 17 00:00:00 2001 From: xwenhuang1 Date: Tue, 30 Jul 2024 11:09:47 +0800 Subject: [PATCH 29/50] Add Setup Script And Pipeline Examples Signed-off-by: xwenhuang1 --- examples/fedmkt/__init__.py | 0 examples/fedmkt/fedmkt.py | 262 ++++++++++++++++++ examples/fedmkt/fedmkt_config.yaml | 99 +++++++ examples/fedmkt/test_fedmkt_llmsuit.yaml | 14 + examples/offsite_tuning/__init__.py | 0 examples/offsite_tuning/offsite_tuning.py | 123 ++++++++ .../offsite_tuning/offsite_tuning_config.yaml | 67 +++++ .../test_offsite_tuning_llmsuite.yaml | 14 + python/requirements.txt | 2 + python/setup.py | 68 ++++- 10 files changed, 640 insertions(+), 9 deletions(-) create mode 100644 examples/fedmkt/__init__.py create mode 100644 examples/fedmkt/fedmkt.py create mode 100644 examples/fedmkt/fedmkt_config.yaml create mode 100644 examples/fedmkt/test_fedmkt_llmsuit.yaml create mode 100644 examples/offsite_tuning/__init__.py create mode 100644 examples/offsite_tuning/offsite_tuning.py create mode 100644 examples/offsite_tuning/offsite_tuning_config.yaml create mode 100644 
examples/offsite_tuning/test_offsite_tuning_llmsuite.yaml diff --git a/examples/fedmkt/__init__.py b/examples/fedmkt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/fedmkt/fedmkt.py b/examples/fedmkt/fedmkt.py new file mode 100644 index 0000000..5c3dd9e --- /dev/null +++ b/examples/fedmkt/fedmkt.py @@ -0,0 +1,262 @@ +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fedmkt_runner +from fate_client.pipeline.components.fate.nn.algo_params import FedMKTTrainingArguments, FedAVGArguments +from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader +from peft import LoraConfig, TaskType +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate.reader import Reader +from transformers import AutoConfig +import argparse +import yaml +from typing import Union, Dict + +def main(config="./config.yaml", param: Union[Dict, str] = None): + if isinstance(config, str): + with open(config, 'r') as f: + config = yaml.safe_load(f) + + if isinstance(param, str): + param = yaml.safe_load(param) + + guest = config['parties']['guest'][0] # replace with actual guest party ID + host = config['parties']['host'][0] # replace with actual host party ID + arbiter = config['parties']['arbiter'][0] # replace with actual arbiter party ID + + process_data_output_dir = config['paths']['process_data_output_dir'] + llm_pretrained_path = config['paths']['llm_pretrained_path'] + slm_pretrained_paths = config['paths']['slm_pretrained_paths'] + vocab_mapping_directory = config['paths']['vocab_mapping_directory'] + + slm_to_llm_vocab_mapping_paths = [ + vocab_mapping_directory + "/" + path for path in config['paths']['slm_to_llm_vocab_mapping_paths'] + ] + llm_to_slm_vocab_mapping_paths = [ + vocab_mapping_directory + "/" + path for path in config['paths']['llm_to_slm_vocab_mapping_paths'] + ] + + slm_models = config['models']['slm_models'] + slm_lora_target_modules = config['lora_config']['slm_lora_target_modules'] + + def get_llm_conf(): + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=param['lora_config']['llm']['r'], + lora_alpha=param['lora_config']['llm']['lora_alpha'], + lora_dropout=param['lora_config']['llm']['lora_dropout'], + target_modules=param['lora_config']['llm']['target_modules'] + ) + lora_config.target_modules = list(lora_config.target_modules) + + llm_model = LLMModelLoader( + "pellm.llama", + "LLaMa", + pretrained_path=llm_pretrained_path, + peft_type="LoraConfig", + peft_config=lora_config.to_dict(), + torch_dtype="bfloat16" + ) + + pub_dataset = LLMDatasetLoader( + "qa_dataset", + "QaDataset", + tokenizer_name_or_path=llm_pretrained_path, + need_preprocess=True, + dataset_name="arc_challenge", + data_part="common", + seq_max_len=512 + ) + + training_args = FedMKTTrainingArguments( + global_epochs=param['training']['llm']['global_epochs'], + per_device_train_batch_size=param['training']['llm']['per_device_train_batch_size'], + gradient_accumulation_steps=param['training']['llm']['gradient_accumulation_steps'], + learning_rate=param['training']['llm']['learning_rate'], + output_dir=param['training']['llm']['output_dir'], + dataloader_num_workers=param['training']['llm']['dataloader_num_workers'], + remove_unused_columns=param['training']['llm']['remove_unused_columns'], + warmup_ratio=param['training']['llm']['warmup_ratio'], + lr_scheduler_type=param['training']['llm']['lr_scheduler_type'], + 
optim=param['training']['llm']['optim'], + adam_beta1=param['training']['llm']['adam_beta1'], + adam_beta2=param['training']['llm']['adam_beta2'], + weight_decay=param['training']['llm']['weight_decay'], + max_grad_norm=param['training']['llm']['max_grad_norm'], + use_cpu=param['training']['llm']['use_cpu'], + vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size, + ) + + fed_args = FedAVGArguments( + aggregate_strategy='epoch', + aggregate_freq=1 + ) + + tokenizer = LLMDataFuncLoader( + "tokenizers.cust_tokenizer", + "get_tokenizer", + tokenizer_name_or_path=llm_pretrained_path + ) + + slm_tokenizers = [ + LLMDataFuncLoader("tokenizers.cust_tokenizer", "get_tokenizer", tokenizer_name_or_path=path) + for path in slm_pretrained_paths + ] + + return get_config_of_fedmkt_runner( + model=llm_model, + training_args=training_args, + fed_args=fed_args, + pub_dataset=pub_dataset, + tokenizer=tokenizer, + slm_tokenizers=slm_tokenizers, + slm_to_llm_vocab_mapping_paths=slm_to_llm_vocab_mapping_paths, + pub_dataset_path=process_data_output_dir, + save_trainable_weights_only=True, + ) + + def get_slm_conf(slm_idx): + slm_pretrained_path = slm_pretrained_paths[slm_idx] + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=param['lora_config']['slm'][slm_idx]['r'], + lora_alpha=param['lora_config']['slm'][slm_idx]['lora_alpha'], + lora_dropout=param['lora_config']['slm'][slm_idx]['lora_dropout'], + target_modules=param['lora_config']['slm'][slm_idx]['target_modules'] + ) + lora_config.target_modules = list(lora_config.target_modules) + llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping_paths[slm_idx] + + slm_model = LLMModelLoader( + slm_models[slm_idx][0], + slm_models[slm_idx][1], + pretrained_path=slm_pretrained_path, + peft_type="LoraConfig", + peft_config=lora_config.to_dict(), + ) + vocab_size = AutoConfig.from_pretrained(slm_pretrained_path).vocab_size + + pub_dataset = LLMDatasetLoader( + "qa_dataset", + "QaDataset", + tokenizer_name_or_path=slm_pretrained_path, + need_preprocess=True, + dataset_name="arc_challenge", + data_part="common", + seq_max_len=512 + ) + + priv_dataset = LLMDatasetLoader( + "qa_dataset", + "QaDataset", + tokenizer_name_or_path=slm_pretrained_path, + need_preprocess=True, + dataset_name="arc_challenge", + data_part="client_0", + seq_max_len=512 + ) + + training_args = FedMKTTrainingArguments( + global_epochs=param['training']['slm']['global_epochs'], + per_device_train_batch_size=param['training']['slm']['per_device_train_batch_size'], + gradient_accumulation_steps=param['training']['slm']['gradient_accumulation_steps'], + learning_rate=param['training']['slm']['learning_rate'] if slm_idx != 1 else 3e-4, + output_dir=param['training']['slm']['output_dir'], + dataloader_num_workers=param['training']['slm']['dataloader_num_workers'], + remove_unused_columns=param['training']['slm']['remove_unused_columns'], + warmup_ratio=param['training']['slm']['warmup_ratio'], + lr_scheduler_type=param['training']['slm']['lr_scheduler_type'], + optim=param['training']['slm']['optim'], + adam_beta1=param['training']['slm']['adam_beta1'], + adam_beta2=param['training']['slm']['adam_beta2'], + weight_decay=param['training']['slm']['weight_decay'], + max_grad_norm=param['training']['slm']['max_grad_norm'], + use_cpu=param['training']['slm']['use_cpu'], + vocab_size=vocab_size, + ) + + fed_args = FedAVGArguments( + aggregate_strategy='epoch', + aggregate_freq=1 + ) + + tokenizer = LLMDataFuncLoader( + "tokenizers.cust_tokenizer", + 
"get_tokenizer", + tokenizer_name_or_path=slm_pretrained_path + ) + + llm_tokenizer = LLMDataFuncLoader( + "tokenizers.cust_tokenizer", + "get_tokenizer", + tokenizer_name_or_path=llm_pretrained_path + ) + + data_collator = LLMDataFuncLoader( + module_name='data_collator.cust_data_collator', + item_name='get_seq2seq_data_collator', + tokenizer_name_or_path=slm_pretrained_path + ) + + return get_config_of_fedmkt_runner( + model=slm_model, + training_args=training_args, + fed_args=fed_args, + pub_dataset=pub_dataset, + priv_dataset=priv_dataset, + tokenizer=tokenizer, + llm_tokenizer=llm_tokenizer, + llm_to_slm_vocab_mapping_path=llm_to_slm_vocab_mapping, + pub_dataset_path=process_data_output_dir, + save_trainable_weights_only=True, + data_collator=data_collator + ) + + pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter, host=host) + pipeline.bind_local_path(path=process_data_output_dir, namespace="experiment", name="arc_challenge") + + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=config['data']['guest']['namespace'], + name=config['data']['guest']['name'] + ) + reader_0.hosts[[0, 1, 2]].task_parameters( + namespace=config['data']['host']['namespace'], + name=config['data']['host']['name'] + ) + + homo_nn_0 = HomoNN( + 'nn_0', + train_data=reader_0.outputs["output_data"], + runner_module="fedmkt_runner", + runner_class="FedMKTRunner", + ) + + homo_nn_0.arbiter.task_parameters( + runner_conf=get_llm_conf() + ) + + homo_nn_0.guest.task_parameters( + runner_conf=get_slm_conf(slm_idx=0) + ) + + for idx in range(1): + homo_nn_0.hosts[idx].task_parameters( + runner_conf=get_slm_conf(slm_idx=idx + 1) + ) + + homo_nn_0.guest.conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed + homo_nn_0.hosts[0].conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed + homo_nn_0.arbiter.conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed + + pipeline.add_tasks([reader_0, homo_nn_0]) + pipeline.conf.set("task", dict(engine_run={"cores": 1})) # the number of gpus of each party + + pipeline.compile() + pipeline.fit() + +if __name__ == "__main__": + parser = argparse.ArgumentParser("LLMSUITE PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, help="config file", default="./config.yaml") + parser.add_argument("-p", "--param", type=str, help="config file for params", default="./fedmkt_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/fedmkt/fedmkt_config.yaml b/examples/fedmkt/fedmkt_config.yaml new file mode 100644 index 0000000..597bee6 --- /dev/null +++ b/examples/fedmkt/fedmkt_config.yaml @@ -0,0 +1,99 @@ +# fedmkt_config.yaml + +# Configuration for Lora +lora_config: + llm: + r: 8 + lora_alpha: 16 + lora_dropout: 0.05 + target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + slm: + - # Configuration for the first SLM model + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: + - q_proj + - v_proj + - # Configuration for the second SLM model + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: + - c_attn + +# Training configuration +training: + llm: + global_epochs: 5 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 4 + learning_rate: 3e-5 + output_dir: "./" + dataloader_num_workers: 4 + remove_unused_columns: false + warmup_ratio: 0.008 + lr_scheduler_type: "cosine" + optim: "adamw_torch" + adam_beta1: 0.9 + 
adam_beta2: 0.95 + weight_decay: 0.1 + max_grad_norm: 1.0 + use_cpu: false + slm: + global_epochs: 5 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 4 + learning_rate: 3e-5 # Adjust learning rate for SLM models + output_dir: "./" + dataloader_num_workers: 4 + remove_unused_columns: false + warmup_ratio: 0.008 + lr_scheduler_type: "cosine" + optim: "adamw_torch" + adam_beta1: 0.9 + adam_beta2: 0.95 + weight_decay: 0.1 + max_grad_norm: 1.0 + use_cpu: false + +# Paths configuration +paths: + process_data_output_dir: "" + llm_pretrained_path: "Llama-2-7b-hf" + slm_pretrained_paths: + - "opt-1.3b" + - "gpt2" + vocab_mapping_directory: "" + slm_to_llm_vocab_mapping_paths: + - "opt_to_llama.json" + - "gpt2_to_llama.json" + - "llama_small_to_llama.json" + llm_to_slm_vocab_mapping_paths: + - "llama_to_opt.json" + - "llama_to_gpt2.json" + - "llama_to_llama_small" + +# Models configuration +models: + slm_models: + - ["pellm.opt", "OPT"] + - ["pellm.gpt2", "GPT2CLM"] + +# Data configuration +data: + guest: + namespace: "experiment" + name: "arc_challenge" + host: + namespace: "experiment" + name: "arc_challenge" + +# Example: Additional custom configuration +custom_config: + some_param: "value" + another_param: 123 diff --git a/examples/fedmkt/test_fedmkt_llmsuit.yaml b/examples/fedmkt/test_fedmkt_llmsuit.yaml new file mode 100644 index 0000000..1516809 --- /dev/null +++ b/examples/fedmkt/test_fedmkt_llmsuit.yaml @@ -0,0 +1,14 @@ +data: + - file: + table_name: arc_challenge + namespace: experiment + role: guest_0 + - file: + table_name: arc_challenge + namespace: experiment + role: host_0 +bloom_lora_vs_zero_shot: + gpt2_fedmkt: + pretrained: "gpt2" + script: "./fedmkt.py" + conf: "./fedmkt_config.yaml" \ No newline at end of file diff --git a/examples/offsite_tuning/__init__.py b/examples/offsite_tuning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/offsite_tuning/offsite_tuning.py b/examples/offsite_tuning/offsite_tuning.py new file mode 100644 index 0000000..73f6884 --- /dev/null +++ b/examples/offsite_tuning/offsite_tuning.py @@ -0,0 +1,123 @@ +import argparse +import yaml +from fate_client.pipeline.components.fate.reader import Reader +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner +from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments +from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader +from fate_client.pipeline.components.fate.nn.torch.base import Sequential +from fate_client.pipeline.components.fate.nn.torch import nn + +def load_params(file_path): + """Load and parse the YAML params file.""" + with open(file_path, 'r') as f: + params = yaml.safe_load(f) + return params + +def setup_pipeline(params): + """Set up the pipeline using the provided parameters.""" + guest = params['pipeline']['guest'] + arbiter = params['pipeline']['arbiter'] + pretrained_model_path = params['paths']['pretrained_model_path'] + + pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter) + + reader = Reader("reader_0", runtime_parties=dict(guest=guest)) + reader.guest.task_parameters( + namespace=params['pipeline']['namespace'], + name=params['pipeline']['name'] + ) + + client_model = LLMModelLoader( + module_name=params['models']['client']['module_name'], + item_name=params['models']['client']['item_name'], + model_name_or_path=pretrained_model_path, + 
emulator_layer_num=params['models']['client']['emulator_layer_num'], + adapter_top_layer_num=params['models']['client']['adapter_top_layer_num'], + adapter_bottom_layer_num=params['models']['client']['adapter_bottom_layer_num'] + ) + + server_model = LLMModelLoader( + module_name=params['models']['server']['module_name'], + item_name=params['models']['server']['item_name'], + model_name_or_path=pretrained_model_path, + emulator_layer_num=params['models']['server']['emulator_layer_num'], + adapter_top_layer_num=params['models']['server']['adapter_top_layer_num'], + adapter_bottom_layer_num=params['models']['server']['adapter_bottom_layer_num'] + ) + + dataset = LLMDatasetLoader( + module_name=params['dataset']['module_name'], + item_name=params['dataset']['item_name'], + tokenizer_name_or_path=params['dataset']['tokenizer_name_or_path'], + select_num=params['dataset']['select_num'] + ) + + data_collator = LLMDataFuncLoader( + module_name=params['data_collator']['module_name'], + item_name=params['data_collator']['item_name'], + tokenizer_name_or_path=params['data_collator']['tokenizer_name_or_path'] + ) + + train_args = Seq2SeqTrainingArguments( + per_device_train_batch_size=params['training']['batch_size'], + learning_rate=params['training']['learning_rate'], + disable_tqdm=False, + num_train_epochs=params['training']['num_train_epochs'], + logging_steps=params['training']['logging_steps'], + logging_strategy='steps', + dataloader_num_workers=4, + use_cpu=False, + deepspeed=params['training']['deepspeed'], # Add DeepSpeed config here + remove_unused_columns=False, + fp16=True + ) + + client_conf = get_conf_of_ot_runner( + model=client_model, + dataset=dataset, + data_collator=data_collator, + training_args=train_args, + fed_args=FedAVGArguments(), + aggregate_model=False, + ) + + server_conf = get_conf_of_ot_runner( + model=server_model, + dataset=dataset, + data_collator=data_collator, + training_args=train_args, + fed_args=FedAVGArguments(), + aggregate_model=False + ) + + homo_nn = HomoNN( + 'nn_0', + train_data=reader.outputs["output_data"], + runner_module="offsite_tuning_runner", + runner_class="OTRunner" + ) + + homo_nn.guest.task_parameters(runner_conf=client_conf) + homo_nn.arbiter.task_parameters(runner_conf=server_conf) + + # If using Eggroll, you can add this line to submit your job + homo_nn.guest.conf.set("launcher_name", "deepspeed") + + pipeline.add_tasks([reader, homo_nn]) + pipeline.conf.set("task", dict(engine_run=params['pipeline']['engine_run'])) + pipeline.compile() + pipeline.fit() + +def main(config_file, param_file): + params = load_params(param_file) + setup_pipeline(params) + +if __name__ == "__main__": + parser = argparse.ArgumentParser("LLMSUITE Offsite-tuning JOB") + parser.add_argument("-c", "--config", type=str, + help="Path to config file", default="./config.yaml") + parser.add_argument("-p", "--param", type=str, + help="Path to parameter file", default="./test_offsite_tuning_llmsuite.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/offsite_tuning/offsite_tuning_config.yaml b/examples/offsite_tuning/offsite_tuning_config.yaml new file mode 100644 index 0000000..deb79cf --- /dev/null +++ b/examples/offsite_tuning/offsite_tuning_config.yaml @@ -0,0 +1,67 @@ +# params.yaml + +paths: + pretrained_model_path: 'gpt2' + +pipeline: + guest: '9999' + arbiter: '9999' + namespace: 'experiment' + name: 'sciq' + engine_run: + cores: 1 + +training: + batch_size: 1 + learning_rate: 5e-5 + num_train_epochs: 1 + logging_steps: 10 + 
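+  # note: the deepspeed mapping below is forwarded unchanged to
+  # Seq2SeqTrainingArguments(deepspeed=...) by offsite_tuning.py; adjust it to the local GPU setup.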
deepspeed: + train_micro_batch_size_per_gpu: 1 + optimizer: + type: "Adam" + params: + lr: 5e-5 + torch_adam: true + adam_w_mode: false + fp16: + enabled: true + gradient_accumulation_steps: 1 + zero_optimization: + stage: 2 + allgather_partitions: true + allgather_bucket_size: 1e8 + overlap_comm: true + reduce_scatter: true + reduce_bucket_size: 1e8 + contiguous_gradients: true + offload_optimizer: + device: "cpu" + offload_param: + device: "cpu" + +models: + client: + module_name: 'offsite_tuning.gpt2' + item_name: 'GPT2LMHeadSubModel' + emulator_layer_num: 11 + adapter_top_layer_num: 2 + adapter_bottom_layer_num: 2 + + server: + module_name: 'offsite_tuning.gpt2' + item_name: 'GPT2LMHeadMainModel' + emulator_layer_num: 11 + adapter_top_layer_num: 2 + adapter_bottom_layer_num: 2 + +dataset: + module_name: 'qa_dataset' + item_name: 'QaDataset' + tokenizer_name_or_path: 'gpt2' + select_num: 100 + +data_collator: + module_name: 'data_collator.cust_data_collator' + item_name: 'get_seq2seq_data_collator' + tokenizer_name_or_path: 'gpt2' diff --git a/examples/offsite_tuning/test_offsite_tuning_llmsuite.yaml b/examples/offsite_tuning/test_offsite_tuning_llmsuite.yaml new file mode 100644 index 0000000..0cdd007 --- /dev/null +++ b/examples/offsite_tuning/test_offsite_tuning_llmsuite.yaml @@ -0,0 +1,14 @@ +data: + - file: + table_name: sciq + namespace: experiment + role: guest_0 + - file: + table_name: sciq + namespace: experiment + role: host_0 +bloom_lora_vs_zero_shot: + gpt2_ot: + pretrained: "gpt2" + script: "./offsite_tuning.py" + conf: "./offsite_tuning_config.yaml" \ No newline at end of file diff --git a/python/requirements.txt b/python/requirements.txt index fb6f6dd..d64d0d0 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -6,4 +6,6 @@ lm_eval==0.4.2 rouge-score==0.1.2 datasets editdistance +torch==2.3.1 +transformers==4.42.4 diff --git a/python/setup.py b/python/setup.py index de2fe15..bdc4d26 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,13 +1,63 @@ # -*- coding: utf-8 -*- +# +# Copyright 2024 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# -from setuptools import setup +from setuptools import find_packages, setup -packages = ["fate_llm", "fate_llm.evaluate.scripts"] -entry_points = {"console_scripts": ["fate_llm = fate_llm.evaluate.scripts.fate_llm_cli:fate_llm_cli"]} +import fate_llm -setup( - name='fate_llm', - version='2.1.0', - packages=packages, - entry_points=entry_points -) +# Define the packages and modules +packages = find_packages(include=["fate_llm", "fate_llm.*"]) + +# Define dependencies +install_requires = [ + "accelerate==0.27.2", + "deepspeed==0.13.3", + "peft==0.8.2", + "sentencepiece==0.2.0", + "lm_eval==0.4.2", + "rouge-score==0.1.2", + "datasets", + "editdistance", + "torch==2.3.1", # Added dependency + "transformers==4.42.4" # Added dependency +] + +# Define the entry points for command-line tools +entry_points = { + "console_scripts": [ + "fate_llm = fate_llm.evaluate.scripts.fate_llm_cli:fate_llm_cli" + ] +} + +# Configure and call the setup function +setup_kwargs = { + "name": "fate_llm", + "version": "2.1.0", + "description": "Federated Learning for Large Language Models", + "long_description": "Federated Learning for Large Language Models (FATE-LLM) provides a framework to train and evaluate large language models in a federated manner.", + "long_description_content_type": "text/markdown", + "author": "FederatedAI", + "author_email": "contact@FedAI.org", + "url": "https://fate.fedai.org/", + "packages": packages, + "install_requires": install_requires, + "entry_points": entry_points, + "python_requires": ">=3.8", + "include_package_data": True +} + +setup(**setup_kwargs) From 150f1def6c49fb979dda53b182c111c07dc06fab Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 15:07:49 +0800 Subject: [PATCH 30/50] update requirements of fate-llm Signed-off-by: mgqa34 --- python/requirements.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/requirements.txt b/python/requirements.txt index fb6f6dd..3027a4d 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -6,4 +6,9 @@ lm_eval==0.4.2 rouge-score==0.1.2 datasets editdistance +torch==2.3.1 +opacus==1.4.1 +fastchat +Jinja2 +sentence-transformers From 8a4578d8e56dfd4159421fd4fc7cc8d00a9f681c Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 15:14:04 +0800 Subject: [PATCH 31/50] update reqs Signed-off-by: mgqa34 --- python/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/requirements.txt b/python/requirements.txt index 3027a4d..6969b59 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -4,7 +4,7 @@ peft==0.8.2 sentencepiece==0.2.0 lm_eval==0.4.2 rouge-score==0.1.2 -datasets +datasets==2.18.0 editdistance torch==2.3.1 opacus==1.4.1 From 72c0265eb296346a1a229617cf3c717890cdef6e Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 16:39:25 +0800 Subject: [PATCH 32/50] fix setup, add yaml files Signed-off-by: mgqa34 --- python/MANIFEST.in | 2 ++ python/fate_llm/data/__init__.py | 0 python/setup.py | 21 +++++++++++++++------ 3 files changed, 17 insertions(+), 6 deletions(-) create mode 100644 python/MANIFEST.in create mode 100644 python/fate_llm/data/__init__.py diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 0000000..6ddc55e --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1,2 @@ +include fate_llm/dataset/data_config/*yaml +include python/fate_llm/evaluate/tasks/*/*yaml \ No newline at end of file diff --git a/python/fate_llm/data/__init__.py b/python/fate_llm/data/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/python/setup.py b/python/setup.py index bdc4d26..46e9538 100644 --- a/python/setup.py +++ b/python/setup.py @@ -17,10 +17,9 @@ from setuptools import find_packages, setup -import fate_llm - # Define the packages and modules -packages = find_packages(include=["fate_llm", "fate_llm.*"]) +packages = find_packages(".") +package_data = {"": ["*"]} # Define dependencies install_requires = [ @@ -32,8 +31,12 @@ "rouge-score==0.1.2", "datasets", "editdistance", - "torch==2.3.1", # Added dependency - "transformers==4.42.4" # Added dependency + "torch==2.3.1", + "transformers==4.37.2", + "opacus==1.4.1", + "fastchat", + "Jinja2", + "sentence-transformers" ] # Define the entry points for command-line tools @@ -43,10 +46,15 @@ ] } +extras_require = { + "fate": ["pyfate==2.2.0"], + "fate_flow": ["fate_flow==2.2.0"] +} + # Configure and call the setup function setup_kwargs = { "name": "fate_llm", - "version": "2.1.0", + "version": "2.2.0", "description": "Federated Learning for Large Language Models", "long_description": "Federated Learning for Large Language Models (FATE-LLM) provides a framework to train and evaluate large language models in a federated manner.", "long_description_content_type": "text/markdown", @@ -56,6 +64,7 @@ "packages": packages, "install_requires": install_requires, "entry_points": entry_points, + "extras_require": extras_require, "python_requires": ">=3.8", "include_package_data": True } From 54ba7d97711b8751b68c63e66624db7fc82bd7c3 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 16:39:48 +0800 Subject: [PATCH 33/50] fix setup, add __init__ files Signed-off-by: mgqa34 --- python/fate_llm/evaluate/tasks/advertise_gen/__init__.py | 0 python/fate_llm/evaluate/tasks/dolly_15k/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/fate_llm/evaluate/tasks/advertise_gen/__init__.py create mode 100644 python/fate_llm/evaluate/tasks/dolly_15k/__init__.py diff --git a/python/fate_llm/evaluate/tasks/advertise_gen/__init__.py b/python/fate_llm/evaluate/tasks/advertise_gen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/evaluate/tasks/dolly_15k/__init__.py b/python/fate_llm/evaluate/tasks/dolly_15k/__init__.py new file mode 100644 index 0000000..e69de29 From 73290b3b3be1dbd20f79484cd366c92663c0543f Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 16:53:33 +0800 Subject: [PATCH 34/50] update __init__ files Signed-off-by: mgqa34 --- python/fate_llm/model_zoo/offsite_tuning/__init__.py | 0 python/fate_llm/model_zoo/pellm/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/fate_llm/model_zoo/offsite_tuning/__init__.py create mode 100644 python/fate_llm/model_zoo/pellm/__init__.py diff --git a/python/fate_llm/model_zoo/offsite_tuning/__init__.py b/python/fate_llm/model_zoo/offsite_tuning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/model_zoo/pellm/__init__.py b/python/fate_llm/model_zoo/pellm/__init__.py new file mode 100644 index 0000000..e69de29 From 29a0893e5aa0df0e4d2a9239f95e8c28b6be893d Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 17:07:45 +0800 Subject: [PATCH 35/50] update __init__ files Signed-off-by: mgqa34 --- python/fate_llm/algo/pdss/encoder_decoder/__init__.py | 0 python/fate_llm/algo/pdss/encoder_decoder/init/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/fate_llm/algo/pdss/encoder_decoder/__init__.py create mode 100644 
python/fate_llm/algo/pdss/encoder_decoder/init/__init__.py diff --git a/python/fate_llm/algo/pdss/encoder_decoder/__init__.py b/python/fate_llm/algo/pdss/encoder_decoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/fate_llm/algo/pdss/encoder_decoder/init/__init__.py b/python/fate_llm/algo/pdss/encoder_decoder/init/__init__.py new file mode 100644 index 0000000..e69de29 From 9d73b402993b25407996bb943c0d7be11f9fed53 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 20:38:41 +0800 Subject: [PATCH 36/50] update tutorial of fdkt Signed-off-by: mgqa34 --- doc/tutorial/fdkt/fdkt.ipynb | 732 +++++++++++++++++++++++++++++++++++ 1 file changed, 732 insertions(+) create mode 100644 doc/tutorial/fdkt/fdkt.ipynb diff --git a/doc/tutorial/fdkt/fdkt.ipynb b/doc/tutorial/fdkt/fdkt.ipynb new file mode 100644 index 0000000..76c9312 --- /dev/null +++ b/doc/tutorial/fdkt/fdkt.ipynb @@ -0,0 +1,732 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Synthesize Data With FDKT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we will demonstrate how to synthesize data using the FATE-LLM framework. In FATE-LLM, we introduce the \"FDKT\" module, specifically designed for domain-specific knowledge transfer on large language models using synthetic data. The FDKT algorithm is based on the paper [Federated Domain-Specific Knowledge Transfer on\n", + "Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212). We integrate its code into the FATE-LLM framework. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataset: Yelp\n", + "We processed and sampled data of the 'Health' subdomain from the [Yelp dataset](https://arxiv.org/abs/1509.01626); the dataset can be downloaded from [here](https://www.yelp.com/dataset). \n", + "Once the dataset has been downloaded, execute the following command to untar the downloaded dataset."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "```shell\n", + "tar -xvf yelp_dataset.tar\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following code will sample 5000 datalines of 'Health' subdomain, and train data will generated under the folder './balance_processed_data/Health/train.json'" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import sys\n", + "import random\n", + "from pathlib import Path\n", + "random.seed(42)\n", + "\n", + "\n", + "base_dir = \"./\"\n", + "business_data_path = os.path.join(base_dir, 'yelp_academic_dataset_business.json')\n", + "review_data_path = os.path.join(base_dir, 'yelp_academic_dataset_review.json')\n", + "\n", + "business_data_file = open(business_data_path, 'r')\n", + "review_data_file = open(review_data_path, 'r')\n", + "\n", + "categories_list = ['Restaurants', 'Shopping', 'Arts', 'Health']\n", + "business_dic = {}\n", + "data_dict = {}\n", + "for category in categories_list:\n", + " business_dic[category] = set()\n", + " data_dict[category] = []\n", + "\n", + "\n", + "def get_categories(categories):\n", + " return_list = []\n", + " for category in categories_list:\n", + " if category in categories:\n", + " return_list.append(category)\n", + " return return_list\n", + "\n", + "\n", + "for line in business_data_file.readlines():\n", + " dic = json.loads(line)\n", + " if 'categories' in dic.keys() and dic['categories'] is not None:\n", + " category = get_categories(dic['categories'])\n", + " if len(category) == 1:\n", + " business_dic[category[0]].add(dic['business_id'])\n", + "\n", + "# for category in categories_list:\n", + "for line in review_data_file.readlines():\n", + " dic = json.loads(line)\n", + " if 'business_id' in dic.keys() and dic['business_id'] is not None:\n", + " for category in categories_list:\n", + " if dic['business_id'] in business_dic[category]:\n", + " if dic['text'] is not None and dic['stars'] is not None:\n", + " data_dict[category].append({'text': dic['text'], 'stars': dic['stars']})\n", + " break\n", + "\n", + "train_data_path = os.path.join('processed_data', \"Health\", 'train.json')\n", + "os.makedirs(Path(train_data_path).parent, exist_ok=True)\n", + "train_data_file = open(train_data_path, 'w')\n", + "data_list = data_dict[\"Health\"]\n", + "\n", + "sample_data_dict = dict()\n", + "\n", + "for data in data_list:\n", + " star = int(data[\"stars\"])\n", + " if star not in sample_data_dict:\n", + " sample_data_dict[star] = []\n", + "\n", + " sample_data_dict[star].append(data)\n", + "\n", + "data_list = []\n", + "star_keys = list(sample_data_dict.keys())\n", + "for star in star_keys:\n", + " sample_data = sample_data_dict[star][:1000]\n", + " random.shuffle(sample_data)\n", + " data_list.extend(sample_data)\n", + "\n", + "random.shuffle(data_list)\n", + "json.dump(data_list, train_data_file, indent=4)\n", + "train_data_file.close()\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Models Use\n", + "Please download the following models, these models are used for data augmentation process.\n", + "\n", + "LLM: [Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat) \n", + "SLM: [gpt2-xl](https://huggingface.co/openai-community/gpt2-xl)\n", + "\n", + "MeanWhile, 'all-mpnet-base-v2' is used to generate embedding vectors in LLM side.\n", + "\n", + "Embedding Model: 
[all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running FDKT Data Synthesis Process With Launcher (Experimental Using)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SLM Setting\n", + "\n", + "In this section, we will introduce some key configurations on the SLM side." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1. Load the Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import transformers\n", + "from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n", + "\n", + "\n", + "slm_pretrained_path = \"gpt2-xl\" # modify this to your local directory\n", + "slm = transformers.AutoModelForCausalLM.from_pretrained(slm_pretrained_path, torch_dtype=torch.bfloat16)\n", + "tokenizer = get_tokenizer(slm_pretrained_path)\n", + "tokenizer.pad_token_id = tokenizer.eos_token_id\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2. Initialize SLM Training Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.fdkt.fdkt_data_aug import FDKTTrainingArguments\n", + "\n", + "\n", + "training_args = FDKTTrainingArguments(\n", + " use_cpu=False, # use gpu to do dp(differential privacy) training process\n", + " device_id=0, # the device number of gpu\n", + " num_train_epochs=1, # dp training epochs\n", + " per_device_train_batch_size=2, # batch size of dp training\n", + " slm_generation_batch_size=32, # batch_size to generate data in slm side\n", + " seq_num_for_single_category=300, # data num for each category(label)\n", + " slm_generation_config=dict(\n", + " max_new_tokens=256,\n", + " temperature=1.0,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " repetition_penalty=1.0,\n", + " pad_token_id=tokenizer.eos_token_id\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3. Initialize Dataset Instance\n", + "\n", + "We provide default templates for the \"Yelp\" and \"AGNews\" datasets; users can refer [here](https://github.com/FederatedAI/FATE-LLM/tree/dev-2.2.0/python/fate_llm/dataset/data_config) for more details. If you want to use your own dataset, please provide the fields label_key/text_key/augment_format/filter_format/tokenize_format/sub_domain/label_list/few_shot_format/text_with_label_format like the two default templates and pass it as an argument." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.dataset.flex_dataset import FlexDataset\n", + "\n", + "\n", + "ds = FlexDataset(\n", + " tokenizer_name_or_path=slm_pretrained_path,\n", + " load_from=\"json\",\n", + " data_part=\"train\",\n", + " dataset_name=\"yelp_review\", # use default template\n", + " # config=dict/template_path # if dataset_name is not \"yelp_review\" or \"ag_news\"\n", + " need_preprocess=True,\n", + " select_num=2000, # use data_num=2000 to train, default is None, None means using all data\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### LLM Setting\n", + "\n", + "In this section, we will introduce some key configurations on the LLM side." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1. 
Deploy VLLM Server And Use OpenAI API Protocol To Speed Up LLM Inference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Please copy the following code to a local file create_and_start_vllm.sh, then run it by executing \"bash create_and_start_vllm.sh\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create_and_start_vllm.sh\n", + "# create vllm environment\n", + "\n", + "python -m venv vllm_venv\n", + "source vllm_venv/bin/activate\n", + "pip install vllm==0.4.3\n", + "pip install numpy==1.26.4 # numpy >= 2.0.0 will raise error, so reinstall numpy<2.0.0\n", + "\n", + "# please modify Qwen1.5-7B-Chat to the local llm model saving path\n", + "export CUDA_VISIBLE_DEVICES=1,2\n", + "nohup python -m vllm.entrypoints.openai.api_server --host 127.0.0.1 --port 9999 --model Qwen1.5-7B-Chat --dtype=half --enforce-eager --api-key demo --device cuda -tp 2 &" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2. Initialize LLM Training Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.fdkt.fdkt_data_aug import FDKTTrainingArguments\n", + "\n", + "\n", + "training_args = FDKTTrainingArguments(\n", + " sample_num_per_cluster=4, # use this to estimate the number of clusters, n_clusters=(len(dataset) + sample_num_per_cluster - 1) // sample_num_per_cluster\n", + " filter_prompt_max_length=2**16,\n", + " filter_generation_config=dict(\n", + " max_tokens=512,\n", + " ),\n", + " aug_generation_config=dict(\n", + " max_tokens=4096,\n", + " temperature=0.8,\n", + " top_p=0.9,\n", + " ),\n", + " aug_prompt_num=20000, # number of prompts used for data augmentation\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3. Initialize Embedding Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.model_zoo.embedding_transformer.st_model import SentenceTransformerModel\n", + "\n", + "\n", + "embedding_lm = SentenceTransformerModel(model_name_or_path=\"all-mpnet-base-v2\").load() # modify model_name_or_path to the local model saved path" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 4. Initialize OpenAI API For Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_llm.algo.fdkt.inference_inst import api_init\n", + "\n", + "\n", + "inference_inst = api_init(\n", + " api_url=\"http://127.0.0.1:9999/v1/\",\n", + " model_name=\"Qwen1.5-7B-Chat\", # modify model_name to the local LLM (e.g. Qwen1.5-7B-Chat) saved path\n", + " api_key=\"demo\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Complete Code \n", + "\n", + "Please paste the code into \"run_fdkt_by_launcher.py\" and execute it with the following command. 
Once the process is finished, augmentation data will be saved in the current directory, whose filename is aug_data_result.json" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "python run_fdkt_by_launcher.py --parties guest:9999 arbiter:10000 --log_level INFO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "\n", + "import torch\n", + "from fate.arch import Context\n", + "from fate.arch.launchers.multiprocess_launcher import launch\n", + "\n", + "# please replace the following four variables to local paths\n", + "llm_pretrained_path = \"Qwen1.5-7B-Chat\"\n", + "embedding_model_path = \"all-mpnet-base-v2\"\n", + "slm_pretrained_path = \"gpt2-xl\"\n", + "slm_data_path = \"./process/Health/train.json\"\n", + "\n", + "\n", + "def get_optimizer(model, optimizer=\"adam\", lr=1e-4):\n", + " if optimizer == \"adam\":\n", + " optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)\n", + " elif optimizer == \"adamw\":\n", + " optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)\n", + " else:\n", + " raise NotImplementedError(\"Given optimizer type is not supported\")\n", + " return optimizer\n", + "\n", + "\n", + "def train_slm(ctx):\n", + " import transformers\n", + " from fate_llm.algo.fdkt.fdkt_data_aug import (\n", + " FDKTSLM,\n", + " FDKTTrainingArguments\n", + " )\n", + " from fate_llm.dataset.flex_dataset import FlexDataset\n", + " from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n", + " from transformers.data import DataCollatorForSeq2Seq\n", + "\n", + " slm = transformers.AutoModelForCausalLM.from_pretrained(slm_pretrained_path, torch_dtype=torch.bfloat16)\n", + " tokenizer = get_tokenizer(slm_pretrained_path)\n", + " tokenizer.pad_token_id = tokenizer.eos_token_id\n", + " training_args = FDKTTrainingArguments(\n", + " use_cpu=False,\n", + " device_id=0,\n", + " num_train_epochs=1,\n", + " per_device_train_batch_size=2,\n", + " slm_generation_batch_size=32,\n", + " seq_num_for_single_category=2000,\n", + " slm_generation_config=dict(\n", + " max_new_tokens=256,\n", + " temperature=1.0,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " repetition_penalty=1.0,\n", + " pad_token_id=tokenizer.eos_token_id\n", + " ),\n", + " # inference_method=\"vllm\",\n", + " )\n", + "\n", + " ds = FlexDataset(\n", + " tokenizer_name_or_path=slm_pretrained_path,\n", + " load_from=\"json\",\n", + " data_part=\"train\",\n", + " dataset_name=\"yelp_review\",\n", + " need_preprocess=True,\n", + " select_num=2000, # use 2000 data to train, default is None, using all data\n", + " )\n", + " ds.load(slm_data_path)\n", + "\n", + " fdkt_runner = FDKTSLM(\n", + " ctx=ctx,\n", + " model=slm,\n", + " training_args=training_args,\n", + " tokenizer=tokenizer,\n", + " train_set=ds,\n", + " optimizer=get_optimizer(slm),\n", + " data_collator=DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=tokenizer.pad_token_id)\n", + " )\n", + "\n", + " aug_data = fdkt_runner.aug_data()\n", + " with open(\"./aug_data_result.json\", \"w\") as fout:\n", + " fout.write(json.dumps(aug_data, indent=4))\n", + "\n", + "\n", + "def train_llm(ctx):\n", + " from fate_llm.algo.fdkt.fdkt_data_aug import (\n", + " FDKTLLM,\n", + " FDKTTrainingArguments\n", + " )\n", + " from fate_llm.model_zoo.embedding_transformer.st_model import SentenceTransformerModel\n", + " from fate_llm.dataset.flex_dataset import FlexDataset\n", + " from 
fate_llm.algo.fdkt.inference_inst import api_init, vllm_init\n", + "\n", + " embedding_lm = SentenceTransformerModel(model_name_or_path=embedding_model_path).load()\n", + " training_args = FDKTTrainingArguments(\n", + " sample_num_per_cluster=5,\n", + " filter_prompt_max_length=2**14,\n", + " filter_generation_config=dict(\n", + " max_tokens=4096,\n", + " ),\n", + " use_cpu=False,\n", + " aug_generation_config=dict(\n", + " max_tokens=4096,\n", + " temperature=0.8,\n", + " top_p=0.9,\n", + " ),\n", + " aug_prompt_num=20000,\n", + " )\n", + "\n", + " ds = FlexDataset(\n", + " tokenizer_name_or_path=llm_pretrained_path,\n", + " load_from=\"json\",\n", + " data_part=\"train\",\n", + " dataset_name=\"yelp_review\",\n", + " need_preprocess=True,\n", + " few_shot_num_per_label=1,\n", + " )\n", + "\n", + " inference_inst = api_init(\n", + " api_url=\"http://127.0.0.1:9999/v1/\",\n", + " model_name=llm_pretrained_path,\n", + " api_key=\"demo\"\n", + " )\n", + "\n", + " fdkt_runner = FDKTLLM(\n", + " ctx=ctx,\n", + " embedding_model=embedding_lm,\n", + " training_args=training_args,\n", + " dataset=ds,\n", + " inference_inst=inference_inst,\n", + " )\n", + "\n", + " fdkt_runner.aug_data()\n", + "\n", + "\n", + "def run(ctx: Context):\n", + " if ctx.is_on_arbiter:\n", + " train_llm(ctx)\n", + " else:\n", + " os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + " train_slm(ctx)\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " launch(run)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running FEDMKT with Pipeline (Industrial Using)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Please make sure that FATE and FATE-Flow has been deployed, paste the following code to test_fdkt_by_pipeline.py, the execute \"python test_fdkt_by_pipeline.py\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fdkt_runner\n", + "from fate_client.pipeline.components.fate.nn.algo_params import FDKTTrainingArguments\n", + "from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n", + "from fate_client.pipeline import FateFlowPipeline\n", + "from fate_client.pipeline.components.fate.reader import Reader\n", + "from fate_client.pipeline.components.fate.nn.torch import nn, optim\n", + "\n", + "\n", + "guest = '9999'# replace this party id to actual guest party id in your enviroment\n", + "arbiter = '9999'# replace this party id to actual arbiter party id in your enviroment\n", + "\n", + "# please replace the following four variables to local paths\n", + "llm_pretrained_path = \"Qwen1.5-7B-Chat\"\n", + "embedding_model_path = \"all-mpnet-base-v2/\"\n", + "slm_pretrained_path = \"gpt2-xl\"\n", + "slm_data_path = \"./process/Health/train.json\" # should be absolute path\n", + "\n", + "\n", + "def get_llm_conf():\n", + " embedding_model = LLMModelLoader(\n", + " \"embedding_transformer.st_model\",\n", + " \"SentenceTransformerModel\",\n", + " model_name_or_path=embedding_model_path\n", + " )\n", + "\n", + " dataset = LLMDatasetLoader(\n", + " \"flex_dataset\",\n", + " \"FlexDataset\",\n", + " tokenizer_name_or_path=llm_pretrained_path,\n", + " need_preprocess=True,\n", + " dataset_name=\"yelp_review\",\n", + " data_part=\"train\",\n", + " load_from=\"json\",\n", + " few_shot_num_per_label=1,\n", + " )\n", + "\n", + " training_args = FDKTTrainingArguments(\n", + " 
sample_num_per_cluster=5,\n", + " filter_prompt_max_length=2 ** 14,\n", + " filter_generation_config=dict(\n", + " max_tokens=4096,\n", + " ),\n", + " use_cpu=False,\n", + " aug_generation_config=dict(\n", + " max_tokens=4096,\n", + " temperature=0.8,\n", + " top_p=0.9,\n", + " ),\n", + " aug_prompt_num=20000,\n", + " )\n", + "\n", + " inference_inst_conf = dict(\n", + " module_name=\"fate_llm.algo.fdkt.inference_inst\",\n", + " item_name=\"api_init\",\n", + " kwargs=dict(\n", + " api_url=\"http://127.0.0.1:9999/v1/\",\n", + " model_name=llm_pretrained_path,\n", + " api_key=\"demo\"\n", + " )\n", + " )\n", + "\n", + " return get_config_of_fdkt_runner(\n", + " training_args=training_args,\n", + " embedding_model=embedding_model,\n", + " dataset=dataset,\n", + " inference_inst_conf=inference_inst_conf,\n", + " )\n", + "\n", + "\n", + "def get_slm_conf():\n", + " slm_model = LLMModelLoader(\n", + " \"hf_model\",\n", + " \"HFAutoModelForCausalLM\",\n", + " pretrained_model_name_or_path=slm_pretrained_path,\n", + " torch_dtype=\"bfloat16\",\n", + " )\n", + "\n", + " tokenizer = LLMDataFuncLoader(\n", + " \"tokenizers.cust_tokenizer\",\n", + " \"get_tokenizer\",\n", + " tokenizer_name_or_path=slm_pretrained_path,\n", + " pad_token_id=50256\n", + " )\n", + "\n", + " training_args = FDKTTrainingArguments(\n", + " use_cpu=False,\n", + " device_id=1,\n", + " num_train_epochs=1,\n", + " per_device_train_batch_size=2,\n", + " slm_generation_batch_size=32,\n", + " seq_num_for_single_category=2000,\n", + " slm_generation_config=dict(\n", + " max_new_tokens=256,\n", + " temperature=1.0,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " repetition_penalty=1.0,\n", + " pad_token_id=50256\n", + " ),\n", + " )\n", + "\n", + " dataset = LLMDatasetLoader(\n", + " \"flex_dataset\",\n", + " \"FlexDataset\",\n", + " tokenizer_name_or_path=slm_pretrained_path,\n", + " need_preprocess=True,\n", + " dataset_name=\"yelp_review\",\n", + " data_part=\"train\",\n", + " load_from=\"json\",\n", + " select_num=2000,\n", + " few_shot_num_per_label=1,\n", + " )\n", + "\n", + " optimizer = optim.Adam(lr=0.01)\n", + "\n", + " return get_config_of_fdkt_runner(\n", + " model=slm_model,\n", + " tokenizer=tokenizer,\n", + " training_args=training_args,\n", + " dataset=dataset,\n", + " optimizer=optimizer,\n", + " data_collator=LLMDataFuncLoader(\n", + " \"data_collator.cust_data_collator\",\n", + " \"get_seq2seq_data_collator\",\n", + " label_pad_token_id=50256,\n", + " tokenizer_name_or_path=slm_pretrained_path,\n", + " pad_token_id=50256,\n", + " ),\n", + " )\n", + "\n", + "\n", + "pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n", + "pipeline.bind_local_path(path=slm_data_path, namespace=\"experiment\", name=\"slm_train\")\n", + "\n", + "\n", + "reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n", + "reader_0.guest.task_parameters(\n", + " namespace=\"experiment\",\n", + " name=\"slm_train\"\n", + ")\n", + "\n", + "\n", + "homo_nn_0 = HomoNN(\n", + " 'homo_nn_0',\n", + " train_data=reader_0.outputs[\"output_data\"],\n", + " runner_module=\"fdkt_runner\",\n", + " runner_class=\"FDKTRunner\",\n", + ")\n", + "\n", + "homo_nn_0.arbiter.task_parameters(\n", + " runner_conf=get_llm_conf()\n", + ")\n", + "\n", + "homo_nn_0.guest.task_parameters(\n", + " runner_conf=get_slm_conf()\n", + ")\n", + "\n", + "pipeline.add_tasks([reader_0, homo_nn_0])\n", + "pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1}))\n", + "\n", + "pipeline.compile()\n", + "pipeline.fit()" + ] + } + ], + "metadata": { + 
"kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 0277f94618ca41fa80ae05980f12acd851e5f0bd Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 20:47:27 +0800 Subject: [PATCH 37/50] update release note of fate_llm 2.2.0 Signed-off-by: mgqa34 --- RELEASE.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index 733cf9e..d66a282 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,12 @@ +## Release 2.2.0 +### Major Features and Improvements +* Integrate the PDSS algorithm, a novel framework that enhances local small language models (SLMs) using differentially private protected Chain of Thoughts (Cot) generated by remote LLMs: + * Implement InferDPT for privacy-preserving Cot generation. + * Support an encoder-decoder mechanism for privacy-preserving Cot generation. + * Add prefix trainers for step-by-step distillation and text encoder-decoder training. +* Integrate the FDKT algorithm, a framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy + + ## Release 2.1.0 ### Major Features and Improvements * New FedMKT Federated Tuning Algorithms: Federated Mutual Knowledge Transfer for Large and Small Language Models From 9b70fb0d71845cbf4cb98c7cae2b19576bbac101 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 30 Jul 2024 20:52:53 +0800 Subject: [PATCH 38/50] update readme: add link of pdss and fdkt Signed-off-by: mgqa34 --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index e964ac1..9277772 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ with Communication Cost under 18 Kilobytes](./doc/tutorial/fedkseed/) - [InferDPT: Privacy-preserving Inference for Black-box Large Language Models](./doc/tutorial/inferdpt/inferdpt_tutorial.ipynb) - [FedMKT: Federated Mutual Knowledge Transfer for Large and Small Language Models](./doc/tutorial/fedmkt/) +- [PDSS: A Privacy-Preserving Framework for Step-by-Step Distillation of Large Language Models](./doc/tutorial/pdss) +- [FDKT: Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](./doc/tutorial/fdkt) ## FATE-LLM Evaluate From 27cff118f74ede62fb4ef1ef6e0882524b895590 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 31 Jul 2024 16:49:23 +0800 Subject: [PATCH 39/50] update reqs and setup script Signed-off-by: mgqa34 --- python/requirements.txt | 2 +- python/setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/requirements.txt b/python/requirements.txt index fa37b45..f15c793 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -12,4 +12,4 @@ opacus==1.4.1 fastchat Jinja2 sentence-transformers - +openai diff --git a/python/setup.py b/python/setup.py index 46e9538..07bc655 100644 --- a/python/setup.py +++ b/python/setup.py @@ -36,7 +36,8 @@ "opacus==1.4.1", "fastchat", "Jinja2", - "sentence-transformers" + "sentence-transformers", + "openai" ] # Define the entry points for command-line tools From 61e0fdc311eea99bc0d2b270a34b0cb142e81123 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 31 Jul 2024 16:51:35 +0800 Subject: [PATCH 40/50] update deployment documentation Signed-off-by: 
mgqa34 --- README.md | 11 +++--- RELEASE.md | 1 + doc/standalone_deploy.md | 84 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 doc/standalone_deploy.md diff --git a/README.md b/README.md index 9277772..84891bf 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,13 @@ FATE-LLM is a framework to support federated learning for large language models( -## Deployment - ### Standalone deployment -Please refer to [FATE-Standalone deployment](https://github.com/FederatedAI/FATE#standalone-deployment). -* To deploy FATE-LLM v2.0, deploy FATE-Standalone with version >= 2.1, then make a new directory `{fate_install}/fate_llm` and clone the code into it, install the python requirements, and add `{fate_install}/fate_llm/python` to `PYTHONPATH` -* To deploy FATE-LLM v1.x, deploy FATE-Standalone with 1.11.3 <= version < 2.0, then copy directory `python/fate_llm` to `{fate_install}/fate/python/fate_llm` +* To deploy FATE-LLM v2.2.0 or higher version, three ways are provided, please refer deploy tutorials for more details: + * deploy with FATE only from pypi then using Launcher to run tasks + * deploy with FATE、FATE-Flow、FATE-Client from pypi, user can run tasks with Pipeline +* To deploy lower versions: please refer to [FATE-Standalone deployment](https://github.com/FederatedAI/FATE#standalone-deployment). + * To deploy FATE-LLM v2.0.* - FATE-LLM v2.1.*, deploy FATE-Standalone with version >= 2.1, then make a new directory `{fate_install}/fate_llm` and clone the code into it, install the python requirements, and add `{fate_install}/fate_llm/python` to `PYTHONPATH` + * To deploy FATE-LLM v1.x, deploy FATE-Standalone with 1.11.3 <= version < 2.0, then copy directory `python/fate_llm` to `{fate_install}/fate/python/fate_llm` ### Cluster deployment Use [FATE-LLM deployment packages](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) to deploy, refer to [FATE-Cluster deployment](https://github.com/FederatedAI/FATE#cluster-deployment) for more deployment details. diff --git a/RELEASE.md b/RELEASE.md index d66a282..1b308ac 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -5,6 +5,7 @@ * Support an encoder-decoder mechanism for privacy-preserving Cot generation. * Add prefix trainers for step-by-step distillation and text encoder-decoder training. * Integrate the FDKT algorithm, a framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy +* Deployment Optimization: support installation of FATE-LLM by PyPI ## Release 2.1.0 diff --git a/doc/standalone_deploy.md b/doc/standalone_deploy.md new file mode 100644 index 0000000..aa8ffd0 --- /dev/null +++ b/doc/standalone_deploy.md @@ -0,0 +1,84 @@ +# FATE-LLM Single-Node Deployment Guide + +## 1. Introduction + +**Server Configuration:** + +- **Quantity:** 1 +- **Configuration:** 8 cores / 16GB memory / 500GB hard disk / GPU Machine +- **Operating System:** CentOS Linux release 7 +- **User:** app, owner: apps + +The single-node version provides two deployment methods, which can be selected based on your needs: +- Install FATE-LLM from PyPI With FATE +- Install FATE-LLM from PyPI with FATE, FATE-Flow, FATE-Client + +## 2. Install FATE-LLM from PyPI With FATE +In this way, users can run tasks with Launcher, a convenient way for fast experimentation. + +### 2.1 Installing Python Environment +- Prepare and install [conda](https://docs.conda.io/projects/miniconda/en/latest/) environment. 
+- Create a virtual environment: + +```shell +# FATE-LLM requires Python >= 3.10 +conda create -n fate_env python=3.10 +conda activate fate_env +``` + +### 2.2 Installing FATE-LLM +This section introduces how to install FATE-LLM from PyPI with FATE; execute the following command to install FATE-LLM. + +```shell +pip install fate_llm[fate]==2.2.0 +``` + +### 2.3 Usage +After installing successfully, please refer to [tutorials](../README.md#quick-start) to run tasks; tasks described in the tutorials that run with Launcher are all supported. + + +## 3. Install FATE-LLM from PyPI with FATE, FATE-Flow, FATE-Client +In this way, user can run tasks with Pipeline or Launcher. + +### 3.1 Installing Python Environment +- Prepare and install [conda](https://docs.conda.io/projects/miniconda/en/latest/) environment. +- Create a virtual environment: + +```shell +# FATE-LLM requires Python >= 3.10 +conda create -n fate_env python=3.10 +conda activate fate_env +``` + +### 3.2 Installing FATE-LLM with FATE, FATE-Flow, FATE-Client + +```shell +pip install fate_client[fate,fate_flow,fate_client]==2.2.0 +``` + +### 3.3 Service Initialization + +```shell +fate_flow init --ip 127.0.0.1 --port 9380 --home $HOME_DIR +pipeline init --ip 127.0.0.1 --port 9380 +``` +- `ip`: The IP address where the service runs. +- `port`: The HTTP port the service runs on. +- `home`: The data storage directory, including data, models, logs, job configurations, and SQLite databases. + +### 3.4 Start FATE-Flow Service + +```shell +fate_flow start +fate_flow status # make sure fate_flow service is started +``` + +FATE-Flow also provides other instructions such as stop and restart; use them only if you want to stop/restart the fate_flow service. +```shell +# Warning: normal installing process does not need to execute stop/restart instructions. +fate_flow stop +fate_flow restart +``` + +### 3.5 Usage +Please refer to [tutorials](../README.md#quick-start) for more usage guides; tasks described in the tutorials that run with Pipeline or Launcher are all supported. 
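The Launcher-based tutorials referenced above share a common entry-point pattern built on the `fate.arch` multiprocess launcher (the FDKT tutorial's run_fdkt_by_launcher.py in this patch series follows it). Below is a minimal, illustrative sketch of that pattern; it assumes only the FATE-LLM installation from section 2, and the file name and placeholder print statements are not part of the deployment steps.

```python
# minimal_launcher_demo.py -- illustrative skeleton of the Launcher pattern used
# by the FATE-LLM tutorials (e.g. run_fdkt_by_launcher.py); replace the
# placeholder branches with the client/server logic from the tutorial you follow.
from fate.arch import Context
from fate.arch.launchers.multiprocess_launcher import launch


def run(ctx: Context):
    if ctx.is_on_arbiter:
        # server/arbiter-side logic goes here (e.g. the LLM party in FDKT)
        print("arbiter party started")
    else:
        # client/guest-side logic goes here (e.g. the SLM party in FDKT)
        print("guest party started")


if __name__ == "__main__":
    # example invocation:
    #   python minimal_launcher_demo.py --parties guest:9999 arbiter:10000 --log_level INFO
    launch(run)
```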
From a168ac39e2fa119144822d4147df008c51a30384 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 31 Jul 2024 17:03:38 +0800 Subject: [PATCH 41/50] update setup script: add fate_client extra reqs Signed-off-by: mgqa34 --- python/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index 07bc655..980928c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -49,7 +49,8 @@ extras_require = { "fate": ["pyfate==2.2.0"], - "fate_flow": ["fate_flow==2.2.0"] + "fate_flow": ["fate_flow==2.2.0"], + "fate_client": ["fate_client==2.2.0"] } # Configure and call the setup function From 31cba66900c287f19a8e2a8a49ebe03c2f86047e Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 31 Jul 2024 17:41:31 +0800 Subject: [PATCH 42/50] add deployment tutorial link to readme Signed-off-by: mgqa34 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 84891bf..802c6a7 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ FATE-LLM is a framework to support federated learning for large language models( ### Standalone deployment -* To deploy FATE-LLM v2.2.0 or higher version, three ways are provided, please refer deploy tutorials for more details: +* To deploy FATE-LLM v2.2.0 or higher version, three ways are provided, please refer [deploy tutorial](./doc/standalone_deploy.md) for more details: * deploy with FATE only from pypi then using Launcher to run tasks * deploy with FATE、FATE-Flow、FATE-Client from pypi, user can run tasks with Pipeline * To deploy lower versions: please refer to [FATE-Standalone deployment](https://github.com/FederatedAI/FATE#standalone-deployment). From f97ef4d9cae010164014bfb7fd8d0ab73a87d1d6 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 31 Jul 2024 20:33:39 +0800 Subject: [PATCH 43/50] limit datasets version Signed-off-by: mgqa34 --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index 980928c..b512d28 100644 --- a/python/setup.py +++ b/python/setup.py @@ -29,7 +29,7 @@ "sentencepiece==0.2.0", "lm_eval==0.4.2", "rouge-score==0.1.2", - "datasets", + "datasets==2.18.0", "editdistance", "torch==2.3.1", "transformers==4.37.2", From 2946ea26d62ee032796851b4090712c04a195986 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 1 Aug 2024 15:40:36 +0800 Subject: [PATCH 44/50] update doc Signed-off-by: mgqa34 --- doc/tutorial/fdkt/fdkt.ipynb | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/tutorial/fdkt/fdkt.ipynb b/doc/tutorial/fdkt/fdkt.ipynb index 76c9312..35eb4f1 100644 --- a/doc/tutorial/fdkt/fdkt.ipynb +++ b/doc/tutorial/fdkt/fdkt.ipynb @@ -426,6 +426,7 @@ " seq_num_for_single_category=2000,\n", " slm_generation_config=dict(\n", " max_new_tokens=256,\n", + " do_sample=True,\n", " temperature=1.0,\n", " top_k=50,\n", " top_p=0.9,\n", @@ -636,6 +637,7 @@ " seq_num_for_single_category=2000,\n", " slm_generation_config=dict(\n", " max_new_tokens=256,\n", + " do_sample=True,\n", " temperature=1.0,\n", " top_k=50,\n", " top_p=0.9,\n", From cc9c82e6575019668613ac35a6812245d2007632 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 1 Aug 2024 15:40:55 +0800 Subject: [PATCH 45/50] optimizer standalone deploy doc Signed-off-by: mgqa34 --- doc/standalone_deploy.md | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/doc/standalone_deploy.md b/doc/standalone_deploy.md index aa8ffd0..38a6d24 100644 --- a/doc/standalone_deploy.md +++ b/doc/standalone_deploy.md @@ -41,14 +41,7 @@ After installing 
successfully, please refer to [tutorials](../README.md#quick-st In this way, user can run tasks with Pipeline or Launcher. ### 3.1 Installing Python Environment -- Prepare and install [conda](https://docs.conda.io/projects/miniconda/en/latest/) environment. -- Create a virtual environment: - -```shell -# FATE-LLM requires Python >= 3.10 -conda create -n fate_env python=3.10 -conda activate fate_env -``` +Please refer to section-2.1 ### 3.2 Installing FATE-LLM with FATE, FATE-Flow, FATE-Client @@ -59,7 +52,8 @@ pip install fate_client[fate,fate_flow,fate_client]==2.2.0 ### 3.3 Service Initialization ```shell -fate_flow init --ip 127.0.0.1 --port 9380 --home $HOME_DIR +mkdir fate_workspace +fate_flow init --ip 127.0.0.1 --port 9380 --home $(pwd)/fate_workspace pipeline init --ip 127.0.0.1 --port 9380 ``` - `ip`: The IP address where the service runs. From 660efcbd0dc66f1265168f3770340dd3cbaa0b99 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 1 Aug 2024 17:21:02 +0800 Subject: [PATCH 46/50] fix tutorial Signed-off-by: mgqa34 --- doc/tutorial/fdkt/fdkt.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/tutorial/fdkt/fdkt.ipynb b/doc/tutorial/fdkt/fdkt.ipynb index 35eb4f1..e4f0394 100644 --- a/doc/tutorial/fdkt/fdkt.ipynb +++ b/doc/tutorial/fdkt/fdkt.ipynb @@ -39,7 +39,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The following code will sample 5000 datalines of 'Health' subdomain, and train data will generated under the folder './balance_processed_data/Health/train.json'" + "The following code will sample 5000 datalines of 'Health' subdomain, and train data will generated under the folder './processed_data/Health/train.json'" ] }, { @@ -391,7 +391,7 @@ "llm_pretrained_path = \"Qwen1.5-7B-Chat\"\n", "embedding_model_path = \"all-mpnet-base-v2\"\n", "slm_pretrained_path = \"gpt2-xl\"\n", - "slm_data_path = \"./process/Health/train.json\"\n", + "slm_data_path = \"./processed_data/Health/train.json\"\n", "\n", "\n", "def get_optimizer(model, optimizer=\"adam\", lr=1e-4):\n", @@ -559,7 +559,7 @@ "llm_pretrained_path = \"Qwen1.5-7B-Chat\"\n", "embedding_model_path = \"all-mpnet-base-v2/\"\n", "slm_pretrained_path = \"gpt2-xl\"\n", - "slm_data_path = \"./process/Health/train.json\" # should be absolute path\n", + "slm_data_path = \"./processed_data/Health/train.json\" # should be absolute path\n", "\n", "\n", "def get_llm_conf():\n", From 8ac3bbb56483e820f86b6e2a4d03fdc7b68f681e Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 1 Aug 2024 17:23:36 +0800 Subject: [PATCH 47/50] fix tutorial params Signed-off-by: mgqa34 --- doc/tutorial/fdkt/fdkt.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/tutorial/fdkt/fdkt.ipynb b/doc/tutorial/fdkt/fdkt.ipynb index e4f0394..e6a4c0f 100644 --- a/doc/tutorial/fdkt/fdkt.ipynb +++ b/doc/tutorial/fdkt/fdkt.ipynb @@ -472,7 +472,7 @@ "\n", " embedding_lm = SentenceTransformerModel(model_name_or_path=embedding_model_path).load()\n", " training_args = FDKTTrainingArguments(\n", - " sample_num_per_cluster=5,\n", + " sample_num_per_cluster=4,\n", " filter_prompt_max_length=2**14,\n", " filter_generation_config=dict(\n", " max_tokens=4096,\n", @@ -581,7 +581,7 @@ " )\n", "\n", " training_args = FDKTTrainingArguments(\n", - " sample_num_per_cluster=5,\n", + " sample_num_per_cluster=4,\n", " filter_prompt_max_length=2 ** 14,\n", " filter_generation_config=dict(\n", " max_tokens=4096,\n", From 9f8e2e56baa3b2ca8b3eb4aec7c740bf472ccebb Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 1 Aug 2024 18:01:16 
+0800 Subject: [PATCH 48/50] update readme to pdss docs Signed-off-by: mgqa34 --- doc/tutorial/pdss/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 doc/tutorial/pdss/README.md diff --git a/doc/tutorial/pdss/README.md b/doc/tutorial/pdss/README.md new file mode 100644 index 0000000..33fd710 --- /dev/null +++ b/doc/tutorial/pdss/README.md @@ -0,0 +1,13 @@ +# FATE-LLM: PDSS +The Algorithm is based on paper ["PDSS: A Privacy-Preserving Framework for Step-by-Step Distillation of Large Language Models"](https://arxiv.org/pdf/2406.12403), which introduce a novel framework for privacy preserving federated distillation. We integrate its code into the FATE-LLM framework. + +## Citation +If you publish work that uses PDSS, please cite PDSS as follows: +``` +@article{fan2024pdss, + title={PDSS: A Privacy-Preserving Framework for Step-by-Step Distillation of Large Language Models}, + author={Fan, Tao and Kang, Yan and Chen, Weijing and Gu, Hanlin and Song, Yuanfeng and Fan, Lixin and Chen, Kai and Yang, Qiang}, + journal={arXiv preprint arXiv:2406.12403}, + year={2024} +} +``` \ No newline at end of file From 74c4d4005af304ca51f2c5cbae9fe3728f15018e Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 1 Aug 2024 23:40:12 +0800 Subject: [PATCH 49/50] add readme to fdkt Signed-off-by: mgqa34 --- doc/tutorial/fdkt/README.md | 14 ++++++++++++++ doc/tutorial/fdkt/fdkt.ipynb | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 doc/tutorial/fdkt/README.md diff --git a/doc/tutorial/fdkt/README.md b/doc/tutorial/fdkt/README.md new file mode 100644 index 0000000..a3d32e8 --- /dev/null +++ b/doc/tutorial/fdkt/README.md @@ -0,0 +1,14 @@ +# FATE-LLM: FDKT +The Algorithm is based on paper [Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212), +a novel framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy. 
+ +## Citation +If you publish work that uses FDKT, please cite FDKT as follows: +``` +@article{li2024federated, + title={Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data}, + author={Li, Haoran and Zhao, Xinyuan and Guo, Dadi and Gu, Hanlin and Zeng, Ziqian and Han, Yuxing and Song, Yangqiu and Fan, Lixin and Yang, Qiang}, + journal={arXiv preprint arXiv:2405.14212}, + year={2024} +} +``` \ No newline at end of file diff --git a/doc/tutorial/fdkt/fdkt.ipynb b/doc/tutorial/fdkt/fdkt.ipynb index e6a4c0f..ef620f1 100644 --- a/doc/tutorial/fdkt/fdkt.ipynb +++ b/doc/tutorial/fdkt/fdkt.ipynb @@ -528,7 +528,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Running FEDMKT with Pipeline (Industrial Using)" + "## Running FDKT with Pipeline (Industrial Using)" ] }, { From 2ff4b062f02e4199104c9a8706bf70c05ed6cfe6 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Fri, 2 Aug 2024 14:26:51 +0800 Subject: [PATCH 50/50] fix doc Signed-off-by: mgqa34 --- doc/tutorial/fdkt/README.md | 4 ++-- doc/tutorial/fedmkt/README.md | 4 ++-- doc/tutorial/pdss/README.md | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/tutorial/fdkt/README.md b/doc/tutorial/fdkt/README.md index a3d32e8..ee8fe6f 100644 --- a/doc/tutorial/fdkt/README.md +++ b/doc/tutorial/fdkt/README.md @@ -1,5 +1,5 @@ # FATE-LLM: FDKT -The Algorithm is based on paper [Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212), +The algorithm is based on paper [Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212), a novel framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy. ## Citation @@ -11,4 +11,4 @@ If you publish work that uses FDKT, please cite FDKT as follows: journal={arXiv preprint arXiv:2405.14212}, year={2024} } -``` \ No newline at end of file +``` diff --git a/doc/tutorial/fedmkt/README.md b/doc/tutorial/fedmkt/README.md index 6ff5df9..d0218ff 100644 --- a/doc/tutorial/fedmkt/README.md +++ b/doc/tutorial/fedmkt/README.md @@ -1,6 +1,6 @@ # FATE-LLM: FedMKT -The Algorithm is based on paper ["FedMKT: Federated Mutual Knowledge Transfer for Large and SmallLanguage Models"](https://arxiv.org/pdf/2406.02224), We integrate its code into the FATE-LLM framework. +The algorithm is based on paper ["FedMKT: Federated Mutual Knowledge Transfer for Large and SmallLanguage Models"](https://arxiv.org/pdf/2406.02224), We integrate its code into the FATE-LLM framework. ## Citation If you publish work that uses FedMKT, please cite FedMKT as follows: @@ -11,4 +11,4 @@ If you publish work that uses FedMKT, please cite FedMKT as follows: journal={arXiv preprint arXiv:2406.02224}, year={2024} } -``` \ No newline at end of file +``` diff --git a/doc/tutorial/pdss/README.md b/doc/tutorial/pdss/README.md index 33fd710..aaaf59b 100644 --- a/doc/tutorial/pdss/README.md +++ b/doc/tutorial/pdss/README.md @@ -1,5 +1,5 @@ # FATE-LLM: PDSS -The Algorithm is based on paper ["PDSS: A Privacy-Preserving Framework for Step-by-Step Distillation of Large Language Models"](https://arxiv.org/pdf/2406.12403), which introduce a novel framework for privacy preserving federated distillation. We integrate its code into the FATE-LLM framework. 
+The algorithm is based on paper ["PDSS: A Privacy-Preserving Framework for Step-by-Step Distillation of Large Language Models"](https://arxiv.org/pdf/2406.12403), which introduce a novel framework for privacy preserving federated distillation. We integrate its code into the FATE-LLM framework. ## Citation If you publish work that uses PDSS, please cite PDSS as follows: @@ -10,4 +10,4 @@ If you publish work that uses PDSS, please cite PDSS as follows: journal={arXiv preprint arXiv:2406.12403}, year={2024} } -``` \ No newline at end of file +```