diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 8b1c405b7..000000000 --- a/.flake8 +++ /dev/null @@ -1,10 +0,0 @@ -# .flake8 -[flake8] -max-line-length = 120 -exclude = - .git, - __pycache__, - benchmark_internal/others, - .venv -# E203: https://github.com/psf/black/issues/315 -extend-ignore=E501,B006,E731,A002,E203 diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 6860bdb0e..000000000 --- a/.isort.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[settings] -profile = black \ No newline at end of file diff --git a/Makefile b/Makefile index 00b677d3e..ed17e4a0f 100644 --- a/Makefile +++ b/Makefile @@ -10,10 +10,9 @@ test: # Command to run flake8 (code style check), isort (import ordering), and black (code formatting) # Subsequent commands still run if the previous fails, but return failure at the end checkstyle: - flake8 .; flake8_status=$$?; \ - isort .; isort_status=$$?; \ - black .; black_status=$$?; \ - if [ $$flake8_status -ne 0 ] || [ $$isort_status -ne 0 ] || [ $$black_status -ne 0 ]; then \ + ruff check . --fix; ruff_check_status=$$?; \ + ruff format .; ruff_format_status=$$?; \ + if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \ exit 1; \ fi diff --git a/benchmark/benchmarks_visualizer.py b/benchmark/benchmarks_visualizer.py index 2cb9b1330..62be03cb2 100644 --- a/benchmark/benchmarks_visualizer.py +++ b/benchmark/benchmarks_visualizer.py @@ -1,5 +1,6 @@ import json import os + from argparse import ArgumentParser from dataclasses import dataclass @@ -39,9 +40,7 @@ def parse_args() -> VisualizationsConfig: VisualizationsConfig: Configuration object for the visualizations script. """ parser = ArgumentParser() - parser.add_argument( - "--kernel-name", type=str, required=True, help="Kernel name to benchmark" - ) + parser.add_argument("--kernel-name", type=str, required=True, help="Kernel name to benchmark") parser.add_argument( "--metric-name", type=str, @@ -54,9 +53,7 @@ def parse_args() -> VisualizationsConfig: required=True, help="Kernel operation mode to visualize (forward/backward/full)", ) - parser.add_argument( - "--display", action="store_true", help="Display the visualization" - ) + parser.add_argument("--display", action="store_true", help="Display the visualization") parser.add_argument( "--overwrite", action="store_true", @@ -126,7 +123,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig): lines = ax.get_lines() colors = [line.get_color() for line in lines] - for (_, group_data), color in zip(df.groupby("kernel_provider"), colors): + for (_, group_data), color in zip(df.groupby("kernel_provider"), colors, strict=False): # for i, row in group_data.iterrows(): y_error_lower = group_data["y_value_50"] - group_data["y_value_20"] y_error_upper = group_data["y_value_80"] - group_data["y_value_50"] @@ -145,9 +142,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig): plt.ylabel(ylabel) plt.tight_layout() - out_path = os.path.join( - VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png" - ) + out_path = os.path.join(VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png") if config.display: plt.show() diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py index 5fc43c7ea..c0922f95c 100644 --- a/benchmark/scripts/benchmark_cpo_loss.py +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -3,14 +3,13 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.utils import infer_device @@ -33,9 +32,7 @@ def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100) from test.chunked_loss.test_cpo_loss import HFCPOLoss super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.cpo_loss = HFCPOLoss().get_batch_loss_metrics def forward(self, x, y): @@ -45,9 +42,7 @@ def forward(self, x, y): class LigerLMHeadCPO(torch.nn.Module): def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.cpo_loss = LigerFusedLinearCPOFunction.apply def forward(self, x, y): @@ -180,12 +175,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_fused_linear_cpo_loss, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_cross_entropy.py b/benchmark/scripts/benchmark_cross_entropy.py index f7b749c98..62ad6127f 100644 --- a/benchmark/scripts/benchmark_cross_entropy.py +++ b/benchmark/scripts/benchmark_cross_entropy.py @@ -1,14 +1,13 @@ import torch import triton + from torch.nn import CrossEntropyLoss -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.utils import infer_device @@ -86,9 +85,7 @@ def full(): y = fwd() y.backward() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, rep=100, quantiles=QUANTILES - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) return SingleBenchmarkRunOutput( y_20=ms_20, @@ -115,12 +112,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_cross_entropy, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index af8e3dac5..23a4b08c1 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -1,15 +1,13 @@ -from test.chunked_loss.test_dpo_loss import HF_DPO_Loss - import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from test.chunked_loss.test_dpo_loss import HF_DPO_Loss +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction from liger_kernel.utils import infer_device @@ -28,9 +26,7 @@ def __init__( bias: bool = False, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index) def forward(self, x, target): @@ -53,9 +49,7 @@ def __init__( bias: bool = False, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.beta = beta self.ignore_index = ignore_index @@ -82,12 +76,8 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO ignore_index = input.extra_benchmark_config["ignore_index"] provider = input.kernel_provider - torch_dpo_loss = TorchDPOLoss( - H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias - ).to(device) - liger_dpo_loss = LigerDPOLoss( - H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias - ).to(device) + torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) # Input shape: [B, T, H] _input = torch.randn(B, T, H, device=device, dtype=dtype) @@ -129,12 +119,8 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu provider = input.kernel_provider mode = input.kernel_operation_mode - torch_dpo_loss = TorchDPOLoss( - H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias - ).to(device) - liger_dpo_loss = LigerDPOLoss( - H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias - ).to(device) + torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) # Input shape: [B, T, H] _input = torch.randn(B, T, H, device=device, dtype=dtype) @@ -215,7 +201,7 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( @@ -223,5 +209,5 @@ def full(): kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_embedding.py b/benchmark/scripts/benchmark_embedding.py index 40722ee1b..72153e021 100644 --- a/benchmark/scripts/benchmark_embedding.py +++ b/benchmark/scripts/benchmark_embedding.py @@ -1,14 +1,13 @@ import torch import triton + from torch.nn import Embedding -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.experimental.embedding import LigerEmbedding from liger_kernel.utils import infer_device @@ -50,9 +49,7 @@ def full(): if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) elif mode == "full": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, quantiles=QUANTILES, rep=100 - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) return SingleBenchmarkRunOutput( y_20=ms_20, y_50=ms_50, @@ -118,12 +115,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_embedding, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py index 2e3b08732..a3cc584de 100644 --- a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +++ b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py @@ -1,17 +1,14 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) - -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.utils import infer_device device = infer_device() @@ -28,12 +25,8 @@ class TorchLMHeadCE(torch.nn.Module): def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) - self.ce_loss = torch.nn.CrossEntropyLoss( - ignore_index=ignore_index, reduction="mean" - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="mean") def forward(self, x, y): logits = self.lin(x) @@ -43,12 +36,8 @@ def forward(self, x, y): class LigerLMHeadCE(torch.nn.Module): def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) - self.ce_loss = LigerFusedLinearCrossEntropyLoss( - ignore_index=ignore_index, reduction="mean" - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.ce_loss = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction="mean") def forward(self, x, y): return self.ce_loss(self.lin.weight, x, y) @@ -161,9 +150,7 @@ def full(): "x_label": "B x T", "x_values": [2**i for i in range(12, 16)], "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16} - ], + "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], "overwrite": args.overwrite, } @@ -172,12 +159,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_fused_linear_cross_entropy, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py index dcefb2137..7fc3d4e9a 100644 --- a/benchmark/scripts/benchmark_fused_linear_jsd.py +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -1,13 +1,12 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD from liger_kernel.utils import infer_device @@ -37,9 +36,9 @@ def forward( log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) - loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( - 1 - self.beta - ) * self.kl(torch.log(m), log_q).sum(dim=-1) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl( + torch.log(m), log_q + ).sum(dim=-1) if label is not None: loss = torch.where(label != self.ignore_index, loss, 0.0) @@ -73,12 +72,8 @@ def __init__( temperature: float = 1.0, ): super().__init__() - self.student_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype, device=device - ) - self.teacher_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype, device=device - ) + self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype) self.temperature = temperature @@ -103,15 +98,9 @@ def __init__( temperature: float = 1.0, ): super().__init__() - self.student_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype, device=device - ) - self.teacher_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype, device=device - ) - self.fused_jsd = LigerFusedLinearJSD( - jsd_beta=beta, ignore_index=ignore_index, temperature=temperature - ) + self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.fused_jsd = LigerFusedLinearJSD(jsd_beta=beta, ignore_index=ignore_index, temperature=temperature) def forward(self, student_input, teacher_input, label=None): return self.fused_jsd( @@ -141,12 +130,12 @@ def bench_memory_fused_linear_jsd( liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) # init the linear in all FusedLinearJSDs with the same weights - torch_lm_head_jsd.student_lin.weight.data = ( - liger_lm_head_jsd.student_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) - torch_lm_head_jsd.teacher_lin.weight.data = ( - liger_lm_head_jsd.teacher_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) teacher_input = torch.rand(BT, H, dtype=dtype, device=device) @@ -189,12 +178,12 @@ def bench_speed_fused_linear_jsd( liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) # init the linear in all FusedLinearJSDs with the same weights - torch_lm_head_jsd.student_lin.weight.data = ( - liger_lm_head_jsd.student_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) - torch_lm_head_jsd.teacher_lin.weight.data = ( - liger_lm_head_jsd.teacher_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) teacher_input = torch.rand(BT, H, dtype=dtype, device=device) @@ -251,9 +240,7 @@ def full(): "x_label": "B x T", "x_values": [2**i for i in range(10, 14)], "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [ - {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16} - ], + "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], "overwrite": args.overwrite, } @@ -262,12 +249,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_fused_linear_jsd, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_geglu.py b/benchmark/scripts/benchmark_geglu.py index 7b0d237ca..a8923251d 100644 --- a/benchmark/scripts/benchmark_geglu.py +++ b/benchmark/scripts/benchmark_geglu.py @@ -1,15 +1,14 @@ import torch import triton + from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.utils import infer_device diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py index 0c3c05608..5a8bf37f4 100644 --- a/benchmark/scripts/benchmark_group_norm.py +++ b/benchmark/scripts/benchmark_group_norm.py @@ -1,13 +1,12 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.group_norm import LigerGroupNorm from liger_kernel.utils import infer_device @@ -27,12 +26,8 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun dtype = extra_benchmark_config["dtype"] x_shape = (M, C, H) - triton_ln = LigerGroupNorm( - num_channels=C, num_groups=C // channels_per_group, eps=eps - ).to(device) - torch_ln = torch.nn.GroupNorm( - num_groups=C // channels_per_group, num_channels=C, eps=eps - ).to(device) + triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device) + torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device) x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) @@ -45,9 +40,7 @@ def y_fwd(): return torch_ln(x) if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500 - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) elif mode == "backward": y = y_fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( @@ -62,9 +55,7 @@ def full(): y = y_fwd() y.backward(dy, retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, quantiles=QUANTILES, grad_to_none=[x], rep=500 - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) return SingleBenchmarkRunOutput( y_20=ms_20, @@ -84,12 +75,8 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu dtype = extra_benchmark_config["dtype"] x_shape = (M, C, H) - triton_ln = LigerGroupNorm( - num_channels=C, num_groups=C // channels_per_group, eps=eps - ).to(device) - torch_ln = torch.nn.GroupNorm( - num_groups=C // channels_per_group, num_channels=C, eps=eps - ).to(device) + triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device) + torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device) x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) @@ -139,12 +126,12 @@ def full(): kernel_operation_modes=["forward", "full", "backward"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_group_norm, kernel_operation_modes=["full", "forward", "backward"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_jsd.py b/benchmark/scripts/benchmark_jsd.py index c5f8bec18..004ec4a30 100644 --- a/benchmark/scripts/benchmark_jsd.py +++ b/benchmark/scripts/benchmark_jsd.py @@ -1,13 +1,12 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.jsd import LigerJSD from liger_kernel.utils import infer_device @@ -37,9 +36,9 @@ def forward( log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) - loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( - 1 - self.beta - ) * self.kl(torch.log(m), log_q).sum(dim=-1) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl( + torch.log(m), log_q + ).sum(dim=-1) if label is not None: loss = torch.where(label != self.ignore_index, loss, 0.0) @@ -59,9 +58,7 @@ def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: torch_jsd = TorchJSD() liger_jsd = LigerJSD() - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( - dim=-1 - ) + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) def fwd(): @@ -87,9 +84,7 @@ def full(): y = fwd() y.backward(retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, quantiles=QUANTILES, rep=100 - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) return SingleBenchmarkRunOutput( y_20=ms_20, y_50=ms_50, @@ -104,9 +99,7 @@ def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput V = input.x B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( - dim=-1 - ) + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) def fwd(): diff --git a/benchmark/scripts/benchmark_kl_div.py b/benchmark/scripts/benchmark_kl_div.py index c52d8e658..4788826c8 100644 --- a/benchmark/scripts/benchmark_kl_div.py +++ b/benchmark/scripts/benchmark_kl_div.py @@ -1,14 +1,13 @@ import torch import torch.nn as nn import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.kl_div import LigerKLDIVLoss from liger_kernel.utils import infer_device @@ -25,9 +24,7 @@ def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu torch_kl_div = nn.KLDivLoss(reduction=reduction) liger_kl_div = LigerKLDIVLoss(reduction=reduction) - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( - dim=-1 - ) + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) target = torch.randn(B * T, V, device=device).softmax(dim=-1) def fwd(): @@ -53,9 +50,7 @@ def full(): y = fwd() y.backward(retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, quantiles=QUANTILES, rep=100 - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) return SingleBenchmarkRunOutput( y_20=ms_20, y_50=ms_50, @@ -71,9 +66,7 @@ def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp V = input.x B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( - dim=-1 - ) + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) target = torch.randn(B * T, V, device=device).softmax(dim=-1) def fwd(): diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py index 4d36d4b4b..91650214e 100644 --- a/benchmark/scripts/benchmark_layer_norm.py +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -1,13 +1,12 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.utils import infer_device @@ -39,9 +38,7 @@ def y_fwd(): return torch_ln(x) if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500 - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) elif mode == "backward": y = y_fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( @@ -56,9 +53,7 @@ def full(): y = y_fwd() y.backward(dy, retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, quantiles=QUANTILES, grad_to_none=[x], rep=500 - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) return SingleBenchmarkRunOutput( y_20=ms_20, @@ -119,12 +114,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_layer_norm, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py index e1b2c8d25..ae50c1d49 100644 --- a/benchmark/scripts/benchmark_orpo_loss.py +++ b/benchmark/scripts/benchmark_orpo_loss.py @@ -3,14 +3,13 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.utils import infer_device @@ -33,9 +32,7 @@ def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100) from test.chunked_loss.test_orpo_loss import HF_ORPO_Loss super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.orpo_loss = HF_ORPO_Loss().get_batch_loss_metrics def forward(self, x, y): @@ -45,9 +42,7 @@ def forward(self, x, y): class LigerLMHeadORPO(torch.nn.Module): def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.orpo_loss = LigerFusedLinearORPOFunction.apply def forward(self, x, y): @@ -180,12 +175,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_fused_linear_orpo_loss, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_qwen2vl_mrope.py b/benchmark/scripts/benchmark_qwen2vl_mrope.py index dccb37d33..a426abdba 100644 --- a/benchmark/scripts/benchmark_qwen2vl_mrope.py +++ b/benchmark/scripts/benchmark_qwen2vl_mrope.py @@ -1,17 +1,14 @@ import torch import triton -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLRotaryEmbedding, - apply_multimodal_rotary_pos_emb, -) -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding +from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb from liger_kernel.utils import infer_device @@ -31,16 +28,8 @@ def bench_speed_qwen2vl_mrope( dtype = extra_benchmark_config["dtype"] # x can be either hidden_size or seq_len - hidden_size = ( - extra_benchmark_config["hidden_size"] - if "hidden_size" in extra_benchmark_config - else input.x - ) - seq_len = ( - extra_benchmark_config["seq_len"] - if "seq_len" in extra_benchmark_config - else input.x - ) + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x head_dim = hidden_size // num_q_heads rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) @@ -56,8 +45,9 @@ def bench_speed_qwen2vl_mrope( requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( - k, device=device + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), ) pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k, pos_ids) @@ -87,9 +77,7 @@ def fwd(): elif mode == "backward": q_out, k_out = fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: torch.autograd.grad( - (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True - ), + lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), grad_to_none=[q, k], rep=400, quantiles=QUANTILES, @@ -124,16 +112,8 @@ def bench_memory_qwen2vl_mrope( dtype = extra_benchmark_config["dtype"] # x can be either hidden_size or seq_len - hidden_size = ( - extra_benchmark_config["hidden_size"] - if "hidden_size" in extra_benchmark_config - else input.x - ) - seq_len = ( - extra_benchmark_config["seq_len"] - if "seq_len" in extra_benchmark_config - else input.x - ) + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x head_dim = hidden_size // num_q_heads rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) @@ -149,8 +129,9 @@ def bench_memory_qwen2vl_mrope( requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( - k, device=device + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), ) pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k, pos_ids) @@ -164,16 +145,10 @@ def bench_memory_qwen2vl_mrope( def full(): if provider == "liger": - q_out, k_out = liger_multimodal_rotary_pos_emb( - q, k, cos, sin, mrope_section - ) + q_out, k_out = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) else: - q_out, k_out = apply_multimodal_rotary_pos_emb( - q, k, cos, sin, mrope_section - ) - torch.autograd.grad( - (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True - ) + q_out, k_out = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) mem_50, mem_20, mem_80 = _test_memory( full, diff --git a/benchmark/scripts/benchmark_rms_norm.py b/benchmark/scripts/benchmark_rms_norm.py index 533a13aec..6bcd56a83 100644 --- a/benchmark/scripts/benchmark_rms_norm.py +++ b/benchmark/scripts/benchmark_rms_norm.py @@ -1,14 +1,13 @@ import torch import torch.nn as nn import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.utils import infer_device @@ -152,12 +151,12 @@ def full(): kernel_operation_modes=["forward", "full", "backward"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_rms_norm, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_rope.py b/benchmark/scripts/benchmark_rope.py index b505c6fe9..f0c2a4f02 100644 --- a/benchmark/scripts/benchmark_rope.py +++ b/benchmark/scripts/benchmark_rope.py @@ -1,17 +1,14 @@ import torch import triton -from transformers.models.llama.modeling_llama import ( - LlamaRotaryEmbedding, - apply_rotary_pos_emb, -) -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.utils import infer_device @@ -29,16 +26,8 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput dtype = extra_benchmark_config["dtype"] # x can be either hidden_size or seq_len - hidden_size = ( - extra_benchmark_config["hidden_size"] - if "hidden_size" in extra_benchmark_config - else input.x - ) - seq_len = ( - extra_benchmark_config["seq_len"] - if "seq_len" in extra_benchmark_config - else input.x - ) + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x head_dim = hidden_size // num_q_heads rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) @@ -54,8 +43,9 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( - k, device=device + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), ) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k, pos_ids) @@ -78,9 +68,7 @@ def fwd(): elif mode == "backward": q_out, k_out = fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: torch.autograd.grad( - (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True - ), + lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), grad_to_none=[q, k], rep=400, quantiles=QUANTILES, @@ -113,16 +101,8 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu dtype = extra_benchmark_config["dtype"] # x can be either hidden_size or seq_len - hidden_size = ( - extra_benchmark_config["hidden_size"] - if "hidden_size" in extra_benchmark_config - else input.x - ) - seq_len = ( - extra_benchmark_config["seq_len"] - if "seq_len" in extra_benchmark_config - else input.x - ) + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x head_dim = hidden_size // num_q_heads rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) @@ -138,8 +118,9 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( - k, device=device + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), ) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k, pos_ids) @@ -149,9 +130,7 @@ def full(): q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids) else: q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids) - torch.autograd.grad( - (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True - ) + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) mem_50, mem_20, mem_80 = _test_memory( full, diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py index a8ee48dea..46eb4f31e 100644 --- a/benchmark/scripts/benchmark_simpo_loss.py +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -3,14 +3,13 @@ import torch import triton -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction from liger_kernel.utils import infer_device @@ -33,9 +32,7 @@ def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100) from test.chunked_loss.test_cpo_loss import HFCPOLoss super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics def forward(self, x, y): @@ -45,9 +42,7 @@ def forward(self, x, y): class LigerLMHeadSimPO(torch.nn.Module): def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.simpo_loss = LigerFusedLinearSimPOFunction.apply def forward(self, x, y): @@ -180,12 +175,12 @@ def full(): kernel_operation_modes=["forward", "full"], metric_name="speed", metric_unit="ms", - **common_configs + **common_configs, ) run_benchmarks( bench_test_fn=bench_memory_fused_linear_simpo_loss, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **common_configs + **common_configs, ) diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index 5feedb557..739284210 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -1,15 +1,14 @@ import torch import triton + from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.utils import infer_device @@ -128,9 +127,7 @@ def full(): elif mode == "backward": do = torch.randn_like(x) y = fwd() - mem_50, mem_20, mem_80 = _test_memory( - lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES - ) + mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES) else: mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py index 6fa80a888..abe0232b6 100644 --- a/benchmark/scripts/utils.py +++ b/benchmark/scripts/utils.py @@ -3,11 +3,18 @@ import json import os import time + from collections import OrderedDict -from dataclasses import asdict, dataclass +from dataclasses import asdict +from dataclasses import dataclass from importlib.metadata import version from itertools import zip_longest -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Union import torch @@ -100,9 +107,7 @@ def _test_memory( total_mem = torch.tensor(total_mem, dtype=torch.float) if quantiles is not None: - quantiles_data = torch.quantile( - total_mem, torch.tensor(quantiles, dtype=torch.float) - ).tolist() + quantiles_data = torch.quantile(total_mem, torch.tensor(quantiles, dtype=torch.float)).tolist() if len(quantiles_data) == 1: quantiles_data = quantiles_data[0] return quantiles_data @@ -174,11 +179,7 @@ def create_unique_key(row): row["metric_name"], row["x_name"], str(row["x_value"]), - ( - row["extra_benchmark_config_str"] - if row["extra_benchmark_config_str"] - else "" - ), + (row["extra_benchmark_config_str"] if row["extra_benchmark_config_str"] else ""), row["gpu_name"], ) @@ -196,9 +197,7 @@ def create_unique_key(row): for row in reader: existing_data.append(row) - existing_data_dict = OrderedDict( - (create_unique_key(row), row) for row in existing_data - ) + existing_data_dict = OrderedDict((create_unique_key(row), row) for row in existing_data) for benchmark_data in benchmark_data_list: benchmark_data_dict = asdict(benchmark_data) @@ -208,9 +207,7 @@ def create_unique_key(row): y_values_80 = benchmark_data_dict.pop("y_values_80") # Need to convert benchmark_data into multiple rows based on x_values and y_values - for x_value, y_value_50, y_value_20, y_value_80 in zip_longest( - x_values, y_values_50, y_values_20, y_values_80 - ): + for x_value, y_value_50, y_value_20, y_value_80 in zip_longest(x_values, y_values_50, y_values_20, y_values_80): row = BenchmarkDataCSVRow( x_value=x_value, y_value_50=y_value_50, @@ -306,9 +303,7 @@ def run_benchmarks( kernel_operation_mode=kernel_operation_mode, extra_benchmark_config=extra_benchmark_config, ) - benchmark_result: SingleBenchmarkRunOutput = bench_test_fn( - single_benchmark_run_input - ) + benchmark_result: SingleBenchmarkRunOutput = bench_test_fn(single_benchmark_run_input) y_values_50.append(benchmark_result.y_50) y_values_20.append(benchmark_result.y_20) y_values_80.append(benchmark_result.y_80) @@ -326,9 +321,7 @@ def run_benchmarks( y_values_50=y_values_50, y_values_20=y_values_20, y_values_80=y_values_80, - extra_benchmark_config_str=json.dumps( - extra_benchmark_config, cls=CustomEncoder - ), + extra_benchmark_config_str=json.dumps(extra_benchmark_config, cls=CustomEncoder), timestamp=get_formatted_time(), liger_version=LIGER_KERNEL_VERSION, ) @@ -337,9 +330,7 @@ def run_benchmarks( print_benchmark_data(benchmark_data_list) - update_benchmark_data_csv( - benchmark_data_list=benchmark_data_list, overwrite=overwrite - ) + update_benchmark_data_csv(benchmark_data_list=benchmark_data_list, overwrite=overwrite) def parse_benchmark_script_args(): diff --git a/dev/fmt-requirements.txt b/dev/fmt-requirements.txt index f086aa46b..1d8f48692 100644 --- a/dev/fmt-requirements.txt +++ b/dev/fmt-requirements.txt @@ -1,3 +1 @@ -flake8 -isort -black +ruff>=0.1.6 diff --git a/dev/modal/tests.py b/dev/modal/tests.py index 686540ed7..092ba042f 100644 --- a/dev/modal/tests.py +++ b/dev/modal/tests.py @@ -25,6 +25,4 @@ def liger_tests(): cwd=REMOTE_ROOT_PATH, ) subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH) - subprocess.run( - ["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH - ) + subprocess.run(["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH) diff --git a/dev/modal/tests_bwd.py b/dev/modal/tests_bwd.py index 231c5b4d7..13b822047 100644 --- a/dev/modal/tests_bwd.py +++ b/dev/modal/tests_bwd.py @@ -32,6 +32,4 @@ def liger_bwd_tests(): cwd=REMOTE_ROOT_PATH, ) subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH) - subprocess.run( - ["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH - ) + subprocess.run(["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH) diff --git a/examples/alignment/run_orpo.py b/examples/alignment/run_orpo.py index 38352053b..2d734f8cb 100644 --- a/examples/alignment/run_orpo.py +++ b/examples/alignment/run_orpo.py @@ -1,6 +1,8 @@ import torch + from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer from trl import ORPOConfig # noqa: F401 from liger_kernel.transformers.trainer import LigerORPOTrainer # noqa: F401 @@ -28,8 +30,6 @@ save_strategy="no", ) -trainer = LigerORPOTrainer( - model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset -) +trainer = LigerORPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset) trainer.train() diff --git a/examples/huggingface/callback.py b/examples/huggingface/callback.py index c612a79a9..c834fc566 100644 --- a/examples/huggingface/callback.py +++ b/examples/huggingface/callback.py @@ -1,9 +1,13 @@ import time + from dataclasses import dataclass import torch import transformers -from transformers import TrainerControl, TrainerState, TrainingArguments + +from transformers import TrainerControl +from transformers import TrainerState +from transformers import TrainingArguments from liger_kernel.utils import infer_device @@ -101,9 +105,7 @@ class EfficiencyCallback(transformers.TrainerCallback): n_decimal_TPS: number of decimal points for TPS """ - def __init__( - self, n_warmup_steps=2, n_decimal_time=2, n_decimal_memory=2, n_decimal_TPS=2 - ): + def __init__(self, n_warmup_steps=2, n_decimal_time=2, n_decimal_memory=2, n_decimal_TPS=2): self.state = State( n_warmup_steps, ) @@ -130,9 +132,7 @@ def on_init_end( 'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second' ) if args.logging_steps != 1: - raise Exception( - "Please set logging_steps=1 to track the efficiency metrics accurately" - ) + raise Exception("Please set logging_steps=1 to track the efficiency metrics accurately") def on_train_begin( self, @@ -152,9 +152,7 @@ def on_log( logs: dict[str, float], **kwargs, ): - if state.global_step < ( - self.state.global_start_step + self.state.n_warmup_steps - ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): return else: # spread self.time, self.memory, self.tps to logs @@ -186,9 +184,7 @@ def on_step_end( control: TrainerControl, **kwargs, ): - if state.global_step < ( - self.state.global_start_step + self.state.n_warmup_steps - ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): # The end the current step_start_tokens_seen is the start of next iteration # tokens @@ -206,12 +202,8 @@ def on_step_end( avg_step_time = self.state.elapsed_time / self.state.elapsed_step self.time.step = global_step - self.time.step_time_sec = round_to_n_decimal( - step_time, self.precision.n_decimal_time - ) - self.time.avg_step_time_sec = round_to_n_decimal( - avg_step_time, self.precision.n_decimal_time - ) + self.time.step_time_sec = round_to_n_decimal(step_time, self.precision.n_decimal_time) + self.time.avg_step_time_sec = round_to_n_decimal(avg_step_time, self.precision.n_decimal_time) self.time.time_to_completion_sec = round_to_n_decimal( avg_step_time * (state.max_steps - global_step), self.precision.n_decimal_time, @@ -221,19 +213,13 @@ def on_step_end( ) # memory - step_peak_memory_allocated = getattr( - torch, self.device - ).memory.max_memory_allocated() - step_peak_memory_reserved = getattr( - torch, self.device - ).memory.max_memory_reserved() + step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated() + step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved() self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory ) - self.state.total_peak_memory_allocated = max( - self.state.total_peak_memory_allocated, step_peak_memory_allocated - ) + self.state.total_peak_memory_allocated = max(self.state.total_peak_memory_allocated, step_peak_memory_allocated) self.memory.total_peak_memory_allocated_MB = round_to_n_decimal( self.state.total_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory, @@ -243,9 +229,7 @@ def on_step_end( step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory ) - self.state.total_peak_memory_reserved = max( - self.state.total_peak_memory_reserved, step_peak_memory_reserved - ) + self.state.total_peak_memory_reserved = max(self.state.total_peak_memory_reserved, step_peak_memory_reserved) self.memory.total_peak_memory_reserved_MB = round_to_n_decimal( self.state.total_peak_memory_reserved / M_BIN_UNIT, @@ -253,9 +237,7 @@ def on_step_end( ) # tokens - step_tokens_seen = ( - state.num_input_tokens_seen - self.state.step_start_tokens_seen - ) + step_tokens_seen = state.num_input_tokens_seen - self.state.step_start_tokens_seen self.state.elapsed_tokens_seen += step_tokens_seen diff --git a/examples/huggingface/launch_on_modal.py b/examples/huggingface/launch_on_modal.py index d126940c1..1171ea42d 100644 --- a/examples/huggingface/launch_on_modal.py +++ b/examples/huggingface/launch_on_modal.py @@ -25,6 +25,7 @@ import os import modal + from modal import gpu TWO_HOURS = 2 * 60 * 60 @@ -32,11 +33,7 @@ app = modal.App("liger-example") -image = ( - modal.Image.debian_slim() - .pip_install_from_requirements("requirements.txt") - .copy_local_dir(".", "/root") -) +image = modal.Image.debian_slim().pip_install_from_requirements("requirements.txt").copy_local_dir(".", "/root") if "HF_TOKEN" not in os.environ: print("HF_TOKEN not found in environment variables, using an empty token.") diff --git a/examples/huggingface/training.py b/examples/huggingface/training.py index 505600268..84a7b288e 100644 --- a/examples/huggingface/training.py +++ b/examples/huggingface/training.py @@ -3,8 +3,10 @@ import datasets import torch import transformers + from callback import EfficiencyCallback -from trl import DataCollatorForCompletionOnlyLM, SFTTrainer +from trl import DataCollatorForCompletionOnlyLM +from trl import SFTTrainer from liger_kernel.transformers import AutoLigerKernelForCausalLM @@ -22,9 +24,7 @@ def formatting_prompts_func(example): def train(): - parser = transformers.HfArgumentParser( - (transformers.TrainingArguments, CustomArguments) - ) + parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments)) training_args, custom_args = parser.parse_args_into_dataclasses() tokenizer = transformers.AutoTokenizer.from_pretrained( custom_args.model_name, @@ -33,9 +33,7 @@ def train(): ) tokenizer.pad_token = tokenizer.eos_token - dataset = datasets.load_dataset(custom_args.dataset)["train"].train_test_split( - test_size=0.1 - ) + dataset = datasets.load_dataset(custom_args.dataset)["train"].train_test_split(test_size=0.1) train_dataset = dataset["train"] eval_dataset = dataset["test"] response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False) diff --git a/examples/huggingface/training_multimodal.py b/examples/huggingface/training_multimodal.py index 454fdb659..13f682448 100644 --- a/examples/huggingface/training_multimodal.py +++ b/examples/huggingface/training_multimodal.py @@ -1,9 +1,11 @@ import os + from dataclasses import dataclass import datasets import torch import transformers + from callback import EfficiencyCallback from datasets import Image as ImageFeature from trl import SFTTrainer @@ -66,7 +68,7 @@ def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn. def _validate_and_extract_the_cauldron(examples) -> dict[str, list]: batch_texts = [] batch_images = [] - for images, texts in zip(examples["images"], examples["texts"]): + for images, texts in zip(examples["images"], examples["texts"], strict=False): if not images: raise ValueError("No image found in example from the_cauldron dataset") if len(images) > 1: @@ -91,16 +93,12 @@ def _format_for_convo(example, tokenizer): def train(): - parser = transformers.HfArgumentParser( - (transformers.TrainingArguments, CustomArguments) - ) + parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments)) training_args, custom_args = parser.parse_args_into_dataclasses() training_args.remove_unused_columns = False # required to not drop the image column training_args.dataset_kwargs = {"skip_prepare_dataset": True} - model, processor, image_token_id = construct_model_and_processor( - custom_args.model_name, custom_args.use_liger - ) + model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger) dataset = ( datasets.load_dataset( diff --git a/examples/lightning/training.py b/examples/lightning/training.py index 6bf068d1b..8e58d8b11 100644 --- a/examples/lightning/training.py +++ b/examples/lightning/training.py @@ -1,14 +1,19 @@ import argparse import math import os -from dataclasses import _MISSING_TYPE, dataclass + +from dataclasses import _MISSING_TYPE +from dataclasses import dataclass import datasets import lightning.pytorch as pl import torch import transformers -from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy -from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision + +from lightning.pytorch.strategies import DeepSpeedStrategy +from lightning.pytorch.strategies import FSDPStrategy +from torch.distributed.fsdp import BackwardPrefetch +from torch.distributed.fsdp import MixedPrecision from torch.utils.data import DataLoader from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer @@ -45,9 +50,7 @@ def lr_lambda(current_step): return float(current_step) / float(max(1, warmup_steps)) else: # Cosine annealing - progress = float(current_step - warmup_steps) / float( - max(1, total_steps - warmup_steps) - ) + progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps)) return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress))) return lr_lambda @@ -58,9 +61,7 @@ def parse_args() -> Args: for k, v in Args.__dataclass_fields__.items(): parser.add_argument(f"--{k}", type=v.type, default=v.default) parsed = parser.parse_args() - return Args( - **{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)} - ) + return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)}) class LanguageModel(pl.LightningModule): @@ -82,9 +83,7 @@ def configure_model(self): self.model.gradient_checkpointing_enable() def forward(self, input_ids, attention_mask, labels=None, **kwargs): - return self.model( - input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs - ) + return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs) def training_step(self, batch): outputs = self.model( @@ -130,8 +129,7 @@ def configure_optimizers(self): fused=True, ) lr_lambda = warmup_cosine_schedule( - warmup_steps=self.trainer.estimated_stepping_batches - * self.args.warmup_ratio, + warmup_steps=self.trainer.estimated_stepping_batches * self.args.warmup_ratio, total_steps=self.trainer.estimated_stepping_batches, min_lr=0, ) @@ -148,9 +146,7 @@ def __init__(self, tokenizer, args: Args): self.args = args self.tokenizer = tokenizer self.response_template_str = " " - response_prompt = tokenizer.encode( - f"{self.response_template_str}", add_special_tokens=False - ) + response_prompt = tokenizer.encode(f"{self.response_template_str}", add_special_tokens=False) self.collator = DataCollatorForCompletionOnlyLM( tokenizer=tokenizer, response_template=response_prompt, @@ -237,9 +233,7 @@ def train(): layers = {Qwen2DecoderLayer} else: layers = {} - raise Warning( - f"Unimplemented layer wrap policy for {args.model} in this example" - ) + raise Warning(f"Unimplemented layer wrap policy for {args.model} in this example") if args.strategy == "fsdp": strategy = FSDPStrategy( @@ -248,9 +242,7 @@ def train(): backward_prefetch=BackwardPrefetch.BACKWARD_PRE, sync_module_states=True, activation_checkpointing_policy=layers, - mixed_precision=MixedPrecision( - param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16 - ), + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16), forward_prefetch=True, ) precision = None @@ -268,20 +260,14 @@ def train(): trainer = pl.Trainer( accelerator=device, strategy=strategy, - devices=( - getattr(torch, device).device_count() - if args.num_gpu is None - else args.num_gpu - ), + devices=(getattr(torch, device).device_count() if args.num_gpu is None else args.num_gpu), default_root_dir=args.output_dir, log_every_n_steps=1, max_epochs=1, precision=precision, ) - tokenizer = transformers.AutoTokenizer.from_pretrained( - args.model, padding_side="left", truncation_side="left" - ) + tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, padding_side="left", truncation_side="left") tokenizer.pad_token = tokenizer.eos_token data_module = DataModule( tokenizer=tokenizer, diff --git a/examples/medusa/callback.py b/examples/medusa/callback.py index 135f46f0b..33a9d1946 100644 --- a/examples/medusa/callback.py +++ b/examples/medusa/callback.py @@ -1,11 +1,15 @@ import os import time + from dataclasses import dataclass import torch import transformers + from accelerate.utils.constants import FSDP_SHARDING_STRATEGY -from transformers import TrainerControl, TrainerState, TrainingArguments +from transformers import TrainerControl +from transformers import TrainerState +from transformers import TrainingArguments from liger_kernel.utils import infer_device @@ -156,9 +160,7 @@ def on_init_end( 'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second' ) if args.logging_steps != 1: - raise Exception( - "Please set logging_steps=1 to track the efficiency metrics accurately" - ) + raise Exception("Please set logging_steps=1 to track the efficiency metrics accurately") def on_train_begin( self, @@ -178,9 +180,7 @@ def on_log( logs: dict[str, float], **kwargs, ): - if state.global_step < ( - self.state.global_start_step + self.state.n_warmup_steps - ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): return else: # spread self.time, self.memory, self.tps, self.mfu to logs @@ -213,9 +213,7 @@ def on_step_end( control: TrainerControl, **kwargs, ): - if state.global_step < ( - self.state.global_start_step + self.state.n_warmup_steps - ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): # The end the current step_start_tokens_seen and step_start_flos are the start of next iteration # tokens @@ -235,12 +233,8 @@ def on_step_end( avg_step_time = self.state.elapsed_time / self.state.elapsed_step self.time.step = global_step - self.time.step_time_sec = round_to_n_decimal( - step_time, self.precision.n_decimal_time - ) - self.time.avg_step_time_sec = round_to_n_decimal( - avg_step_time, self.precision.n_decimal_time - ) + self.time.step_time_sec = round_to_n_decimal(step_time, self.precision.n_decimal_time) + self.time.avg_step_time_sec = round_to_n_decimal(avg_step_time, self.precision.n_decimal_time) self.time.time_to_completion_sec = round_to_n_decimal( avg_step_time * (state.max_steps - global_step), self.precision.n_decimal_time, @@ -250,19 +244,13 @@ def on_step_end( ) # memory - step_peak_memory_allocated = getattr( - torch, self.device - ).memory.max_memory_allocated() - step_peak_memory_reserved = getattr( - torch, self.device - ).memory.max_memory_reserved() + step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated() + step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved() self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory ) - self.state.total_peak_memory_allocated = max( - self.state.total_peak_memory_allocated, step_peak_memory_allocated - ) + self.state.total_peak_memory_allocated = max(self.state.total_peak_memory_allocated, step_peak_memory_allocated) self.memory.total_peak_memory_allocated_MB = round_to_n_decimal( self.state.total_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory, @@ -272,9 +260,7 @@ def on_step_end( step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory ) - self.state.total_peak_memory_reserved = max( - self.state.total_peak_memory_reserved, step_peak_memory_reserved - ) + self.state.total_peak_memory_reserved = max(self.state.total_peak_memory_reserved, step_peak_memory_reserved) self.memory.total_peak_memory_reserved_MB = round_to_n_decimal( self.state.total_peak_memory_reserved / M_BIN_UNIT, @@ -282,9 +268,7 @@ def on_step_end( ) # tokens - step_tokens_seen = ( - state.num_input_tokens_seen - self.state.step_start_tokens_seen - ) + step_tokens_seen = state.num_input_tokens_seen - self.state.step_start_tokens_seen self.state.elapsed_tokens_seen += step_tokens_seen @@ -329,20 +313,14 @@ def on_step_end( num_gpus = EfficiencyCallback._get_effective_num_gpus() step_achieved_tflops = step_flos / step_time / num_gpus / T_DEC_UNIT - avg_achieved_tflops = ( - self.state.elapsed_flos / self.state.elapsed_time / num_gpus / T_DEC_UNIT - ) + avg_achieved_tflops = self.state.elapsed_flos / self.state.elapsed_time / num_gpus / T_DEC_UNIT precision_bits = 16 if args.bf16 or args.fp16 else 32 gpu_peak_tflops = EfficiencyCallback._get_gpu_peak_tflops(precision_bits) - self.mfu.step_MFU = round_to_n_decimal( - step_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU - ) + self.mfu.step_MFU = round_to_n_decimal(step_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU) - self.mfu.avg_MFU = round_to_n_decimal( - avg_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU - ) + self.mfu.avg_MFU = round_to_n_decimal(avg_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU) # The end the current step_start_tokens_seen and step_start_flos are the start of next iteration @@ -357,9 +335,7 @@ def _get_effective_num_gpus(): world_size = int(os.environ.get("WORLD_SIZE", "1")) if transformers.utils.strtobool(os.environ.get("ACCELERATE_USE_FSDP", "false")): - sharding_strategy = os.environ.get( - "FSDP_SHARDING_STRATEGY", FSDP_SHARDING_STRATEGY[0] - ).upper() + sharding_strategy = os.environ.get("FSDP_SHARDING_STRATEGY", FSDP_SHARDING_STRATEGY[0]).upper() # Either specified as string or enum number if sharding_strategy in { diff --git a/examples/medusa/medusa_util.py b/examples/medusa/medusa_util.py index 5b4f9ac9f..7c66e0e08 100644 --- a/examples/medusa/medusa_util.py +++ b/examples/medusa/medusa_util.py @@ -1,15 +1,16 @@ import types -from typing import List, Optional + +from typing import List +from typing import Optional import torch + from torch import nn from torch.nn import CrossEntropyLoss from transformers import PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss class MedusaConfig(PretrainedConfig): @@ -80,12 +81,7 @@ def calculate_loss_contribution( if i == 0: return loss_i if not medusa_only_heads else 0 else: - return ( - loss_i - * medusa_decay_coefficient**i - * medusa_heads_coefficient - * medusa_scheduler_coefficient - ) + return loss_i * medusa_decay_coefficient**i * medusa_heads_coefficient * medusa_scheduler_coefficient def add_medusa_heads( @@ -214,17 +210,11 @@ def forward( lce = LigerFusedLinearCrossEntropyLoss() for i in range(model.medusa_num_heads + 1): shift_hidden_states = ( - hidden_states[..., : -(1 + i), :] - .contiguous() - .view(-1, model.config.hidden_size) + hidden_states[..., : -(1 + i), :].contiguous().view(-1, model.config.hidden_size) ) shift_labels = labels[..., (1 + i) :].contiguous().view(-1) - weight = ( - model.lm_head.weight - if i == 0 - else model.medusa_head[i - 1][-1].weight - ) + weight = model.lm_head.weight if i == 0 else model.medusa_head[i - 1][-1].weight loss_i = lce(weight, shift_hidden_states, shift_labels) loss += calculate_loss_contribution( @@ -236,21 +226,11 @@ def forward( medusa_scheduler_coefficient, ) else: - loss_fct = CrossEntropyLoss() for i in range(model.medusa_num_heads + 1): - medusa_logits_i = ( - medusa_logits[i, :, : -(1 + i)] - .contiguous() - .view(-1, medusa_logits.shape[-1]) - ) + medusa_logits_i = medusa_logits[i, :, : -(1 + i)].contiguous().view(-1, medusa_logits.shape[-1]) medusa_logits_i = medusa_logits_i.float() - medusa_labels = ( - labels[..., (1 + i) :] - .contiguous() - .view(-1) - .to(medusa_logits_i.device) - ) + medusa_labels = labels[..., (1 + i) :].contiguous().view(-1).to(medusa_logits_i.device) loss_i = loss_fct(medusa_logits_i, medusa_labels) @@ -270,9 +250,7 @@ def forward( for i in range(model.medusa_num_heads): medusa_logits.append(model.medusa_head[i](hidden_states)) - return_dict = ( - return_dict if return_dict is not None else model.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else model.config.use_return_dict if not return_dict: output = (medusa_logits,) + outputs[1:] diff --git a/examples/medusa/train.py b/examples/medusa/train.py index 3fa879baa..64a6e5870 100644 --- a/examples/medusa/train.py +++ b/examples/medusa/train.py @@ -19,20 +19,22 @@ import json import os import pathlib -from dataclasses import dataclass, field -from typing import Dict, Optional + +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import Optional import torch import transformers + from callback import EfficiencyCallback from medusa_util import add_medusa_heads from safetensors.torch import save_file from sklearn.model_selection import train_test_split from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullStateDictConfig, - StateDictType, -) +from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType from torch.utils.data import Dataset from transformers import Trainer from transformers.trainer_pt_utils import LabelSmoother @@ -53,9 +55,7 @@ class DataArguments: default="Aeala/ShareGPT_Vicuna_unfiltered", metadata={"help": "Path to the training data."}, ) - eval_data_path: str = field( - default=None, metadata={"help": "Path to the evaluation data."} - ) + eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."}) lazy_preprocess: bool = True @@ -66,9 +66,7 @@ class TrainingArguments(transformers.TrainingArguments): optim: str = field(default="adamw_torch") model_max_length: int = field( default=2048, - metadata={ - "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." - }, + metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, ) medusa_num_heads: int = field( default=1, @@ -102,9 +100,7 @@ class TrainingArguments(transformers.TrainingArguments): ) medusa_only_heads: bool = field( default=False, - metadata={ - "help": "If train medusa heads only, default is False, the whole model will be trained" - }, + metadata={"help": "If train medusa heads only, default is False, the whole model will be trained"}, ) use_liger: bool = field( default=False, @@ -162,9 +158,7 @@ def preprocess( } for c in conversation["conversations"] ] - prompt = tokenizer.apply_chat_template( - tokenizer_compatible_conv, tokenize=False - ) + prompt = tokenizer.apply_chat_template(tokenizer_compatible_conv, tokenize=False) prompts.append(prompt) conversations.append(tokenizer_compatible_conv) @@ -181,9 +175,7 @@ def preprocess( input_ids = encoding.input_ids # Mask targets. Only compute loss on the assistant outputs. - for conv_index, (conversation, target, prompt) in enumerate( - zip(conversations, targets, prompts) - ): + for conv_index, (conversation, target, prompt) in enumerate(zip(conversations, targets, prompts, strict=False)): # print(conv_index) for turn in conversation: if turn["role"] == "assistant": @@ -192,9 +184,7 @@ def preprocess( start = prompt.index(content.strip()) # stop = start + len(content) indices = [] - for tok_index, (tok_start, tok_stop) in enumerate( - encoding.offset_mapping[conv_index] - ): + for tok_index, (tok_start, tok_stop) in enumerate(encoding.offset_mapping[conv_index]): if tok_stop >= start or tok_start < tok_stop: indices.append(tok_index) target[indices] = encoding.input_ids[conv_index][indices] @@ -273,9 +263,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]: return ret -def make_supervised_data_module( - tokenizer: transformers.PreTrainedTokenizer, data_args, test_size=0.05 -) -> Dict: +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, test_size=0.05) -> Dict: """Make dataset and collator for supervised fine-tuning. Args: @@ -286,18 +274,14 @@ def make_supervised_data_module( Returns: dict: A dictionary containing train and eval datasets. """ - dataset_cls = ( - LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset - ) + dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset rank0_print("Loading data...") # Load the entire dataset train_json = json.load(open(data_args.data_path, "r")) # Perform a train-test split based on test_size - train_data, eval_data = train_test_split( - train_json, test_size=test_size, random_state=42 - ) + train_data, eval_data = train_test_split(train_json, test_size=test_size, random_state=42) # Create the train and eval datasets train_dataset = dataset_cls(train_data, tokenizer=tokenizer) eval_dataset = dataset_cls(eval_data, tokenizer=tokenizer) @@ -308,9 +292,7 @@ def make_supervised_data_module( def train(): global local_rank - parser = transformers.HfArgumentParser( - (ModelArguments, DataArguments, TrainingArguments) - ) + parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() local_rank = training_args.local_rank @@ -326,9 +308,7 @@ def train(): # Making sure the tokenizer works before loading the model. print(tokenizer(["This is a test", "secondary"], padding=True)) - print( - tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]) - ) + print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}])) # Load model and tokenizer model = transformers.AutoModelForCausalLM.from_pretrained( diff --git a/pyproject.toml b/pyproject.toml index 37a3963f8..45f9237b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,4 +23,36 @@ namespaces = false pythonpath = ["src", "."] asyncio_mode = "auto" log_cli = true -log_cli_level = "INFO" \ No newline at end of file +log_cli_level = "INFO" + +[tool.ruff] +line-length = 120 +target-version = "py310" +respect-gitignore = true +src = ["src"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort +] +ignore = ["E501", "B006", "E731", "A002", "E203"] + +exclude = [ + ".git", + "__pycache__", + "benchmark_internal/others", + ".venv", +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + +[tool.ruff.lint.isort] +known-first-party = ["liger_kernel"] +force-single-line = true +lines-between-types = 1 diff --git a/setup.py b/setup.py index 57ffbc7ce..ebc34985c 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # setup.py import subprocess + from typing import Literal from setuptools import setup diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 2b8052e25..41ec78a9d 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -1,15 +1,12 @@ import torch import torch.nn.functional as F -from liger_kernel.chunked_loss.fused_linear_preference import ( - LigerFusedLinearPreferenceBase, -) +from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase): - @staticmethod - def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): + def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0): """ Paper: https://arxiv.org/pdf/2401.08417 @@ -30,9 +27,12 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). full_target (torch.Tensor): Non chunked full target tensor beta (float): Weight for the CPO loss + label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0. """ logits = beta * (chosen_logps - rejected_logps) - loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) + loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / ( + full_target.shape[0] // 2 + ) return loss @staticmethod @@ -45,6 +45,7 @@ def forward( ignore_index=-100, beta=0.1, alpha=1.0, + label_smoothing=0.0, compute_nll_loss=True, compiled=True, ): @@ -58,6 +59,7 @@ def forward( ignore_index=ignore_index, alpha=alpha, beta=beta, + label_smoothing=label_smoothing, compute_nll_loss=compute_nll_loss, compiled=compiled, ) @@ -65,7 +67,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): @@ -78,6 +80,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, compute_nll_loss: bool = True, compiled: bool = True, ): @@ -90,6 +93,7 @@ def __init__( self.ignore_index = ignore_index self.beta = beta self.alpha = alpha + self.label_smoothing = label_smoothing self.compute_nll_loss = compute_nll_loss self.compiled = compiled @@ -102,6 +106,7 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.beta, self.alpha, + self.label_smoothing, self.compute_nll_loss, self.compiled, ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 5f1b17cf5..6b2b7b0d9 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -1,13 +1,10 @@ import torch import torch.nn.functional as F -from liger_kernel.chunked_loss.fused_linear_preference import ( - LigerFusedLinearPreferenceBase, -) +from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): - @staticmethod def preference_loss_fn( chosen_logps, @@ -64,7 +61,7 @@ def forward( ref_bias=None, ignore_index=-100, beta=0.1, - compute_nll_loss=True, + compute_nll_loss=False, compiled=True, use_ref_model=True, ): @@ -100,7 +97,7 @@ def __init__( self, ignore_index: int = -100, beta: float = 0.1, - compute_nll_loss: bool = True, + compute_nll_loss: bool = False, compiled: bool = True, use_ref_model: bool = False, ): diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 10e726055..e4f06646f 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -2,11 +2,11 @@ from functools import partial import torch + from torch.nn import functional as F class LigerFusedLinearDistillationBase(torch.autograd.Function): - @abstractmethod def distillation_loss_fn(student_logits, teacher_logits, temperature): """ @@ -89,25 +89,25 @@ def _compute_loss( compute_ce_loss (bool): Whether to compute CE loss. loss_kwargs (dict): Additional arguments for the loss function. """ - student_logits_chunk, teacher_logits_chunk, hard_loss = ( - LigerFusedLinearDistillationBase.chunk_forward( - student_input_chunk, - student_weight, - teacher_input_chunk, - teacher_weight, - target_chunk, - student_bias=student_bias, - teacher_bias=teacher_bias, - ignore_index=ignore_index, - compute_ce_loss=compute_ce_loss, - ) + ( + student_logits_chunk, + teacher_logits_chunk, + hard_loss, + ) = LigerFusedLinearDistillationBase.chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=student_bias, + teacher_bias=teacher_bias, + ignore_index=ignore_index, + compute_ce_loss=compute_ce_loss, ) hard_loss /= full_target.shape[0] - soft_loss = distillation_loss_fn( - student_logits_chunk, teacher_logits_chunk, temperature - ) + soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) soft_loss /= full_target.shape[0] loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss @@ -174,17 +174,18 @@ def forward( def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): if student_bias is not None: - (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( - chunk_loss, + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( - chunk_soft_loss, - chunk_hard_loss, - chunk_student_logits, - chunk_teacher_logits, + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), ), - ) = torch.func.grad_and_value( - loss_func_to_call, argnums=(0, 1, 5), has_aux=True - )( + ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)( student_input_chunk, student_weight, teacher_input_chunk, @@ -195,17 +196,18 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): ) grad_bias.add_(chunk_grad_bias) else: - (chunk_grad_input, chunk_grad_weight), ( - chunk_loss, + ( + (chunk_grad_input, chunk_grad_weight), ( - chunk_soft_loss, - chunk_hard_loss, - chunk_student_logits, - chunk_teacher_logits, + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), ), - ) = torch.func.grad_and_value( - loss_func_to_call, argnums=(0, 1), has_aux=True - )( + ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)( student_input_chunk, student_weight, teacher_input_chunk, @@ -229,9 +231,7 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): for student_input_chunk, teacher_input_chunk, target_chunk in zip( _student_input_chunks, _teacher_input_chunks, _target_chunks ): - grad_input = accumulate_chunk( - student_input_chunk, teacher_input_chunk, target_chunk - ) + grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk) grad_inputs.append(grad_input) ctx.save_for_backward( diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index fff0791ec..f389050c0 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -2,11 +2,11 @@ from functools import partial import torch + from torch.nn import functional as F class LigerFusedLinearPreferenceBase(torch.autograd.Function): - @abstractmethod def preference_loss_fn(*args, **kwargs): """ @@ -102,9 +102,7 @@ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk): Fused forward and backward pass for a chunk of input and target. """ if bias is not None: - return torch.func.grad_and_value( - compute_loss, argnums=(0, 1, 3), has_aux=True - )( + return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)( input_chunk, weight, target_chunk, @@ -112,43 +110,47 @@ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk): ref_input_chunk=ref_input_chunk, ) else: - return torch.func.grad_and_value( - compute_loss, argnums=(0, 1), has_aux=True - )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk) + return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)( + input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk + ) def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): if bias is not None: - (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( - chunk_loss, + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( - chunk_chosen_logps, - chunk_rejected_logps, - chunk_chosen_logits_mean, - chunk_rejected_logits_mean, - chunk_nll_loss, - *aux_outputs, + chunk_loss, + ( + chunk_chosen_logps, + chunk_rejected_logps, + chunk_chosen_logits_mean, + chunk_rejected_logits_mean, + chunk_nll_loss, + *aux_outputs, + ), ), ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk) grad_bias.add_(chunk_grad_bias) # accumulate bias gradient else: - (chunk_grad_input, chunk_grad_weight), ( - chunk_loss, + ( + (chunk_grad_input, chunk_grad_weight), ( - chunk_chosen_logps, - chunk_rejected_logps, - chunk_chosen_logits_mean, - chunk_rejected_logits_mean, - chunk_nll_loss, - *aux_outputs, + chunk_loss, + ( + chunk_chosen_logps, + chunk_rejected_logps, + chunk_chosen_logits_mean, + chunk_rejected_logits_mean, + chunk_nll_loss, + *aux_outputs, + ), ), ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk) # Accumulate gradients grad_weight.add_(chunk_grad_weight) grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]]) - grad_rejected_inputs.append( - chunk_grad_input[chosen_target_chunk.shape[0] :] - ) + grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :]) # Accumulate loss loss_acc.add_(chunk_loss) @@ -165,9 +167,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): if len(aggregated_aux_outputs) == 0: for aux in aux_outputs: if aux.ndim == 0: - aggregated_aux_outputs.append( - torch.zeros((), device=aux.device) - ) + aggregated_aux_outputs.append(torch.zeros((), device=aux.device)) else: aggregated_aux_outputs.append([]) @@ -189,12 +189,8 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0) if use_ref_model: - _ref_chosen_input_chunks = torch.chunk( - ref_input[:len_chosen], chunks=chunks, dim=0 - ) - _ref_rejected_input_chunks = torch.chunk( - ref_input[len_chosen:], chunks=chunks, dim=0 - ) + _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0) + _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0) for ( chosen_input_chunk, @@ -208,26 +204,15 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): _rejected_input_chunks, _chosen_target_chunks, _rejected_target_chunks, - ( - _ref_chosen_input_chunks - if use_ref_model - else [None] * len(_chosen_input_chunks) - ), - ( - _ref_rejected_input_chunks - if use_ref_model - else [None] * len(_rejected_input_chunks) - ), + (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)), + (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)), + strict=False, ): input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0) ref_input_chunk = ( - torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) - if use_ref_model - else None - ) - target_chunk = torch.cat( - [chosen_target_chunk, rejected_target_chunk], dim=0 + torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None ) + target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0) # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation torch._dynamo.mark_dynamic(input_chunk, 1) @@ -265,9 +250,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): @staticmethod def backward(ctx, *grad_output): grad_input, grad_weight, grad_bias = ctx.saved_tensors - if torch.ne( - grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device) - ): + if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)): grad_input = grad_input * grad_output[0][0] grad_weight = grad_weight * grad_output[0][0] grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None @@ -301,9 +284,7 @@ def chunk_forward( loss_mask = target_chunk != ignore_index label_chunk = torch.where(loss_mask, target_chunk, 0) - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( - -1 - ) + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1) average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) chosen_logps = average_log_prob[:len_chosen_chunk] @@ -370,13 +351,8 @@ def _compute_loss( ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) + chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]) rejected_logits_mean = rejected_logits.sum() / ( full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] ) @@ -408,7 +384,7 @@ def _compute_loss( else: preference_loss, aux_outputs = preference_loss_outputs, [] - loss = alpha * chosen_nll_loss - preference_loss + loss = alpha * chosen_nll_loss + preference_loss return_vars = ( chosen_logps, rejected_logps, diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index c860d4bd9..dfed5d3a7 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -1,13 +1,10 @@ import torch import torch.nn.functional as F -from liger_kernel.chunked_loss.fused_linear_preference import ( - LigerFusedLinearPreferenceBase, -) +from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase): - @staticmethod def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): """ @@ -32,11 +29,10 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): beta (float): Weight for the odds ratio loss. """ log_odds = (chosen_logps - rejected_logps) - ( - torch.log1p(-torch.exp(chosen_logps)) - - torch.log1p(-torch.exp(rejected_logps)) + torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) ) ratio = F.logsigmoid(log_odds) - loss = beta * ratio.sum() / (full_target.shape[0] // 2) + loss = -beta * ratio.sum() / (full_target.shape[0] // 2) chosen_rewards = beta * chosen_logps rejected_rewards = beta * rejected_logps diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 7efa0603d..975bcefab 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -1,16 +1,18 @@ import torch import torch.nn.functional as F -from liger_kernel.chunked_loss.fused_linear_preference import ( - LigerFusedLinearPreferenceBase, -) +from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase): - @staticmethod def preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5 + chosen_logps, + rejected_logps, + full_target, + beta=0.1, + gamma=0.5, + label_smoothing=0.0, ): """ Paper: https://arxiv.org/pdf/2405.14734 @@ -33,9 +35,13 @@ def preference_loss_fn( full_target: Non chunked full target tensor beta (float): beta weight gamma (float): gemma margin term + label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0. """ logits = beta * (chosen_logps - rejected_logps) - gamma - loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) + loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / ( + full_target.shape[0] // 2 + ) + return loss @staticmethod @@ -48,6 +54,7 @@ def forward( ignore_index=-100, beta=0.1, alpha=1.0, + label_smoothing=0.0, compute_nll_loss=False, compiled=True, gamma=0.5, @@ -63,6 +70,7 @@ def forward( ignore_index=ignore_index, alpha=alpha, beta=beta, + label_smoothing=label_smoothing, compiled=compiled, gamma=gamma, ) @@ -70,7 +78,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None class LigerFusedLinearSimPOLoss(torch.nn.Module): @@ -83,6 +91,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, compute_nll_loss: bool = True, compiled: bool = True, gamma: float = 0.5, @@ -96,6 +105,7 @@ def __init__( self.ignore_index = ignore_index self.beta = beta self.alpha = alpha + self.label_smoothing = label_smoothing self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.gamma = gamma @@ -109,6 +119,7 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.beta, self.alpha, + self.label_smoothing, self.compute_nll_loss, self.compiled, self.gamma, diff --git a/src/liger_kernel/env_report.py b/src/liger_kernel/env_report.py index 6739c5a68..ff3185509 100644 --- a/src/liger_kernel/env_report.py +++ b/src/liger_kernel/env_report.py @@ -1,5 +1,6 @@ import platform import sys + from importlib.metadata import version @@ -27,15 +28,9 @@ def print_env_report(): import torch print(f"PyTorch version: {torch.__version__}") - cuda_version = ( - torch.version.cuda if torch.cuda.is_available() else "Not available" - ) + cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available" print(f"CUDA version: {cuda_version}") - hip_version = ( - torch.version.hip - if torch.cuda.is_available() and torch.version.hip - else "Not available" - ) + hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available" print(f"HIP(ROCm) version: {hip_version}") except ImportError: @@ -58,9 +53,7 @@ def print_env_report(): print("Transformers: Not installed") try: - xpu_version = ( - torch.version.xpu if torch.xpu.is_available() else "XPU Not Available" - ) + xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available" print(f"XPU version: {xpu_version}") except ImportError: print("XPU version: Unable to query") diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index d9f7947b2..c7f049b56 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -1,11 +1,14 @@ import operator + from typing import Optional import torch import triton import triton.language as tl -from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import is_hip if compare_version("triton", operator.ge, "3.0.0"): try: @@ -17,8 +20,8 @@ else: from triton.language.math import tanh -_TRUE = tl.constexpr(1) -_FALSE = tl.constexpr(0) +_TRUE: tl.constexpr = tl.constexpr(1) +_FALSE: tl.constexpr = tl.constexpr(0) @triton.jit @@ -103,9 +106,7 @@ def liger_cross_entropy_kernel( # 3. [Online softmax] first pass: find max + sum m = float("-inf") # m is the max value. use the notation from the paper d = 0.0 # d is the sum. use the notation from the paper - ori_X_y = tl.load(X_ptr + y).cast( - tl.float32 - ) # we need to store the original value of X_y for the loss calculation + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation if HAS_SOFTCAPPING: ori_X_y = softcap * tanh(ori_X_y / softcap) @@ -284,14 +285,10 @@ def cross_entropy_forward( return_z_loss, ): if not isinstance(return_z_loss, int): - assert ( - return_z_loss in _bool_to_return_z_loss - ), f"return_z_loss must be True or False. Got: {return_z_loss}" + assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}" return_z_loss = _bool_to_return_z_loss[return_z_loss] else: - assert ( - return_z_loss in _bool_to_return_z_loss - ), f"return_z_loss must be True or False. Got: {return_z_loss}" + assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}" BT, V = _input.shape n_rows = BT diff --git a/src/liger_kernel/ops/experimental/embedding.py b/src/liger_kernel/ops/experimental/embedding.py index 985cca100..159b9a66d 100644 --- a/src/liger_kernel/ops/experimental/embedding.py +++ b/src/liger_kernel/ops/experimental/embedding.py @@ -34,9 +34,7 @@ def embedding_forward_kernel( ) output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] - tl.store( - output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :] - ) + tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]) @triton.jit diff --git a/src/liger_kernel/ops/experimental/mm_int8int2.py b/src/liger_kernel/ops/experimental/mm_int8int2.py index 4de17124b..326d53632 100644 --- a/src/liger_kernel/ops/experimental/mm_int8int2.py +++ b/src/liger_kernel/ops/experimental/mm_int8int2.py @@ -37,9 +37,7 @@ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor: else: packed_tensor_shape = (row_dim, *original_shape[1:]) - packed = torch.zeros( - packed_tensor_shape, device=intweights.device, dtype=torch.uint8 - ) + packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8) unpacked = intweights.to(torch.uint8) def lshift(t: torch.Tensor, bits: int): @@ -327,17 +325,13 @@ def matmul_kernel( def matmul(a, b): - assert ( - a.shape[1] == b.shape[0] * 4 - ), "Incompatible dimensions, the weight matrix need to be packed" + assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed" assert a.is_contiguous(), "Matrix A must be contiguous" M, K = a.shape _, N = b.shape # c is in int32 to avoid any overflows or underflows c = torch.empty((M, N), device=a.device, dtype=torch.int32) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) matmul_kernel[grid]( a, b, diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 15481c34d..26de6591a 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -2,12 +2,10 @@ import triton from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel -from liger_kernel.ops.utils import ( - amp_custom_bwd, - amp_custom_fwd, - element_mul_kernel, - is_hip, -) +from liger_kernel.ops.utils import amp_custom_bwd +from liger_kernel.ops.utils import amp_custom_fwd +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import is_hip # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling @@ -41,14 +39,10 @@ def fused_linear_cross_entropy_forward( BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) inc_factor = triton.cdiv(V, H) # (V + H - 1) // H - chunk_size = triton.next_power_of_2( - triton.cdiv(BT, inc_factor) - ) # (BT + inc_factor - 1) // inc_factor + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size - grad_weight = ( - torch.zeros_like(weight, device=device) if weight.requires_grad else None - ) + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None grad_input = torch.zeros_like(_input, device=device) grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None # we use fp32 for loss accumulator @@ -144,15 +138,16 @@ def fused_linear_cross_entropy_forward( alpha=1.0, ) - loss = torch.sum(loss_1d) + if reduction == "none": + loss = loss_1d + else: + loss = torch.sum(loss_1d) return loss, grad_input, grad_weight, grad_bias -def fused_linear_cross_entropy_backward( - grad_output, grad_input, grad_weight, grad_bias -): +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time - if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. BT, H = grad_input.shape diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py index 288ee7403..f0c4d7bea 100644 --- a/src/liger_kernel/ops/fused_linear_jsd.py +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -4,12 +4,10 @@ import triton from liger_kernel.ops.jsd import _jsd_kernel -from liger_kernel.ops.utils import ( - amp_custom_bwd, - amp_custom_fwd, - element_mul_kernel, - is_hip, -) +from liger_kernel.ops.utils import amp_custom_bwd +from liger_kernel.ops.utils import amp_custom_fwd +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import is_hip # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling @@ -43,16 +41,10 @@ def fused_linear_jsd_forward( BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) inc_factor = triton.cdiv(V, H) # (V + H - 1) // H - chunk_size = triton.next_power_of_2( - triton.cdiv(BT, inc_factor) - ) # (BT + inc_factor - 1) // inc_factor + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size - grad_weight = ( - torch.zeros_like(student_weight, device=device) - if student_weight.requires_grad - else None - ) + grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None grad_input = torch.zeros_like(student_input) # we use fp32 for loss accumulator loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device) @@ -73,12 +65,8 @@ def fused_linear_jsd_forward( # shape: chunk_size x V # For anything starting from logits to the final JSD loss, we do computation # in FP32 to avoid losing numerical stability. - student_logits_chunk = (student_input_chunk @ student_weight.t()).to( - torch.float32 - ) - teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to( - torch.float32 - ) + student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32) + teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32) chunk_n_rows = student_logits_chunk.shape[0] # unreduced loss @@ -104,9 +92,7 @@ def fused_linear_jsd_forward( dX_ptr=student_prob_chunk, dX_stride=student_prob_chunk.stride(-2), label_ptr=( - shift_labels[start_idx:end_idx] - if has_label - else torch.empty(1, device=device) + shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device) ), # dummy ptr if no label beta=jsd_beta, n_non_ignore=n_non_ignore, @@ -121,9 +107,7 @@ def fused_linear_jsd_forward( student_logits_chunk = ( student_prob_chunk - torch.softmax(student_logits_chunk, dim=-1) - * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to( - student_prob_chunk.shape - ) + * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape) ) / temperature # now we traverse back to grad w.r.t. input to `lm_head` and grad # w.r.t. `lm_head` which should be computed in original dtype @@ -239,7 +223,5 @@ def forward( @amp_custom_bwd def backward(ctx, grad_output): (grad_input, grad_weight) = ctx.saved_tensors - grad_input, grad_weight = fused_linear_jsd_backward( - grad_output, grad_input, grad_weight - ) + grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight) return (grad_input, grad_weight, None, None, None, None, None, None) diff --git a/src/liger_kernel/ops/geglu.py b/src/liger_kernel/ops/geglu.py index cd16ee1a6..83868cf2d 100644 --- a/src/liger_kernel/ops/geglu.py +++ b/src/liger_kernel/ops/geglu.py @@ -4,11 +4,9 @@ import triton import triton.language as tl -from liger_kernel.ops.utils import ( - calculate_settings, - compare_version, - ensure_contiguous, -) +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous if compare_version("triton", operator.ge, "3.0.0"): try: @@ -22,9 +20,7 @@ @triton.jit -def _geglu_tanh_forward_kernel( - a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr -): +def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): program_id = tl.program_id(0).to(tl.int64) # locate start index @@ -49,9 +45,7 @@ def _geglu_tanh_forward_kernel( @triton.jit -def _geglu_tanh_backward_kernel( - dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr -): +def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): program_id = tl.program_id(0).to(tl.int64) # locate start index @@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel( # where z = sqrt(2/pi) * (a + 0.044715 * a^3) term1 = 0.5 * (1 + tanh_result) tanh_sq = tanh_result * tanh_result - term2 = ( - 0.5 - * a_row - * (1 - tanh_sq) - * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row)) - ) + term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row)) da_row = dc_row * b_row * (term1 + term2) tl.store(a + col_offsets, da_row, mask=mask) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index aeb4323f3..96633d01a 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -4,7 +4,8 @@ import triton import triton.language as tl -from liger_kernel.ops.utils import compare_version, ensure_contiguous +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous if compare_version("triton", operator.ge, "3.0.0"): try: @@ -73,9 +74,7 @@ def _group_norm_forward_kernel( # Normalize hidden_size_per_channel = hidden_size // channels_per_group - for channel_idx in tl.range( - group_idx * channels_per_group, (group_idx + 1) * channels_per_group - ): + for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): W = tl.load(W_ptr + channel_idx) B = tl.load(B_ptr + channel_idx) for i in range(0, hidden_size_per_channel, BLOCK_SIZE): @@ -132,21 +131,15 @@ def _group_norm_backward_kernel( UPSTREAM_ptr += batch_idx * X_row_stride # Mean and rstd are the same shape so have the same strides - mean = tl.load( - Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride - ) - rstd = tl.load( - RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride - ) + mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) c1 = 0.0 c2 = 0.0 block_range = tl.arange(0, BLOCK_SIZE) # We need to compute the sum terms of the backprop equations across all channels in the group - for channel_idx in range( - group_idx * channels_per_group, (group_idx + 1) * channels_per_group - ): + for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): dW = 0.0 dB = 0.0 # Move the pointers to the correct channel @@ -181,9 +174,7 @@ def _group_norm_backward_kernel( c1 = c1 / N c2 = c2 / N - for channel_idx in tl.range( - group_idx * channels_per_group, (group_idx + 1) * channels_per_group - ): + for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): # Move the pointers to the correct channel W = tl.load(W_ptr + channel_idx) for i in range(0, hidden_size, BLOCK_SIZE): @@ -203,9 +194,7 @@ def _group_norm_backward_kernel( x_hat = (X - mean) * rstd wdy = W * UPSTREAM_grad dx = (wdy - (x_hat * c1 + c2)) * rstd - tl.store( - DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask - ) + tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask) def group_norm_forward(X, num_channels, num_groups, W, B, eps): @@ -216,9 +205,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): X = X.view(batch_size, num_groups, -1).contiguous() hidden_size = X.shape[-1] BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) - Y = torch.empty( - (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device - ) + Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) @@ -307,16 +294,12 @@ def forward( ) ctx.num_channels = num_channels ctx.num_groups = num_groups - ctx.save_for_backward( - X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD - ) + ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) return Y @staticmethod @ensure_contiguous def backward(ctx, dY): X, W, B, Mean, RSTD = ctx.saved_tensors - DX, DW, DB = group_norm_backward( - dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups - ) + DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups) return DX, DW, DB, None, None, None diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py index 08048a060..5b6fc5219 100644 --- a/src/liger_kernel/ops/jsd.py +++ b/src/liger_kernel/ops/jsd.py @@ -98,9 +98,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): loss_stride=loss.stride(-2), dX_ptr=dX, dX_stride=dX.stride(-2), - label_ptr=( - shift_labels if has_label else torch.empty(1, device=_input.device) - ), # dummy ptr if no label + label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label beta=beta, n_non_ignore=n_non_ignore, ignore_index=ignore_index, @@ -165,9 +163,7 @@ def forward( shift_labels = shift_labels.contiguous() has_label = True - loss, dX = jsd_forward( - _input, target, shift_labels, beta, ignore_index, has_label - ) + loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label) ctx.save_for_backward(dX) return loss diff --git a/src/liger_kernel/ops/kl_div.py b/src/liger_kernel/ops/kl_div.py index 2e3c6e933..bf3ee9b28 100644 --- a/src/liger_kernel/ops/kl_div.py +++ b/src/liger_kernel/ops/kl_div.py @@ -4,7 +4,8 @@ import triton import triton.language as tl -from liger_kernel.ops.utils import ensure_contiguous, is_hip +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import is_hip def get_num_warps(BLOCK_SIZE): @@ -23,10 +24,10 @@ def get_num_warps(BLOCK_SIZE): REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] -_REDUCTION_MODE_NONE = tl.constexpr(0) -_REDUCTION_MODE_SUM = tl.constexpr(1) -_REDUCTION_MODE_MEAN = tl.constexpr(2) -_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3) +_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0) +_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1) +_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3) _str_to_reduction_mode = { "none": _REDUCTION_MODE_NONE.value, @@ -218,9 +219,7 @@ def forward( ctx.save_for_backward(y_true) ctx.reduction = reduction ctx.log_target = log_target - return kldiv_forward_triton( - y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps - ) + return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps) @staticmethod @ensure_contiguous @@ -238,9 +237,7 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: new_grads = torch.empty_like(y_true) - derivative = kldiv_backward_triton( - y_true, grad_output, new_grads, ctx.log_target - ) + derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target) if ctx.reduction == "batchmean": derivative = derivative / y_true.shape[0] diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py index 70c372237..6d527c7ee 100644 --- a/src/liger_kernel/ops/layer_norm.py +++ b/src/liger_kernel/ops/layer_norm.py @@ -5,11 +5,9 @@ import triton import triton.language as tl -from liger_kernel.ops.utils import ( - calculate_settings, - compare_version, - ensure_contiguous, -) +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous if compare_version("triton", operator.ge, "3.0.0"): try: diff --git a/src/liger_kernel/ops/qwen2vl_mrope.py b/src/liger_kernel/ops/qwen2vl_mrope.py index 103b15604..fbd120f96 100644 --- a/src/liger_kernel/ops/qwen2vl_mrope.py +++ b/src/liger_kernel/ops/qwen2vl_mrope.py @@ -67,36 +67,20 @@ def _triton_qwen2vl_mrope( # program instance (i.e. for the current token) separately # #################################################################### # left half of the head - first_half_q_offsets = ( - tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] - ) - first_half_k_offsets = ( - tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] - ) - first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( - tl.arange(0, pad_hd // 2)[None, :] < hd // 2 - ) - first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( - tl.arange(0, pad_hd // 2)[None, :] < hd // 2 - ) - q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( - sin_row.dtype - ) - k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( - sin_row.dtype - ) + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) # right half of the head second_half_q_offsets = first_half_q_offsets + (hd // 2) second_half_k_offsets = first_half_k_offsets + (hd // 2) second_q_mask = first_q_mask second_k_mask = first_k_mask - q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( - sin_row.dtype - ) - k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( - sin_row.dtype - ) + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) if not BACKWARD_PASS: # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] @@ -124,7 +108,6 @@ def _triton_qwen2vl_mrope( def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): - # transpose it back to the physical shape because Triton looks at the physical storage # note: q and k are incontiguous before the transformation and will become contiguous after transpose q = q.transpose(1, 2) diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index fff199a93..5fc9674fb 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -17,12 +17,10 @@ import triton import triton.language as tl -from liger_kernel.ops.utils import ( - calculate_settings, - compare_version, - ensure_contiguous, - torch_to_triton_dtype, -) +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import torch_to_triton_dtype if compare_version("triton", operator.ge, "3.0.0"): try: @@ -35,9 +33,9 @@ from triton.language.math import rsqrt -_CASTING_MODE_NONE = tl.constexpr(-1) -_CASTING_MODE_LLAMA = tl.constexpr(0) -_CASTING_MODE_GEMMA = tl.constexpr(1) +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) @triton.jit @@ -177,9 +175,7 @@ def _rms_norm_backward_kernel( dX_row = rstd_row * m - dX_row += (rstd_row) * ( - -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row - ) + dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row) # calculate the gradient of W if casting_mode == _CASTING_MODE_LLAMA: @@ -207,14 +203,10 @@ def _rms_norm_backward_kernel( def rms_norm_forward(X, W, eps, offset, casting_mode): if not isinstance(casting_mode, int): - assert ( - casting_mode in _str_to_casting_mode - ), f"Invalid casting mode: {casting_mode}" + assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}" casting_mode = _str_to_casting_mode[casting_mode] else: - assert ( - casting_mode in _str_to_casting_mode.values() - ), f"Invalid casting mode: {casting_mode}" + assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}" shape = X.shape dim = shape[-1] @@ -225,17 +217,11 @@ def rms_norm_forward(X, W, eps, offset, casting_mode): Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) # RSTD is to cache rstd for each row # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode - rstd_dtype = ( - torch.float32 - if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) - else X.dtype - ) + rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) # Check constraints. - assert ( - X.shape[1] == W.shape[0] - ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" + assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" _rms_norm_forward_kernel[(n_rows,)]( Y, @@ -256,9 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode): return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode -def rms_norm_backward( - dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place -): +def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place): shape = dY.shape dim = shape[-1] dY = dY.view(-1, dim) @@ -340,9 +324,7 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True): X: (B, T, H) or (BxT, H) W: (H,) """ - Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward( - X, W, eps, offset, casting_mode - ) + Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode) ctx.offset = offset ctx.casting_mode = casting_mode ctx.in_place = in_place diff --git a/src/liger_kernel/ops/rope.py b/src/liger_kernel/ops/rope.py index 0cd88efeb..ebf4702c3 100644 --- a/src/liger_kernel/ops/rope.py +++ b/src/liger_kernel/ops/rope.py @@ -15,6 +15,7 @@ def _triton_rope( sin_row_stride, sl, bs: tl.constexpr, + cos_bs: tl.constexpr, n_qh: tl.constexpr, n_kh: tl.constexpr, hd: tl.constexpr, @@ -29,7 +30,7 @@ def _triton_rope( # k size: (bsz, seq_len, num_kv_heads, head_dim) # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) - # cos size: (1, seq_len, head_dim) + # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) # stride: (seq_len * head_dim, head_dim, 1) pid = tl.program_id(0) @@ -48,9 +49,19 @@ def _triton_rope( # and pid % sl to get the sequence index. # 2. We only need the left half of cos and sin matrix because the right half is just # a clone of the left half. - cos_row_idx = pid % (sl) - cos = cos + cos_row_idx * cos_row_stride - sin = sin + cos_row_idx * sin_row_stride + batch_idx = pid // sl + cos_row_idx = pid % sl + cos = cos + tl.where( + cos_bs == 1, + cos_row_idx * cos_row_stride, + batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride, + ) + sin = sin + tl.where( + cos_bs == 1, + cos_row_idx * sin_row_stride, + batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride, + ) + cos_offsets = tl.arange(0, pad_hd // 2) cos_mask = cos_offsets < hd // 2 cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) @@ -61,36 +72,20 @@ def _triton_rope( # program instance (i.e. for the current token) separately # #################################################################### # left half of the head - first_half_q_offsets = ( - tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] - ) - first_half_k_offsets = ( - tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] - ) - first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( - tl.arange(0, pad_hd // 2)[None, :] < hd // 2 - ) - first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( - tl.arange(0, pad_hd // 2)[None, :] < hd // 2 - ) - q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( - sin_row.dtype - ) - k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( - sin_row.dtype - ) + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) # right half of the head second_half_q_offsets = first_half_q_offsets + (hd // 2) second_half_k_offsets = first_half_k_offsets + (hd // 2) second_q_mask = first_q_mask second_k_mask = first_k_mask - q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( - sin_row.dtype - ) - k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( - sin_row.dtype - ) + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) if not BACKWARD_PASS: # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] @@ -118,7 +113,6 @@ def _triton_rope( def rope_forward(q, k, cos, sin): - # transpose it back to the physical shape because Triton looks at the physical storage # note: q and k are incontiguous before the transformation and will become contiguous after transpose q = q.transpose(1, 2) @@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin): k = k.contiguous() cos = cos.contiguous() sin = sin.contiguous() + cos_batch_size = cos.shape[0] _triton_rope[(n_row,)]( q, @@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin): sin.stride(-2), seq_len, batch_size, + cos_batch_size, n_q_head, n_kv_head, head_dim, @@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin): dk = dk.transpose(1, 2) batch_size, seq_len, n_q_head, head_dim = dq.shape + cos_batch_size = cos.shape[0] n_kv_head = dk.shape[2] pad_hd = triton.next_power_of_2(head_dim) pad_n_q_head = triton.next_power_of_2(n_q_head) @@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin): sin.stride(-2), seq_len, batch_size, + cos_batch_size, n_q_head, n_kv_head, head_dim, @@ -221,8 +219,8 @@ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """ q size: (bsz, n_q_head, seq_len, head_dim) k size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (1, seq_len, head_dim) - sin size: (1, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) """ q, k, cos, sin = rope_forward(q, k, cos, sin) ctx.save_for_backward(cos, sin) @@ -232,8 +230,8 @@ def backward(ctx, dq, dk): """ dq size: (bsz, n_q_head, seq_len, head_dim) dk size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (1, seq_len, head_dim) - sin size: (1, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) """ cos, sin = ctx.saved_tensors diff --git a/src/liger_kernel/ops/swiglu.py b/src/liger_kernel/ops/swiglu.py index 5dffa1133..a1feca26b 100644 --- a/src/liger_kernel/ops/swiglu.py +++ b/src/liger_kernel/ops/swiglu.py @@ -2,7 +2,8 @@ import triton import triton.language as tl -from liger_kernel.ops.utils import calculate_settings, ensure_contiguous +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import ensure_contiguous @triton.jit @@ -11,9 +12,7 @@ def silu(x): @triton.jit -def _swiglu_forward_kernel( - a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr -): +def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): program_id = tl.program_id(0).to(tl.int64) # locate start index @@ -32,9 +31,7 @@ def _swiglu_forward_kernel( @triton.jit -def _swiglu_backward_kernel( - dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr -): +def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): program_id = tl.program_id(0).to(tl.int64) # locate start index @@ -84,7 +81,6 @@ def swiglu_forward(a, b): def swiglu_backward(a, b, dc): - ori_shape = dc.shape n_cols = ori_shape[-1] dc = dc.view(-1, n_cols) diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py index d87adac44..8a15bf8d8 100644 --- a/src/liger_kernel/ops/utils.py +++ b/src/liger_kernel/ops/utils.py @@ -13,11 +13,13 @@ import functools import importlib import operator + from typing import Callable import torch import triton import triton.language as tl + from packaging.version import Version from liger_kernel.utils import infer_device diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index ffb8235cc..cbf330cc2 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -1,31 +1,23 @@ -from liger_kernel.transformers.auto_model import ( # noqa: F401 - AutoLigerKernelForCausalLM, -) +from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401 -from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401 - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 -from liger_kernel.transformers.monkey_patch import ( # noqa: F401 - _apply_liger_kernel, - _apply_liger_kernel_to_instance, - apply_liger_kernel_to_gemma, - apply_liger_kernel_to_gemma2, - apply_liger_kernel_to_llama, - apply_liger_kernel_to_mistral, - apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_mllama, - apply_liger_kernel_to_phi3, - apply_liger_kernel_to_qwen2, - apply_liger_kernel_to_qwen2_vl, -) +from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401 +from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401 from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 -from liger_kernel.transformers.swiglu import ( # noqa: F401 - LigerBlockSparseTop2MLP, - LigerPhi3SwiGLUMLP, - LigerSwiGLUMLP, -) +from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401 +from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401 +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401 diff --git a/src/liger_kernel/transformers/auto_model.py b/src/liger_kernel/transformers/auto_model.py index 42527a1ee..130a03863 100644 --- a/src/liger_kernel/transformers/auto_model.py +++ b/src/liger_kernel/transformers/auto_model.py @@ -1,11 +1,10 @@ import inspect -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig +from transformers import AutoModelForCausalLM -from liger_kernel.transformers.monkey_patch import ( - MODEL_TYPE_TO_APPLY_LIGER_FN, - _apply_liger_kernel, -) +from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN +from liger_kernel.transformers.monkey_patch import _apply_liger_kernel def _get_model_config(model_dir, **model_init_kwargs): @@ -34,12 +33,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type] apply_fn_signature = inspect.signature(apply_fn) - applicable_kwargs = { - key: value - for key, value in kwargs.items() - if key not in apply_fn_signature.parameters - } + applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters} - return super().from_pretrained( - pretrained_model_name_or_path, *model_args, **applicable_kwargs - ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs) diff --git a/src/liger_kernel/transformers/experimental/embedding.py b/src/liger_kernel/transformers/experimental/embedding.py index efe81f1bd..bf76ad6c3 100644 --- a/src/liger_kernel/transformers/experimental/embedding.py +++ b/src/liger_kernel/transformers/experimental/embedding.py @@ -7,9 +7,7 @@ class LigerEmbedding(nn.Module): - def __init__( - self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None - ): + def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None): super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 60d472129..dd34fafb1 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -1,9 +1,7 @@ from typing import Optional from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction -from liger_kernel.ops.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyFunction, -) +from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.ops.geglu import LigerGELUMulFunction from liger_kernel.ops.group_norm import LigerGroupNormFunction @@ -162,9 +160,7 @@ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1): return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim) -def liger_rms_norm( - X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True -): +def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True): return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place) diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index c13148f91..0c6ce0328 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -2,9 +2,7 @@ import torch -from liger_kernel.ops.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyFunction, -) +from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): diff --git a/src/liger_kernel/transformers/geglu.py b/src/liger_kernel/transformers/geglu.py index 89376f18a..f2ee8f6d1 100644 --- a/src/liger_kernel/transformers/geglu.py +++ b/src/liger_kernel/transformers/geglu.py @@ -19,7 +19,4 @@ def __init__(self, config): # So we can safely assume we use tanh approximation form all the time def forward(self, x): - - return self.down_proj( - LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)) - ) + return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))) diff --git a/src/liger_kernel/transformers/group_norm.py b/src/liger_kernel/transformers/group_norm.py index d0cc6799b..ca3d314e2 100644 --- a/src/liger_kernel/transformers/group_norm.py +++ b/src/liger_kernel/transformers/group_norm.py @@ -27,19 +27,13 @@ def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones self.num_channels = num_channels self.num_groups = num_groups self.eps = eps - self.weight = nn.Parameter( - torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels) - ) - self.bias = nn.Parameter( - torch.randn(num_channels) if bias else torch.zeros(num_channels) - ) + self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)) + self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels)) self.variance_epsilon = eps def forward(self, hidden_states): # hidden_states: (batch_size, num_channels, *) - assert ( - hidden_states.dim() >= 3 - ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" + assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" assert ( hidden_states.size(1) == self.num_channels ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" diff --git a/src/liger_kernel/transformers/jsd.py b/src/liger_kernel/transformers/jsd.py index c9d78ff8a..843b79ab2 100644 --- a/src/liger_kernel/transformers/jsd.py +++ b/src/liger_kernel/transformers/jsd.py @@ -67,6 +67,4 @@ def forward( log_p: torch.Tensor, shift_labels: Optional[torch.LongTensor] = None, ): - return LigerJSDFunction.apply( - log_q, log_p, shift_labels, self.beta, self.ignore_index - ) + return LigerJSDFunction.apply(log_q, log_p, shift_labels, self.beta, self.ignore_index) diff --git a/src/liger_kernel/transformers/kl_div.py b/src/liger_kernel/transformers/kl_div.py index 8bd50dad0..878578604 100644 --- a/src/liger_kernel/transformers/kl_div.py +++ b/src/liger_kernel/transformers/kl_div.py @@ -9,6 +9,4 @@ def __init__(self, eps: float = 1e-10, *args, **kwargs): self.eps = eps def forward(self, y_pred, y_true): - return LigerKLDivLossFunction.apply( - y_pred, y_true, self.reduction, self.log_target, self.eps - ) + return LigerKLDivLossFunction.apply(y_pred, y_true, self.reduction, self.log_target, self.eps) diff --git a/src/liger_kernel/transformers/layer_norm.py b/src/liger_kernel/transformers/layer_norm.py index 9590898a7..135d5fde7 100644 --- a/src/liger_kernel/transformers/layer_norm.py +++ b/src/liger_kernel/transformers/layer_norm.py @@ -13,18 +13,12 @@ def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"): ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" self.hidden_size = hidden_size self.eps = eps - self.weight = nn.Parameter( - torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size) - ) - self.bias = nn.Parameter( - torch.randn(hidden_size) if bias else torch.zeros(hidden_size) - ) + self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)) + self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): - return LigerLayerNormFunction.apply( - hidden_states, self.weight, self.bias, self.variance_epsilon - ) + return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon) def extra_repr(self): return f"{self.hidden_size}, eps={self.eps}" diff --git a/src/liger_kernel/transformers/model/gemma.py b/src/liger_kernel/transformers/model/gemma.py index f7b9814e9..8fd8eab7b 100644 --- a/src/liger_kernel/transformers/model/gemma.py +++ b/src/liger_kernel/transformers/model/gemma.py @@ -1,27 +1,23 @@ -from typing import List, Optional, Tuple, Union +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch + from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma.modeling_gemma import ( - _CONFIG_FOR_DOC, - GEMMA_INPUTS_DOCSTRING, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC +from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, @@ -64,19 +60,11 @@ def lce_forward_deprecated( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -139,9 +127,7 @@ def lce_forward_deprecated( @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -188,19 +174,11 @@ def lce_forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/src/liger_kernel/transformers/model/gemma2.py b/src/liger_kernel/transformers/model/gemma2.py index 8ce5aa696..59d2ed48a 100644 --- a/src/liger_kernel/transformers/model/gemma2.py +++ b/src/liger_kernel/transformers/model/gemma2.py @@ -1,22 +1,20 @@ import logging -from typing import Optional, Tuple, Union + +from typing import Optional +from typing import Tuple +from typing import Union import torch + from torch.nn import CrossEntropyLoss from transformers.cache_utils import HybridCache from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma2.modeling_gemma2 import ( - _CONFIG_FOR_DOC, - GEMMA2_INPUTS_DOCSTRING, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) - -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC +from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings + +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss logger = logging.getLogger(__name__) @@ -63,19 +61,11 @@ def lce_forward_deprecated( "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -104,9 +94,7 @@ def lce_forward_deprecated( shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) shift_labels = shift_labels.view(-1) - lce = LigerFusedLinearCrossEntropyLoss( - softcap=self.config.final_logit_softcapping - ) + lce = LigerFusedLinearCrossEntropyLoss(softcap=self.config.final_logit_softcapping) loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) else: @@ -146,9 +134,7 @@ def lce_forward_deprecated( @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -201,19 +187,11 @@ def lce_forward( "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/src/liger_kernel/transformers/model/llama.py b/src/liger_kernel/transformers/model/llama.py index b8d12c76a..e4dde0f55 100644 --- a/src/liger_kernel/transformers/model/llama.py +++ b/src/liger_kernel/transformers/model/llama.py @@ -1,30 +1,27 @@ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch import torch.nn.functional as F + from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama.modeling_llama import ( - _CONFIG_FOR_DOC, - LLAMA_INPUTS_DOCSTRING, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) - -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC +from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings + +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss if TYPE_CHECKING: from transformers.cache_utils import Cache @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, @@ -67,19 +64,11 @@ def lce_forward_deprecated( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -113,13 +102,8 @@ def lce_forward_deprecated( else: if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split( - self.vocab_size // self.config.pretraining_tp, dim=0 - ) - logits = [ - F.linear(hidden_states, lm_head_slices[i]) - for i in range(self.config.pretraining_tp) - ] + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) @@ -151,9 +135,7 @@ def lce_forward_deprecated( @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -201,19 +183,11 @@ def lce_forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/src/liger_kernel/transformers/model/mistral.py b/src/liger_kernel/transformers/model/mistral.py index cc2ab9b76..30568adaf 100644 --- a/src/liger_kernel/transformers/model/mistral.py +++ b/src/liger_kernel/transformers/model/mistral.py @@ -1,27 +1,23 @@ -from typing import List, Optional, Tuple, Union +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch + from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.mistral.modeling_mistral import ( - _CONFIG_FOR_DOC, - MISTRAL_INPUTS_DOCSTRING, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC +from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -65,19 +61,11 @@ def lce_forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py index 22fea53da..c0bde2634 100644 --- a/src/liger_kernel/transformers/model/mixtral.py +++ b/src/liger_kernel/transformers/model/mixtral.py @@ -1,27 +1,23 @@ -from typing import List, Optional, Tuple, Union +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch + from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import MoeCausalLMOutputWithPast -from transformers.models.mixtral.modeling_mixtral import ( - _CONFIG_FOR_DOC, - MIXTRAL_INPUTS_DOCSTRING, - load_balancing_loss_func, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC +from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING +from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, @@ -38,7 +34,7 @@ def lce_forward_deprecated( cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" - Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy + Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy Args: @@ -66,25 +62,15 @@ def lce_forward_deprecated( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits + output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -138,9 +124,7 @@ def lce_forward_deprecated( attention_mask, ) if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( - loss.device - ) # make sure to reside in the same device + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device if not return_dict: output = (logits,) + outputs[1:] @@ -160,9 +144,7 @@ def lce_forward_deprecated( @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy def lce_forward( self, @@ -212,25 +194,15 @@ def lce_forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits + output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -288,9 +260,7 @@ def lce_forward( attention_mask, ) if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( - loss.device - ) # make sure to reside in the same device + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/liger_kernel/transformers/model/mllama.py b/src/liger_kernel/transformers/model/mllama.py index fcf45293e..6f7258554 100644 --- a/src/liger_kernel/transformers/model/mllama.py +++ b/src/liger_kernel/transformers/model/mllama.py @@ -1,24 +1,22 @@ -from typing import List, Optional, Tuple, Union +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch + from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig") def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, @@ -66,19 +64,11 @@ def lce_forward_deprecated( I love the idea of snowflakes gently falling, each one ``` """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -143,9 +133,7 @@ def lce_forward_deprecated( @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig") def lce_forward( self, input_ids: torch.LongTensor = None, @@ -198,19 +186,11 @@ def lce_forward( I love the idea of snowflakes gently falling, each one ``` """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py index e860582ce..696fadc7e 100644 --- a/src/liger_kernel/transformers/model/phi3.py +++ b/src/liger_kernel/transformers/model/phi3.py @@ -1,26 +1,22 @@ -from typing import List, Optional, Tuple, Union +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch + from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.phi3.modeling_phi3 import ( - _CONFIG_FOR_DOC, - PHI3_INPUTS_DOCSTRING, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC +from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, @@ -64,19 +60,11 @@ def lce_forward_deprecated( 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -138,9 +126,7 @@ def lce_forward_deprecated( @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -202,19 +188,11 @@ def lce_forward( f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py index b019e4c88..6f5045f75 100644 --- a/src/liger_kernel/transformers/model/qwen2.py +++ b/src/liger_kernel/transformers/model/qwen2.py @@ -1,26 +1,22 @@ -from typing import List, Optional, Tuple, Union +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch + from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.qwen2.modeling_qwen2 import ( - _CONFIG_FOR_DOC, - QWEN2_INPUTS_DOCSTRING, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC +from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, @@ -63,19 +59,11 @@ def lce_forward_deprecated( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -137,9 +125,7 @@ def lce_forward_deprecated( @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -187,19 +173,11 @@ def lce_forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py index 983d2d946..474c68fc5 100644 --- a/src/liger_kernel/transformers/model/qwen2_vl.py +++ b/src/liger_kernel/transformers/model/qwen2_vl.py @@ -1,28 +1,24 @@ -from typing import List, Optional, Tuple, Union +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import torch + from packaging import version from torch.nn import CrossEntropyLoss from transformers import __version__ as transformers_version -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - _CONFIG_FOR_DOC, - QWEN2_VL_INPUTS_DOCSTRING, - Qwen2VLCausalLMOutputWithPast, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC +from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) +@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -82,19 +78,11 @@ def lce_forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) @@ -144,9 +132,7 @@ def lce_forward( # transformers and leads to failed tests or users noticing differences in results. # TODO: remove above conditional when liger drops support for transformers<4.47.0 if position_ids is None and input_ids is not None: - position_ids, _ = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) + position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) outputs = self.model( input_ids=None, diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 01b5f6efe..eafce145e 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -1,9 +1,11 @@ import inspect import logging + from functools import partial from typing import Callable import transformers + from packaging import version from transformers import PreTrainedModel @@ -12,38 +14,24 @@ from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward -from liger_kernel.transformers.model.gemma import ( - lce_forward_deprecated as gemma_lce_forward_deprecated, -) +from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward -from liger_kernel.transformers.model.gemma2 import ( - lce_forward_deprecated as gemma2_lce_forward_deprected, -) +from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward -from liger_kernel.transformers.model.llama import ( - lce_forward_deprecated as llama_lce_forward_deprecated, -) +from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward -from liger_kernel.transformers.model.mixtral import ( - lce_forward_deprecated as mixtral_lce_forward_deprecated, -) +from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward -from liger_kernel.transformers.model.phi3 import ( - lce_forward_deprecated as phi3_lce_forward_deprecated, -) +from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward -from liger_kernel.transformers.model.qwen2 import ( - lce_forward_deprecated as qwen2_lce_forward_deprecated, -) +from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb -from liger_kernel.transformers.swiglu import ( - LigerBlockSparseTop2MLP, - LigerPhi3SwiGLUMLP, - LigerSwiGLUMLP, -) +from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP +from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP transformer_version = version.parse(transformers.__version__) @@ -57,23 +45,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable): module.__dict__[method_name] = new_method.__get__(module, module.__class__) -def _patch_rms_norm_module( - module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True -): +def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True): module.offset = offset module.casting_mode = casting_mode - module.variance_epsilon = ( - getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps - ) + module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps module.in_place = in_place _bind_method_to_module(module, "forward", LigerRMSNorm.forward) _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr) def _patch_layer_norm_module(module, eps=1e-6): - module.variance_epsilon = ( - getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps - ) + module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps module.hidden_size = module.normalized_shape _bind_method_to_module(module, "forward", LigerLayerNorm.forward) _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr) @@ -145,9 +127,7 @@ def apply_liger_kernel_to_llama( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -184,17 +164,13 @@ def apply_liger_kernel_to_mllama( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." from transformers.models.mllama import modeling_mllama - from transformers.models.mllama.modeling_mllama import ( - MllamaForCausalLM, - MllamaForConditionalGeneration, - MllamaTextModel, - MllamaVisionModel, - ) + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration + from transformers.models.mllama.modeling_mllama import MllamaTextModel + from transformers.models.mllama.modeling_mllama import MllamaVisionModel from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward - from liger_kernel.transformers.model.mllama import ( - lce_forward_deprecated as mllama_lce_forward_deprecated, - ) + from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated if rope: modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -241,9 +217,7 @@ def apply_liger_kernel_to_mllama( _patch_rms_norm_module(text_model.norm) for decoder_layer in text_model.layers: if swiglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -317,9 +291,7 @@ def apply_liger_kernel_to_mistral( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -391,9 +363,7 @@ def apply_liger_kernel_to_mixtral( for decoder_layer in base_model.layers: if swiglu: for expert in decoder_layer.block_sparse_moe.experts: - _bind_method_to_module( - expert, "forward", LigerBlockSparseTop2MLP.forward - ) + _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -431,12 +401,8 @@ def apply_liger_kernel_to_gemma( from transformers.models.gemma.modeling_gemma import GemmaModel # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109 - LigerRMSNormForGemma = partial( - LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" - ) - _patch_rms_norm_module_for_gemma = partial( - _patch_rms_norm_module, casting_mode="gemma", offset=1.0 - ) + LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma") + _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0) if rope: modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -471,9 +437,7 @@ def apply_liger_kernel_to_gemma( for decoder_layer in base_model.layers: if geglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerGEGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward) if rms_norm: _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm) _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm) @@ -510,9 +474,7 @@ def apply_liger_kernel_to_gemma2( from transformers.models.gemma2 import modeling_gemma2 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model - LigerRMSNormForGemma2 = partial( - LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False - ) + LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False) _patch_rms_norm_module_for_gemma2 = partial( _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False ) @@ -551,20 +513,12 @@ def apply_liger_kernel_to_gemma2( for decoder_layer in base_model.layers: if geglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerGEGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward) if rms_norm: _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm) - _patch_rms_norm_module_for_gemma2( - decoder_layer.post_attention_layernorm - ) - _patch_rms_norm_module_for_gemma2( - decoder_layer.pre_feedforward_layernorm - ) - _patch_rms_norm_module_for_gemma2( - decoder_layer.post_feedforward_layernorm - ) + _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm) def apply_liger_kernel_to_qwen2( @@ -633,9 +587,7 @@ def apply_liger_kernel_to_qwen2( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -674,14 +626,10 @@ def apply_liger_kernel_to_qwen2_vl( from transformers.models.qwen2_vl import modeling_qwen2_vl from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel - from liger_kernel.transformers.model.qwen2_vl import ( - lce_forward as qwen2_vl_lce_forward, - ) + from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward if rope: - modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = ( - liger_multimodal_rotary_pos_emb - ) + modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb if rms_norm: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439 modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm @@ -712,9 +660,7 @@ def apply_liger_kernel_to_qwen2_vl( _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -783,9 +729,7 @@ def apply_liger_kernel_to_phi3( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module( - decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward - ) + _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -826,24 +770,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None: return if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys(): - logger.info( - f"There are currently no Liger kernels supported for model type: {model_type}." - ) + logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.") return apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type] apply_fn_signature = inspect.signature(apply_fn) # Filter out the keyword arguments that are not supported by the apply function - applicable_kwargs = { - key: value - for key, value in kwargs.items() - if key in apply_fn_signature.parameters - } + applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters} - logger.info( - f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}" - ) + logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}") # Assume this is invoked pre-model initialization, so we only need to patch transformers code apply_fn(**applicable_kwargs) @@ -857,20 +793,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None: - model: the model instance to apply Liger kernels to - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function. """ - model_type = getattr(model, "config", None) and getattr( - model.config, "model_type", None - ) + model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None) if not model_type: - logger.info( - "Model type could not be determined from model config. No Liger kernels will be applied." - ) + logger.info("Model type could not be determined from model config. No Liger kernels will be applied.") return if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys(): - logger.info( - f"There are currently no Liger kernels supported for model type: {model_type}." - ) + logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.") return apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type] @@ -878,11 +808,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None: apply_fn_signature = inspect.signature(apply_fn) # Filter out the keyword arguments that are not supported by the apply function - applicable_kwargs = { - key: value - for key, value in kwargs.items() - if key in apply_fn_signature.parameters - } + applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters} logger.info( f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}" ) diff --git a/src/liger_kernel/transformers/rms_norm.py b/src/liger_kernel/transformers/rms_norm.py index e2b472aa7..d3a50e02f 100644 --- a/src/liger_kernel/transformers/rms_norm.py +++ b/src/liger_kernel/transformers/rms_norm.py @@ -19,9 +19,7 @@ def __init__( "ones", "zeros", ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" - self.weight = nn.Parameter( - torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size) - ) + self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)) self.variance_epsilon, self.offset, self.casting_mode, self.in_place = ( eps, offset, @@ -40,4 +38,6 @@ def forward(self, hidden_states): ) def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}" + return ( + f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}" + ) diff --git a/src/liger_kernel/transformers/rope.py b/src/liger_kernel/transformers/rope.py index a40b29af3..de060ea01 100644 --- a/src/liger_kernel/transformers/rope.py +++ b/src/liger_kernel/transformers/rope.py @@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): Args: q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). - cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim). - sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim). + cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim). + sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim). position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None. unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. diff --git a/src/liger_kernel/transformers/swiglu.py b/src/liger_kernel/transformers/swiglu.py index 42f4df106..01961e57a 100644 --- a/src/liger_kernel/transformers/swiglu.py +++ b/src/liger_kernel/transformers/swiglu.py @@ -16,10 +16,7 @@ def __init__(self, config): raise ValueError(f"Activation function {config.hidden_act} not supported.") def forward(self, x): - - return self.down_proj( - LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)) - ) + return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))) class LigerBlockSparseTop2MLP(nn.Module): @@ -36,7 +33,6 @@ def __init__(self, config): raise ValueError(f"Activation function {config.hidden_act} not supported.") def forward(self, x): - return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x))) @@ -51,9 +47,7 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_up_proj = nn.Linear( - self.hidden_size, 2 * self.intermediate_size, bias=False - ) + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) if config.hidden_act not in ["silu", "swish"]: raise ValueError(f"Activation function {config.hidden_act} not supported.") diff --git a/src/liger_kernel/transformers/trainer/__init__.py b/src/liger_kernel/transformers/trainer/__init__.py index b677d868b..df5de2038 100644 --- a/src/liger_kernel/transformers/trainer/__init__.py +++ b/src/liger_kernel/transformers/trainer/__init__.py @@ -1,6 +1,4 @@ try: - from liger_kernel.transformers.trainer.orpo_trainer import ( # noqa: F401 - LigerORPOTrainer, - ) + from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer # noqa: F401 except ImportError: raise ImportError("Please `pip install trl` to use LigerORPOTrainer") diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py index 3605b9f1b..ca54733d0 100644 --- a/src/liger_kernel/transformers/trainer/orpo_trainer.py +++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py @@ -1,7 +1,14 @@ -from typing import Any, Callable, Dict, List, Literal, Tuple, Union +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Literal +from typing import Tuple +from typing import Union import torch import torch.nn as nn + from torch.distributed.fsdp import FullyShardedDataParallel from trl.trainer import ORPOTrainer @@ -17,7 +24,7 @@ class _FSDPForwardRedirection: This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`) - will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of + will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just the `lm_head` part of a model, we need this trick too to properly get its params all-gathered. @@ -62,9 +69,7 @@ def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any: class LigerORPOTrainer(ORPOTrainer): def concatenated_forward( self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] - ) -> Tuple[ - torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor - ]: + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """ Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. @@ -79,9 +84,7 @@ def concatenated_forward( model_kwargs = ( { - "decoder_input_ids": self._shift_right( - concatenated_batch["concatenated_labels"] - ), + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), } if self.is_encoder_decoder else {} @@ -109,14 +112,10 @@ def concatenated_forward( **model_kwargs, ) - orpo_loss_fn = LigerFusedLinearORPOLoss( - ignore_index=self.label_pad_token_id, beta=self.beta - ) + orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) def orpo_partial(lm_head, last_hidden_state, concatenated_labels): - return orpo_loss_fn( - lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias - ) + return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias) orpo_loss, aux_outputs = _FSDPForwardRedirection()( model, @@ -125,6 +124,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels): outputs.last_hidden_state, concatenated_batch["concatenated_labels"], ) + # if aux_loss_enabled, add the aux_loss to the orpo_loss + if self.aux_loss_enabled: + orpo_loss += self.aux_loss_coef * outputs.aux_loss + return orpo_loss, aux_outputs def get_batch_loss_metrics( @@ -145,9 +148,7 @@ def get_batch_loss_metrics( ) = aux_outputs[:5] # return loss, metrics - chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[ - 5: - ] + chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:] reward_accuracies = (chosen_rewards > rejected_rewards).float() diff --git a/src/liger_kernel/triton/__init__.py b/src/liger_kernel/triton/__init__.py index 88c282f8b..d373966a9 100644 --- a/src/liger_kernel/triton/__init__.py +++ b/src/liger_kernel/triton/__init__.py @@ -1,3 +1 @@ -from liger_kernel.triton.monkey_patch import ( # noqa: F401 - apply_liger_triton_cache_manager, -) +from liger_kernel.triton.monkey_patch import apply_liger_triton_cache_manager # noqa: F401 diff --git a/src/liger_kernel/triton/monkey_patch.py b/src/liger_kernel/triton/monkey_patch.py index 70863f4e3..bac4a6a0d 100644 --- a/src/liger_kernel/triton/monkey_patch.py +++ b/src/liger_kernel/triton/monkey_patch.py @@ -37,6 +37,4 @@ def apply_liger_triton_cache_manager(): Experimental feature to get around transient FileNotFoundError in triton compilation. For more details please see https://github.com/triton-lang/triton/pull/4295 """ - os.environ["TRITON_CACHE_MANAGER"] = ( - "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager" - ) + os.environ["TRITON_CACHE_MANAGER"] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager" diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index f0fef7734..5f5e15ad1 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -1,4 +1,3 @@ -from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed from typing import Tuple import pytest @@ -9,6 +8,9 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed device = infer_device() @@ -60,19 +62,17 @@ def alignment_loss( if self.loss_type == "sigmoid": # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. losses = ( - F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) elif self.loss_type == "simpo": logits = logits - (self.simpo_gamma / self.beta) losses = ( - F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) else: - raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']" - ) + raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']") return losses @@ -86,17 +86,17 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, loss_type: str = "sigmoid", simpo_gamma: float = 0.5, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.cpo_loss = HFCPOLoss( ignore_index=ignore_index, beta=beta, loss_type=loss_type, + label_smoothing=label_smoothing, simpo_gamma=simpo_gamma, ).get_batch_loss_metrics @@ -114,13 +114,15 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.cpo_loss = LigerFusedLinearCPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + label_smoothing=label_smoothing, ) def forward(self, x, y): @@ -142,11 +144,22 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize( - "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] -) +@pytest.mark.parametrize("ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + alpha, + label_smoothing, ): B = 2 * B # cpo loss requires B to be even @@ -157,6 +170,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, ) liger_lm_head_cpo = LigerLMHeadCPO( H=H, @@ -165,6 +179,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, ) torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( @@ -172,9 +187,7 @@ def test_correctness( ) if bias: - torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( - V, device=device, dtype=dtype - ) + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) @@ -270,12 +283,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1, aggregated_aux_outputs1 = LigerFusedLinearCPOFunction.apply( - input1, weight1, target, bias1 - ) - loss2, aggregated_aux_outputs2 = liger_fused_linear_cpo( - input2, weight2, target, bias2 - ) + loss1, aggregated_aux_outputs1 = LigerFusedLinearCPOFunction.apply(input1, weight1, target, bias1) + loss2, aggregated_aux_outputs2 = liger_fused_linear_cpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 0ac8faeb8..ab18a0f24 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -1,5 +1,3 @@ -from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed - import pytest import torch import torch.nn.functional as F @@ -8,6 +6,9 @@ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction from liger_kernel.chunked_loss.functional import liger_fused_linear_dpo from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed device = infer_device() @@ -23,10 +24,17 @@ class HFDPOLoss(HFAlignmentLoss): """ def __init__( - self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, ): super().__init__( - beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, ) def alignment_loss( @@ -61,18 +69,18 @@ def __init__( dtype: torch.dtype, bias: bool = False, ref_bias: bool = False, + compute_nll_loss: bool = False, ignore_index: int = -100, beta: float = 0.1, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) - self.ref_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=ref_bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) self.dpo_loss = HFDPOLoss( - ignore_index=ignore_index, beta=beta, use_ref_model=True + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, ).get_batch_loss_metrics def forward(self, x, ref_x, y): @@ -95,18 +103,18 @@ def __init__( dtype: torch.dtype, bias: bool = False, ref_bias: bool = False, + compute_nll_loss: bool = False, ignore_index: int = -100, beta: float = 0.1, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) - self.ref_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=ref_bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) self.dpo_loss = LigerFusedLinearDPOLoss( - ignore_index=ignore_index, beta=beta, use_ref_model=True + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, ) def forward(self, x, ref_x, y): @@ -132,14 +140,27 @@ def forward(self, x, ref_x, y): "scalar, dtype, atol, rtol", [ (1.0, torch.bfloat16, 5e-2, 5e-1), - (1.0, torch.float32, 2e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), ], ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("compute_nll_loss", [True, False]) @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ref_bias, + compute_nll_loss, + ignore_index, + beta, ): B = 2 * B # dpo loss requires B to be even @@ -149,6 +170,7 @@ def test_correctness( dtype=dtype, bias=bias, ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, ) @@ -158,6 +180,7 @@ def test_correctness( dtype=dtype, bias=bias, ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, ) @@ -165,26 +188,22 @@ def test_correctness( torch_lm_head_dpo.lin.weight.data = liger_lm_head_dpo.lin.weight.data = torch.randn( V, H, device=device, dtype=dtype ) - torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = ( - torch.randn(V, H, device=device, dtype=dtype) + torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype ) if bias: - torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = torch.randn( - V, device=device, dtype=dtype - ) + torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) if ref_bias: - torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = ( - torch.randn(V, device=device, dtype=dtype) + torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = torch.randn( + V, device=device, dtype=dtype ) _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) - ref_input = ( - torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar - ) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar target = torch.randint( 0, @@ -251,16 +270,15 @@ def test_correctness( ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("ref_bias", [True, False]) -def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): +@pytest.mark.parametrize("compute_nll_loss", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss): B = 2 * B _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) - ref_input = ( - torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar - ) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar target = torch.randint( 0, @@ -290,10 +308,28 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply( - input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1 + input1, + weight1, + target, + bias1, + ref_input, + ref_weight1, + ref_bias1, + -100, + 0.1, + compute_nll_loss, ) loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo( - input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2 + input2, + weight2, + target, + bias2, + ref_input, + ref_weight2, + ref_bias2, + -100, + 0.1, + compute_nll_loss, ) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 9f5d81b18..529b7dff7 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -1,4 +1,3 @@ -from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed from typing import Tuple import pytest @@ -9,6 +8,9 @@ from liger_kernel.chunked_loss.functional import liger_fused_linear_orpo from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed device = infer_device() @@ -53,11 +55,10 @@ def alignment_loss( # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) log_odds = (policy_chosen_logps - policy_rejected_logps) - ( - torch.log1p(-torch.exp(policy_chosen_logps)) - - torch.log1p(-torch.exp(policy_rejected_logps)) + torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) ) ratio = F.logsigmoid(log_odds) - losses = self.beta * ratio + losses = -self.beta * ratio chosen_rewards = self.beta * policy_chosen_logps rejected_rewards = self.beta * policy_rejected_logps @@ -82,12 +83,8 @@ def __init__( beta: float = 0.1, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) - self.orpo_loss = HFORPOLoss( - ignore_index=ignore_index, beta=beta - ).get_batch_loss_metrics + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.orpo_loss = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics def forward(self, x, y): return self.orpo_loss(self.lin.weight, x, y, self.lin.bias) @@ -104,9 +101,7 @@ def __init__( beta: float = 0.1, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.orpo_loss = LigerFusedLinearORPOLoss(ignore_index=ignore_index, beta=beta) def forward(self, x, y): @@ -148,14 +143,12 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta=beta, ) - torch_lm_head_orpo.lin.weight.data = liger_lm_head_orpo.lin.weight.data = ( - torch.randn(V, H, device=device, dtype=dtype) + torch_lm_head_orpo.lin.weight.data = liger_lm_head_orpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype ) if bias: - torch_lm_head_orpo.lin.bias.data = liger_lm_head_orpo.lin.bias.data = ( - torch.randn(V, device=device, dtype=dtype) - ) + torch_lm_head_orpo.lin.bias.data = liger_lm_head_orpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) @@ -251,12 +244,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1, aggregated_aux_outputs1 = LigerFusedLinearORPOFunction.apply( - input1, weight1, target, bias1 - ) - loss2, aggregated_aux_outputs2 = liger_fused_linear_orpo( - input2, weight2, target, bias2 - ) + loss1, aggregated_aux_outputs1 = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1) + loss2, aggregated_aux_outputs2 = liger_fused_linear_orpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py index 3d0937c27..4a6f01959 100644 --- a/test/chunked_loss/test_simpo_loss.py +++ b/test/chunked_loss/test_simpo_loss.py @@ -1,6 +1,3 @@ -from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO -from test.utils import assert_verbose_allclose, set_seed - import pytest import torch @@ -8,6 +5,9 @@ from liger_kernel.chunked_loss.functional import liger_fused_linear_simpo from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction from liger_kernel.utils import infer_device +from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO +from test.utils import assert_verbose_allclose +from test.utils import set_seed device = infer_device() @@ -25,14 +25,17 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, gamma: float = 0.5, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.simpo_loss = LigerFusedLinearSimPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + gamma=gamma, + label_smoothing=label_smoothing, ) def forward(self, x, y): @@ -54,11 +57,22 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize( - "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)] -) +@pytest.mark.parametrize("ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + gamma, + label_smoothing, ): B = 2 * B # SimPO loss requires B to be even @@ -70,6 +84,7 @@ def test_correctness( ignore_index=ignore_index, beta=beta, loss_type="simpo", + label_smoothing=label_smoothing, simpo_gamma=gamma, ) liger_lm_head_simpo = LigerLMHeadSimPO( @@ -79,16 +94,17 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, gamma=gamma, ) - torch_lm_head_simpo.lin.weight.data = liger_lm_head_simpo.lin.weight.data = ( - torch.randn(V, H, device=device, dtype=dtype) + torch_lm_head_simpo.lin.weight.data = liger_lm_head_simpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype ) if bias: - torch_lm_head_simpo.lin.bias.data = liger_lm_head_simpo.lin.bias.data = ( - torch.randn(V, device=device, dtype=dtype) + torch_lm_head_simpo.lin.bias.data = liger_lm_head_simpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype ) _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar @@ -185,12 +201,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1, aggregated_aux_outputs1 = LigerFusedLinearSimPOFunction.apply( - input1, weight1, target, bias1 - ) - loss2, aggregated_aux_outputs2 = liger_fused_linear_simpo( - input2, weight2, target, bias2 - ) + loss1, aggregated_aux_outputs1 = LigerFusedLinearSimPOFunction.apply(input1, weight1, target, bias1) + loss2, aggregated_aux_outputs2 = liger_fused_linear_simpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 0f7e410c4..8566558e7 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -1,44 +1,47 @@ -from test.utils import ( - DEFAULT_DATASET_PATH, - MiniModelConfig, - assert_verbose_allclose, - revert_liger_kernel_to_gemma, - revert_liger_kernel_to_gemma2, - revert_liger_kernel_to_llama, - revert_liger_kernel_to_mistral, - revert_liger_kernel_to_mixtral, - revert_liger_kernel_to_mllama, - revert_liger_kernel_to_phi3, - revert_liger_kernel_to_qwen2, - revert_liger_kernel_to_qwen2_vl, - set_seed, - simple_collate_fn, - supports_bfloat16, -) - import pytest import torch + from datasets import load_from_disk from torch.utils.data import DataLoader -from transformers.models.gemma import GemmaConfig, GemmaForCausalLM -from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM -from transformers.models.llama import LlamaConfig, LlamaForCausalLM -from transformers.models.mistral import MistralConfig, MistralForCausalLM -from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM -from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM -from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM - -from liger_kernel.transformers import ( - apply_liger_kernel_to_gemma, - apply_liger_kernel_to_gemma2, - apply_liger_kernel_to_llama, - apply_liger_kernel_to_mistral, - apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_mllama, - apply_liger_kernel_to_phi3, - apply_liger_kernel_to_qwen2, - apply_liger_kernel_to_qwen2_vl, -) +from transformers.models.gemma import GemmaConfig +from transformers.models.gemma import GemmaForCausalLM +from transformers.models.gemma2 import Gemma2Config +from transformers.models.gemma2 import Gemma2ForCausalLM +from transformers.models.llama import LlamaConfig +from transformers.models.llama import LlamaForCausalLM +from transformers.models.mistral import MistralConfig +from transformers.models.mistral import MistralForCausalLM +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral import MixtralForCausalLM +from transformers.models.phi3 import Phi3Config +from transformers.models.phi3 import Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config +from transformers.models.qwen2 import Qwen2ForCausalLM + +from liger_kernel.transformers import apply_liger_kernel_to_gemma +from liger_kernel.transformers import apply_liger_kernel_to_gemma2 +from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_mistral +from liger_kernel.transformers import apply_liger_kernel_to_mixtral +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_phi3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from test.utils import DEFAULT_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import revert_liger_kernel_to_gemma +from test.utils import revert_liger_kernel_to_gemma2 +from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_mistral +from test.utils import revert_liger_kernel_to_mixtral +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_phi3 +from test.utils import revert_liger_kernel_to_qwen2 +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import set_seed +from test.utils import simple_collate_fn +from test.utils import supports_bfloat16 try: # Mllama is only available in transformers>=4.45.0 @@ -52,9 +55,7 @@ try: # Qwen2-VL is only available in transformers>4.44.2 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig - from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLForConditionalGeneration, - ) + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration QWEN2_VL_AVAILABLE = True except ImportError: @@ -434,9 +435,7 @@ def run_mini_model( model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) - loader = DataLoader( - train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn - ) + loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) loader_iter = iter(loader) optimizer = torch.optim.AdamW(model.parameters(), lr=lr) @@ -470,9 +469,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), pytest.param( "mini_mllama", @@ -502,9 +499,7 @@ def run_mini_model( 1e-2, 1e-2, marks=[ - pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), pytest.mark.skipif( not MLLAMA_AVAILABLE, reason="Mllama not available in this version of transformers", @@ -523,9 +518,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), pytest.param( # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0 "mini_qwen2_vl", @@ -555,9 +548,7 @@ def run_mini_model( 1e-2, 1e-2, marks=[ - pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), pytest.mark.skipif( not QWEN2_VL_AVAILABLE, reason="Qwen2-VL not available in this version of transformers", @@ -576,9 +567,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( @@ -592,9 +581,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), # TODO: mixtral is flaky so disable the test for now # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), @@ -626,9 +613,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( @@ -642,9 +627,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate @@ -679,13 +662,9 @@ def test_mini_model( ): # Non-liger models should be initialized and tested first to avoid the module being overridden - expected_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr - ) + expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) - actual_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True - ) + actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True) # Compare every step of the loss assert_verbose_allclose( @@ -710,7 +689,6 @@ def test_mini_model( for expected_param, actual_param in zip( expected_output["model"].named_parameters(), actual_output["model"].named_parameters(), + strict=False, ): - assert_verbose_allclose( - expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol - ) + assert_verbose_allclose(expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol) diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py index 07ddd9493..27483f08f 100644 --- a/test/convergence/test_mini_models_multimodal.py +++ b/test/convergence/test_mini_models_multimodal.py @@ -1,40 +1,33 @@ import functools import os -from test.utils import ( - FAKE_CONFIGS_PATH, - UNTOKENIZED_DATASET_PATH, - MiniModelConfig, - assert_verbose_allclose, - load_tokenizer_config, - multimodal_collate_fn, - revert_liger_kernel_to_mllama, - revert_liger_kernel_to_qwen2_vl, - set_seed, - supports_bfloat16, - train_bpe_tokenizer, -) import pytest import torch + from datasets import load_dataset from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerFast -from liger_kernel.transformers import ( - apply_liger_kernel_to_mllama, - apply_liger_kernel_to_qwen2_vl, -) +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from test.utils import FAKE_CONFIGS_PATH +from test.utils import UNTOKENIZED_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import load_tokenizer_config +from test.utils import multimodal_collate_fn +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import set_seed +from test.utils import supports_bfloat16 +from test.utils import train_bpe_tokenizer try: # Qwen2-VL is only available in transformers>=4.45.0 from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig - from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( - Qwen2VLImageProcessor, - ) - from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLForConditionalGeneration, - ) + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor QWEN2_VL_AVAILABLE = True @@ -43,15 +36,11 @@ try: # Mllama is only available in transformers>=4.45.0 - from transformers.models.mllama.configuration_mllama import ( - MllamaConfig, - MllamaTextConfig, - MllamaVisionConfig, - ) + from transformers.models.mllama.configuration_mllama import MllamaConfig + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.configuration_mllama import MllamaVisionConfig from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor - from transformers.models.mllama.modeling_mllama import ( - MllamaForConditionalGeneration, - ) + from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration from transformers.models.mllama.processing_mllama import MllamaProcessor MLLAMA_AVAILABLE = True @@ -79,9 +68,7 @@ if MLLAMA_AVAILABLE: MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_mllama, fused_linear_cross_entropy=False - ), + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_mllama, fused_linear_cross_entropy=False), liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, model_class=MllamaForConditionalGeneration, mini_model_config=MllamaConfig( @@ -136,9 +123,7 @@ if QWEN2_VL_AVAILABLE: MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False - ), + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False), liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, model_class=Qwen2VLForConditionalGeneration, mini_model_config=Qwen2VLConfig( @@ -188,9 +173,7 @@ def create_processor(model_name): if model_name == "mini_qwen2_vl": tokenizer_config = load_tokenizer_config( - os.path.join( - FAKE_CONFIGS_PATH, "Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json" - ) + os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json") ) tokenizer_base = train_bpe_tokenizer( [ @@ -201,13 +184,9 @@ def create_processor(model_name): ) ] ) - qwen_tokenizer = Qwen2TokenizerFast( - tokenizer_object=tokenizer_base, **tokenizer_config - ) + qwen_tokenizer = Qwen2TokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) image_processor = Qwen2VLImageProcessor() - return Qwen2VLProcessor( - image_processor=image_processor, tokenizer=qwen_tokenizer - ) + return Qwen2VLProcessor(image_processor=image_processor, tokenizer=qwen_tokenizer) elif model_name == "mini_mllama": tokenizer_config = load_tokenizer_config( @@ -225,13 +204,9 @@ def create_processor(model_name): ) ] ) - fast_tokenizer = PreTrainedTokenizerFast( - tokenizer_object=tokenizer_base, **tokenizer_config - ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) image_processor = MllamaImageProcessor(size={"height": 560, "width": 560}) - return MllamaProcessor( - image_processor=image_processor, tokenizer=fast_tokenizer - ) + return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) else: raise ValueError(f"Processor not available for model {model_name}") @@ -265,9 +240,7 @@ def apply_chat_template(example): "content": [{"type": "text", "text": example["text"]}], }, ] - example["text"] = processor.tokenizer.apply_chat_template( - conversation, tokenize=False - ) + example["text"] = processor.tokenizer.apply_chat_template(conversation, tokenize=False) return example def preprocess_function(examples): @@ -282,9 +255,7 @@ def preprocess_function(examples): ) train_dataset = ( - load_dataset( - "text", data_files={"train": UNTOKENIZED_DATASET_PATH}, split="train" - ) + load_dataset("text", data_files={"train": UNTOKENIZED_DATASET_PATH}, split="train") .to_iterable_dataset() # only map examples as-needed and on-demand .map(generate_procedural_image, with_indices=True) .map(apply_chat_template) @@ -341,9 +312,7 @@ def run_mini_model_multimodal( model.gradient_checkpointing_enable() train_dataset = create_multimodal_dataset(model_name) - loader = DataLoader( - train_dataset, batch_size=2, shuffle=False, collate_fn=multimodal_collate_fn - ) + loader = DataLoader(train_dataset, batch_size=2, shuffle=False, collate_fn=multimodal_collate_fn) loader_iter = iter(loader) optimizer = torch.optim.AdamW(model.parameters(), lr=lr) @@ -394,9 +363,7 @@ def run_mini_model_multimodal( 1e-2, 1e-2, marks=[ - pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), pytest.mark.skipif( not QWEN2_VL_AVAILABLE, reason="Qwen2-VL not available in this version of transformers", @@ -431,9 +398,7 @@ def run_mini_model_multimodal( 1e-2, 1e-2, marks=[ - pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), pytest.mark.skipif( not MLLAMA_AVAILABLE, reason="Mllama not available in this version of transformers", @@ -455,9 +420,7 @@ def test_mini_model_multimodal( param_rtol, ): # Non-liger models should be initialized and tested first to avoid the module being overridden - expected_output = run_mini_model_multimodal( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr - ) + expected_output = run_mini_model_multimodal(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) actual_output = run_mini_model_multimodal( model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True @@ -484,7 +447,6 @@ def test_mini_model_multimodal( for expected_param, actual_param in zip( expected_output["model"].named_parameters(), actual_output["model"].named_parameters(), + strict=False, ): - assert_verbose_allclose( - expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol - ) + assert_verbose_allclose(expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol) diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py index e7672c4a4..9abed2bd9 100644 --- a/test/convergence/test_mini_models_with_logits.py +++ b/test/convergence/test_mini_models_with_logits.py @@ -1,44 +1,47 @@ -from test.utils import ( - DEFAULT_DATASET_PATH, - MiniModelConfig, - assert_verbose_allclose, - revert_liger_kernel_to_gemma, - revert_liger_kernel_to_gemma2, - revert_liger_kernel_to_llama, - revert_liger_kernel_to_mistral, - revert_liger_kernel_to_mixtral, - revert_liger_kernel_to_mllama, - revert_liger_kernel_to_phi3, - revert_liger_kernel_to_qwen2, - revert_liger_kernel_to_qwen2_vl, - set_seed, - simple_collate_fn, - supports_bfloat16, -) - import pytest import torch + from datasets import load_from_disk from torch.utils.data import DataLoader -from transformers.models.gemma import GemmaConfig, GemmaForCausalLM -from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM -from transformers.models.llama import LlamaConfig, LlamaForCausalLM -from transformers.models.mistral import MistralConfig, MistralForCausalLM -from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM -from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM -from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM - -from liger_kernel.transformers import ( - apply_liger_kernel_to_gemma, - apply_liger_kernel_to_gemma2, - apply_liger_kernel_to_llama, - apply_liger_kernel_to_mistral, - apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_mllama, - apply_liger_kernel_to_phi3, - apply_liger_kernel_to_qwen2, - apply_liger_kernel_to_qwen2_vl, -) +from transformers.models.gemma import GemmaConfig +from transformers.models.gemma import GemmaForCausalLM +from transformers.models.gemma2 import Gemma2Config +from transformers.models.gemma2 import Gemma2ForCausalLM +from transformers.models.llama import LlamaConfig +from transformers.models.llama import LlamaForCausalLM +from transformers.models.mistral import MistralConfig +from transformers.models.mistral import MistralForCausalLM +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral import MixtralForCausalLM +from transformers.models.phi3 import Phi3Config +from transformers.models.phi3 import Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config +from transformers.models.qwen2 import Qwen2ForCausalLM + +from liger_kernel.transformers import apply_liger_kernel_to_gemma +from liger_kernel.transformers import apply_liger_kernel_to_gemma2 +from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_mistral +from liger_kernel.transformers import apply_liger_kernel_to_mixtral +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_phi3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from test.utils import DEFAULT_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import revert_liger_kernel_to_gemma +from test.utils import revert_liger_kernel_to_gemma2 +from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_mistral +from test.utils import revert_liger_kernel_to_mixtral +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_phi3 +from test.utils import revert_liger_kernel_to_qwen2 +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import set_seed +from test.utils import simple_collate_fn +from test.utils import supports_bfloat16 try: # Mllama is only available in transformers>=4.45.0 @@ -52,9 +55,7 @@ try: # Qwen2-VL is only available in transformers>4.44.2 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig - from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLForConditionalGeneration, - ) + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration QWEN2_VL_AVAILABLE = True except ImportError: @@ -433,9 +434,7 @@ def run_mini_model( model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) - loader = DataLoader( - train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn - ) + loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) loader_iter = iter(loader) optimizer = torch.optim.AdamW(model.parameters(), lr=lr) @@ -469,9 +468,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), pytest.param( "mini_mllama", @@ -501,9 +498,7 @@ def run_mini_model( 1e-2, 1e-2, marks=[ - pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), pytest.mark.skipif( not MLLAMA_AVAILABLE, reason="Mllama not available in this version of transformers", @@ -522,9 +517,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), pytest.param( "mini_qwen2_vl", @@ -554,9 +547,7 @@ def run_mini_model( 1e-2, 1e-2, marks=[ - pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), pytest.mark.skipif( not QWEN2_VL_AVAILABLE, reason="Qwen2-VL not available in this version of transformers", @@ -575,9 +566,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( @@ -591,9 +580,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), # TODO: mixtral is flaky so disable the test for now # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), @@ -625,9 +612,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( @@ -641,9 +626,7 @@ def run_mini_model( 1e-2, 1e-2, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate @@ -678,13 +661,9 @@ def test_mini_model( ): # Non-liger models should be initialized and tested first to avoid the module being overridden - expected_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr - ) + expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) - actual_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True - ) + actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True) # Compare every step of the loss assert_verbose_allclose( @@ -709,7 +688,6 @@ def test_mini_model( for expected_param, actual_param in zip( expected_output["model"].named_parameters(), actual_output["model"].named_parameters(), + strict=False, ): - assert_verbose_allclose( - expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol - ) + assert_verbose_allclose(expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol) diff --git a/test/resources/scripts/generate_tokenized_dataset.py b/test/resources/scripts/generate_tokenized_dataset.py index c450d56b1..9a4b8c8f3 100644 --- a/test/resources/scripts/generate_tokenized_dataset.py +++ b/test/resources/scripts/generate_tokenized_dataset.py @@ -13,19 +13,13 @@ def prepare_dataset(tokenizer, text_file_path: str): dataset = load_dataset("text", data_files={"train": text_file_path}) def tokenize_function(examples): - return tokenizer( - examples["text"], padding="max_length", truncation=True, max_length=128 - ) + return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128) - tokenized_dataset = dataset.map( - tokenize_function, batched=True, remove_columns=["text"] - ) + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"]) return tokenized_dataset["train"] -def generate_tokenized_dataset( - tokenizer_path: str, text_file_path: str, output_dir: str -) -> None: +def generate_tokenized_dataset(tokenizer_path: str, text_file_path: str, output_dir: str) -> None: """ Generate tokenized dataset from a text file, where each line is a different example. @@ -44,9 +38,7 @@ def generate_tokenized_dataset( if __name__ == "__main__": # Example usage: # python generate_tokenized_dataset.py --tokenizer_path /shared/public/models/Mistral-7B --text_file_path ./../../resources/tiny_shakespeare.txt --output_dir ./../../resources/tiny_shakespeare_tokenized - parser = argparse.ArgumentParser( - description="Generate tokenized dataset from a text file." - ) + parser = argparse.ArgumentParser(description="Generate tokenized dataset from a text file.") # Add arguments parser.add_argument( diff --git a/test/transformers/test_auto_model.py b/test/transformers/test_auto_model.py index 021030406..0506d69d8 100644 --- a/test/transformers/test_auto_model.py +++ b/test/transformers/test_auto_model.py @@ -1,14 +1,14 @@ from inspect import signature from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock +from unittest.mock import patch -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig +from transformers import AutoModelForCausalLM from liger_kernel.transformers import AutoLigerKernelForCausalLM -from liger_kernel.transformers.monkey_patch import ( - MODEL_TYPE_TO_APPLY_LIGER_FN, - apply_liger_kernel_to_llama, -) +from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama def test_auto_liger_kernel_for_causal_lm_from_pretrained(): @@ -33,20 +33,17 @@ def test_auto_liger_kernel_for_causal_lm_from_pretrained(): mock_model_config.model_type = "llama" mock_llama = mock.Mock() - with patch.dict( - MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama} - ), mock.patch.object( - AutoConfig, "from_pretrained", return_value=mock_model_config - ), mock.patch.object( - AutoModelForCausalLM, "from_pretrained", return_value="mock_model" - ) as mock_super_from_pretrained: - + with ( + patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}), + mock.patch.object(AutoConfig, "from_pretrained", return_value=mock_model_config), + mock.patch.object( + AutoModelForCausalLM, "from_pretrained", return_value="mock_model" + ) as mock_super_from_pretrained, + ): # Mock the function signature of apply_liger_kernel_to_llama mock_llama.__signature__ = signature(apply_liger_kernel_to_llama) - model = AutoLigerKernelForCausalLM.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs - ) + model = AutoLigerKernelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) # Check that the apply_liger_kernel_to_llama mock was called with the correct kwargs mock_llama.assert_called_once_with(rope=False, swiglu=True) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 5c050b983..7b8d1a9d0 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -1,14 +1,14 @@ -from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 - import pytest import torch import torch.nn.functional as F + +from test.utils import assert_verbose_allclose +from test.utils import set_seed +from test.utils import supports_bfloat16 from torch.nn import CrossEntropyLoss -from liger_kernel.ops.cross_entropy import ( - LigerCrossEntropyFunction, - liger_cross_entropy_kernel, -) +from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction +from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel from liger_kernel.ops.utils import is_hip from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.functional import liger_cross_entropy @@ -98,10 +98,7 @@ def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, r assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_with_ignore_index_once( - target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol -): - +def _test_correctness_with_ignore_index_once(target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol): torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar @@ -114,9 +111,7 @@ def _test_correctness_with_ignore_index_once( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices target[indices_to_assign] = ignore_index output = torch_ce(_input, target) @@ -129,10 +124,7 @@ def _test_correctness_with_ignore_index_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_with_label_smoothing_once( - target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol -): - +def _test_correctness_with_label_smoothing_once(target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol): torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing) _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar @@ -154,10 +146,7 @@ def _test_correctness_with_label_smoothing_once( def _test_correctness_with_label_smoothing_with_ignore_index_once( target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol ): - - torch_ce = CrossEntropyLoss( - ignore_index=ignore_index, label_smoothing=label_smoothing - ) + torch_ce = CrossEntropyLoss(ignore_index=ignore_index, label_smoothing=label_smoothing) _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) @@ -169,9 +158,7 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices target[indices_to_assign] = ignore_index output = torch_ce(_input, target) @@ -184,10 +171,7 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_with_softcap_once( - target_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol -): - +def _test_correctness_with_softcap_once(target_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol): torch_ce = CrossEntropyLoss(reduction=reduction) _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar @@ -198,9 +182,7 @@ def _test_correctness_with_softcap_once( # upcasting to match liger's casting strategy # and downcasting to original dtype - output = torch_ce( - softcap * torch.tanh(_input.to(torch.float32) / softcap), target - ).to(dtype) + output = torch_ce(softcap * torch.tanh(_input.to(torch.float32) / softcap), target).to(dtype) output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) @@ -289,9 +271,7 @@ def _test_correctness_with_z_loss_with_other_params_once( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices target[indices_to_assign] = ignore_index if return_z_loss: @@ -419,7 +399,6 @@ def _test_correctness_functional( atol, rtol, ): - _input = torch.randn(B * T, V, device=device, dtype=dtype) * scalar x1 = _input.clone().requires_grad_(True) @@ -474,9 +453,7 @@ def _test_correctness_functional( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], @@ -522,20 +499,14 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], ) -def test_correctness_with_ignore_index( - B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol -): +def test_correctness_with_ignore_index(B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol): liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) - _test_correctness_with_ignore_index_once( - liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol - ) + _test_correctness_with_ignore_index_once(liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol) @pytest.mark.parametrize( @@ -554,20 +525,14 @@ def test_correctness_with_ignore_index( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], ) -def test_correctness_with_label_smoothing_once( - B, T, V, label_smoothing, scalar, dtype, atol, rtol -): +def test_correctness_with_label_smoothing_once(B, T, V, label_smoothing, scalar, dtype, atol, rtol): liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing) - _test_correctness_with_label_smoothing_once( - liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol - ) + _test_correctness_with_label_smoothing_once(liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol) @pytest.mark.parametrize( @@ -586,9 +551,7 @@ def test_correctness_with_label_smoothing_once( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], @@ -622,20 +585,14 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], ) -def test_correctness_with_softcap_once( - B, T, V, softcap, reduction, scalar, dtype, atol, rtol -): +def test_correctness_with_softcap_once(B, T, V, softcap, reduction, scalar, dtype, atol, rtol): liger_ce = LigerCrossEntropyLoss(softcap=softcap, reduction=reduction) - _test_correctness_with_softcap_once( - liger_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol - ) + _test_correctness_with_softcap_once(liger_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol) @pytest.mark.parametrize( @@ -654,9 +611,7 @@ def test_correctness_with_softcap_once( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], @@ -714,9 +669,7 @@ def test_correctness_with_z_loss_once( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], @@ -790,9 +743,7 @@ def test_correctness_with_z_loss_with_other_params_once( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], @@ -903,9 +854,7 @@ def test_correctness_with_weight_with_other_params_once( ) def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol): liger_ce = LigerCrossEntropyLoss(reduction=reduction) - _test_correctness_not_last_layer_once( - liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol - ) + _test_correctness_not_last_layer_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol) def test_float32_internal(): diff --git a/test/transformers/test_embedding.py b/test/transformers/test_embedding.py index 416784d0f..72c3cc721 100644 --- a/test/transformers/test_embedding.py +++ b/test/transformers/test_embedding.py @@ -1,5 +1,6 @@ import pytest import torch + from torch.nn import Embedding from liger_kernel.transformers.experimental.embedding import LigerEmbedding @@ -33,24 +34,12 @@ (torch.float32, 1e-6, 1e-5, device), ], ) -def test_embedding_correctness( - num_embeddings, embedding_dim, padding_idx, dtype, atol, rtol, device -): - print( - f"\nTesting embedding with size: ({num_embeddings}, {embedding_dim}), padding_idx: {padding_idx}" - ) +def test_embedding_correctness(num_embeddings, embedding_dim, padding_idx, dtype, atol, rtol, device): + print(f"\nTesting embedding with size: ({num_embeddings}, {embedding_dim}), padding_idx: {padding_idx}") torch.manual_seed(42) - torch_embedding = ( - Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - .to(dtype) - .to(device) - ) - liger_embedding = ( - LigerEmbedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - .to(dtype) - .to(device) - ) + torch_embedding = Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx).to(dtype).to(device) + liger_embedding = LigerEmbedding(num_embeddings, embedding_dim, padding_idx=padding_idx).to(dtype).to(device) liger_embedding.weight.data.copy_(torch_embedding.weight.data) if padding_idx is not None: @@ -69,6 +58,4 @@ def test_embedding_correctness( torch_output.backward(grad_output) liger_output.backward(grad_output) - assert torch.allclose( - torch_embedding.weight.grad, liger_embedding.weight.grad, atol=atol, rtol=rtol - ) + assert torch.allclose(torch_embedding.weight.grad, liger_embedding.weight.grad, atol=atol, rtol=rtol) diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 8909d9337..58e76ac46 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -1,17 +1,15 @@ -from test.transformers.test_cross_entropy import CrossEntropyWithZLoss -from test.utils import assert_verbose_allclose, set_seed from typing import Optional import pytest import torch -from liger_kernel.ops.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyFunction, -) +from test.transformers.test_cross_entropy import CrossEntropyWithZLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.utils import infer_device device = infer_device() @@ -49,9 +47,7 @@ def __init__( softcap: Optional[float] = None, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.ce_loss = CrossEntropyWithZLoss( weight=ce_weight, ignore_index=ignore_index, @@ -83,9 +79,7 @@ def __init__( softcap: Optional[float] = None, ): super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.ce_loss = LigerFusedLinearCrossEntropyLoss( ce_weight=ce_weight, ignore_index=ignore_index, @@ -118,6 +112,8 @@ def forward(self, x, y): ("mean", 1.0, torch.float32, 1e-5, 5e-4), ("sum", 1.0, torch.bfloat16, 5e-0, 5e1), ("sum", 1.0, torch.float32, 1e-3, 5e-2), + ("none", 1.0, torch.bfloat16, 5e-0, 5e1), + ("none", 1.0, torch.float32, 1e-3, 5e-2), ], ) @pytest.mark.parametrize("bias", [True, False]) @@ -176,14 +172,10 @@ def test_correctness( ).to(device) # init the linear in all CEs with the same weights - torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand( - V, H, device=device, dtype=dtype - ) + torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(V, H, device=device, dtype=dtype) if bias: - torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand( - V, device=device, dtype=dtype - ) + torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand(V, device=device, dtype=dtype) _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar _input1 = _tensor.detach().clone().requires_grad_(True) @@ -194,9 +186,7 @@ def test_correctness( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices target[indices_to_assign] = ignore_index output1 = torch_lm_head_ce(_input1, target) @@ -204,8 +194,8 @@ def test_correctness( assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - output1.backward() - output2.backward() + output1.backward(gradient=torch.ones_like(output1)) + output2.backward(gradient=torch.ones_like(output2)) assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) @@ -303,9 +293,7 @@ def test_amp(B, T, H, V, cast_dtype, atol, rtol): ).to(device) # init the linear in all CEs with the same weights - torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand( - V, H, device=device, dtype=dtype - ) + torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(V, H, device=device, dtype=dtype) _tensor = torch.randn(B * T, H, device=device, dtype=dtype) _input1 = _tensor.detach().clone().requires_grad_(True) diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py index 75f4d775c..c33d28885 100644 --- a/test/transformers/test_fused_linear_jsd.py +++ b/test/transformers/test_fused_linear_jsd.py @@ -1,9 +1,10 @@ -from test.transformers.test_jsd import JSD as TorchJSD -from test.utils import assert_verbose_allclose, set_seed - import pytest import torch +from test.transformers.test_jsd import JSD as TorchJSD +from test.utils import assert_verbose_allclose +from test.utils import set_seed + from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.transformers.functional import liger_fused_linear_jsd from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD @@ -34,12 +35,8 @@ def __init__( temperature: float = 1.0, ): super().__init__() - self.student_lin = torch.nn.Linear( - in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device - ) - self.teacher_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype, device=device - ) + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype) self.temperature = temperature @@ -64,15 +61,9 @@ def __init__( temperature: float = 1.0, ): super().__init__() - self.student_lin = torch.nn.Linear( - in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device - ) - self.teacher_lin = torch.nn.Linear( - in_features=H, out_features=V, bias=False, dtype=dtype, device=device - ) - self.fused_jsd = LigerFusedLinearJSD( - jsd_beta=beta, ignore_index=ignore_index, temperature=temperature - ) + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.fused_jsd = LigerFusedLinearJSD(jsd_beta=beta, ignore_index=ignore_index, temperature=temperature) def forward(self, student_input, teacher_input, label=None): return self.fused_jsd( @@ -131,12 +122,12 @@ def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): ).to(device) # init the linear in all FusedLinearJSDs with the same weights - torch_lm_head_jsd.student_lin.weight.data = ( - liger_lm_head_jsd.student_lin.weight.data - ) = torch.rand(V, H // 2, device=device, dtype=dtype) - torch_lm_head_jsd.teacher_lin.weight.data = ( - liger_lm_head_jsd.teacher_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H // 2, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar _input1 = _tensor.detach().clone().requires_grad_(True) @@ -186,9 +177,7 @@ def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): (1.0, 1.0, 2), ], ) -def test_correctness_with_ignore_index( - B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol -): +def test_correctness_with_ignore_index(B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol): torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -209,12 +198,12 @@ def test_correctness_with_ignore_index( ).to(device) # init the linear in all FusedLinearJSDs with the same weights - torch_lm_head_jsd.student_lin.weight.data = ( - liger_lm_head_jsd.student_lin.weight.data - ) = torch.rand(V, H // 2, device=device, dtype=dtype) - torch_lm_head_jsd.teacher_lin.weight.data = ( - liger_lm_head_jsd.teacher_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H // 2, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar _input1 = _tensor.detach().clone().requires_grad_(True) @@ -228,9 +217,7 @@ def test_correctness_with_ignore_index( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices label[indices_to_assign] = ignore_index output1 = torch_lm_head_jsd(_input1, teacher_input, label) @@ -266,12 +253,8 @@ def test_correctness_with_ignore_index( (0.5, torch.float32, 1e-5, 5e-4), ], ) -@pytest.mark.parametrize( - "temperature, beta, ignore_index", [(1.0, 0.5, -100), (2.0, 0.1, 42)] -) -def test_correctness_functional( - B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol -): +@pytest.mark.parametrize("temperature, beta, ignore_index", [(1.0, 0.5, -100), (2.0, 0.1, 42)]) +def test_correctness_functional(B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol): # init the linear in all FusedLinearJSDs with the same weights _weight = torch.rand(V, H // 2, device=device, dtype=dtype) _weight1 = _weight.detach().clone().requires_grad_(True) @@ -289,9 +272,7 @@ def test_correctness_functional( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices label[indices_to_assign] = ignore_index output1 = liger_fused_linear_jsd( @@ -346,9 +327,7 @@ def test_correctness_functional( (2.0, 0.1, 42), ], ) -def test_correctness_all_ignored( - B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol -): +def test_correctness_all_ignored(B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol): torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -369,12 +348,12 @@ def test_correctness_all_ignored( ).to(device) # init the linear in all FusedLinearJSDs with the same weights - torch_lm_head_jsd.student_lin.weight.data = ( - liger_lm_head_jsd.student_lin.weight.data - ) = torch.rand(V, H // 2, device=device, dtype=dtype) - torch_lm_head_jsd.teacher_lin.weight.data = ( - liger_lm_head_jsd.teacher_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H // 2, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar _input1 = _tensor.detach().clone().requires_grad_(True) @@ -392,9 +371,7 @@ def test_correctness_all_ignored( output2.backward() - assert_verbose_allclose( - torch.zeros_like(_input2.grad), _input2.grad, atol=atol, rtol=rtol - ) + assert_verbose_allclose(torch.zeros_like(_input2.grad), _input2.grad, atol=atol, rtol=rtol) @pytest.mark.parametrize( @@ -433,12 +410,12 @@ def test_amp(autocast_dtype, atol, rtol): beta=beta, ).to(device) # init the linear in all FusedLinearJSDs with the same weights - torch_lm_head_jsd.student_lin.weight.data = ( - liger_lm_head_jsd.student_lin.weight.data - ) = torch.rand(V, H // 2, device=device, dtype=dtype) - torch_lm_head_jsd.teacher_lin.weight.data = ( - liger_lm_head_jsd.teacher_lin.weight.data - ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H // 2, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) _tensor = torch.rand(B * T, H // 2, device=device, dtype=autocast_dtype) * scalar _input1 = _tensor.detach().clone().requires_grad_(True) @@ -452,9 +429,7 @@ def test_amp(autocast_dtype, atol, rtol): num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices label[indices_to_assign] = ignore_index with torch.autocast(device_type=device, dtype=autocast_dtype): diff --git a/test/transformers/test_geglu.py b/test/transformers/test_geglu.py index 0d5919729..dee8dbeac 100644 --- a/test/transformers/test_geglu.py +++ b/test/transformers/test_geglu.py @@ -1,7 +1,7 @@ -from test.utils import supports_bfloat16 - import pytest import torch + +from test.utils import supports_bfloat16 from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP @@ -38,9 +38,7 @@ torch.bfloat16, 1e4, 6e-3, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ], ) diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 4f53444d5..a73258133 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -29,24 +29,16 @@ (torch.float32, 1e-4, 1e-4), ], ) -def test_liger_group_norm( - batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol -): +def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol): torch.manual_seed(0) - _tensor = torch.randn( - batch_size, num_channels, hidden_size, dtype=dtype, device=device - ) + _tensor = torch.randn(batch_size, num_channels, hidden_size, dtype=dtype, device=device) liger_x = _tensor.clone().detach().requires_grad_(True) torch_x = _tensor.clone().detach().requires_grad_(True) liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device) - torch_ln = ( - torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6) - .to(dtype) - .to(device) - ) + torch_ln = torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6).to(dtype).to(device) with torch.no_grad(): torch_ln.weight.copy_(liger_ln.weight) @@ -62,9 +54,5 @@ def test_liger_group_norm( liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) - assert torch.allclose( - liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol - ), "Bias grads different" - assert torch.allclose( - liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol - ), "Weight grads different" + assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol), "Bias grads different" + assert torch.allclose(liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol), "Weight grads different" diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py index 86f4e3388..c0214010c 100644 --- a/test/transformers/test_jsd.py +++ b/test/transformers/test_jsd.py @@ -1,12 +1,16 @@ -from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 from typing import Optional import pytest import torch + +from test.utils import assert_verbose_allclose +from test.utils import set_seed +from test.utils import supports_bfloat16 from torch.nn import KLDivLoss from liger_kernel.transformers.functional import liger_jsd -from liger_kernel.transformers.jsd import LigerJSD, LigerJSDFunction +from liger_kernel.transformers.jsd import LigerJSD +from liger_kernel.transformers.jsd import LigerJSDFunction from liger_kernel.utils import infer_device device = infer_device() @@ -39,13 +43,14 @@ def forward( loss = self.kl(log_p, log_q).sum(dim=-1) else: log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) - log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view( - -1, log_q.size(-1) + log_p, log_q = ( + log_p.view(-1, log_p.size(-1)), + log_q.view(-1, log_q.size(-1)), ) m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) - loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( - 1 - self.beta - ) * self.kl(torch.log(m), log_q).sum(dim=-1) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl( + torch.log(m), log_q + ).sum(dim=-1) if label is not None: loss = torch.where(label != self.ignore_index, loss, 0.0) @@ -75,9 +80,7 @@ def forward( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (torch.float32, 1e-8, 1e-6), (torch.float16, 1e-3, 1e-3), @@ -98,9 +101,7 @@ def _test_correctness_once( ): torch_jsd = JSD(dtype=dtype) - input = torch.randn( - B * T, V, device=device, dtype=dtype, requires_grad=True - ).log_softmax(dim=-1) + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) @@ -140,9 +141,7 @@ def _test_correctness_with_beta_once( ): torch_jsd = JSD(beta=beta, dtype=dtype) - input = torch.randn( - B * T, V, device=device, dtype=dtype, requires_grad=True - ).log_softmax(dim=-1) + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) @@ -177,9 +176,7 @@ def _test_correctness_with_ignore_index_once( ): torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) - input = torch.randn( - B * T, V, device=device, dtype=dtype, requires_grad=True - ).log_softmax(dim=-1) + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) @@ -193,9 +190,7 @@ def _test_correctness_with_ignore_index_once( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices label[indices_to_assign] = ignore_index output = torch_jsd(x1, target, label) @@ -207,12 +202,8 @@ def _test_correctness_with_ignore_index_once( assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) -def _test_correctness_functional( - B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device=device -): - input = torch.randn( - B * T, V, device=device, dtype=dtype, requires_grad=True - ).log_softmax(dim=-1) +def _test_correctness_functional(B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device=device): + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) @@ -226,9 +217,7 @@ def _test_correctness_functional( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices label[indices_to_assign] = ignore_index output = LigerJSDFunction.apply(x1, target, label, beta, ignore_index) @@ -278,9 +267,7 @@ def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol): @pytest.mark.parametrize("ignore_index", [2, 42]) def test_correctness_with_ignore_index(B, T, V, ignore_index, dtype, atol, rtol): liger_jsd = LigerJSD(ignore_index=ignore_index) - _test_correctness_with_ignore_index_once( - liger_jsd, ignore_index, B, T, V, dtype, atol, rtol - ) + _test_correctness_with_ignore_index_once(liger_jsd, ignore_index, B, T, V, dtype, atol, rtol) @pytest.mark.parametrize(*_SHAPE_PARAMS) @@ -292,12 +279,8 @@ def test_correctness_with_ignore_index(B, T, V, ignore_index, dtype, atol, rtol) (0.1, 42, True), ], ) -def test_correctness_functional( - B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol -): - _test_correctness_functional( - B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol - ) +def test_correctness_functional(B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol): + _test_correctness_functional(B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol) # @pytest.mark.parametrize(*_SHAPE_PARAMS) @@ -314,9 +297,7 @@ def test_correctness_with_all_indices_ignored( torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) liger_jsd = LigerJSD(ignore_index=ignore_index) - inp = torch.randn( - B * T, V, device=device, dtype=dtype, requires_grad=True - ).log_softmax(dim=-1) + inp = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True).log_softmax(dim=-1) x1 = inp.detach().clone().requires_grad_(True) x2 = inp.detach().clone().requires_grad_(True) @@ -331,9 +312,7 @@ def test_correctness_with_all_indices_ignored( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices label[indices_to_assign] = ignore_index output = torch_jsd(x1, target, label) diff --git a/test/transformers/test_kl_div.py b/test/transformers/test_kl_div.py index 1f0c2d5ad..84386f4e4 100644 --- a/test/transformers/test_kl_div.py +++ b/test/transformers/test_kl_div.py @@ -1,7 +1,7 @@ -from test.utils import supports_bfloat16 - import pytest import torch + +from test.utils import supports_bfloat16 from torch.nn import KLDivLoss from liger_kernel.transformers.kl_div import LigerKLDIVLoss @@ -25,9 +25,7 @@ torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (torch.float32, 1e-8, 1e-6), (torch.float16, 1e-3, 1e-3), @@ -51,9 +49,7 @@ def _test_correctness_once( torch.manual_seed(0) torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) - input = torch.randn( - B * T, V, device=device, dtype=dtype, requires_grad=True - ).log_softmax(dim=-1) + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) @@ -85,9 +81,7 @@ def _test_correctness_once( @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol): liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) - _test_correctness_once( - liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target - ) + _test_correctness_once(liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target) @pytest.mark.parametrize(*_SHAPE_PARAMS) diff --git a/test/transformers/test_layer_norm.py b/test/transformers/test_layer_norm.py index 4ac152440..264b730c9 100644 --- a/test/transformers/test_layer_norm.py +++ b/test/transformers/test_layer_norm.py @@ -47,9 +47,7 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): torch_output.backward(grad_output, retain_graph=True) assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) - assert torch.allclose( - liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol - ) + assert torch.allclose(liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol) assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) @@ -66,9 +64,7 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): (torch.float32, 1e-5, 1e-5), ], ) -def test_liger_layer_norm_functional( - hidden_size, batch_size, seq_len, dtype, atol, rtol -): +def test_liger_layer_norm_functional(hidden_size, batch_size, seq_len, dtype, atol, rtol): torch.manual_seed(0) input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) diff --git a/test/transformers/test_mm_int8int2.py b/test/transformers/test_mm_int8int2.py index a2458523a..57347ca89 100644 --- a/test/transformers/test_mm_int8int2.py +++ b/test/transformers/test_mm_int8int2.py @@ -1,11 +1,9 @@ import pytest import torch -from liger_kernel.ops.experimental.mm_int8int2 import ( - matmul, - pack_weights, - unpack_weights, -) +from liger_kernel.ops.experimental.mm_int8int2 import matmul +from liger_kernel.ops.experimental.mm_int8int2 import pack_weights +from liger_kernel.ops.experimental.mm_int8int2 import unpack_weights from liger_kernel.utils import infer_device device = infer_device() @@ -44,15 +42,11 @@ (1e-2, 1e-2, device), ], ) -def test_kernel_correctness( - batch_size, seq_len, out_features, size, atol, rtol, device -): +def test_kernel_correctness(batch_size, seq_len, out_features, size, atol, rtol, device): print(f"\nTesting kernel with size: {size}, atol: {atol}, rtol: {rtol}") # Generate the random tensors - ht = torch.randint( - -127, 127, (batch_size, seq_len, size * 4), device=device, dtype=torch.int8 - ) + ht = torch.randint(-127, 127, (batch_size, seq_len, size * 4), device=device, dtype=torch.int8) u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8) # Calculate dimensions @@ -63,18 +57,14 @@ def test_kernel_correctness( # Unpack weights and compute torch output unpacked = unpack_weights(u.T, bits=2).T - torch_output = torch.matmul( - ht.to(torch.float32), unpacked.T.contiguous().to(torch.float32) - ) + torch_output = torch.matmul(ht.to(torch.float32), unpacked.T.contiguous().to(torch.float32)) # Print the results (optional, can be commented out) print("triton_output =", triton_output) print("torch_output =", torch_output) # Check if outputs are close within the given tolerances - assert torch.allclose( - triton_output, torch_output.to(torch.int32), atol=atol, rtol=rtol - ), "Results differ" + assert torch.allclose(triton_output, torch_output.to(torch.int32), atol=atol, rtol=rtol), "Results differ" @pytest.mark.skip(reason="mm_int8int2 is under experimentation") @@ -104,6 +94,4 @@ def test_kernel_correctness( def test_unpack_pack_correctness(out_features, size, device): u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8) - assert ( - pack_weights(unpack_weights(u.T), 2) == u.T - ).all(), "Packed weights do not match original weights." + assert (pack_weights(unpack_weights(u.T), 2) == u.T).all(), "Packed weights do not match original weights." diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 19e8eb161..811cd74cc 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -1,26 +1,28 @@ import inspect + from inspect import signature -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import patch import pytest import torch import transformers -from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel - -from liger_kernel.transformers import ( - LigerBlockSparseTop2MLP, - LigerGEGLUMLP, - LigerPhi3SwiGLUMLP, - LigerRMSNorm, - LigerSwiGLUMLP, - monkey_patch, -) + +from transformers import AutoModelForCausalLM +from transformers import PretrainedConfig +from transformers import PreTrainedModel + +from liger_kernel.transformers import LigerBlockSparseTop2MLP +from liger_kernel.transformers import LigerGEGLUMLP +from liger_kernel.transformers import LigerPhi3SwiGLUMLP +from liger_kernel.transformers import LigerRMSNorm +from liger_kernel.transformers import LigerSwiGLUMLP +from liger_kernel.transformers import monkey_patch from liger_kernel.transformers.layer_norm import LigerLayerNorm -from liger_kernel.transformers.monkey_patch import ( - MODEL_TYPE_TO_APPLY_LIGER_FN, - _apply_liger_kernel, - _apply_liger_kernel_to_instance, -) +from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN +from liger_kernel.transformers.monkey_patch import _apply_liger_kernel +from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # Check if optional modules are available @@ -44,18 +46,16 @@ def is_qwen2_vl_available(): def test_import_from_root(): try: - from liger_kernel.transformers import ( # noqa: F401 - AutoLigerKernelForCausalLM, - apply_liger_kernel_to_gemma, - apply_liger_kernel_to_gemma2, - apply_liger_kernel_to_llama, - apply_liger_kernel_to_mistral, - apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_mllama, - apply_liger_kernel_to_phi3, - apply_liger_kernel_to_qwen2, - apply_liger_kernel_to_qwen2_vl, - ) + from liger_kernel.transformers import AutoLigerKernelForCausalLM # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_gemma # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_gemma2 # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_llama # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_mistral # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_mixtral # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_mllama # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_phi3 # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_qwen2 # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl # noqa: F401 except Exception: pytest.fail("Import kernel patch from root fails") @@ -236,37 +236,21 @@ def test_apply_liger_kernel_to_instance_for_llama(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) # Ensure that the model patched with Liger modules can work properly try: @@ -279,9 +263,7 @@ def test_apply_liger_kernel_to_instance_for_llama(): def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.mllama.modeling_mllama"): - from transformers.models.mllama.modeling_mllama import ( - MllamaForConditionalGeneration, - ) + from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration # Instantiate a dummy model config = transformers.models.mllama.configuration_mllama.MllamaConfig( @@ -314,79 +296,59 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): assert isinstance(dummy_model_instance, MllamaForConditionalGeneration) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.language_model.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.language_model.model.norm.forward) != inspect.getsource( + LigerRMSNorm.forward + ) for layer in dummy_model_instance.language_model.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - - assert inspect.getsource( - dummy_model_instance.vision_model.layernorm_pre.forward - ) != inspect.getsource(LigerLayerNorm.forward) - assert inspect.getsource( - dummy_model_instance.vision_model.layernorm_post.forward - ) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + + assert inspect.getsource(dummy_model_instance.vision_model.layernorm_pre.forward) != inspect.getsource( + LigerLayerNorm.forward + ) + assert inspect.getsource(dummy_model_instance.vision_model.layernorm_post.forward) != inspect.getsource( + LigerLayerNorm.forward + ) for layer in dummy_model_instance.vision_model.transformer.layers: - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerLayerNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource( + LigerLayerNorm.forward + ) for layer in dummy_model_instance.vision_model.global_transformer.layers: - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerLayerNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource( + LigerLayerNorm.forward + ) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.language_model.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.language_model.model.norm.forward) == inspect.getsource( + LigerRMSNorm.forward + ) for layer in dummy_model_instance.language_model.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - - assert inspect.getsource( - dummy_model_instance.vision_model.layernorm_pre.forward - ) == inspect.getsource(LigerLayerNorm.forward) - assert inspect.getsource( - dummy_model_instance.vision_model.layernorm_post.forward - ) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + + assert inspect.getsource(dummy_model_instance.vision_model.layernorm_pre.forward) == inspect.getsource( + LigerLayerNorm.forward + ) + assert inspect.getsource(dummy_model_instance.vision_model.layernorm_post.forward) == inspect.getsource( + LigerLayerNorm.forward + ) for layer in dummy_model_instance.vision_model.transformer.layers: - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerLayerNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource( + LigerLayerNorm.forward + ) for layer in dummy_model_instance.vision_model.global_transformer.layers: - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerLayerNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource( + LigerLayerNorm.forward + ) try: print(dummy_model_instance) @@ -423,33 +385,19 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm(): # Check that model instance variables are not yet patched with Liger modules assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) try: print(dummy_model_instance) @@ -472,37 +420,21 @@ def test_apply_liger_kernel_to_instance_for_mistral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) try: print(dummy_model_instance) @@ -527,39 +459,23 @@ def test_apply_liger_kernel_to_instance_for_mixtral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert inspect.getsource(expert.forward) != inspect.getsource( - LigerBlockSparseTop2MLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(expert.forward) != inspect.getsource(LigerBlockSparseTop2MLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert inspect.getsource(expert.forward) == inspect.getsource( - LigerBlockSparseTop2MLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(expert.forward) == inspect.getsource(LigerBlockSparseTop2MLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) try: print(dummy_model_instance) @@ -582,37 +498,21 @@ def test_apply_liger_kernel_to_instance_for_gemma(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerGEGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerGEGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) try: print(dummy_model_instance) @@ -635,49 +535,29 @@ def test_apply_liger_kernel_to_instance_for_gemma2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerGEGLUMLP.forward + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.pre_feedforward_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_feedforward_layernorm.forward) != inspect.getsource( + LigerRMSNorm.forward ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.pre_feedforward_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_feedforward_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerGEGLUMLP.forward + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.pre_feedforward_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_feedforward_layernorm.forward) == inspect.getsource( + LigerRMSNorm.forward ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.pre_feedforward_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_feedforward_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) try: print(dummy_model_instance) @@ -700,37 +580,21 @@ def test_apply_liger_kernel_to_instance_for_qwen2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) try: print(dummy_model_instance) @@ -742,9 +606,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2(): def test_apply_liger_kernel_to_instance_for_qwen2_vl(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"): - from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLForConditionalGeneration, - ) + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration # Instantiate a dummy model config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig( @@ -770,51 +632,27 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl(): assert isinstance(dummy_model_instance, Qwen2VLForConditionalGeneration) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) for vision_block in dummy_model_instance.visual.blocks: - assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource( - LigerLayerNorm.forward - ) - assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource( - LigerLayerNorm.forward - ) + assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerLayerNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) for vision_block in dummy_model_instance.visual.blocks: - assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource( - LigerLayerNorm.forward - ) - assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource( - LigerLayerNorm.forward - ) + assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerLayerNorm.forward) try: print(dummy_model_instance) @@ -837,37 +675,21 @@ def test_apply_liger_kernel_to_instance_for_phi3(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource( - LigerPhi3SwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerPhi3SwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource( - dummy_model_instance.model.norm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource( - LigerPhi3SwiGLUMLP.forward - ) - assert inspect.getsource( - layer.input_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource( - layer.post_attention_layernorm.forward - ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerPhi3SwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) try: print(dummy_model_instance) diff --git a/test/transformers/test_qwen2vl_mrope.py b/test/transformers/test_qwen2vl_mrope.py index 239ba7784..1ca73b82e 100644 --- a/test/transformers/test_qwen2vl_mrope.py +++ b/test/transformers/test_qwen2vl_mrope.py @@ -1,13 +1,11 @@ -from test.utils import supports_bfloat16 - import pytest import torch +from test.utils import supports_bfloat16 + try: - from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLRotaryEmbedding, - apply_multimodal_rotary_pos_emb, - ) + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding + from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb IS_QWEN_AVAILABLE = True except Exception: @@ -21,9 +19,7 @@ device = infer_device() -@pytest.mark.skipif( - not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers." -) +@pytest.mark.skipif(not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers.") @pytest.mark.parametrize("bsz", [1, 2]) @pytest.mark.parametrize("seq_len", [128, 131]) @pytest.mark.parametrize("num_q_heads, num_kv_heads", [(64, 8), (28, 4), (12, 2)]) @@ -43,28 +39,16 @@ torch.bfloat16, 1e-1, 1e-5, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ], ) -def test_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol -): +def test_correctness(bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol): rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) - _tensor_q = ( - torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) - .transpose(1, 2) - .to(dtype) - ) + _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype) - _tensor_k = ( - torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) - .transpose(1, 2) - .to(dtype) - ) + _tensor_k = torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device).transpose(1, 2).to(dtype) q1 = _tensor_q.clone().requires_grad_(True) k1 = _tensor_k.clone().requires_grad_(True) @@ -73,9 +57,7 @@ def test_correctness( k2 = _tensor_k.clone().requires_grad_(True) # NOTE: this position ids distribution is different from the real one, just to test op correctness - pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view( - 3, bsz, seq_len - ) + pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(3, bsz, seq_len) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -90,20 +72,14 @@ def test_correctness( torch.randn_like(hf_k, device=device).to(dtype), ) - q1_grad, k1_grad = torch.autograd.grad( - (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True - ) - q2_grad, k2_grad = torch.autograd.grad( - (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True - ) + q1_grad, k1_grad = torch.autograd.grad((hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True) + q2_grad, k2_grad = torch.autograd.grad((tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True) torch.testing.assert_close(q1_grad, q2_grad, atol=atol, rtol=rtol) torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol) -@pytest.mark.skipif( - not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers." -) +@pytest.mark.skipif(not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers.") @pytest.mark.parametrize( "bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section", [ @@ -118,9 +94,7 @@ def test_correctness( (torch.bfloat16, 1e-1, 1e-5), ], ) -def test_functional_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol -): +def test_functional_correctness(bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol): _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) @@ -132,9 +106,7 @@ def test_functional_correctness( rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) - pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view( - 3, bsz, seq_len - ) + pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(3, bsz, seq_len) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section) diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py index 5831b1ec2..49c4656ee 100644 --- a/test/transformers/test_rms_norm.py +++ b/test/transformers/test_rms_norm.py @@ -1,10 +1,13 @@ import os -from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 import pytest import torch import torch.nn as nn +from test.utils import assert_verbose_allclose +from test.utils import set_seed +from test.utils import supports_bfloat16 + from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.transformers.functional import liger_rms_norm from liger_kernel.transformers.rms_norm import LigerRMSNorm @@ -91,9 +94,7 @@ def forward(self, x): torch.bfloat16, 2e-1, 2e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ], ) @@ -112,9 +113,7 @@ def forward(self, x): False, ], ) -def test_correctness( - bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place -): +def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place): _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype) h1 = _tensor.clone().requires_grad_(True) @@ -130,19 +129,13 @@ def test_correctness( # triton triton_rms = ( - LigerRMSNorm( - hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place - ) - .to(device) - .to(dtype) + LigerRMSNorm(hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place).to(device).to(dtype) ) triton_o = triton_rms(h2) triton_o.backward(do, retain_graph=True) assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol) - assert_verbose_allclose( - ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol - ) + assert_verbose_allclose(ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol) print(f"{h1.grad=}") print(f"{h2.grad=}") assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20) @@ -170,9 +163,7 @@ def test_correctness( (GemmaRMSNorm, 1.0, "gemma"), ], ) -def test_correctness_functional( - bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode -): +def test_correctness_functional(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode): # h _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype) diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index 74080b57f..2670e8c81 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -1,11 +1,9 @@ -from test.utils import supports_bfloat16 - import pytest import torch -from transformers.models.llama.modeling_llama import ( - LlamaRotaryEmbedding, - apply_rotary_pos_emb, -) + +from test.utils import supports_bfloat16 +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.transformers.functional import liger_rope @@ -40,28 +38,30 @@ torch.bfloat16, 1e-1, 1e-5, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ], ) +@pytest.mark.parametrize( + "expand_position_ids", + [True, False], +) def test_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, + seq_len, + num_q_heads, + num_kv_heads, + head_dim, + dtype, + expand_position_ids, + atol, + rtol, ): rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) - _tensor_q = ( - torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) - .transpose(1, 2) - .to(dtype) - ) + _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype) - _tensor_k = ( - torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) - .transpose(1, 2) - .to(dtype) - ) + _tensor_k = torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device).transpose(1, 2).to(dtype) q1 = _tensor_q.clone().requires_grad_(True) k1 = _tensor_k.clone().requires_grad_(True) @@ -70,6 +70,8 @@ def test_correctness( k2 = _tensor_k.clone().requires_grad_(True) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + if expand_position_ids: + pos_ids = pos_ids.expand(bsz, -1) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -84,12 +86,8 @@ def test_correctness( torch.randn_like(hf_k, device=device).to(dtype), ) - q1_grad, k1_grad = torch.autograd.grad( - (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True - ) - q2_grad, k2_grad = torch.autograd.grad( - (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True - ) + q1_grad, k1_grad = torch.autograd.grad((hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True) + q2_grad, k2_grad = torch.autograd.grad((tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True) assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol) assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol) @@ -111,8 +109,20 @@ def test_correctness( (torch.bfloat16, 1e-1, 1e-5), ], ) +@pytest.mark.parametrize( + "expand_position_ids", + [True, False], +) def test_functional_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, + seq_len, + num_q_heads, + num_kv_heads, + head_dim, + expand_position_ids, + dtype, + atol, + rtol, ): _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) @@ -126,6 +136,8 @@ def test_functional_correctness( rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + if expand_position_ids: + pos_ids = pos_ids.expand(bsz, -1) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin) diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index 154d5061f..0e98eec27 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -1,7 +1,7 @@ -from test.utils import supports_bfloat16 - import pytest import torch + +from test.utils import supports_bfloat16 from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP from transformers.models.phi3.configuration_phi3 import Phi3Config @@ -9,7 +9,8 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction from liger_kernel.transformers.functional import liger_swiglu -from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP +from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.utils import infer_device device = infer_device() @@ -46,15 +47,11 @@ torch.bfloat16, 1e4, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ], ) -def test_correctness_llamamlp( - bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol -): +def test_correctness_llamamlp(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol): _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) @@ -126,15 +123,11 @@ def test_correctness_llamamlp( torch.bfloat16, 1e4, 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), ], ) -def test_correctness_phi3mlp( - bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol -): +def test_correctness_phi3mlp(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol): _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) diff --git a/test/transformers/test_trainer_integration.py b/test/transformers/test_trainer_integration.py index 554da5ea9..b7c940d1f 100644 --- a/test/transformers/test_trainer_integration.py +++ b/test/transformers/test_trainer_integration.py @@ -3,8 +3,6 @@ def test_import(): try: - from liger_kernel.transformers.trainer_integration import ( # noqa: F401 - _apply_liger_kernel, - ) + from liger_kernel.transformers.trainer_integration import _apply_liger_kernel # noqa: F401 except Exception: pytest.fail("Import _apply_liger_kernel fails") diff --git a/test/transformers/test_transformers.py b/test/transformers/test_transformers.py index 9601229ec..61431871b 100644 --- a/test/transformers/test_transformers.py +++ b/test/transformers/test_transformers.py @@ -3,16 +3,14 @@ def test_import_from_root(): try: - from liger_kernel.transformers import ( # noqa: F401 - LigerBlockSparseTop2MLP, - LigerCrossEntropyLoss, - LigerFusedLinearCrossEntropyLoss, - LigerGEGLUMLP, - LigerLayerNorm, - LigerPhi3SwiGLUMLP, - LigerRMSNorm, - LigerSwiGLUMLP, - liger_rotary_pos_emb, - ) + from liger_kernel.transformers import LigerBlockSparseTop2MLP # noqa: F401 + from liger_kernel.transformers import LigerCrossEntropyLoss # noqa: F401 + from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss # noqa: F401 + from liger_kernel.transformers import LigerGEGLUMLP # noqa: F401 + from liger_kernel.transformers import LigerLayerNorm # noqa: F401 + from liger_kernel.transformers import LigerPhi3SwiGLUMLP # noqa: F401 + from liger_kernel.transformers import LigerRMSNorm # noqa: F401 + from liger_kernel.transformers import LigerSwiGLUMLP # noqa: F401 + from liger_kernel.transformers import liger_rotary_pos_emb # noqa: F401 except Exception: pytest.fail("Import kernels from root fails") diff --git a/test/utils.py b/test/utils.py index 3d3799ad0..ec4abd5a8 100644 --- a/test/utils.py +++ b/test/utils.py @@ -2,18 +2,25 @@ import json import os import random + from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Tuple +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple import numpy as np import torch import torch.nn as nn -from tokenizers import AddedToken, Tokenizer + +from tokenizers import AddedToken +from tokenizers import Tokenizer from tokenizers.models import BPE from tokenizers.pre_tokenizers import Whitespace from tokenizers.trainers import BpeTrainer -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig +from transformers import PreTrainedModel from transformers.tokenization_utils_base import BatchEncoding from liger_kernel.utils import infer_device @@ -80,13 +87,9 @@ def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print= nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor( - torch.isposinf(tensor1), torch.isposinf(tensor2) - ) + posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2)) # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor( - torch.isneginf(tensor1), torch.isneginf(tensor2) - ) + neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2)) # Find all mismatched elements mismatched = torch.logical_or( @@ -108,29 +111,19 @@ def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print= print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: i = tuple(index.tolist()) - mismatch_details.append( - f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}" - ) + mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}") if num_mismatched > max_print: - mismatch_details.append( - f"... and {num_mismatched - max_print} more mismatched elements." - ) + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") raise AssertionError("\n".join(mismatch_details)) # Pre-tokenized dataset using Mistral-7B tokenizer used for convergence tests -DEFAULT_DATASET_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare_tokenized" -) +DEFAULT_DATASET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare_tokenized") -UNTOKENIZED_DATASET_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare.txt" -) +UNTOKENIZED_DATASET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare.txt") -FAKE_CONFIGS_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "resources/fake_configs" -) +FAKE_CONFIGS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/fake_configs") @dataclass @@ -145,9 +138,7 @@ def simple_collate_fn(data: List[Dict[str, Any]]): """A basic collate function to use for DataLoader""" input_ids = torch.stack([torch.tensor(item["input_ids"]) for item in data]) - attention_mask = torch.stack( - [torch.tensor(item["attention_mask"]) for item in data] - ) + attention_mask = torch.stack([torch.tensor(item["attention_mask"]) for item in data]) labels = input_ids.clone() return BatchEncoding( @@ -234,9 +225,7 @@ def revert_liger_kernel_to_llama(model_config: MiniModelConfig): print("Liger kernel patches have been reverted.") -def revert_liger_kernel_to_mllama( - model_config: MiniModelConfig, model_type: str = "causal_lm" -): +def revert_liger_kernel_to_mllama(model_config: MiniModelConfig, model_type: str = "causal_lm"): """ Revert all Liger kernel patches applied to MLlama. """ @@ -246,6 +235,7 @@ def revert_liger_kernel_to_mllama( "conditional_generation", ], f'model_type must be "causal_lm" or "conditional_generation", Got: {model_type}' import torch.nn as nn + from transformers.models.mllama import modeling_mllama importlib.reload(nn) @@ -343,18 +333,19 @@ def revert_liger_kernel_to_phi3(model_config: MiniModelConfig): class HFAlignmentLoss: - def __init__( self, alpha: float = 1.0, beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, + compute_nll_loss: bool = True, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model + self.compute_nll_loss = compute_nll_loss @abstractmethod def alignment_loss(self): @@ -377,18 +368,14 @@ def get_batch_logps( A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. """ if logits.shape[:-1] != labels.shape: - raise ValueError( - "Logits (batch and sequence length dim) and labels must have the same shape." - ) + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") loss_mask = labels != self.ignore_index # dummy token; we'll ignore the losses on these tokens later labels = torch.where(labels == self.ignore_index, 0, labels) - per_token_logps = torch.gather( - logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) - ).squeeze(2) + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) @@ -408,9 +395,7 @@ def get_ref_logps( ref_logits = _input @ ref_weight.t() if ref_bias is not None: ref_logits = ref_logits + ref_bias - ref_all_logps = self.get_batch_logps( - ref_logits, target, average_log_prob=average_log_prob - ) + ref_all_logps = self.get_batch_logps(ref_logits, target, average_log_prob=average_log_prob) return ( ref_all_logps[: _input.shape[0] // 2], ref_all_logps[_input.shape[0] // 2 :], @@ -423,9 +408,7 @@ def concatenated_forward( target: torch.LongTensor, bias: torch.FloatTensor = None, average_log_prob: bool = True, - ) -> Tuple[ - torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor - ]: + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. @@ -448,9 +431,9 @@ def cross_entropy_loss(logits, labels): return loss labels = target - chosen_nll_loss = cross_entropy_loss( - all_logits[:len_chosen], labels[:len_chosen] - ) + chosen_nll_loss = torch.tensor(0.0, device=all_logits.device) + if self.compute_nll_loss: + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) all_logps = self.get_batch_logps( all_logits, @@ -485,9 +468,7 @@ def get_batch_loss_metrics( ): """Compute the loss metrics for the given batch of inputs for train or test.""" - forward_output = self.concatenated_forward( - _input, weight, target, bias, average_log_prob - ) + forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob) ( policy_chosen_logps, policy_rejected_logps, @@ -503,15 +484,13 @@ def get_batch_loss_metrics( ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps - alignment_loss_outputs = self.alignment_loss( - policy_chosen_logps, policy_rejected_logps, **loss_kwargs - ) + alignment_loss_outputs = self.alignment_loss(policy_chosen_logps, policy_rejected_logps, **loss_kwargs) if isinstance(alignment_loss_outputs, tuple): losses, *aggregated_aux_outputs = alignment_loss_outputs else: losses, aggregated_aux_outputs = alignment_loss_outputs, [] # full loss - loss = policy_nll_loss * self.alpha - losses.mean() + loss = policy_nll_loss * self.alpha + losses.mean() return_vars = ( policy_chosen_logps, policy_rejected_logps, @@ -628,7 +607,5 @@ def get_batch_loss_metrics( soft_loss = self.distillation_loss(student_logits, teacher_logits) # full loss - loss = ( - self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean() - ) + loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean() return loss