
Kernels for GroupNorm (#353)
pramodith authored Nov 7, 2024
1 parent f33de99 commit a954b73
Showing 6 changed files with 708 additions and 0 deletions.
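For context, not part of the diff: GroupNorm splits the C channels of an (N, C, *) input into groups and normalizes each group by its own mean and variance before applying a per-channel affine transform. A minimal PyTorch reference sketch of that computation (group_norm_reference is a name chosen here for illustration):

import torch

def group_norm_reference(x, num_groups, weight, bias, eps=1e-6):
    # Split the channels into num_groups groups and normalize each group by
    # its own mean and (biased) variance, then apply the per-channel affine.
    N, C = x.shape[0], x.shape[1]
    g = x.reshape(N, num_groups, -1)
    mean = g.mean(dim=-1, keepdim=True)
    var = g.var(dim=-1, keepdim=True, unbiased=False)
    y = ((g - mean) / torch.sqrt(var + eps)).reshape_as(x)
    affine_shape = (1, C) + (1,) * (x.dim() - 2)
    return y * weight.view(affine_shape) + bias.view(affine_shape)

# Agrees with torch.nn.GroupNorm up to floating-point error.
x = torch.randn(8, 32, 128)
gn = torch.nn.GroupNorm(num_groups=8, num_channels=32, eps=1e-6)
torch.testing.assert_close(
    group_norm_reference(x, 8, gn.weight, gn.bias), gn(x), rtol=1e-5, atol=1e-5
)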
114 changes: 114 additions & 0 deletions benchmark/data/all_benchmark_data.csv

Large diffs are not rendered by default.

147 changes: 147 additions & 0 deletions benchmark/scripts/benchmark_group_norm.py
@@ -0,0 +1,147 @@
import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.group_norm import LigerGroupNorm


def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    # The swept x-value is the channel count C; the rest of the problem size
    # comes from the extra benchmark config.
    C = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    H = extra_benchmark_config["H"]
    channels_per_group = extra_benchmark_config["channels_per_group"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, C, H)
    liger_gn = LigerGroupNorm(
        num_channels=C, num_groups=C // channels_per_group, eps=eps
    ).to("cuda")
    torch_gn = torch.nn.GroupNorm(
        num_groups=C // channels_per_group, num_channels=C, eps=eps
    ).to("cuda")

    x = torch.randn(x_shape, dtype=dtype, device="cuda")
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # Select the module under test based on the provider.
    def y_fwd():
        if provider == "liger":
            return liger_gn(x)
        if provider == "huggingface":
            return torch_gn(x)

    # Time the forward pass, the backward pass, or both, reporting the
    # quantiles defined in QUANTILES.
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
        )
    elif mode == "backward":
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full, quantiles=QUANTILES, grad_to_none=[x], rep=500
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    C = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    H = extra_benchmark_config["H"]
    channels_per_group = extra_benchmark_config["channels_per_group"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, C, H)
    liger_gn = LigerGroupNorm(
        num_channels=C, num_groups=C // channels_per_group, eps=eps
    ).to("cuda")
    torch_gn = torch.nn.GroupNorm(
        num_groups=C // channels_per_group, num_channels=C, eps=eps
    ).to("cuda")

    x = torch.randn(x_shape, dtype=dtype, device="cuda")
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return liger_gn(x)
        if provider == "huggingface":
            return torch_gn(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    # Measure peak memory over a full forward + backward pass.
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "group_norm",
        "x_name": "C",
        "x_label": "num_channels",
        # Sweep the channel count over powers of two from 32 to 2048.
        "x_values": [2**i for i in range(5, 12)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "M": 128,
                "H": 512,
                "channels_per_group": 4,
                "dtype": torch.float32,
                "eps": 1e-6,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_group_norm,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_group_norm,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
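For orientation, not part of the commit: the benchmark treats LigerGroupNorm as a drop-in replacement for torch.nn.GroupNorm, so a quick numerical sanity check looks like the sketch below, assuming both modules keep the default affine initialization (weight = 1, bias = 0):

import torch
from liger_kernel.transformers.group_norm import LigerGroupNorm

x = torch.randn(8, 32, 128, device="cuda", requires_grad=True)
liger_gn = LigerGroupNorm(num_channels=32, num_groups=8, eps=1e-6).to("cuda")
torch_gn = torch.nn.GroupNorm(num_groups=8, num_channels=32, eps=1e-6).to("cuda")

# Forward outputs should agree within floating-point tolerance.
torch.testing.assert_close(liger_gn(x), torch_gn(x), rtol=1e-4, atol=1e-4)

The script itself runs standalone (python benchmark/scripts/benchmark_group_norm.py), which is presumably how the 114 rows added to benchmark/data/all_benchmark_data.csv above were produced.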