Feature: Basic distilling. #6527

Open · wants to merge 4 commits into base: main
34 changes: 32 additions & 2 deletions src/llamafactory/hparams/finetuning_args.py
@@ -126,6 +126,30 @@ class LoraArguments:
)


@dataclass
class DistillationArguments:
r"""
Arguments pertaining to the distillation training.
"""

distilling_lambda: float = field(
default=0.5,
metadata={"help": "The lambda parameter in the distilling loss."},
)
distilling_temperature: float = field(
default=1.0,
metadata={"help": "The temperature parameter in the distilling softmax."},
)
teacher_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the teacher model used for the distilling."},
)
teacher_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the teacher model."},
)


@dataclass
class RLHFArguments:
r"""
@@ -334,7 +358,13 @@ class SwanLabArguments:

@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
FreezeArguments,
LoraArguments,
RLHFArguments,
GaloreArguments,
BAdamArgument,
SwanLabArguments,
DistillationArguments,
):
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
@@ -344,7 +374,7 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "distillation"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
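For reference, a minimal sketch of setting the new fields programmatically (field names follow this diff; the paths are hypothetical placeholders):

# Hedged sketch: constructing FinetuningArguments with the fields added in this PR.
from llamafactory.hparams import FinetuningArguments

finetuning_args = FinetuningArguments(
    stage="distillation",             # new stage introduced by this PR
    teacher_model="path/to/teacher",  # frozen teacher checkpoint (placeholder path)
    distilling_lambda=0.5,            # weight of the KL term added to the label loss
    distilling_temperature=2.0,       # >1 softens both distributions before the KL
)
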
18 changes: 18 additions & 0 deletions src/llamafactory/train/distil/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_distillation


__all__ = ["run_distillation"]
158 changes: 158 additions & 0 deletions src/llamafactory/train/distil/trainer.py
@@ -0,0 +1,158 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class CustomDistillationTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""

def __init__(
self,
teacher_model: Union["PreTrainedModel", torch.nn.Module],
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
**kwargs,
):
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")

super().__init__(**kwargs)

if teacher_model is not None:
if self.is_deepspeed_enabled:
if not (
getattr(teacher_model, "is_loaded_in_8bit", False)
or getattr(teacher_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.teacher_model = self._prepare_deepspeed(teacher_model)
else:
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
self.finetuning_args = finetuning_args

if processor is not None:
self.add_callback(SaveProcessorCallback(processor))

if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()

@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)

@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

return super()._get_train_sampler()

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
labels = inputs.get("labels")
padding_mask = labels.eq(-100)
label_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
self.teacher_model.eval()
with torch.no_grad():
teacher_outputs = self.teacher_model(**inputs)
# Shape: (batch_size, seq_len, vocab_size)
teacher_prob = torch.nn.functional.softmax(
teacher_outputs.logits / self.finetuning_args.distilling_temperature, dim=-1
)
student_logprob = torch.nn.functional.log_softmax(
outputs.logits / self.finetuning_args.distilling_temperature, dim=-1
)
kl_losses = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1)
kl_losses.masked_fill_(padding_mask, 0)
num_active_elements = padding_mask.numel() - padding_mask.long().sum()
loss = (
self.finetuning_args.distilling_lambda
* kl_losses.mean()
/ (num_active_elements * student_logprob.shape[-1])
+ label_loss
)

if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
loss = loss / self.args.gradient_accumulation_steps

return (loss, outputs) if return_outputs else loss

def _prepare_deepspeed(self, model: "PreTrainedModel"):
import deepspeed # type: ignore

deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)

if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
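For reference, a self-contained sketch mirroring the loss combined in compute_loss above (the function name and signature are illustrative, not part of the diff):

# Illustrative standalone sketch of the distillation objective implemented above:
# forward KL(teacher || student) on non-padding tokens, added to the label loss.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, label_loss, lam=0.5, temperature=1.0):
    padding_mask = labels.eq(-100)  # IGNORE_INDEX marks prompt and padding positions
    teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)
    student_logprob = F.log_softmax(student_logits / temperature, dim=-1)
    kl = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1)
    kl = kl.masked_fill(padding_mask, 0.0)
    num_active = padding_mask.numel() - padding_mask.long().sum()
    # Normalization mirrors the diff: mean over all positions, then divided by
    # the number of active tokens times the vocabulary size.
    return lam * kl.mean() / (num_active * student_logprob.shape[-1]) + label_loss
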
140 changes: 140 additions & 0 deletions src/llamafactory/train/distil/workflow.py
@@ -0,0 +1,140 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps, get_logits_processor
from ...extras.ploting import plot_loss
from ...hparams import FinetuningArguments, ModelArguments
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .trainer import CustomDistillationTrainer


if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback

from ...hparams import DataArguments, GeneratingArguments


logger = get_logger(__name__)


def run_distillation(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

# Load teacher model
# TODO teacher_model_quantization_bit
teacher_model_args = ModelArguments.copyfrom(
model_args,
model_name_or_path=finetuning_args.teacher_model,
adapter_name_or_path=finetuning_args.teacher_model_adapters,
)
teacher_finetuning_args = FinetuningArguments()
teacher_model = load_model(tokenizer, teacher_model_args, teacher_finetuning_args, is_trainable=False)
# Ensure the teacher and student tokenizers share the same vocabulary
teacher_tokenizer = load_tokenizer(teacher_model_args)["tokenizer"]
assert (
teacher_tokenizer.get_vocab() == tokenizer.get_vocab()
), "The teacher's and student's tokenizers must have the same vocabulary dictionary."

if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction

# TODO handling `prepare_decoder_input_ids_from_labels`.
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model if not training_args.predict_with_generate else None,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)

# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal dataset

# Initialize our Trainer
trainer = CustomDistillationTrainer(
model=model,
teacher_model=teacher_model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**tokenizer_module,
)

# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()

# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="sft"
)

trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])

if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation

# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Predict
if training_args.do_predict:
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
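The vocabulary assertion above exists because the KL term compares teacher and student logits index by index; a minimal sketch of that precondition (model paths are placeholders):

# Hedged sketch: a teacher is only usable if its tokenizer vocabulary matches the student's.
from transformers import AutoTokenizer

student_tok = AutoTokenizer.from_pretrained("path/to/student")  # placeholder
teacher_tok = AutoTokenizer.from_pretrained("path/to/teacher")  # placeholder
assert student_tok.get_vocab() == teacher_tok.get_vocab(), "Teacher and student vocabularies differ."
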
3 changes: 3 additions & 0 deletions src/llamafactory/train/tuner.py
@@ -25,6 +25,7 @@
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .distil import run_distillation
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -65,6 +66,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "distillation":
run_distillation(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
else:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")

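A hedged usage sketch of the new dispatch path (model paths, dataset, and template below are placeholders, not prescribed by this diff):

# Launch the distillation stage through run_exp with an argument dict;
# the "distillation" branch added above routes it to run_distillation.
from llamafactory.train.tuner import run_exp

run_exp({
    "stage": "distillation",
    "do_train": True,
    "model_name_or_path": "path/to/student",  # placeholder
    "teacher_model": "path/to/teacher",       # placeholder
    "finetuning_type": "lora",
    "dataset": "identity",                    # placeholder dataset name
    "template": "default",                    # placeholder template
    "output_dir": "saves/distill-demo",
})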