ℹ️ XPU support for DPO (#2533)
* add xpu support
* bug fix
* remove header
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: Quentin Gallouédec <[email protected]>
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: Quentin Gallouédec <[email protected]>
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: Quentin Gallouédec <[email protected]>
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: Quentin Gallouédec <[email protected]>
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: Quentin Gallouédec <[email protected]>
* Update trl/trainer/dpo_trainer.py
  Co-authored-by: Quentin Gallouédec <[email protected]>
* fix import and use the util

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
3 people authored Jan 8, 2025
1 parent 4516772 commit d6a7e9d
Showing 1 changed file with 14 additions and 7 deletions.
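In short: the patch replaces the hard-coded "cuda" device type passed to amp.autocast, and the direct torch.cuda.empty_cache() call, with device-aware equivalents so that DPO training and evaluation also run on Intel XPU devices.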
trl/trainer/dpo_trainer.py: 14 additions & 7 deletions
@@ -50,7 +50,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_peft_available
+from transformers.utils import is_peft_available, is_torch_xpu_available
 from transformers.utils.deprecation import deprecate_kwarg

 from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
@@ -61,6 +61,7 @@
     RunningMoments,
     cap_exp,
     disable_dropout_in_model,
+    empty_cache,
     generate_model_card,
     get_comet_experiment_url,
     log_table_to_comet_experiment,
@@ -741,7 +742,7 @@ def get_train_dataloader(self) -> DataLoader:
                 ref_rejected_logps.append(ref_rejected_logp.cpu())

                 # Unnecessary cache clearing to avoid OOM
-                torch.cuda.empty_cache()
+                empty_cache()
                 self.accelerator.free_memory()

             all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy()
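The empty_cache() call above uses the device-agnostic utility that this commit adds to the import list alongside RunningMoments. Its implementation is not part of this diff; a minimal sketch of the idea, assuming only the public torch.cuda and torch.xpu APIs, could look like:

import torch
from transformers.utils import is_torch_xpu_available


def empty_cache() -> None:
    # Release cached allocator blocks on whichever accelerator backend is active.
    # Sketch only: TRL's actual helper is not shown in this diff.
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()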
@@ -821,7 +822,8 @@ def null_ref_context(self):

     def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict:
         """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
-        compte_ref_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+        device_type = "xpu" if is_torch_xpu_available() else "cuda"
+        compte_ref_context_manager = amp.autocast(device_type) if self._peft_has_been_casted_to_bf16 else nullcontext()
         with torch.no_grad(), compte_ref_context_manager:
             if self.ref_model is None:
                 with self.null_ref_context():
@@ -1333,7 +1335,10 @@ def compute_loss(
         return_outputs=False,
         num_items_in_batch=None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+        device_type = "xpu" if is_torch_xpu_available() else "cuda"
+        compute_loss_context_manager = (
+            amp.autocast(device_type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+        )
         with compute_loss_context_manager:
             loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

@@ -1351,8 +1356,9 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor])
         """Generate samples from the model and reference model for the given batch of inputs."""

         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
-        # the torch cuda amp context manager as some hidden states are silently casted to full precision.
-        generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+        # the torch amp context manager as some hidden states are silently casted to full precision.
+        device_type = "xpu" if is_torch_xpu_available() else "cuda"
+        generate_context_manager = amp.autocast(device_type) if self._peft_has_been_casted_to_bf16 else nullcontext()

         with generate_context_manager:
             policy_output = model.generate(
@@ -1406,7 +1412,8 @@ def prediction_step(
         else:
             ignore_keys = []

-        prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+        device_type = "xpu" if is_torch_xpu_available() else "cuda"
+        prediction_context_manager = amp.autocast(device_type) if self._peft_has_been_casted_to_bf16 else nullcontext()

         with torch.no_grad(), prediction_context_manager:
             loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
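Each of the four context-manager changes above follows the same pattern: resolve the autocast device type once via is_torch_xpu_available(), then pass it to amp.autocast instead of a hard-coded "cuda". A self-contained sketch of that pattern (the peft_casted_to_bf16 flag below is a stand-in for the trainer's self._peft_has_been_casted_to_bf16 attribute):

from contextlib import nullcontext

import torch
from transformers.utils import is_torch_xpu_available

# Stand-in for the trainer's self._peft_has_been_casted_to_bf16 attribute.
peft_casted_to_bf16 = True

# "xpu" on Intel GPUs, otherwise "cuda"; computed once and reused.
device_type = "xpu" if is_torch_xpu_available() else "cuda"

# Autocast only when the PEFT model was cast to bf16; otherwise a no-op context.
autocast_ctx = torch.amp.autocast(device_type) if peft_casted_to_bf16 else nullcontext()

with torch.no_grad(), autocast_ctx:
    pass  # a reference-model forward pass would run here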
