Skip to content

Commit

Permalink
Merge branch 'main' into tcc/weight-ce
Browse files Browse the repository at this point in the history
  • Loading branch information
Tcc0403 committed Dec 28, 2024
2 parents 5535a60 + 9875488 commit 8195398
Show file tree
Hide file tree
Showing 112 changed files with 1,629 additions and 2,769 deletions.
10 changes: 0 additions & 10 deletions .flake8

This file was deleted.

2 changes: 0 additions & 2 deletions .isort.cfg

This file was deleted.

7 changes: 3 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ test:
# Command to run flake8 (code style check), isort (import ordering), and black (code formatting)
# Subsequent commands still run if the previous fails, but return failure at the end
checkstyle:
flake8 .; flake8_status=$$?; \
isort .; isort_status=$$?; \
black .; black_status=$$?; \
if [ $$flake8_status -ne 0 ] || [ $$isort_status -ne 0 ] || [ $$black_status -ne 0 ]; then \
ruff check . --fix; ruff_check_status=$$?; \
ruff format .; ruff_format_status=$$?; \
if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
exit 1; \
fi

Expand Down
15 changes: 5 additions & 10 deletions benchmark/benchmarks_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os

from argparse import ArgumentParser
from dataclasses import dataclass

Expand Down Expand Up @@ -39,9 +40,7 @@ def parse_args() -> VisualizationsConfig:
VisualizationsConfig: Configuration object for the visualizations script.
"""
parser = ArgumentParser()
parser.add_argument(
"--kernel-name", type=str, required=True, help="Kernel name to benchmark"
)
parser.add_argument("--kernel-name", type=str, required=True, help="Kernel name to benchmark")
parser.add_argument(
"--metric-name",
type=str,
Expand All @@ -54,9 +53,7 @@ def parse_args() -> VisualizationsConfig:
required=True,
help="Kernel operation mode to visualize (forward/backward/full)",
)
parser.add_argument(
"--display", action="store_true", help="Display the visualization"
)
parser.add_argument("--display", action="store_true", help="Display the visualization")
parser.add_argument(
"--overwrite",
action="store_true",
Expand Down Expand Up @@ -126,7 +123,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
lines = ax.get_lines()
colors = [line.get_color() for line in lines]

for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors, strict=False):
# for i, row in group_data.iterrows():
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
Expand All @@ -145,9 +142,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
plt.ylabel(ylabel)
plt.tight_layout()

out_path = os.path.join(
VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
)
out_path = os.path.join(VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png")

if config.display:
plt.show()
Expand Down
27 changes: 11 additions & 16 deletions benchmark/scripts/benchmark_cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.utils import infer_device
Expand All @@ -33,9 +32,7 @@ def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100)
from test.chunked_loss.test_cpo_loss import HFCPOLoss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.cpo_loss = HFCPOLoss().get_batch_loss_metrics

def forward(self, x, y):
Expand All @@ -45,9 +42,7 @@ def forward(self, x, y):
class LigerLMHeadCPO(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.cpo_loss = LigerFusedLinearCPOFunction.apply

def forward(self, x, y):
Expand Down Expand Up @@ -180,12 +175,12 @@ def full():
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_cpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
**common_configs,
)
23 changes: 10 additions & 13 deletions benchmark/scripts/benchmark_cross_entropy.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import torch
import triton

from torch.nn import CrossEntropyLoss
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device
Expand Down Expand Up @@ -86,9 +85,7 @@ def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full, rep=100, quantiles=QUANTILES
)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)

return SingleBenchmarkRunOutput(
y_20=ms_20,
Expand All @@ -115,12 +112,12 @@ def full():
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_cross_entropy,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
**common_configs,
)
46 changes: 16 additions & 30 deletions benchmark/scripts/benchmark_dpo_loss.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from test.chunked_loss.test_dpo_loss import HF_DPO_Loss

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.utils import infer_device
Expand All @@ -28,9 +26,7 @@ def __init__(
bias: bool = False,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)

def forward(self, x, target):
Expand All @@ -53,9 +49,7 @@ def __init__(
bias: bool = False,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.beta = beta
self.ignore_index = ignore_index

Expand All @@ -82,12 +76,8 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

torch_dpo_loss = TorchDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
liger_dpo_loss = LigerDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
Expand Down Expand Up @@ -129,12 +119,8 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_dpo_loss = TorchDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
liger_dpo_loss = LigerDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
Expand Down Expand Up @@ -215,13 +201,13 @@ def full():
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
**common_configs,
)

run_benchmarks(
bench_test_fn=bench_memory_dpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
**common_configs,
)
23 changes: 10 additions & 13 deletions benchmark/scripts/benchmark_embedding.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import torch
import triton

from torch.nn import Embedding
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.experimental.embedding import LigerEmbedding
from liger_kernel.utils import infer_device
Expand Down Expand Up @@ -50,9 +49,7 @@ def full():
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif mode == "full":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full, quantiles=QUANTILES, rep=100
)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
Expand Down Expand Up @@ -118,12 +115,12 @@ def full():
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_embedding,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
**common_configs,
)
Loading

0 comments on commit 8195398

Please sign in to comment.