Skip to content

Commit

Permalink
Fix kernel arugments in flce
Browse files Browse the repository at this point in the history
  • Loading branch information
Tcc0403 committed Dec 22, 2024
1 parent 7ed4dd9 commit 5535a60
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def fused_linear_cross_entropy_forward(
X_stride=logits_chunk.stride(-2),
Y_ptr=target_chunk,
Y_stride=target_chunk.stride(-1), # always 1
weight_ptr=ce_weight, # dummy ptr, not used
weight_ptr=ce_weight if ce_weight is not None else _input, # dummy if None
loss_ptr=loss_1d_slice,
z_loss_ptr=loss_1d_slice, # dummy ptr, not used
loss_stride=loss_1d_slice.stride(-1), # always 1
Expand All @@ -115,7 +115,7 @@ def fused_linear_cross_entropy_forward(
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_WEIGHT=True if ce_weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
Expand Down

0 comments on commit 5535a60

Please sign in to comment.