diff --git a/examples/huggingface/training.py b/examples/huggingface/training.py index a3083fd0a..320050b97 100644 --- a/examples/huggingface/training.py +++ b/examples/huggingface/training.py @@ -13,9 +13,7 @@ @dataclass class CustomArguments: - model_name: str = ( - "meta-llama/Meta-Llama-3-8B" - ) + model_name: str = "meta-llama/Meta-Llama-3-8B" dataset: str = "tatsu-lab/alpaca" max_seq_length: int = 512 use_liger: bool = False diff --git a/examples/lightning/training.py b/examples/lightning/training.py index 997ae99e9..cd502e970 100644 --- a/examples/lightning/training.py +++ b/examples/lightning/training.py @@ -7,13 +7,14 @@ import lightning.pytorch as pl import torch import transformers -from liger_kernel.transformers import apply_liger_kernel_to_llama from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision from torch.utils.data import DataLoader from transformers.models.llama.modeling_llama import LlamaDecoderLayer from trl import DataCollatorForCompletionOnlyLM +from liger_kernel.transformers import apply_liger_kernel_to_llama + apply_liger_kernel_to_llama(fused_linear_cross_entropy=True, cross_entropy=False)