From 21bacccd107b40146e599e2c27af46d7d157f174 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 16 Dec 2024 22:43:02 +0100
Subject: [PATCH] [Transformer] fix ORPO loss for MOE models (#479)

## Summary

Add the missing MoE auxiliary loss to the ORPO loss when it is enabled in the trainer.

- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
---
 src/liger_kernel/transformers/trainer/orpo_trainer.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py
index 3605b9f1b..184430ac1 100644
--- a/src/liger_kernel/transformers/trainer/orpo_trainer.py
+++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py
@@ -125,6 +125,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
             outputs.last_hidden_state,
             concatenated_batch["concatenated_labels"],
         )
+        # if aux_loss_enabled, add the aux_loss to the orpo_loss
+        if self.aux_loss_enabled:
+            orpo_loss += self.aux_loss_coef * outputs.aux_loss
+
         return orpo_loss, aux_outputs
 
     def get_batch_loss_metrics(
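For reviewers unfamiliar with the MoE code path: `outputs.aux_loss` is the router load-balancing loss that MoE models return when auxiliary-loss output is enabled, and `self.aux_loss_coef` scales it before it is folded into the main objective. Below is a minimal standalone sketch of the combination this patch performs, assuming scalar loss tensors; the `combine_losses` helper and the numeric values are illustrative only, not part of the trainer.

```python
import torch

def combine_losses(
    orpo_loss: torch.Tensor,
    aux_loss: torch.Tensor,
    aux_loss_enabled: bool,
    aux_loss_coef: float,
) -> torch.Tensor:
    """Fold the MoE router load-balancing loss into the ORPO loss when enabled.

    Mirrors the patched branch: orpo_loss += aux_loss_coef * aux_loss.
    """
    if aux_loss_enabled:
        orpo_loss = orpo_loss + aux_loss_coef * aux_loss
    return orpo_loss

# Dummy scalars standing in for real model outputs.
orpo_loss = torch.tensor(1.25)  # preference-optimization loss
aux_loss = torch.tensor(0.08)   # router load-balancing loss from the MoE layers
print(combine_losses(orpo_loss, aux_loss, aux_loss_enabled=True, aux_loss_coef=0.001))
# tensor(1.2501)
```

Before this branch existed, the auxiliary loss returned by MoE models was silently dropped, so the load-balancing term contributed no gradient during ORPO training; the fix restores the same weighted-sum behavior used in the standard trainer paths.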