From 4ddce2b78ae58a42daa520f3f7a7ec9caae88caa Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Fri, 22 Nov 2024 10:31:44 +0800 Subject: [PATCH] Format Signed-off-by: Austin Liu --- src/liger_kernel/chunked_loss/dpo_loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index c7c9aae1b..0e929720e 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -99,6 +99,7 @@ def backward(ctx, grad_output): # Return these gradients, followed by None for the remaining inputs return d_input, d_weight, d_target, None, None, d_bias, None, None, None + class LigerFusedLinearDPOLoss(torch.nn.Module): """ Fused linear layer with DPO loss. @@ -132,4 +133,4 @@ def forward(self, lin_weight, _input, target, bias=None): self.beta, self.compute_nll_loss, self.compiled, - ) \ No newline at end of file + )