From 90fb5e4a3cb971ab996638596f07652404e719e5 Mon Sep 17 00:00:00 2001 From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:56:19 +0000 Subject: [PATCH] Optimize CE Loss by casting dtype to float32 inside kernel (#406) ## Summary This PR is essentially a reproduction of #238 along with the necessary changes to merge the code with main. ## Testing Done - Hardware Type: A100-SMX4 40GB - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [X] run `make test-convergence` to ensure convergence --- benchmark/data/all_benchmark_data.csv | 48 ++++++------ src/liger_kernel/ops/cross_entropy.py | 18 +++-- .../ops/fused_linear_cross_entropy.py | 11 --- test/transformers/test_cross_entropy.py | 77 ++++++++++++++++++- 4 files changed, 112 insertions(+), 42 deletions(-) diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index ed25905cd..4e966cab2 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -179,30 +179,30 @@ embedding,torch_compile,full,memory,MB,V,embedding dimension,16384,1536.125,1536 embedding,torch_compile,full,memory,MB,V,embedding dimension,32768,3072.125,3072.125,3072.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1 embedding,torch_compile,full,memory,MB,V,embedding dimension,65536,6144.125,6144.125,6144.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1 embedding,torch_compile,full,memory,MB,V,embedding dimension,131072,12288.125,12288.125,12288.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1 -fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,4096,111.0453109741211,111.0453109741211,111.0453109741211,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1 -fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,8192,161.67047119140625,161.67047119140625,161.67047119140625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1 -fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,16384,264.1196594238281,264.1196594238281,264.1196594238281,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1 -fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,32768,492.00390625,492.00390625,492.00390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1 -fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,4096,19.030847549438477,18.991506576538086,19.17319679260254,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1 -fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,8192,37.99166488647461,37.977237701416016,38.0060920715332,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1 -fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,16384,76.0440673828125,76.0440673828125,76.0440673828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA 
A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1 -fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,32768,151.54771423339844,151.54771423339844,151.54771423339844,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1 -fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,4096,113.0862045288086,113.0862045288086,113.0862045288086,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1 -fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,8192,166.76512145996094,166.76512145996094,166.76512145996094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1 -fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,16384,270.321044921875,270.321044921875,270.321044921875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1 -fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,32768,495.4810485839844,495.4810485839844,495.4810485839844,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1 -fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,4096,55.55372619628906,55.55372619628906,55.55372619628906,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1 -fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,8192,111.50227355957031,111.50227355957031,111.50227355957031,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1 -fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,16384,223.53219604492188,223.53219604492188,223.53219604492188,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1 -fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,32768,457.7295227050781,457.7295227050781,457.7295227050781,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1 -fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,4096,4245.546875,4245.546875,4245.546875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1 -fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,8192,4466.96875,4466.96875,4466.96875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1 -fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,16384,4910.4375,4910.4375,4910.4375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1 -fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,32768,5794.625,5794.625,5794.625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1 -fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,4096,6092.2822265625,6092.2822265625,6092.2822265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1 -fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x 
T,8192,9162.3134765625,9162.3134765625,9162.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1 -fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,16384,15302.3759765625,15302.3759765625,15302.3759765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1 -fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,32768,27582.5,27582.5,27582.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1 +fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,4096,119.52153778076172,119.52153778076172,119.52153778076172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2 +fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,8192,168.08563232421875,168.08563232421875,168.08563232421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2 +fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,16384,274.07342529296875,274.07342529296875,274.07342529296875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2 +fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,32768,508.4652099609375,508.4652099609375,508.4652099609375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2 +fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,4096,20.911680221557617,20.90903663635254,20.915321350097656,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2 +fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,8192,37.97203063964844,37.9546012878418,37.989463806152344,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2 +fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,16384,76.39142608642578,76.39142608642578,76.39142608642578,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2 +fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,32768,151.91404724121094,151.91404724121094,151.91404724121094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2 +fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,4096,121.43059539794922,121.43059539794922,121.43059539794922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2 +fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,8192,166.70867919921875,166.70867919921875,166.70867919921875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2 +fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,16384,277.1166687011719,277.1166687011719,277.1166687011719,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2 +fused_linear_cross_entropy,liger,full,speed,ms,BT,B x 
T,32768,511.0638732910156,511.0638732910156,511.0638732910156,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2 +fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,4096,55.96684646606445,55.96684646606445,55.96684646606445,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2 +fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,8192,111.45471954345703,111.45471954345703,111.45471954345703,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2 +fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,16384,220.7836151123047,220.7836151123047,220.7836151123047,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2 +fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,32768,452.4712829589844,452.4712829589844,452.4712829589844,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2 +fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,4096,4245.5478515625,4245.5478515625,4245.5478515625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2 +fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,8192,4466.9697265625,4466.9697265625,4466.9697265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2 +fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,16384,4910.4384765625,4910.4384765625,4910.4384765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2 +fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,32768,5794.6259765625,5794.6259765625,5794.6259765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2 +fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,4096,6092.2822265625,6092.2822265625,6092.2822265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2 +fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,8192,9162.3134765625,9162.3134765625,9162.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2 +fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,16384,15302.3759765625,15302.3759765625,15302.3759765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2 +fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,32768,27582.5,27582.5,27582.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2 geglu,liger,full,speed,ms,T,sequence length,1024,30.03536033630371,30.03536033630371,30.03536033630371,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1 geglu,liger,full,speed,ms,T,sequence 
length,2048,54.04060745239258,54.04060745239258,54.04060745239258,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
geglu,liger,full,speed,ms,T,sequence length,4096,108.52435302734375,108.52435302734375,108.52435302734375,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 8cc116a0e..2a980c69e 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -92,8 +92,8 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(
-        X_ptr + y
+    ori_X_y = tl.load(X_ptr + y).cast(
+        tl.float32
     )  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -106,8 +106,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
@@ -141,8 +144,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index 963590d45..191a2b3d2 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -26,7 +26,6 @@ def fused_linear_cross_entropy_forward(
     reduction="mean",
     softcap=None,
 ):
-    dtype = _input.dtype
     device = _input.device

     # inputs have shape: BT x H
@@ -74,9 +73,6 @@ def fused_linear_cross_entropy_forward(
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         n_non_ignore = (target_chunk != ignore_index).sum().item()

-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
-
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
         target_chunk = target_chunk.contiguous()
@@ -103,13 +99,6 @@ def fused_linear_cross_entropy_forward(
             num_warps=32 if not is_hip() else 16,
         )

-        # gradient of logits_chunk is computed in-place by the above triton kernel.
-        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
-        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-        # Propagating to lm_head's backward, we'll switch back to the original dtype.
-        logits_chunk = logits_chunk.to(dtype)
-
         # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
         # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
         # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index 6ec73a1a3..82edc98fa 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -5,7 +5,10 @@
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss

-from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.cross_entropy import (
+    LigerCrossEntropyFunction,
+    liger_cross_entropy_kernel,
+)
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy

@@ -711,3 +714,75 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rto
     _test_correctness_not_last_layer_once(
         liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol
     )
+
+
+def test_float32_internal():
+    """
+    This test validates that the internal softmax calculations occur in float32,
+    even if the input dtype is bfloat16.
+    """
+    # Set up test parameters
+    batch_size = 4
+    n_cols = 128256
+    n_non_ignore = batch_size
+    ignore_index = -100
+    label_smoothing = 0.0
+    lse_square_scale = 0.0
+    softcap = 0.0
+    BLOCK_SIZE = 32768
+    reduction = "mean"
+
+    # Initialize input tensors
+    X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device="cuda")
+    Y = torch.randint(0, n_cols, (batch_size,), device="cuda")
+
+    # Run kernel for bfloat16
+    X_bf16 = X_init.clone()
+    loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
+    liger_cross_entropy_kernel[(batch_size,)](
+        X_ptr=X_bf16,
+        X_stride=X_bf16.stride(-2),
+        Y_ptr=Y,
+        Y_stride=Y.stride(-1),
+        z_loss_ptr=loss_bf16,  # dummy ptr, not used
+        loss_ptr=loss_bf16,
+        loss_stride=loss_bf16.stride(-1),
+        n_cols=n_cols,
+        n_non_ignore=n_non_ignore,
+        ignore_index=ignore_index,
+        lse_square_scale=lse_square_scale,
+        label_smoothing=label_smoothing,
+        reduction=reduction,
+        softcap=softcap,
+        RETURN_Z_LOSS=0,  # False
+        HAS_SOFTCAPPING=False,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=32,
+    )
+
+    # Run kernel for float32
+    X_fp32 = X_init.float()
+    loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
+    liger_cross_entropy_kernel[(batch_size,)](
+        X_ptr=X_fp32,
+        X_stride=X_fp32.stride(-2),
+        Y_ptr=Y,
+        Y_stride=Y.stride(-1),
+        loss_ptr=loss_fp32,
+        z_loss_ptr=loss_fp32,  # dummy ptr, not used
+        loss_stride=loss_fp32.stride(-1),
+        n_cols=n_cols,
+        n_non_ignore=n_non_ignore,
+        ignore_index=ignore_index,
+        lse_square_scale=lse_square_scale,
+        label_smoothing=label_smoothing,
+        reduction=reduction,
+        softcap=softcap,
+        RETURN_Z_LOSS=0,  # False
+        HAS_SOFTCAPPING=False,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=32,
+    )
+
+    assert torch.allclose(X_bf16, X_fp32.bfloat16())
+    assert torch.allclose(loss_bf16, loss_fp32)
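
For context, the core change is that the Triton kernel now upcasts each block of logits to float32 at load time (`.cast(tl.float32)` on the result of `tl.load`), so `fused_linear_cross_entropy_forward` no longer has to materialize a float32 copy of the whole `logits_chunk` via `.float()` and cast it back with `.to(dtype)`. Below is a minimal, self-contained sketch of the same cast-on-load pattern; the toy kernel, its name, and the shapes are illustrative only and are not part of this PR or the Liger codebase.

```python
# Toy example (not from the repo): load bf16 values but accumulate the
# online-softmax statistics in fp32, mirroring the .cast(tl.float32) added
# to liger_cross_entropy_kernel in this PR.
import torch
import triton
import triton.language as tl


@triton.jit
def toy_row_logsumexp_kernel(X_ptr, out_ptr, X_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    X_ptr += row * X_stride
    m = float("-inf")  # running max
    d = 0.0            # running sum of exp(x - m)
    for i in range(0, n_cols, BLOCK_SIZE):
        offs = i + tl.arange(0, BLOCK_SIZE)
        # bf16 in global memory, fp32 in registers: cast on load.
        x = tl.load(X_ptr + offs, mask=offs < n_cols, other=float("-inf")).cast(tl.float32)
        block_max = tl.max(x)
        new_m = tl.maximum(m, block_max)
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(x - new_m))
        m = new_m
    tl.store(out_ptr + row, m + tl.log(d))


# Usage: the input stays bf16 end to end; only the in-kernel math is fp32.
X = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")
out = torch.empty(4, dtype=torch.float32, device="cuda")
toy_row_logsumexp_kernel[(4,)](X, out, X.stride(0), 1024, BLOCK_SIZE=256)
assert torch.allclose(out, torch.logsumexp(X.float(), dim=-1), atol=1e-4)
```

With this pattern the bf16 logits are never copied into an fp32 buffer in global memory; only the per-block registers and the accumulated statistics (`m`, `d`) are fp32, which is where the memory savings in `fused_linear_cross_entropy_forward` come from while keeping the numerics of the previous fp32 upcast.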