Optimize CE Loss by casting dtype to float32 inside kernel (#406)
## Summary
This PR is essentially a reproduction of #238 along with the necessary
changes to merge the code with main.
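
For background, softmax/log-softmax over a 128K-entry vocabulary row is precision-sensitive: bfloat16 carries only 8 significand bits vs. float32's 24, so casting to float32 inside the kernel keeps the max/sum accumulation accurate. A minimal PyTorch sketch of the gap (illustrative only; the tensor size mirrors the benchmarks below, and nothing here is taken from the diff):

```python
# Illustrative only: compare softmax computed from bf16 inputs vs. a float32
# reference. The printed gap is the bf16 rounding error the kernel-side
# float32 cast is meant to avoid accumulating.
import torch

torch.manual_seed(0)
logits = torch.randn(128256, dtype=torch.bfloat16)  # one vocab-sized row

probs_bf16 = torch.softmax(logits, dim=0)           # bf16 in, bf16 out
probs_fp32 = torch.softmax(logits.float(), dim=0)   # float32 reference

max_gap = (probs_bf16.float() - probs_fp32).abs().max().item()
print(f"max |softmax_bf16 - softmax_fp32| = {max_gap:.3e}")
```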

## Testing Done

- Hardware Type: A100-SXM4 40GB
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence
pramodith authored Nov 22, 2024
1 parent d907ec0 commit 90fb5e4
Showing 4 changed files with 112 additions and 42 deletions.
48 changes: 24 additions & 24 deletions benchmark/data/all_benchmark_data.csv
@@ -179,30 +179,30 @@
embedding,torch_compile,full,memory,MB,V,embedding dimension,16384,1536.125,1536.125,1536.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
embedding,torch_compile,full,memory,MB,V,embedding dimension,32768,3072.125,3072.125,3072.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
embedding,torch_compile,full,memory,MB,V,embedding dimension,65536,6144.125,6144.125,6144.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
embedding,torch_compile,full,memory,MB,V,embedding dimension,131072,12288.125,12288.125,12288.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
-fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,4096,111.0453109741211,111.0453109741211,111.0453109741211,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1
-fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,8192,161.67047119140625,161.67047119140625,161.67047119140625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1
-fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,16384,264.1196594238281,264.1196594238281,264.1196594238281,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1
-fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,32768,492.00390625,492.00390625,492.00390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:13,0.2.1
-fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,4096,19.030847549438477,18.991506576538086,19.17319679260254,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1
-fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,8192,37.99166488647461,37.977237701416016,38.0060920715332,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1
-fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,16384,76.0440673828125,76.0440673828125,76.0440673828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1
-fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,32768,151.54771423339844,151.54771423339844,151.54771423339844,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:35:45,0.2.1
-fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,4096,113.0862045288086,113.0862045288086,113.0862045288086,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1
-fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,8192,166.76512145996094,166.76512145996094,166.76512145996094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1
-fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,16384,270.321044921875,270.321044921875,270.321044921875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1
-fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,32768,495.4810485839844,495.4810485839844,495.4810485839844,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:22,0.2.1
-fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,4096,55.55372619628906,55.55372619628906,55.55372619628906,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1
-fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,8192,111.50227355957031,111.50227355957031,111.50227355957031,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1
-fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,16384,223.53219604492188,223.53219604492188,223.53219604492188,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1
-fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,32768,457.7295227050781,457.7295227050781,457.7295227050781,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:36:56,0.2.1
-fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,4096,4245.546875,4245.546875,4245.546875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1
-fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,8192,4466.96875,4466.96875,4466.96875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1
-fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,16384,4910.4375,4910.4375,4910.4375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1
-fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,32768,5794.625,5794.625,5794.625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:37:34,0.2.1
-fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,4096,6092.2822265625,6092.2822265625,6092.2822265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1
-fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,8192,9162.3134765625,9162.3134765625,9162.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1
-fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,16384,15302.3759765625,15302.3759765625,15302.3759765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1
-fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,32768,27582.5,27582.5,27582.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:02,0.2.1
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,4096,119.52153778076172,119.52153778076172,119.52153778076172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,8192,168.08563232421875,168.08563232421875,168.08563232421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,16384,274.07342529296875,274.07342529296875,274.07342529296875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,32768,508.4652099609375,508.4652099609375,508.4652099609375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,4096,20.911680221557617,20.90903663635254,20.915321350097656,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,8192,37.97203063964844,37.9546012878418,37.989463806152344,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,16384,76.39142608642578,76.39142608642578,76.39142608642578,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,32768,151.91404724121094,151.91404724121094,151.91404724121094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,4096,121.43059539794922,121.43059539794922,121.43059539794922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,8192,166.70867919921875,166.70867919921875,166.70867919921875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,16384,277.1166687011719,277.1166687011719,277.1166687011719,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,32768,511.0638732910156,511.0638732910156,511.0638732910156,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,4096,55.96684646606445,55.96684646606445,55.96684646606445,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,8192,111.45471954345703,111.45471954345703,111.45471954345703,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,16384,220.7836151123047,220.7836151123047,220.7836151123047,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,32768,452.4712829589844,452.4712829589844,452.4712829589844,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,4096,4245.5478515625,4245.5478515625,4245.5478515625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,8192,4466.9697265625,4466.9697265625,4466.9697265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,16384,4910.4384765625,4910.4384765625,4910.4384765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,32768,5794.6259765625,5794.6259765625,5794.6259765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,4096,6092.2822265625,6092.2822265625,6092.2822265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,8192,9162.3134765625,9162.3134765625,9162.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,16384,15302.3759765625,15302.3759765625,15302.3759765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,32768,27582.5,27582.5,27582.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
geglu,liger,full,speed,ms,T,sequence length,1024,30.03536033630371,30.03536033630371,30.03536033630371,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
geglu,liger,full,speed,ms,T,sequence length,2048,54.04060745239258,54.04060745239258,54.04060745239258,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
geglu,liger,full,speed,ms,T,sequence length,4096,108.52435302734375,108.52435302734375,108.52435302734375,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
18 changes: 12 additions & 6 deletions src/liger_kernel/ops/cross_entropy.py
@@ -92,8 +92,8 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(
-        X_ptr + y
+    ori_X_y = tl.load(X_ptr + y).cast(
+        tl.float32
     )  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -106,8 +106,11 @@
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
@@ -141,8 +144,11 @@
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
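
Both hunks above sit inside the kernel's online-softmax passes, where `m` is the running max and `d` the running sum named in the comments; the `.cast(tl.float32)` ensures those accumulators are fed float32 values. As a plain-Python refresher of what the first pass computes (a sketch with assumed names, not the kernel code):

```python
import math


def online_softmax_stats(x: list[float], block_size: int) -> tuple[float, float]:
    """Sketch of the first pass: running max m and running sum d over blocks."""
    m = float("-inf")  # running max; notation follows the kernel comments
    d = 0.0            # running sum of exp(x_i - m)
    for i in range(0, len(x), block_size):
        block = x[i : i + block_size]
        new_m = max(m, max(block))
        # rescale the old sum to the new max, then fold in this block
        d = d * math.exp(m - new_m) + sum(math.exp(v - new_m) for v in block)
        m = new_m
    return m, d


# softmax(x)[j] == math.exp(x[j] - m) / d, and logsumexp(x) == m + math.log(d)
```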
11 changes: 0 additions & 11 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -26,7 +26,6 @@ def fused_linear_cross_entropy_forward(
     reduction="mean",
     softcap=None,
 ):
-    dtype = _input.dtype
     device = _input.device
 
     # inputs have shape: BT x H
@@ -74,9 +73,6 @@
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         n_non_ignore = (target_chunk != ignore_index).sum().item()
 
-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
-
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
         target_chunk = target_chunk.contiguous()
@@ -103,13 +99,6 @@
             num_warps=32 if not is_hip() else 16,
         )
 
-        # gradient of logits_chunk is computed in-place by the above triton kernel.
-        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
-        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-        # Propagating to lm_head's backward, we'll switch back to the original dtype.
-        logits_chunk = logits_chunk.to(dtype)
-
         # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
         # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
         # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
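
With the cast moved into the kernel, the host-side upcast removed here (a transient float32 materialization of every logits chunk, plus the cast back) is no longer needed. A rough, hypothetical sizing of what that copy costs, with the chunk shape assumed rather than taken from the code:

```python
# Hypothetical back-of-the-envelope; chunk_size and V are assumed here, not
# read from the code (the real chunk size is computed at runtime).
chunk_size, V = 4096, 128256
bf16_mib = chunk_size * V * 2 / 2**20  # logits chunk kept in bf16
fp32_mib = chunk_size * V * 4 / 2**20  # transient fp32 copy the old code made
print(f"bf16 chunk: {bf16_mib:.0f} MiB; avoided fp32 copy: {fp32_mib:.0f} MiB")
# -> bf16 chunk: 1002 MiB; avoided fp32 copy: 2004 MiB
```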
77 changes: 76 additions & 1 deletion test/transformers/test_cross_entropy.py
@@ -5,7 +5,10 @@
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 
-from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.cross_entropy import (
+    LigerCrossEntropyFunction,
+    liger_cross_entropy_kernel,
+)
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy

@@ -711,3 +714,75 @@ def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol):
    _test_correctness_not_last_layer_once(
        liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol
    )


def test_float32_internal():
    """
    This test validates that the internal softmax calculations occur in float32,
    even if the input dtype is bfloat16.
    """
    # Set up test parameters
    batch_size = 4
    n_cols = 128256
    n_non_ignore = batch_size
    ignore_index = -100
    label_smoothing = 0.0
    lse_square_scale = 0.0
    softcap = 0.0
    BLOCK_SIZE = 32768
    reduction = "mean"

    # Initialize input tensors
    X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device="cuda")
    Y = torch.randint(0, n_cols, (batch_size,), device="cuda")

    # Run kernel for bfloat16
    X_bf16 = X_init.clone()
    loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
    liger_cross_entropy_kernel[(batch_size,)](
        X_ptr=X_bf16,
        X_stride=X_bf16.stride(-2),
        Y_ptr=Y,
        Y_stride=Y.stride(-1),
        loss_ptr=loss_bf16,
        z_loss_ptr=loss_bf16,  # dummy ptr, not used
        loss_stride=loss_bf16.stride(-1),
        n_cols=n_cols,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=0,  # False
        HAS_SOFTCAPPING=False,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    # Run kernel for float32
    X_fp32 = X_init.float()
    loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
    liger_cross_entropy_kernel[(batch_size,)](
        X_ptr=X_fp32,
        X_stride=X_fp32.stride(-2),
        Y_ptr=Y,
        Y_stride=Y.stride(-1),
        loss_ptr=loss_fp32,
        z_loss_ptr=loss_fp32,  # dummy ptr, not used
        loss_stride=loss_fp32.stride(-1),
        n_cols=n_cols,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=0,  # False
        HAS_SOFTCAPPING=False,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    # The kernel overwrites X in place with gradients, so comparing the tensors
    # after both runs checks that the bf16 path matches the float32 reference.
    assert torch.allclose(X_bf16, X_fp32.bfloat16())
    assert torch.allclose(loss_bf16, loss_fp32)
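
A note on the test above: `liger_cross_entropy_kernel[(batch_size,)](...)` is Triton's grid-launch syntax, one kernel instance per row. For reference, a self-contained sketch of the same launch and upcast pattern (toy kernel and all names assumed, not from this repo):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def row_max_kernel(X_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)  # one program instance per row of X
    offs = tl.arange(0, BLOCK_SIZE)
    x = tl.load(
        X_ptr + row * n_cols + offs, mask=offs < n_cols, other=float("-inf")
    ).cast(tl.float32)  # same per-block upcast pattern as the CE kernel
    tl.store(out_ptr + row, tl.max(x))


X = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")
out = torch.empty(4, dtype=torch.float32, device="cuda")
row_max_kernel[(4,)](X, out, 1024, BLOCK_SIZE=1024)  # grid = (num_rows,)
```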
