fix chosen_nll_loss in chunked losses (#486)
## Summary
Fix the NLL loss in the chunked losses when the model is a decoder-only model by shifting the logits and targets before computing the per-token log probabilities.
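
For a decoder-only (causal) LM, the logit at position `t` predicts the token at position `t + 1`, so the last logit and the first target must be dropped before the NLL is computed. A minimal sketch of the pattern applied here (shapes and tensor names are illustrative only, not the kernel code):

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 5, 11                         # illustrative batch, length, vocab
logits = torch.randn(B, T, V)              # lm_head output
targets = torch.randint(0, V, (B, T))      # labels aligned with the inputs
ignore_index = -100

# Decoder-only shift: logits[:, t] predicts targets[:, t + 1].
shifted_logits = logits[:, :-1, :]
shifted_targets = targets[:, 1:]

log_probs = F.log_softmax(shifted_logits.float(), dim=-1)
nll = F.nll_loss(
    log_probs.reshape(-1, V),
    shifted_targets.reshape(-1),
    reduction="sum",
    ignore_index=ignore_index,
)
# Normalize by the number of non-ignored *shifted* targets,
# mirroring the denominator change in this commit.
nll = nll / (shifted_targets != ignore_index).sum()
```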

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
kashif authored Dec 18, 2024
1 parent ac56674 commit 61eefe9
Showing 5 changed files with 101 additions and 37 deletions.
11 changes: 10 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
@@ -47,6 +47,7 @@ def forward(
alpha=1.0,
compute_nll_loss=True,
compiled=True,
is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx,
@@ -60,12 +61,13 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None
return *grads, None, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -80,18 +82,24 @@ def __init__(
alpha: float = 1.0,
compute_nll_loss: bool = True,
compiled: bool = True,
is_encoder_decoder: bool = False,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
alpha (float): Weight for the NLL loss.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to compile the loss function.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.is_encoder_decoder = is_encoder_decoder

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearCPOFunction.apply(
@@ -104,4 +112,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.alpha,
self.compute_nll_loss,
self.compiled,
self.is_encoder_decoder,
)
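
A hedged usage sketch of the updated CPO module with the new flag passed at construction. The sizes, the even chosen/rejected split along the batch dimension, and the exact return structure are assumptions for illustration, not guaranteed by this diff:

```python
import torch

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss

B, T, H, V = 2, 8, 16, 32                    # illustrative sizes
loss_fn = LigerFusedLinearCPOLoss(
    ignore_index=-100,
    beta=0.1,
    alpha=1.0,
    compute_nll_loss=True,
    compiled=False,                          # skip torch.compile for a quick check
    is_encoder_decoder=False,                # decoder-only: shift handled internally
)

lm_head_weight = torch.randn(V, H, requires_grad=True)
hidden = torch.randn(2 * B, T, H, requires_grad=True)   # [chosen; rejected] stacked
target = torch.randint(0, V, (2 * B, T))

out = loss_fn(lm_head_weight, hidden, target)            # loss (plus any aux outputs)
```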
8 changes: 7 additions & 1 deletion src/liger_kernel/chunked_loss/dpo_loss.py
@@ -67,6 +67,7 @@ def forward(
compute_nll_loss=True,
compiled=True,
use_ref_model=True,
is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
@@ -83,12 +84,13 @@ def forward(
ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None, None, None
return *grads, None, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -103,6 +105,7 @@ def __init__(
compute_nll_loss: bool = True,
compiled: bool = True,
use_ref_model: bool = False,
is_encoder_decoder: bool = False,
):
"""
Args:
@@ -111,13 +114,15 @@ def __init__(
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
use_ref_model (bool): Whether to use a reference model for the DPO loss.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.use_ref_model = use_ref_model
self.is_encoder_decoder = is_encoder_decoder

def forward(
self,
@@ -142,4 +147,5 @@ def forward(
self.compute_nll_loss,
self.compiled,
self.use_ref_model,
self.is_encoder_decoder,
)
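
The extra `None` appended to each `backward` return above follows the `torch.autograd.Function` contract: `backward` must return one value per `forward` argument (after `ctx`), and non-differentiable arguments such as the new `is_encoder_decoder` flag get `None`. A toy sketch of the pattern, not the kernel itself:

```python
import torch


class LinearWithFlag(torch.autograd.Function):
    """Toy Function: two tensor inputs plus a plain bool flag."""

    @staticmethod
    def forward(ctx, x, weight, some_flag=False):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output @ weight
        grad_weight = grad_output.t() @ x
        # One entry per forward argument: tensors get gradients, the flag gets
        # None. Adding another forward argument means appending another None,
        # which is exactly what this commit does for is_encoder_decoder.
        return grad_x, grad_weight, None


x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
LinearWithFlag.apply(x, w, False).sum().backward()
```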
92 changes: 61 additions & 31 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -26,6 +26,7 @@ def forward(
ignore_index=-100,
alpha=1.0,
beta=0.1,
is_encoder_decoder=False,
compute_nll_loss=True,
compiled=True,
use_ref_model=False,
@@ -56,6 +57,7 @@ def forward(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the preference loss.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
@@ -94,6 +96,7 @@ def forward(
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
is_encoder_decoder=is_encoder_decoder,
**loss_kwargs,
)

@@ -282,33 +285,48 @@ def chunk_forward(
bias=None,
ignore_index=-100,
compute_nll_loss=True,
is_encoder_decoder=False,
):
len_chosen_chunk = target_chunk.shape[0] // 2
# Calculate logits and log probabilities
logits_chunk = input_chunk @ weight.t()
if bias is not None:
logits_chunk = logits_chunk + bias
logits_chunk += bias
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

# Split chunk into chosen and rejected portions
len_chosen_chunk = target_chunk.shape[0] // 2

# Handle sequence shifting for non-encoder-decoder models
if not is_encoder_decoder:
logits_chunk = logits_chunk[:, :-1]
log_probs_chunk = log_probs_chunk[:, :-1]
target_chunk = target_chunk[:, 1:]

# Calculate NLL loss for chosen sequences
chosen_nll_loss = 0.0
if compute_nll_loss:
chosen_probs = log_probs_chunk[:len_chosen_chunk]
chosen_targets = target_chunk[:len_chosen_chunk]
chosen_nll_loss = F.nll_loss(
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
target_chunk[:len_chosen_chunk].view(-1),
chosen_probs.reshape(-1, chosen_probs.shape[-1]),
chosen_targets.reshape(-1),
reduction="sum",
ignore_index=ignore_index,
)

# Calculate per-token log probabilities
loss_mask = target_chunk != ignore_index
label_chunk = torch.where(loss_mask, target_chunk, 0)

per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-1
)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]

# Split results for chosen and rejected
chosen_logps, rejected_logps = (
average_log_prob[:len_chosen_chunk],
average_log_prob[len_chosen_chunk:],
)
chosen_logits = logits_chunk[:len_chosen_chunk]
rejected_logits = logits_chunk[len_chosen_chunk:]

@@ -331,6 +349,7 @@ def _compute_loss(
ignore_index=-100,
alpha=1.0,
beta=0.1,
is_encoder_decoder=False,
compute_nll_loss=True,
use_ref_model=False,
ref_input_chunk=None,
@@ -350,6 +369,7 @@ def _compute_loss(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the preference loss.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
compute_nll_loss (bool): Whether to compute NLL loss.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
@@ -369,33 +389,43 @@ def _compute_loss(
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
is_encoder_decoder=is_encoder_decoder,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
if not is_encoder_decoder:
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2, 1:] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0]
)
else:
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)

if use_ref_model:
with torch.no_grad():
(
ref_chosen_logps,
ref_rejected_logps,
ref_chosen_logits,
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
ref_input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
(ref_chosen_logps, ref_rejected_logps, _, _, _) = (
LigerFusedLinearPreferenceBase.chunk_forward(
ref_input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
is_encoder_decoder=is_encoder_decoder, # assume the ref model is the same family
)
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
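
For the preference term, `chunk_forward` gathers per-token log probabilities of the (already shifted, for decoder-only models) targets and averages them over non-ignored positions. A standalone sketch of that masking and averaging, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

ignore_index = -100
# Shifted log-probs and targets for one chunk: (2 * B, T - 1, V) and (2 * B, T - 1).
log_probs = F.log_softmax(torch.randn(4, 7, 11), dim=-1)
targets = torch.randint(0, 11, (4, 7))
targets[:, -2:] = ignore_index                        # pretend the tail is padding

loss_mask = targets != ignore_index
safe_targets = torch.where(loss_mask, targets, torch.zeros_like(targets))

per_token_logps = log_probs.gather(-1, safe_targets.unsqueeze(-1)).squeeze(-1)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

# First half of the chunk is the chosen sequences, second half the rejected ones.
chosen_logps, rejected_logps = average_log_prob.chunk(2, dim=0)
```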
10 changes: 9 additions & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
@@ -57,6 +57,7 @@ def forward(
beta=0.1,
compute_nll_loss=True,
compiled=True,
is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
@@ -69,12 +70,13 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None
return *grads, None, None, None, None, None


class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -88,17 +90,22 @@ def __init__(
beta: float = 0.1,
compute_nll_loss: bool = True,
compiled: bool = True,
is_encoder_decoder: bool = False,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to compile the loss function.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.is_encoder_decoder = is_encoder_decoder

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearORPOFunction.apply(
@@ -110,4 +117,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.beta,
self.compute_nll_loss,
self.compiled,
self.is_encoder_decoder,
)
17 changes: 14 additions & 3 deletions test/utils.py
@@ -350,11 +350,13 @@ def __init__(
beta: float = 0.1,
ignore_index: int = -100,
use_ref_model: bool = False,
is_encoder_decoder: bool = False,
):
self.alpha = alpha
self.beta = beta
self.ignore_index = ignore_index
self.use_ref_model = use_ref_model
self.is_encoder_decoder = is_encoder_decoder

@abstractmethod
def alignment_loss(self):
@@ -372,7 +374,6 @@ def get_batch_logps(
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
is_encoder_decoder: Whether the model is an encoder-decoder model.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
"""
@@ -381,6 +382,9 @@ def get_batch_logps(
"Logits (batch and sequence length dim) and labels must have the same shape."
)

if not self.is_encoder_decoder:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
loss_mask = labels != self.ignore_index

# dummy token; we'll ignore the losses on these tokens later
@@ -440,6 +444,9 @@ def concatenated_forward(
def cross_entropy_loss(logits, labels):
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
if not self.is_encoder_decoder:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
@@ -461,8 +468,12 @@ def cross_entropy_loss(logits, labels):
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1]
rejected_logits = all_logits[len_chosen:, :-1]
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (
chosen_logps,
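
The fused path (shift, `log_softmax`, `nll_loss` summed and divided by the non-ignored count) and the reference path in `test/utils.py` (shift, `nn.CrossEntropyLoss` with its default mean over non-ignored tokens) should agree on decoder-only inputs. A quick sanity sketch of that equivalence (illustrative shapes, not part of the test suite):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
ignore_index = -100
logits = torch.randn(2, 6, 13)
labels = torch.randint(0, 13, (2, 6))
labels[:, -1] = ignore_index

# Shared decoder-only shift.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Kernel-style accounting: summed NLL over non-ignored tokens, then divide.
log_probs = F.log_softmax(shift_logits.float(), dim=-1)
kernel_nll = F.nll_loss(
    log_probs.reshape(-1, log_probs.shape[-1]),
    shift_labels.reshape(-1),
    reduction="sum",
    ignore_index=ignore_index,
) / (shift_labels != ignore_index).sum()

# Reference-style accounting: CrossEntropyLoss means over non-ignored tokens.
ref_nll = nn.CrossEntropyLoss(ignore_index=ignore_index)(
    shift_logits.reshape(-1, shift_logits.shape[-1]), shift_labels.reshape(-1)
)

assert torch.allclose(kernel_nll, ref_nll, atol=1e-6)
```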
