Skip to content

Commit

Permalink
Xpu support (#407)
Browse files Browse the repository at this point in the history
## Summary
Replica of #396 
Adds xpu support so all tests, benchmarks etc. run on XPUs or Intel
GPUs.

## Details
infer_device() function is moved to a separate file and in any file
where previously "cuda" was needed, infer_device is imported and "cuda"
is replaced with return value of a call to infer_device()

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
A100 80GB PCIe, RTX 3060, Intel Data Center GPU Max 1550
<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
  • Loading branch information
mgrabban and lancerts authored Nov 23, 2024
1 parent 90fb5e4 commit 7e3683e
Show file tree
Hide file tree
Showing 48 changed files with 365 additions and 252 deletions.
6 changes: 3 additions & 3 deletions benchmark/scripts/benchmark_cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
)

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

Expand Down Expand Up @@ -66,7 +69,6 @@ def bench_memory_fused_linear_cpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)

Expand Down Expand Up @@ -107,8 +109,6 @@ def bench_speed_fused_linear_cpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

device = "cuda"

torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)

Expand Down
11 changes: 7 additions & 4 deletions benchmark/scripts/benchmark_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
)

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()


def bench_memory_cross_entropy(
Expand All @@ -24,8 +27,8 @@ def bench_memory_cross_entropy(
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]

_input = torch.randn(B * T, V, requires_grad=True, device="cuda")
target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1)
_input = torch.randn(B * T, V, requires_grad=True, device=device)
target = torch.randint(V, (B * T, 1), device=device).squeeze(1)

def fwd():
if provider == "liger":
Expand Down Expand Up @@ -57,8 +60,8 @@ def bench_speed_cross_entropy(
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]

_input = torch.randn(B * T, V, requires_grad=True, device="cuda")
target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1)
_input = torch.randn(B * T, V, requires_grad=True, device=device)
target = torch.randint(V, (B * T, 1), device=device).squeeze(1)

def fwd():
if provider == "liger":
Expand Down
5 changes: 3 additions & 2 deletions benchmark/scripts/benchmark_dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
)

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.utils import infer_device

device = infer_device()


class TorchDPOLoss(torch.nn.Module):
Expand Down Expand Up @@ -79,7 +82,6 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

device = "cuda"
torch_dpo_loss = TorchDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
Expand Down Expand Up @@ -127,7 +129,6 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
provider = input.kernel_provider
mode = input.kernel_operation_mode

device = "cuda"
torch_dpo_loss = TorchDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
Expand Down
7 changes: 3 additions & 4 deletions benchmark/scripts/benchmark_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
)

from liger_kernel.transformers.experimental.embedding import LigerEmbedding
from liger_kernel.utils import infer_device

device = infer_device()

# NOTE: For torch compile, we will just use default inductor settings. No further customization
# is needed.
Expand All @@ -26,8 +29,6 @@ def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
D = input.extra_benchmark_config["D"]
dtype = input.extra_benchmark_config["dtype"]

device = "cuda"

torch_emb = Embedding(V, D).to(device).to(dtype)
liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
torch_compile_emb = torch.compile(torch_emb)
Expand Down Expand Up @@ -68,8 +69,6 @@ def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
D = input.extra_benchmark_config["D"]
dtype = input.extra_benchmark_config["dtype"]

device = "cuda"

torch_emb = Embedding(V, D).to(device).to(dtype)
liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
torch_compile_emb = torch.compile(torch_emb)
Expand Down
6 changes: 3 additions & 3 deletions benchmark/scripts/benchmark_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
from liger_kernel.utils import infer_device

device = infer_device()


class TorchLMHeadCE(torch.nn.Module):
Expand Down Expand Up @@ -65,7 +68,6 @@ def bench_memory_fused_linear_cross_entropy(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)

Expand Down Expand Up @@ -105,8 +107,6 @@ def bench_speed_fused_linear_cross_entropy(
provider = input.kernel_provider
mode = input.kernel_operation_mode

device = "cuda"

torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)

Expand Down
5 changes: 3 additions & 2 deletions benchmark/scripts/benchmark_fused_linear_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
)

from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD
from liger_kernel.utils import infer_device

device = infer_device()


class TorchJSD(torch.nn.Module):
Expand Down Expand Up @@ -134,7 +137,6 @@ def bench_memory_fused_linear_jsd(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)

Expand Down Expand Up @@ -183,7 +185,6 @@ def bench_speed_fused_linear_jsd(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)

Expand Down
5 changes: 3 additions & 2 deletions benchmark/scripts/benchmark_geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
)

from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
Expand All @@ -31,7 +34,6 @@ def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
)

x_shape = (bsz, seq_len, hidden_size)
device = "cuda"

# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
Expand Down Expand Up @@ -99,7 +101,6 @@ def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp
)

x_shape = (bsz, seq_len, hidden_size)
device = "cuda"
# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

Expand Down
15 changes: 9 additions & 6 deletions benchmark/scripts/benchmark_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
)

from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
Expand All @@ -26,12 +29,12 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
x_shape = (M, C, H)
triton_ln = LigerGroupNorm(
num_channels=C, num_groups=C // channels_per_group, eps=eps
).to("cuda")
).to(device)
torch_ln = torch.nn.GroupNorm(
num_groups=C // channels_per_group, num_channels=C, eps=eps
).to("cuda")
).to(device)

x = torch.randn(x_shape, dtype=dtype, device="cuda")
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

Expand Down Expand Up @@ -83,12 +86,12 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu
x_shape = (M, C, H)
triton_ln = LigerGroupNorm(
num_channels=C, num_groups=C // channels_per_group, eps=eps
).to("cuda")
).to(device)
torch_ln = torch.nn.GroupNorm(
num_groups=C // channels_per_group, num_channels=C, eps=eps
).to("cuda")
).to(device)

x = torch.randn(x_shape, dtype=dtype, device="cuda")
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

Expand Down
11 changes: 7 additions & 4 deletions benchmark/scripts/benchmark_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
)

from liger_kernel.transformers.jsd import LigerJSD
from liger_kernel.utils import infer_device

device = infer_device()


class TorchJSD(torch.nn.Module):
Expand Down Expand Up @@ -56,10 +59,10 @@ def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
torch_jsd = TorchJSD()
liger_jsd = LigerJSD()

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)
target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
Expand Down Expand Up @@ -101,10 +104,10 @@ def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)
target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
Expand Down
11 changes: 7 additions & 4 deletions benchmark/scripts/benchmark_kl_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
)

from liger_kernel.transformers.kl_div import LigerKLDIVLoss
from liger_kernel.utils import infer_device

device = infer_device()

S, E = 12, 18

Expand All @@ -22,10 +25,10 @@ def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
torch_kl_div = nn.KLDivLoss(reduction=reduction)
liger_kl_div = LigerKLDIVLoss(reduction=reduction)

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
target = torch.randn(B * T, V, device=device).softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
Expand Down Expand Up @@ -68,10 +71,10 @@ def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
target = torch.randn(B * T, V, device=device).softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
Expand Down
15 changes: 9 additions & 6 deletions benchmark/scripts/benchmark_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
)

from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
Expand All @@ -22,10 +25,10 @@ def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
dtype = extra_benchmark_config["dtype"]

x_shape = (M, N)
triton_ln = LigerLayerNorm(hidden_size=N).to("cuda")
torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda")
triton_ln = LigerLayerNorm(hidden_size=N).to(device)
torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)

x = torch.randn(x_shape, dtype=dtype, device="cuda")
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

Expand Down Expand Up @@ -73,10 +76,10 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu

x_shape = (M, N)

triton_ln = LigerLayerNorm(hidden_size=N).to("cuda")
torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda")
triton_ln = LigerLayerNorm(hidden_size=N).to(device)
torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)

x = torch.randn(x_shape, dtype=dtype, device="cuda")
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

Expand Down
6 changes: 3 additions & 3 deletions benchmark/scripts/benchmark_orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
)

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

Expand Down Expand Up @@ -66,7 +69,6 @@ def bench_memory_fused_linear_orpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)

Expand Down Expand Up @@ -107,8 +109,6 @@ def bench_speed_fused_linear_orpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

device = "cuda"

torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)

Expand Down
Loading

0 comments on commit 7e3683e

Please sign in to comment.