diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 494aade9f..41ec78a9d 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -33,7 +33,6 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, labe loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / ( full_target.shape[0] // 2 ) - return loss @staticmethod diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 17b82d09c..ba6fbef81 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -127,13 +127,16 @@ def fused_linear_cross_entropy_forward( alpha=alpha, ) - loss = torch.sum(loss_1d) + if reduction == "none": + loss = loss_1d + else: + loss = torch.sum(loss_1d) return loss, grad_input, grad_weight, grad_bias def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time - if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. BT, H = grad_input.shape diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 08d126656..9f0d5d5e1 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -108,6 +108,8 @@ def forward(self, x, y): ("mean", 1.0, torch.float32, 1e-5, 5e-4), ("sum", 1.0, torch.bfloat16, 5e-0, 5e1), ("sum", 1.0, torch.float32, 1e-3, 5e-2), + ("none", 1.0, torch.bfloat16, 5e-0, 5e1), + ("none", 1.0, torch.float32, 1e-3, 5e-2), ], ) @pytest.mark.parametrize("bias", [True, False]) @@ -185,8 +187,8 @@ def test_correctness( assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - output1.backward() - output2.backward() + output1.backward(gradient=torch.ones_like(output1)) + output2.backward(gradient=torch.ones_like(output2)) assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)