From 3b6d09350fbbbea95792322808b8b1ee02ad12d0 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sat, 21 Dec 2024 16:54:27 +0800 Subject: [PATCH 1/6] Update fused_linear_cross_entropy.py --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 191a2b3d2..84e20d665 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -133,7 +133,10 @@ def fused_linear_cross_entropy_forward( alpha=alpha, ) - loss = torch.sum(loss_1d) + if reduction == "none": + loss = torch.sum(loss_1d) + else: + loss = loss_1d return loss, grad_input, grad_weight, grad_bias From 30fd95be562bccd2457fb4a53b7a8b90776b7cb0 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sat, 21 Dec 2024 17:01:02 +0800 Subject: [PATCH 2/6] fix typo --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 84e20d665..56a02e4e9 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -134,9 +134,9 @@ def fused_linear_cross_entropy_forward( ) if reduction == "none": - loss = torch.sum(loss_1d) - else: loss = loss_1d + else: + loss = torch.sum(loss_1d) return loss, grad_input, grad_weight, grad_bias From a57f5a24295dcc4011397e72ec5f3afcdde47fd5 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sat, 21 Dec 2024 17:19:18 +0800 Subject: [PATCH 3/6] checkstyle --- src/liger_kernel/chunked_loss/cpo_loss.py | 2 +- src/liger_kernel/chunked_loss/simpo_loss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index dd84a4dbf..daf64b91b 100644 --- 
a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -36,7 +36,7 @@ def preference_loss_fn( """ logits = beta * (chosen_logps - rejected_logps) loss = ( - - F.logsigmoid(logits) * (1 - label_smoothing) + -F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing ).sum() / (full_target.shape[0] // 2) diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 5d5867252..0890596ea 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -42,7 +42,7 @@ def preference_loss_fn( """ logits = beta * (chosen_logps - rejected_logps) - gamma loss = ( - - F.logsigmoid(logits) * (1 - label_smoothing) + -F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing ).sum() / (full_target.shape[0] // 2) From e8b0a3a624fbd4ea6e0650bbd516ea4936ece5c2 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sat, 21 Dec 2024 19:01:44 +0800 Subject: [PATCH 4/6] add test --- test/transformers/test_fused_linear_cross_entropy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index a6bcd4d8b..d451f35ae 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -114,6 +114,8 @@ def forward(self, x, y): ("mean", 1.0, torch.float32, 1e-5, 5e-4), ("sum", 1.0, torch.bfloat16, 5e-0, 5e1), ("sum", 1.0, torch.float32, 1e-3, 5e-2), + ("none", 1.0, torch.bfloat16, 5e-0, 5e1), + ("none", 1.0, torch.float32, 1e-3, 5e-2), ], ) @pytest.mark.parametrize("bias", [True, False]) @@ -197,8 +199,8 @@ def test_correctness( assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - output1.backward() - output2.backward() + output1.backward(gradients=torch.ones_like(output1)) + 
output2.backward(gradients=torch.ones_like(output2)) assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) From bc0874c8e56872df767c16a61ccd50305a4e55f4 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sat, 21 Dec 2024 19:05:01 +0800 Subject: [PATCH 5/6] fix typo --- test/transformers/test_fused_linear_cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index d451f35ae..fb19db13f 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -199,8 +199,8 @@ def test_correctness( assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - output1.backward(gradients=torch.ones_like(output1)) - output2.backward(gradients=torch.ones_like(output2)) + output1.backward(gradient=torch.ones_like(output1)) + output2.backward(gradient=torch.ones_like(output2)) assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) From ad414fd5b054288930db829633ef1c3dbb6484ad Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sat, 21 Dec 2024 21:30:59 +0800 Subject: [PATCH 6/6] use non-broadcastable torch.equal to prevent it from returning a tensor --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 56a02e4e9..bad2c4600 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -144,7 +144,7 @@ def fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ): # If cross entropy is the last layer, grad_output is 1.0. 
Skip the mul to save time - if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. BT, H = grad_input.shape