
Commit

Revert "Add ref_input parameter to support separate inputs for refere…
Browse files Browse the repository at this point in the history
…nce model" (#469)

Reverts #467 until the test is fixed cc @shivam15s
ByronHsu authored Dec 11, 2024
1 parent eee40c5 · commit 969ce3a
Showing 1 changed file with 2 additions and 6 deletions.
src/liger_kernel/chunked_loss/fused_linear_preference.py (2 additions, 6 deletions)

@@ -29,7 +29,7 @@ def forward(
         compute_nll_loss=True,
         compiled=True,
         use_ref_model=False,
-        ref_input=None,
+        # TODO: ref input
         ref_weight=None,
         ref_bias=None,
         **loss_kwargs,
@@ -59,7 +59,6 @@ def forward(
             compute_nll_loss (bool): Whether to compute NLL loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
-            ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size).
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Other possible arguments that a loss function might need
@@ -93,7 +92,6 @@ def forward(
             compute_nll_loss=compute_nll_loss,
             full_target=target,
             use_ref_model=use_ref_model,
-            ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
             **loss_kwargs,
@@ -303,7 +301,6 @@ def _compute_loss(
         beta=0.1,
         compute_nll_loss=True,
         use_ref_model=False,
-        ref_input=None,
         ref_weight=None,
         ref_bias=None,
         **loss_kwargs,
@@ -322,7 +319,6 @@ def _compute_loss(
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
-            ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Additional arguments for the loss function.
@@ -361,7 +357,7 @@ def _compute_loss(
                 ref_rejected_logits,
                 ref_chosen_nll_loss,
             ) = LigerFusedLinearPreferenceBase.chunk_forward(
-                ref_input,
+                input_chunk,
                 ref_weight,
                 target_chunk,
                 ref_bias,
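
For readers following the change, below is a minimal sketch of the functional effect of this revert. The toy chunk_forward, shapes, and tensors are illustrative assumptions, not the Liger-Kernel implementation; the point is that after the revert the reference model's logits are once again computed from the same input_chunk passed to the policy head, rather than from the separate ref_input tensor that #467 introduced.

import torch

def chunk_forward(input_chunk, weight, target_chunk, bias=None):
    # Toy stand-in for LigerFusedLinearPreferenceBase.chunk_forward (assumption):
    # project hidden states to vocab logits. target_chunk is accepted only to
    # mirror the real call signature shown in the diff; it is unused here.
    logits = input_chunk @ weight.t()
    if bias is not None:
        logits = logits + bias
    return logits

# Illustrative shapes: the 2 * chunk_size rows hold the chosen and rejected halves.
chunk_size, seq_len, hidden_size, vocab_size = 2, 8, 16, 32
input_chunk = torch.randn(2 * chunk_size, seq_len, hidden_size)
target_chunk = torch.randint(0, vocab_size, (2 * chunk_size, seq_len))
ref_weight = torch.randn(vocab_size, hidden_size)
ref_bias = torch.randn(vocab_size)

# After this revert, the reference logits come from input_chunk itself; there is
# no separate ref_input path until #467 is re-landed with its test fixed.
ref_logits = chunk_forward(input_chunk, ref_weight, target_chunk, ref_bias)
print(ref_logits.shape)  # torch.Size([4, 8, 32])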
