From 7e3683e23f8a9a5663913fd0ea7b0b03ea1a667b Mon Sep 17 00:00:00 2001
From: Golam Rabbani
Date: Fri, 22 Nov 2024 20:03:06 -0800
Subject: [PATCH] XPU support (#407)

## Summary
Replica of #396. Adds XPU support so that all tests, benchmarks, etc. run on XPUs (Intel GPUs).

## Details
The infer_device() function is moved into a separate file (src/liger_kernel/utils.py). Every file that previously hard-coded "cuda" now imports infer_device and uses its return value instead (see the sketch after the first hunk below).

## Testing Done
- Hardware Type: A100 80GB PCIe, RTX 3060, Intel Data Center GPU Max 1550
- [x] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang
---
 benchmark/scripts/benchmark_cpo_loss.py | 6 +--
 benchmark/scripts/benchmark_cross_entropy.py | 11 +++--
 benchmark/scripts/benchmark_dpo_loss.py | 5 +-
 benchmark/scripts/benchmark_embedding.py | 7 ++-
 .../benchmark_fused_linear_cross_entropy.py | 6 +--
 .../scripts/benchmark_fused_linear_jsd.py | 5 +-
 benchmark/scripts/benchmark_geglu.py | 5 +-
 benchmark/scripts/benchmark_group_norm.py | 15 +++---
 benchmark/scripts/benchmark_jsd.py | 11 +++--
 benchmark/scripts/benchmark_kl_div.py | 11 +++--
 benchmark/scripts/benchmark_layer_norm.py | 15 +++---
 benchmark/scripts/benchmark_orpo_loss.py | 6 +--
 benchmark/scripts/benchmark_qwen2vl_mrope.py | 27 ++++++-----
 benchmark/scripts/benchmark_rms_norm.py | 15 +++---
 benchmark/scripts/benchmark_rope.py | 27 ++++++-----
 benchmark/scripts/benchmark_simpo_loss.py | 6 +--
 benchmark/scripts/benchmark_swiglu.py | 5 +-
 benchmark/scripts/utils.py | 13 ++++--
 examples/huggingface/callback.py | 13 ++++--
 examples/lightning/training.py | 10 +++-
 examples/medusa/callback.py | 15 ++++--
 src/liger_kernel/__init__.py | 0
 src/liger_kernel/ops/layer_norm.py | 7 ++-
 src/liger_kernel/ops/rms_norm.py | 1 +
 src/liger_kernel/ops/utils.py | 7 ++-
 src/liger_kernel/utils.py | 13 ++++++
 test/chunked_loss/test_cpo_loss.py | 19 ++++----
 test/chunked_loss/test_dpo_loss.py | 27 ++++++-----
 test/chunked_loss/test_orpo_loss.py | 19 ++++----
 test/chunked_loss/test_simpo_loss.py | 19 ++++----
 test/convergence/test_mini_models.py | 6 ++-
 .../test_mini_models_multimodal.py | 6 ++-
 .../test_mini_models_with_logits.py | 6 ++-
 test/transformers/test_cross_entropy.py | 46 ++++++++++---------
 test/transformers/test_embedding.py | 5 +-
 .../test_fused_linear_cross_entropy.py | 11 ++---
 test/transformers/test_fused_linear_jsd.py | 11 ++---
 test/transformers/test_geglu.py | 19 ++++----
 test/transformers/test_group_norm.py | 9 ++--
 test/transformers/test_jsd.py | 13 ++++--
 test/transformers/test_kl_div.py | 5 +-
 test/transformers/test_layer_norm.py | 15 +++---
 test/transformers/test_mm_int8int2.py | 7 ++-
 test/transformers/test_qwen2vl_mrope.py | 23 ++++++----
 test/transformers/test_rms_norm.py | 12 ++---
 test/transformers/test_rope.py | 23 ++++++----
 test/transformers/test_swiglu.py | 29 ++++++------
 test/utils.py | 25 +++-------
 48 files changed, 365 insertions(+), 252 deletions(-)
 create mode 100644 src/liger_kernel/__init__.py
 create mode 100644 src/liger_kernel/utils.py

diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py
index d10c8da8a..5fc43c7ea 100644
--- a/benchmark/scripts/benchmark_cpo_loss.py
+++ b/benchmark/scripts/benchmark_cpo_loss.py
@@ -13,6 +13,9 @@
 )
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
+from liger_kernel.utils import infer_device
+
+device = infer_device()
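The hunk above is the template for the whole patch. Here is a minimal, self-contained sketch of the pattern (the helper mirrors the new src/liger_kernel/utils.py; the PyTorch-version caveat and the example tensor shape are assumptions for illustration, not part of the patch):

```python
import torch


def infer_device():
    # Mirrors the helper added in src/liger_kernel/utils.py: prefer CUDA,
    # then XPU, and fall back to CPU. Assumes a PyTorch build that exposes
    # torch.xpu (2.4+); older builds may need a hasattr(torch, "xpu") guard.
    if torch.cuda.is_available():
        return "cuda"
    elif torch.xpu.is_available():
        return "xpu"
    else:
        return "cpu"


device = infer_device()

# Call sites stop hard-coding "cuda" and pass the inferred name instead:
x = torch.randn(8, 128, device=device, requires_grad=True)

# Backend-specific namespaces are reached by name, the same way the
# benchmark utilities and example callbacks in this patch query memory stats:
if device != "cpu":
    getattr(torch, device).reset_peak_memory_stats()
    peak_mb = getattr(torch, device).max_memory_allocated() / 2**20
```

The remaining hunks apply this substitution mechanically, file by file: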
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -66,7 +69,6 @@ def bench_memory_fused_linear_cpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) @@ -107,8 +109,6 @@ def bench_speed_fused_linear_cpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" - torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) diff --git a/benchmark/scripts/benchmark_cross_entropy.py b/benchmark/scripts/benchmark_cross_entropy.py index d6dffbf7e..f7b749c98 100644 --- a/benchmark/scripts/benchmark_cross_entropy.py +++ b/benchmark/scripts/benchmark_cross_entropy.py @@ -11,6 +11,9 @@ ) from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.utils import infer_device + +device = infer_device() def bench_memory_cross_entropy( @@ -24,8 +27,8 @@ def bench_memory_cross_entropy( B = input.extra_benchmark_config["B"] T = input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda") - target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1) + _input = torch.randn(B * T, V, requires_grad=True, device=device) + target = torch.randint(V, (B * T, 1), device=device).squeeze(1) def fwd(): if provider == "liger": @@ -57,8 +60,8 @@ def bench_speed_cross_entropy( B = input.extra_benchmark_config["B"] T = input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda") - target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1) + _input = torch.randn(B * T, V, requires_grad=True, device=device) + target = torch.randint(V, (B * T, 1), device=device).squeeze(1) def fwd(): if provider == "liger": diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index 537be47bc..af8e3dac5 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -12,6 +12,9 @@ ) from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.utils import infer_device + +device = infer_device() class TorchDPOLoss(torch.nn.Module): @@ -79,7 +82,6 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO ignore_index = input.extra_benchmark_config["ignore_index"] provider = input.kernel_provider - device = "cuda" torch_dpo_loss = TorchDPOLoss( H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias ).to(device) @@ -127,7 +129,6 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" torch_dpo_loss = TorchDPOLoss( H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias ).to(device) diff --git a/benchmark/scripts/benchmark_embedding.py b/benchmark/scripts/benchmark_embedding.py index 1f20aec35..40722ee1b 100644 --- a/benchmark/scripts/benchmark_embedding.py +++ b/benchmark/scripts/benchmark_embedding.py @@ -11,6 +11,9 @@ ) from liger_kernel.transformers.experimental.embedding import LigerEmbedding +from liger_kernel.utils import infer_device + +device = infer_device() # NOTE: For torch compile, we will just use default inductor settings. No further customization # is needed. 
@@ -26,8 +29,6 @@ def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO D = input.extra_benchmark_config["D"] dtype = input.extra_benchmark_config["dtype"] - device = "cuda" - torch_emb = Embedding(V, D).to(device).to(dtype) liger_emb = LigerEmbedding(V, D).to(device).to(dtype) torch_compile_emb = torch.compile(torch_emb) @@ -68,8 +69,6 @@ def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun D = input.extra_benchmark_config["D"] dtype = input.extra_benchmark_config["dtype"] - device = "cuda" - torch_emb = Embedding(V, D).to(device).to(dtype) liger_emb = LigerEmbedding(V, D).to(device).to(dtype) torch_compile_emb = torch.compile(torch_emb) diff --git a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py index eaceeed03..2e3b08732 100644 --- a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +++ b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py @@ -12,6 +12,9 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.utils import infer_device + +device = infer_device() class TorchLMHeadCE(torch.nn.Module): @@ -65,7 +68,6 @@ def bench_memory_fused_linear_cross_entropy( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) @@ -105,8 +107,6 @@ def bench_speed_fused_linear_cross_entropy( provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" - torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py index 7f652de8a..dcefb2137 100644 --- a/benchmark/scripts/benchmark_fused_linear_jsd.py +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -10,6 +10,9 @@ ) from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD +from liger_kernel.utils import infer_device + +device = infer_device() class TorchJSD(torch.nn.Module): @@ -134,7 +137,6 @@ def bench_memory_fused_linear_jsd( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) @@ -183,7 +185,6 @@ def bench_speed_fused_linear_jsd( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) diff --git a/benchmark/scripts/benchmark_geglu.py b/benchmark/scripts/benchmark_geglu.py index 81611de3f..7b0d237ca 100644 --- a/benchmark/scripts/benchmark_geglu.py +++ b/benchmark/scripts/benchmark_geglu.py @@ -12,6 +12,9 @@ ) from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -31,7 +34,6 @@ def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, 
device=device, dtype=dtype, requires_grad=True) @@ -99,7 +101,6 @@ def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py index 595d379f8..0c3c05608 100644 --- a/benchmark/scripts/benchmark_group_norm.py +++ b/benchmark/scripts/benchmark_group_norm.py @@ -10,6 +10,9 @@ ) from liger_kernel.transformers.group_norm import LigerGroupNorm +from liger_kernel.utils import infer_device + +device = infer_device() def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -26,12 +29,12 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun x_shape = (M, C, H) triton_ln = LigerGroupNorm( num_channels=C, num_groups=C // channels_per_group, eps=eps - ).to("cuda") + ).to(device) torch_ln = torch.nn.GroupNorm( num_groups=C // channels_per_group, num_channels=C, eps=eps - ).to("cuda") + ).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) @@ -83,12 +86,12 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu x_shape = (M, C, H) triton_ln = LigerGroupNorm( num_channels=C, num_groups=C // channels_per_group, eps=eps - ).to("cuda") + ).to(device) torch_ln = torch.nn.GroupNorm( num_groups=C // channels_per_group, num_channels=C, eps=eps - ).to("cuda") + ).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) diff --git a/benchmark/scripts/benchmark_jsd.py b/benchmark/scripts/benchmark_jsd.py index 272008315..c5f8bec18 100644 --- a/benchmark/scripts/benchmark_jsd.py +++ b/benchmark/scripts/benchmark_jsd.py @@ -10,6 +10,9 @@ ) from liger_kernel.transformers.jsd import LigerJSD +from liger_kernel.utils import infer_device + +device = infer_device() class TorchJSD(torch.nn.Module): @@ -56,10 +59,10 @@ def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: torch_jsd = TorchJSD() liger_jsd = LigerJSD() - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": @@ -101,10 +104,10 @@ def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput V = input.x B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": diff --git a/benchmark/scripts/benchmark_kl_div.py b/benchmark/scripts/benchmark_kl_div.py index c446c7ae2..c52d8e658 100644 --- a/benchmark/scripts/benchmark_kl_div.py +++ b/benchmark/scripts/benchmark_kl_div.py @@ -11,6 +11,9 @@ ) from liger_kernel.transformers.kl_div import LigerKLDIVLoss +from liger_kernel.utils 
import infer_device + +device = infer_device() S, E = 12, 18 @@ -22,10 +25,10 @@ def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu torch_kl_div = nn.KLDivLoss(reduction=reduction) liger_kl_div = LigerKLDIVLoss(reduction=reduction) - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": @@ -68,10 +71,10 @@ def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp V = input.x B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py index 89f07c640..4d36d4b4b 100644 --- a/benchmark/scripts/benchmark_layer_norm.py +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -10,6 +10,9 @@ ) from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.utils import infer_device + +device = infer_device() def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -22,10 +25,10 @@ def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun dtype = extra_benchmark_config["dtype"] x_shape = (M, N) - triton_ln = LigerLayerNorm(hidden_size=N).to("cuda") - torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda") + triton_ln = LigerLayerNorm(hidden_size=N).to(device) + torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) @@ -73,10 +76,10 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu x_shape = (M, N) - triton_ln = LigerLayerNorm(hidden_size=N).to("cuda") - torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda") + triton_ln = LigerLayerNorm(hidden_size=N).to(device) + torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py index dda42d772..e1b2c8d25 100644 --- a/benchmark/scripts/benchmark_orpo_loss.py +++ b/benchmark/scripts/benchmark_orpo_loss.py @@ -13,6 +13,9 @@ ) from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction +from liger_kernel.utils import infer_device + +device = infer_device() sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -66,7 +69,6 @@ def bench_memory_fused_linear_orpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) @@ -107,8 +109,6 @@ def bench_speed_fused_linear_orpo_loss( provider = input.kernel_provider mode = 
input.kernel_operation_mode - device = "cuda" - torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) diff --git a/benchmark/scripts/benchmark_qwen2vl_mrope.py b/benchmark/scripts/benchmark_qwen2vl_mrope.py index 77ed61921..dccb37d33 100644 --- a/benchmark/scripts/benchmark_qwen2vl_mrope.py +++ b/benchmark/scripts/benchmark_qwen2vl_mrope.py @@ -14,6 +14,9 @@ ) from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() def bench_speed_qwen2vl_mrope( @@ -40,23 +43,23 @@ def bench_speed_qwen2vl_mrope( ) head_dim = hidden_size // num_q_heads - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k, pos_ids) mrope_section_hw = head_dim * 3 // 16 @@ -133,23 +136,23 @@ def bench_memory_qwen2vl_mrope( ) head_dim = hidden_size // num_q_heads - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k, pos_ids) mrope_section_hw = head_dim * 3 // 16 diff --git a/benchmark/scripts/benchmark_rms_norm.py b/benchmark/scripts/benchmark_rms_norm.py index 46734504e..533a13aec 100644 --- a/benchmark/scripts/benchmark_rms_norm.py +++ b/benchmark/scripts/benchmark_rms_norm.py @@ -11,6 +11,9 @@ ) from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.utils import infer_device + +device = infer_device() class LlamaRMSNorm(nn.Module): @@ -42,10 +45,10 @@ def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu x_shape = (M, N) - triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to("cuda") - llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to("cuda") + triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) + llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) @@ -104,10 +107,10 @@ def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO x_shape 
= (M, N) - triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to("cuda") - llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to("cuda") + triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) + llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) diff --git a/benchmark/scripts/benchmark_rope.py b/benchmark/scripts/benchmark_rope.py index 265fe703a..b505c6fe9 100644 --- a/benchmark/scripts/benchmark_rope.py +++ b/benchmark/scripts/benchmark_rope.py @@ -14,6 +14,9 @@ ) from liger_kernel.transformers.rope import liger_rotary_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -38,23 +41,23 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput ) head_dim = hidden_size // num_q_heads - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k, pos_ids) def fwd(): @@ -122,23 +125,23 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu ) head_dim = hidden_size // num_q_heads - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k, pos_ids) def full(): diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py index 457f6f2e8..a8ee48dea 100644 --- a/benchmark/scripts/benchmark_simpo_loss.py +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -13,6 +13,9 @@ ) from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction +from liger_kernel.utils import infer_device + +device = infer_device() sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -66,7 +69,6 @@ def bench_memory_fused_linear_simpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) @@ -107,8 +109,6 
@@ def bench_speed_fused_linear_simpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" - torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index 08689d24e..5feedb557 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -12,6 +12,9 @@ ) from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -33,7 +36,6 @@ def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) @@ -103,7 +105,6 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py index 1d147b51b..6fa80a888 100644 --- a/benchmark/scripts/utils.py +++ b/benchmark/scripts/utils.py @@ -11,6 +11,10 @@ import torch +from liger_kernel.utils import infer_device + +device = infer_device() + LIGER_KERNEL_VERSION = version("liger-kernel") QUANTILES = [0.5, 0.2, 0.8] @@ -88,10 +92,10 @@ def _test_memory( total_mem = [] for _ in range(_iter): - torch.cuda.memory.reset_peak_memory_stats() + getattr(torch, device).memory.reset_peak_memory_stats() func() # Convert to MB - mem = torch.cuda.max_memory_allocated() / 2**20 + mem = getattr(torch, device).max_memory_allocated() / 2**20 total_mem.append(mem) total_mem = torch.tensor(total_mem, dtype=torch.float) @@ -141,8 +145,9 @@ def get_gpu_name(): """ Returns the current GPU name, formatted to serve as a directory name """ - if torch.cuda.is_available(): - gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()) + torch_device = getattr(torch, device) + if torch_device.is_available(): + gpu_name = torch_device.get_device_name(torch_device.current_device()) return gpu_name else: raise Exception("Benchmarks can only be run on GPU.") diff --git a/examples/huggingface/callback.py b/examples/huggingface/callback.py index 9582c81fd..c612a79a9 100644 --- a/examples/huggingface/callback.py +++ b/examples/huggingface/callback.py @@ -5,6 +5,8 @@ import transformers from transformers import TrainerControl, TrainerState, TrainingArguments +from liger_kernel.utils import infer_device + # https://simple.wikipedia.org/wiki/Byte # For memory, we use binary system M_BIN_UNIT = 2**20 @@ -111,6 +113,7 @@ def __init__( self.time = Time() self.memory = Memory() self.tps = TPS() + self.device = infer_device() def on_init_end( self, @@ -171,7 +174,7 @@ def on_step_begin( several inputs. 
""" # memory - torch.cuda.reset_peak_memory_stats() + getattr(torch, self.device).reset_peak_memory_stats() # time self.state.step_start_time = time.perf_counter() @@ -218,8 +221,12 @@ def on_step_end( ) # memory - step_peak_memory_allocated = torch.cuda.memory.max_memory_allocated() - step_peak_memory_reserved = torch.cuda.memory.max_memory_reserved() + step_peak_memory_allocated = getattr( + torch, self.device + ).memory.max_memory_allocated() + step_peak_memory_reserved = getattr( + torch, self.device + ).memory.max_memory_reserved() self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory diff --git a/examples/lightning/training.py b/examples/lightning/training.py index f70e9aac1..6bf068d1b 100644 --- a/examples/lightning/training.py +++ b/examples/lightning/training.py @@ -15,6 +15,7 @@ from trl import DataCollatorForCompletionOnlyLM from liger_kernel.transformers import AutoLigerKernelForCausalLM +from liger_kernel.utils import infer_device _RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"} QUESTION = "" @@ -263,10 +264,15 @@ def train(): strategy = "auto" precision = "bf16-true" + device = infer_device() trainer = pl.Trainer( - accelerator="cuda", + accelerator=device, strategy=strategy, - devices=torch.cuda.device_count() if args.num_gpu is None else args.num_gpu, + devices=( + getattr(torch, device).device_count() + if args.num_gpu is None + else args.num_gpu + ), default_root_dir=args.output_dir, log_every_n_steps=1, max_epochs=1, diff --git a/examples/medusa/callback.py b/examples/medusa/callback.py index ef4c38f1e..135f46f0b 100644 --- a/examples/medusa/callback.py +++ b/examples/medusa/callback.py @@ -7,6 +7,8 @@ from accelerate.utils.constants import FSDP_SHARDING_STRATEGY from transformers import TrainerControl, TrainerState, TrainingArguments +from liger_kernel.utils import infer_device + # https://simple.wikipedia.org/wiki/Byte # For memory, we use binary system M_BIN_UNIT = 2**20 @@ -137,6 +139,7 @@ def __init__( self.memory = Memory() self.tps = TPS() self.mfu = MFU() + self.device = infer_device() def on_init_end( self, @@ -198,7 +201,7 @@ def on_step_begin( several inputs. 
""" # memory - torch.cuda.reset_peak_memory_stats() + getattr(torch, self.device).reset_peak_memory_stats() # time self.state.step_start_time = time.perf_counter() @@ -247,8 +250,12 @@ def on_step_end( ) # memory - step_peak_memory_allocated = torch.cuda.memory.max_memory_allocated() - step_peak_memory_reserved = torch.cuda.memory.max_memory_reserved() + step_peak_memory_allocated = getattr( + torch, self.device + ).memory.max_memory_allocated() + step_peak_memory_reserved = getattr( + torch, self.device + ).memory.max_memory_reserved() self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory @@ -381,7 +388,7 @@ def _get_gpu_peak_tflops(precision_bits: int = 16): if precision_bits not in {16, 32}: raise Exception(f"Precision bits {precision_bits} is not supported") - device_name = torch.cuda.get_device_name() + device_name = getattr(torch, infer_device()).get_device_name() if "A100" in device_name: # data from https://www.nvidia.com/en-us/data-center/a100/ diff --git a/src/liger_kernel/__init__.py b/src/liger_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py index 75df1f6ba..70c372237 100644 --- a/src/liger_kernel/ops/layer_norm.py +++ b/src/liger_kernel/ops/layer_norm.py @@ -180,8 +180,13 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD): dY = dY.view(-1, dim) n_rows, n_cols = dY.shape + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) - sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index 633a3275b..fff199a93 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -264,6 +264,7 @@ def rms_norm_backward( dY = dY.view(-1, dim) n_rows, n_cols = dY.shape + sm_count = 1 if X.device.type == "cuda": sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count elif X.device.type == "xpu": diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py index 4a24223d0..d87adac44 100644 --- a/src/liger_kernel/ops/utils.py +++ b/src/liger_kernel/ops/utils.py @@ -20,6 +20,8 @@ import triton.language as tl from packaging.version import Version +from liger_kernel.utils import infer_device + def is_hip() -> bool: return torch.version.hip is not None @@ -69,10 +71,11 @@ def compare_version(package: str, operator: Callable, target: str): def get_amp_custom_fwd_bwd() -> Callable: + device = infer_device() if compare_version("torch", operator.ge, "2.4.0"): return ( - functools.partial(torch.amp.custom_fwd, device_type="cuda"), - functools.partial(torch.amp.custom_bwd, device_type="cuda"), + functools.partial(torch.amp.custom_fwd, device_type=device), + functools.partial(torch.amp.custom_bwd, device_type=device), ) return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd diff --git a/src/liger_kernel/utils.py b/src/liger_kernel/utils.py new file mode 100644 index 000000000..0a6d5feba --- /dev/null +++ b/src/liger_kernel/utils.py @@ -0,0 +1,13 @@ +import torch + + +def 
infer_device(): + """ + Get current device name based on available devices + """ + if torch.cuda.is_available(): + return "cuda" + elif torch.xpu.is_available(): + return "xpu" + else: + return "cpu" diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index 6f9305ec8..1bdb7dc83 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -8,6 +8,9 @@ from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo +from liger_kernel.utils import infer_device + +device = infer_device() # set random seed globally set_seed() @@ -166,15 +169,15 @@ def test_correctness( ) torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( - V, H, device="cuda", dtype=dtype + V, H, device=device, dtype=dtype ) if bias: torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( - V, device="cuda", dtype=dtype + V, device=device, dtype=dtype ) - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -185,7 +188,7 @@ def test_correctness( B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -235,7 +238,7 @@ def test_correctness( def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B = 2 * B - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -246,15 +249,15 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B, T, ), - device="cuda", + device=device, dtype=torch.long, ) - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 2f9d1d94e..9b17b6d05 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -7,6 +7,9 @@ from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction from liger_kernel.chunked_loss.functional import liger_fused_linear_dpo +from liger_kernel.utils import infer_device + +device = infer_device() # set random seed globally set_seed() @@ -148,22 +151,22 @@ def test_correctness( ) torch_lm_head_dpo.lin.weight.data = liger_lm_head_dpo.lin.weight.data = torch.randn( - V, H, device="cuda", dtype=dtype + V, H, device=device, dtype=dtype ) torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = ( - torch.randn(V, H, device="cuda", dtype=dtype) + torch.randn(V, H, device=device, dtype=dtype) ) if bias: torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = 
torch.randn( - V, device="cuda", dtype=dtype + V, device=device, dtype=dtype ) if ref_bias: torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = ( - torch.randn(V, device="cuda", dtype=dtype) + torch.randn(V, device=device, dtype=dtype) ) - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -174,7 +177,7 @@ def test_correctness( B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -225,7 +228,7 @@ def test_correctness( def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): B = 2 * B - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -236,23 +239,23 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref B, T, ), - device="cuda", + device=device, dtype=torch.long, ) - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) - _ref_weight = torch.randn(V, H, device="cuda", dtype=dtype) + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) ref_weight1 = _ref_weight.detach().clone().requires_grad_(True) ref_weight2 = _ref_weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - _ref_bias = torch.randn(V, device="cuda", dtype=dtype) if ref_bias else None + _ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 41e6c9421..4c95634ed 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -8,6 +8,9 @@ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss from liger_kernel.chunked_loss.functional import liger_fused_linear_orpo from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction +from liger_kernel.utils import infer_device + +device = infer_device() # set random seed globally set_seed() @@ -137,15 +140,15 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, ) torch_lm_head_orpo.lin.weight.data = liger_lm_head_orpo.lin.weight.data = ( - torch.randn(V, H, device="cuda", dtype=dtype) + torch.randn(V, H, device=device, dtype=dtype) ) if bias: torch_lm_head_orpo.lin.bias.data = liger_lm_head_orpo.lin.bias.data = ( - torch.randn(V, device="cuda", dtype=dtype) + torch.randn(V, device=device, dtype=dtype) ) - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ 
-156,7 +159,7 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -206,7 +209,7 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B = 2 * B - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -217,15 +220,15 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B, T, ), - device="cuda", + device=device, dtype=torch.long, ) - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py index 89658b69c..901247191 100644 --- a/test/chunked_loss/test_simpo_loss.py +++ b/test/chunked_loss/test_simpo_loss.py @@ -7,6 +7,9 @@ from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss from liger_kernel.chunked_loss.functional import liger_fused_linear_simpo from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction +from liger_kernel.utils import infer_device + +device = infer_device() # set random seed globally set_seed() @@ -80,15 +83,15 @@ def test_correctness( ) torch_lm_head_simpo.lin.weight.data = liger_lm_head_simpo.lin.weight.data = ( - torch.randn(V, H, device="cuda", dtype=dtype) + torch.randn(V, H, device=device, dtype=dtype) ) if bias: torch_lm_head_simpo.lin.bias.data = liger_lm_head_simpo.lin.bias.data = ( - torch.randn(V, device="cuda", dtype=dtype) + torch.randn(V, device=device, dtype=dtype) ) - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -99,7 +102,7 @@ def test_correctness( B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -149,7 +152,7 @@ def test_correctness( def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B = 2 * B - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -160,15 +163,15 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B, T, ), - device="cuda", + device=device, dtype=torch.long, ) - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = torch.randn(V, device=device, 
dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 5c30349ae..051effcfa 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -60,6 +60,10 @@ except ImportError: QWEN2_VL_AVAILABLE = False +from liger_kernel.utils import infer_device + +device = infer_device() + MINI_MODEL_SETUPS = { "mini_llama3": MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_llama, @@ -427,7 +431,7 @@ def run_mini_model( else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to("cuda") + model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader( diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py index bb9d8e712..07ddd9493 100644 --- a/test/convergence/test_mini_models_multimodal.py +++ b/test/convergence/test_mini_models_multimodal.py @@ -58,6 +58,10 @@ except ImportError: MLLAMA_AVAILABLE = False +from liger_kernel.utils import infer_device + +device = infer_device() + torch.use_deterministic_algorithms(True) # Only setting torch.use_deterministic_algorithms(True) throws the following error: @@ -333,7 +337,7 @@ def run_mini_model_multimodal( else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to("cuda") + model = create_model(model_name).to(dtype).to(device) model.gradient_checkpointing_enable() train_dataset = create_multimodal_dataset(model_name) diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py index 0b183e3d3..e7672c4a4 100644 --- a/test/convergence/test_mini_models_with_logits.py +++ b/test/convergence/test_mini_models_with_logits.py @@ -60,6 +60,10 @@ except ImportError: QWEN2_VL_AVAILABLE = False +from liger_kernel.utils import infer_device + +device = infer_device() + MINI_MODEL_SETUPS = { "mini_llama3": MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_llama, @@ -427,7 +431,7 @@ def run_mini_model( else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to("cuda") + model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader( train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 82edc98fa..28e3ec5dc 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -11,7 +11,9 @@ ) from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.functional import liger_cross_entropy +from liger_kernel.utils import infer_device +device = infer_device() set_seed(42) @@ -74,11 +76,11 @@ def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, r torch.manual_seed(0) torch_ce = CrossEntropyLoss(reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = 
_tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -95,11 +97,11 @@ def _test_correctness_with_ignore_index_once( torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( @@ -126,11 +128,11 @@ def _test_correctness_with_label_smoothing_once( torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -150,11 +152,11 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( ignore_index=ignore_index, label_smoothing=label_smoothing ) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( @@ -181,12 +183,12 @@ def _test_correctness_with_softcap_once( torch_ce = CrossEntropyLoss(reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar # upcasting to match liger's casting strategy _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # downcasting to original dtype output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype) @@ -217,11 +219,11 @@ def _test_correctness_with_z_loss_once( dtype=dtype, ) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) if return_z_loss: output, z_output = torch_ce(_input, target) output2, z_output2 = target_ce(_input2, target) @@ -266,11 +268,11 @@ def _test_correctness_with_z_loss_with_other_params_once( dtype=dtype, ) - _tensor = torch.randn(B * T, 
V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( @@ -305,11 +307,11 @@ def _test_correctness_not_last_layer_once( torch_ce = CrossEntropyLoss(reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -333,12 +335,12 @@ def _test_correctness_functional( rtol, ): - _input = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B * T, V, device=device, dtype=dtype) * scalar x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) y1, y1_z = liger_cross_entropy( x1, @@ -733,12 +735,12 @@ def test_float32_internal(): reduction = "mean" # Initialize input tensors - X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device="cuda") - Y = torch.randint(0, n_cols, (batch_size,), device="cuda") + X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device=device) + Y = torch.randint(0, n_cols, (batch_size,), device=device) # Run kernel for bfloat16 X_bf16 = X_init.clone() - loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device="cuda") + loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_bf16, X_stride=X_bf16.stride(-2), @@ -762,7 +764,7 @@ def test_float32_internal(): # Run kernel for float32 X_fp32 = X_init.float() - loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device="cuda") + loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_fp32, X_stride=X_fp32.stride(-2), diff --git a/test/transformers/test_embedding.py b/test/transformers/test_embedding.py index 998a544c5..416784d0f 100644 --- a/test/transformers/test_embedding.py +++ b/test/transformers/test_embedding.py @@ -3,6 +3,9 @@ from torch.nn import Embedding from liger_kernel.transformers.experimental.embedding import LigerEmbedding +from liger_kernel.utils import infer_device + +device = infer_device() SLEEP_SECONDS = 0.1 @@ -27,7 +30,7 @@ @pytest.mark.parametrize( "dtype, atol, rtol, device", [ - (torch.float32, 1e-6, 1e-5, "cuda"), + (torch.float32, 1e-6, 1e-5, device), ], ) def test_embedding_correctness( diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index bc210ca77..a6bcd4d8b 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -12,6 +12,9 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.utils import 
infer_device + +device = infer_device() # set random seed globally set_seed() @@ -142,7 +145,6 @@ def test_correctness( atol, rtol, ): - device = "cuda" torch_lm_head_ce = TorchLMHeadCE( H=H, V=V, @@ -233,8 +235,6 @@ def test_correctness( ) @pytest.mark.parametrize("bias", [True, False]) def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol): - device = "cuda" - _input = torch.randn(B * T, H, device=device, dtype=dtype) * scalar x1 = _input.detach().clone().requires_grad_(True) x2 = _input.detach().clone().requires_grad_(True) @@ -277,7 +277,6 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol): ], ) def test_amp(B, T, H, V, cast_dtype, atol, rtol): - device = "cuda" dtype = torch.float32 torch_lm_head_ce = TorchLMHeadCE( H=H, @@ -307,13 +306,13 @@ def test_amp(B, T, H, V, cast_dtype, atol, rtol): target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) - with torch.autocast(device_type="cuda", dtype=cast_dtype): + with torch.autocast(device_type=device, dtype=cast_dtype): output1 = torch_lm_head_ce(_input1, target) output2 = liger_lm_head_ce(_input2, target) assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - with torch.autocast(device_type="cuda", dtype=cast_dtype): + with torch.autocast(device_type=device, dtype=cast_dtype): output1.backward() output2.backward() diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py index 0d011f2a0..75f4d775c 100644 --- a/test/transformers/test_fused_linear_jsd.py +++ b/test/transformers/test_fused_linear_jsd.py @@ -7,6 +7,9 @@ from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.transformers.functional import liger_fused_linear_jsd from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD +from liger_kernel.utils import infer_device + +device = infer_device() set_seed(42) @@ -110,7 +113,6 @@ def forward(self, student_input, teacher_input, label=None): ], ) def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -187,7 +189,6 @@ def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): def test_correctness_with_ignore_index( B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol ): - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -271,8 +272,6 @@ def test_correctness_with_ignore_index( def test_correctness_functional( B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol ): - device = "cuda" - # init the linear in all FusedLinearJSDs with the same weights _weight = torch.rand(V, H // 2, device=device, dtype=dtype) _weight1 = _weight.detach().clone().requires_grad_(True) @@ -350,7 +349,6 @@ def test_correctness_functional( def test_correctness_all_ignored( B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol ): - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -415,7 +413,6 @@ def test_amp(autocast_dtype, atol, rtol): ignore_index = -100 temperature = 1.0 beta = 0.5 - device = "cuda" dtype = torch.float32 torch_lm_head_jsd = TorchLMHeadJSD( H=H, @@ -460,7 +457,7 @@ def test_amp(autocast_dtype, atol, rtol): ] # Randomly select indices label[indices_to_assign] = ignore_index - with torch.autocast(device_type="cuda", dtype=autocast_dtype): + with torch.autocast(device_type=device, dtype=autocast_dtype): output1 = torch_lm_head_jsd(_input1, teacher_input, label) output2 = 
liger_lm_head_jsd(_input2, teacher_input, label) diff --git a/test/transformers/test_geglu.py b/test/transformers/test_geglu.py index 184c971d2..0d5919729 100644 --- a/test/transformers/test_geglu.py +++ b/test/transformers/test_geglu.py @@ -8,6 +8,9 @@ from liger_kernel.ops.geglu import LigerGELUMulFunction from liger_kernel.transformers.functional import liger_geglu from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() LLAMA_CONFIG = LlamaConfig( hidden_size=4096, @@ -42,22 +45,22 @@ ], ) def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol): - _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) # initialize weights - G = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - U = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) - llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype) llama_mlp.gate_proj.weight.data = G.T llama_mlp.up_proj.weight.data = U.T llama_mlp.down_proj.weight.data = D.T - liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype) liger_mlp.gate_proj.weight.data = G.T liger_mlp.up_proj.weight.data = U.T liger_mlp.down_proj.weight.data = D.T @@ -121,8 +124,8 @@ def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, ], ) def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol): - _input = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) - _b = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) + _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 32419ed6a..4f53444d5 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -4,6 +4,9 @@ import torch from liger_kernel.transformers.group_norm import LigerGroupNorm +from liger_kernel.utils import infer_device + +device = infer_device() random_batch_size = random.randint(1, 16) random_num_groups = random.randint(1, 32) @@ -32,17 +35,17 @@ def test_liger_group_norm( torch.manual_seed(0) _tensor = torch.randn( - batch_size, num_channels, hidden_size, dtype=dtype, device="cuda" + batch_size, num_channels, hidden_size, dtype=dtype, device=device ) liger_x = _tensor.clone().detach().requires_grad_(True) torch_x = _tensor.clone().detach().requires_grad_(True) - liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() + liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device) torch_ln = ( torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6) .to(dtype) - .cuda() + .to(device) ) with torch.no_grad(): diff --git 
a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py index 23087d621..86f4e3388 100644 --- a/test/transformers/test_jsd.py +++ b/test/transformers/test_jsd.py @@ -7,6 +7,9 @@ from liger_kernel.transformers.functional import liger_jsd from liger_kernel.transformers.jsd import LigerJSD, LigerJSDFunction +from liger_kernel.utils import infer_device + +device = infer_device() set_seed(42) @@ -91,7 +94,7 @@ def _test_correctness_once( atol, rtol, is_last_layer=True, - device="cuda", + device=device, ): torch_jsd = JSD(dtype=dtype) @@ -133,7 +136,7 @@ def _test_correctness_with_beta_once( atol, rtol, is_last_layer=True, - device="cuda", + device=device, ): torch_jsd = JSD(beta=beta, dtype=dtype) @@ -170,7 +173,7 @@ def _test_correctness_with_ignore_index_once( dtype, atol, rtol, - device="cuda", + device=device, ): torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) @@ -205,7 +208,7 @@ def _test_correctness_with_ignore_index_once( def _test_correctness_functional( - B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device="cuda" + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device=device ): input = torch.randn( B * T, V, device=device, dtype=dtype, requires_grad=True @@ -305,7 +308,7 @@ def test_correctness_with_all_indices_ignored( dtype=torch.bfloat16, atol=1e-3, rtol=1e-3, - device="cuda", + device=device, ): ignore_index = -100 torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) diff --git a/test/transformers/test_kl_div.py b/test/transformers/test_kl_div.py index 5cc3eba6a..1f0c2d5ad 100644 --- a/test/transformers/test_kl_div.py +++ b/test/transformers/test_kl_div.py @@ -5,6 +5,9 @@ from torch.nn import KLDivLoss from liger_kernel.transformers.kl_div import LigerKLDIVLoss +from liger_kernel.utils import infer_device + +device = infer_device() _SHAPE_PARAMS = ( "B, T, V", @@ -43,7 +46,7 @@ def _test_correctness_once( reduction, log_target, is_last_layer=True, - device="cuda", + device=device, ): torch.manual_seed(0) torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) diff --git a/test/transformers/test_layer_norm.py b/test/transformers/test_layer_norm.py index f570e7b21..4ac152440 100644 --- a/test/transformers/test_layer_norm.py +++ b/test/transformers/test_layer_norm.py @@ -4,6 +4,9 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction from liger_kernel.transformers.functional import liger_layer_norm from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.utils import infer_device + +device = infer_device() @pytest.mark.parametrize( @@ -22,13 +25,13 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): torch.manual_seed(0) - x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) liger_x = x.clone().requires_grad_(True) torch_x = x.clone().requires_grad_(True) - liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).cuda() - torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).cuda() + liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).to(device) + torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).to(device) with torch.no_grad(): torch_ln.weight.copy_(liger_ln.weight) @@ -68,17 +71,17 @@ def test_liger_layer_norm_functional( ): torch.manual_seed(0) - input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) x1 = 
input.clone().requires_grad_(True) x2 = input.clone().requires_grad_(True) - w = torch.randn(hidden_size, device="cuda", dtype=dtype) + w = torch.randn(hidden_size, device=device, dtype=dtype) w1 = w.clone().requires_grad_(True) w2 = w.clone().requires_grad_(True) - b = torch.randn(hidden_size, device="cuda", dtype=dtype) + b = torch.randn(hidden_size, device=device, dtype=dtype) b1 = b.clone().requires_grad_(True) b2 = b.clone().requires_grad_(True) diff --git a/test/transformers/test_mm_int8int2.py b/test/transformers/test_mm_int8int2.py index d7d13a958..a2458523a 100644 --- a/test/transformers/test_mm_int8int2.py +++ b/test/transformers/test_mm_int8int2.py @@ -6,6 +6,9 @@ pack_weights, unpack_weights, ) +from liger_kernel.utils import infer_device + +device = infer_device() # input_features = size*4 when the weight matrix is unpacked @@ -38,7 +41,7 @@ @pytest.mark.parametrize( "atol, rtol, device", [ - (1e-2, 1e-2, "cuda"), + (1e-2, 1e-2, device), ], ) def test_kernel_correctness( @@ -95,7 +98,7 @@ def test_kernel_correctness( @pytest.mark.parametrize( "device", [ - "cuda", + device, ], ) def test_unpack_pack_correctness(out_features, size, device): diff --git a/test/transformers/test_qwen2vl_mrope.py b/test/transformers/test_qwen2vl_mrope.py index fb3f4b80e..bfc1f9ac2 100644 --- a/test/transformers/test_qwen2vl_mrope.py +++ b/test/transformers/test_qwen2vl_mrope.py @@ -16,6 +16,9 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction from liger_kernel.transformers.functional import liger_qwen2vl_mrope from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() @pytest.mark.skipif( @@ -49,16 +52,16 @@ def test_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol ): - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) _tensor_q = ( - torch.randn((bsz, seq_len, num_q_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) _tensor_k = ( - torch.randn((bsz, seq_len, num_kv_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) @@ -70,7 +73,7 @@ def test_correctness( k2 = _tensor_k.clone().requires_grad_(True) # NOTE: this position ids distribution is different from the real one, just to test op correctness - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -81,8 +84,8 @@ def test_correctness( # validate backward pass dq, dk = ( - torch.randn_like(hf_q, device="cuda"), - torch.randn_like(hf_k, device="cuda").to(dtype), + torch.randn_like(hf_q, device=device), + torch.randn_like(hf_k, device=device).to(dtype), ) q1_grad, k1_grad = torch.autograd.grad( @@ -116,8 +119,8 @@ def test_correctness( def test_functional_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol ): - _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device="cuda", dtype=dtype) - _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device="cuda", dtype=dtype) + _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) + _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, 
dtype=dtype) q1 = _q.clone().requires_grad_(True) q2 = _q.clone().requires_grad_(True) @@ -125,9 +128,9 @@ def test_functional_correctness( k1 = _k.clone().requires_grad_(True) k2 = _k.clone().requires_grad_(True) - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section) diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py index 3fce0dcaa..dc0c78643 100644 --- a/test/transformers/test_rms_norm.py +++ b/test/transformers/test_rms_norm.py @@ -1,10 +1,5 @@ import os -from test.utils import ( - assert_verbose_allclose, - infer_device, - set_seed, - supports_bfloat16, -) +from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 import pytest import torch @@ -13,10 +8,13 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.transformers.functional import liger_rms_norm from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.utils import infer_device + +device = infer_device() set_seed(42) torch.use_deterministic_algorithms(True) -device = infer_device() + # Only setting torch.use_deterministic_algorithms(True) might throw the following error: # RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, # but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index 8e1198025..74080b57f 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -10,6 +10,9 @@ from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.transformers.functional import liger_rope from liger_kernel.transformers.rope import liger_rotary_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() SLEEP_SECONDS = 0.1 @@ -46,16 +49,16 @@ def test_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol ): - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) _tensor_q = ( - torch.randn((bsz, seq_len, num_q_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) _tensor_k = ( - torch.randn((bsz, seq_len, num_kv_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) @@ -66,7 +69,7 @@ def test_correctness( q2 = _tensor_q.clone().requires_grad_(True) k2 = _tensor_k.clone().requires_grad_(True) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -77,8 +80,8 @@ def test_correctness( # validate backward pass dq, dk = ( - torch.randn_like(hf_q, device="cuda"), - torch.randn_like(hf_k, device="cuda").to(dtype), + torch.randn_like(hf_q, device=device), + torch.randn_like(hf_k, device=device).to(dtype), ) q1_grad, k1_grad = torch.autograd.grad( @@ -111,8 
+114,8 @@ def test_correctness( def test_functional_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol ): - _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device="cuda", dtype=dtype) - _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device="cuda", dtype=dtype) + _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) + _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) q1 = _q.clone().requires_grad_(True) q2 = _q.clone().requires_grad_(True) @@ -120,9 +123,9 @@ def test_functional_correctness( k1 = _k.clone().requires_grad_(True) k2 = _k.clone().requires_grad_(True) - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin) diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index e1f4f092b..154d5061f 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -10,6 +10,9 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction from liger_kernel.transformers.functional import liger_swiglu from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() LLAMA_CONFIG = LlamaConfig( hidden_size=4096, @@ -52,22 +55,22 @@ def test_correctness_llamamlp( bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol ): - _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) # initialize weights - G = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - U = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) - llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype) llama_mlp.gate_proj.weight.data = G.T llama_mlp.up_proj.weight.data = U.T llama_mlp.down_proj.weight.data = D.T - liger_mlp = LigerSwiGLUMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + liger_mlp = LigerSwiGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype) liger_mlp.gate_proj.weight.data = G.T liger_mlp.up_proj.weight.data = U.T liger_mlp.down_proj.weight.data = D.T @@ -132,20 +135,20 @@ def test_correctness_llamamlp( def test_correctness_phi3mlp( bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol ): - _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) # initialize weights - GU = torch.randn(hidden_size, intermediate_size * 2, device="cuda", dtype=dtype) - D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + GU = torch.randn(hidden_size, 
intermediate_size * 2, device=device, dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) - phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to("cuda").to(dtype) + phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to(device).to(dtype) phi3_mlp.gate_up_proj.weight.data = GU.T phi3_mlp.down_proj.weight.data = D.T - liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to("cuda").to(dtype) + liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to(device).to(dtype) liger_mlp.gate_up_proj.weight.data = GU.T liger_mlp.down_proj.weight.data = D.T @@ -193,8 +196,8 @@ def test_correctness_phi3mlp( ], ) def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol): - _input = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) - _b = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) + _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) diff --git a/test/utils.py b/test/utils.py index f209a0388..e8383d659 100644 --- a/test/utils.py +++ b/test/utils.py @@ -16,20 +16,9 @@ from transformers import PretrainedConfig, PreTrainedModel from transformers.tokenization_utils_base import BatchEncoding +from liger_kernel.utils import infer_device -def infer_device(): - """ - Get current device name based on available devices - """ - if torch.cuda.is_available(): - return "cuda" - elif torch.xpu.is_available(): - return "xpu" - else: - return "cpu" - - -torch_device = infer_device() +device = infer_device() def set_seed(seed=42): @@ -43,7 +32,7 @@ def set_seed(seed=42): # PyTorch random seed torch.manual_seed(seed) - if torch_device == "cuda": + if device == "cuda": # If you are using CUDA torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. @@ -51,8 +40,8 @@ def set_seed(seed=42): # PyTorch backend settings torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - elif torch_device == "xpu": - # If you ware using intel GPU + elif device == "xpu": + # If you are using XPU torch.xpu.manual_seed(seed) torch.xpu.manual_seed_all(seed) @@ -225,9 +214,9 @@ def train_bpe_tokenizer(special_tokens: List[str], unk_token: str = "<|unk|>"): def supports_bfloat16(): - if torch_device == "cuda": + if device == "cuda": return torch.cuda.get_device_capability() >= (8, 0) # Ampere and newer - elif torch_device == "xpu": + elif device == "xpu": return True else: return False
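
For reference, the device-detection helper that this patch centralizes into the new src/liger_kernel/utils.py follows the pattern visible in the lines removed from test/utils.py above. A minimal sketch, assuming the new module mirrors that removed body (the diffstat lists src/liger_kernel/utils.py at 13 lines):

    import torch


    def infer_device():
        """
        Get current device name based on available devices.
        """
        if torch.cuda.is_available():
            # NVIDIA GPUs expose the "cuda" backend
            return "cuda"
        elif torch.xpu.is_available():
            # Intel GPUs expose the "xpu" backend in recent PyTorch builds
            return "xpu"
        else:
            # Fall back to CPU when no accelerator is present
            return "cpu"

Call sites then bind the result once at module import time and thread it through every tensor allocation and autocast context, as the hunks above do. A hypothetical usage sketch (the tensor shapes and the op are placeholders, not taken from the patch):

    from liger_kernel.utils import infer_device

    device = infer_device()

    x = torch.randn(8, 4096, device=device, dtype=torch.bfloat16)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        y = x * 2  # stand-in for the kernels under test

Binding the device string once per module removes every hardcoded "cuda" literal from the tests and benchmarks, which is what lets the same suite run unmodified on NVIDIA and Intel GPUs.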