Revert "fix chosen_nll_loss in chunked losses (#486)" (#489)
This reverts commit 61eefe9.


## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
shivam15s authored Dec 19, 2024
1 parent 61eefe9 commit 7a781b7
Showing 5 changed files with 37 additions and 101 deletions.
11 changes: 1 addition & 10 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -47,7 +47,6 @@ def forward(
alpha=1.0,
compute_nll_loss=True,
compiled=True,
is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx,
@@ -61,13 +60,12 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None
return *grads, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -82,24 +80,18 @@ def __init__(
alpha: float = 1.0,
compute_nll_loss: bool = True,
compiled: bool = True,
is_encoder_decoder: bool = False,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
alpha (float): Weight for the NLL loss.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to compile the loss function.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.is_encoder_decoder = is_encoder_decoder

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearCPOFunction.apply(
@@ -112,5 +104,4 @@ def forward(self, lin_weight, _input, target, bias=None):
self.alpha,
self.compute_nll_loss,
self.compiled,
self.is_encoder_decoder,
)
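
For orientation, here is a hedged usage sketch of the CPO module as it looks after this revert. The constructor arguments and the forward signature (`lin_weight, _input, target, bias=None`) come from the diff above; the shapes, dtype, and how the return value is consumed are illustrative assumptions only.

```python
# Illustrative sketch, not from this commit: shapes and return handling are assumptions.
import torch

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss

B, T, H, V = 2, 128, 512, 32000  # assumed: pairs per batch, seq len, hidden size, vocab size

# After this revert, is_encoder_decoder is no longer a constructor argument.
loss_fn = LigerFusedLinearCPOLoss(
    ignore_index=-100,
    beta=0.1,
    alpha=1.0,
    compute_nll_loss=True,
    compiled=True,
)

lm_head_weight = torch.randn(V, H, requires_grad=True)
# Chosen sequences are stacked first, rejected sequences second (2 * B rows total).
hidden_states = torch.randn(2 * B, T, H, requires_grad=True)
targets = torch.randint(0, V, (2 * B, T))

# Depending on the version, this returns the loss alone or a tuple with auxiliary aggregates.
outputs = loss_fn(lm_head_weight, hidden_states, targets)
```
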
8 changes: 1 addition & 7 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -67,7 +67,6 @@ def forward(
compute_nll_loss=True,
compiled=True,
use_ref_model=True,
is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
@@ -84,13 +83,12 @@ def forward(
ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None, None, None, None
return *grads, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -105,7 +103,6 @@ def __init__(
compute_nll_loss: bool = True,
compiled: bool = True,
use_ref_model: bool = False,
is_encoder_decoder: bool = False,
):
"""
Args:
@@ -114,15 +111,13 @@ def __init__(
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
use_ref_model (bool): Whether to use a reference model for the DPO loss.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.use_ref_model = use_ref_model
self.is_encoder_decoder = is_encoder_decoder

def forward(
self,
@@ -147,5 +142,4 @@ def forward(
self.compute_nll_loss,
self.compiled,
self.use_ref_model,
self.is_encoder_decoder,
)
92 changes: 31 additions & 61 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -26,7 +26,6 @@ def forward(
ignore_index=-100,
alpha=1.0,
beta=0.1,
is_encoder_decoder=False,
compute_nll_loss=True,
compiled=True,
use_ref_model=False,
@@ -57,7 +56,6 @@ def forward(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the preference loss.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
@@ -96,7 +94,6 @@ def forward(
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
is_encoder_decoder=is_encoder_decoder,
**loss_kwargs,
)

@@ -285,48 +282,33 @@ def chunk_forward(
bias=None,
ignore_index=-100,
compute_nll_loss=True,
is_encoder_decoder=False,
):
# Calculate logits and log probabilities
len_chosen_chunk = target_chunk.shape[0] // 2
logits_chunk = input_chunk @ weight.t()
if bias is not None:
logits_chunk += bias
logits_chunk = logits_chunk + bias
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

# Split chunk into chosen and rejected portions
len_chosen_chunk = target_chunk.shape[0] // 2

# Handle sequence shifting for non-encoder-decoder models
if not is_encoder_decoder:
logits_chunk = logits_chunk[:, :-1]
log_probs_chunk = log_probs_chunk[:, :-1]
target_chunk = target_chunk[:, 1:]

# Calculate NLL loss for chosen sequences
chosen_nll_loss = 0.0
if compute_nll_loss:
chosen_probs = log_probs_chunk[:len_chosen_chunk]
chosen_targets = target_chunk[:len_chosen_chunk]
chosen_nll_loss = F.nll_loss(
chosen_probs.reshape(-1, chosen_probs.shape[-1]),
chosen_targets.reshape(-1),
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
target_chunk[:len_chosen_chunk].view(-1),
reduction="sum",
ignore_index=ignore_index,
)

# Calculate per-token log probabilities
loss_mask = target_chunk != ignore_index
label_chunk = torch.where(loss_mask, target_chunk, 0)

per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-1
)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

# Split results for chosen and rejected
chosen_logps, rejected_logps = (
average_log_prob[:len_chosen_chunk],
average_log_prob[len_chosen_chunk:],
)
chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]

chosen_logits = logits_chunk[:len_chosen_chunk]
rejected_logits = logits_chunk[len_chosen_chunk:]

@@ -349,7 +331,6 @@ def _compute_loss(
ignore_index=-100,
alpha=1.0,
beta=0.1,
is_encoder_decoder=False,
compute_nll_loss=True,
use_ref_model=False,
ref_input_chunk=None,
@@ -369,7 +350,6 @@ def _compute_loss(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the preference loss.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
compute_nll_loss (bool): Whether to compute NLL loss.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
@@ -389,43 +369,33 @@ def _compute_loss(
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
is_encoder_decoder=is_encoder_decoder,
)
if not is_encoder_decoder:
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2, 1:] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0]
)
else:
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)

if use_ref_model:
with torch.no_grad():
(ref_chosen_logps, ref_rejected_logps, _, _, _) = (
LigerFusedLinearPreferenceBase.chunk_forward(
ref_input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
is_encoder_decoder=is_encoder_decoder, # assume the ref model is the same family
)
(
ref_chosen_logps,
ref_rejected_logps,
ref_chosen_logits,
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
ref_input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
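
To make the restored math easier to follow, here is a minimal plain-PyTorch sketch of the `chunk_forward` logic after the revert (no sequence shifting, NLL over the chosen half only). It mirrors the hunk above but is not the fused/chunked kernel itself; variable names and the standalone-function form are for illustration.

```python
# Minimal sketch of the reverted chunk_forward math; assumes input_chunk: (2B, T, H),
# weight: (V, H), target_chunk: (2B, T), with chosen rows stacked before rejected rows.
import torch
import torch.nn.functional as F


def chunk_forward_sketch(input_chunk, weight, target_chunk, bias=None,
                         ignore_index=-100, compute_nll_loss=True):
    # First half of the chunk is "chosen", second half "rejected".
    len_chosen = target_chunk.shape[0] // 2

    logits = input_chunk @ weight.t()
    if bias is not None:
        logits = logits + bias
    log_probs = F.log_softmax(logits.float(), dim=-1)

    # NLL loss over the chosen half only, summed; the caller normalizes it later.
    chosen_nll_loss = torch.tensor(0.0)
    if compute_nll_loss:
        chosen_nll_loss = F.nll_loss(
            log_probs[:len_chosen].view(-1, log_probs.shape[-1]),
            target_chunk[:len_chosen].view(-1),
            reduction="sum",
            ignore_index=ignore_index,
        )

    # Average per-token log-probability of the given labels, masking ignore_index.
    loss_mask = target_chunk != ignore_index
    labels = torch.where(loss_mask, target_chunk, 0)  # dummy index for masked positions
    per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    avg_logps = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

    return avg_logps[:len_chosen], avg_logps[len_chosen:], chosen_nll_loss
```

As the `_compute_loss` hunk shows, the returned chosen/rejected average log-probs feed the preference loss, and after the revert the summed NLL term is divided by the number of non-ignored target tokens in the chosen half of the full batch, without the encoder-decoder special case.
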
10 changes: 1 addition & 9 deletions src/liger_kernel/chunked_loss/orpo_loss.py
@@ -57,7 +57,6 @@ def forward(
beta=0.1,
compute_nll_loss=True,
compiled=True,
is_encoder_decoder=False,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
@@ -70,13 +69,12 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
is_encoder_decoder=is_encoder_decoder,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None
return *grads, None, None, None, None


class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -90,22 +88,17 @@ def __init__(
beta: float = 0.1,
compute_nll_loss: bool = True,
compiled: bool = True,
is_encoder_decoder: bool = False,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to compile the loss function.
is_encoder_decoder (bool): Whether the model is an encoder-decoder model.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.is_encoder_decoder = is_encoder_decoder

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearORPOFunction.apply(
@@ -117,5 +110,4 @@ def forward(self, lin_weight, _input, target, bias=None):
self.beta,
self.compute_nll_loss,
self.compiled,
self.is_encoder_decoder,
)
17 changes: 3 additions & 14 deletions test/utils.py
@@ -350,13 +350,11 @@ def __init__(
beta: float = 0.1,
ignore_index: int = -100,
use_ref_model: bool = False,
is_encoder_decoder: bool = False,
):
self.alpha = alpha
self.beta = beta
self.ignore_index = ignore_index
self.use_ref_model = use_ref_model
self.is_encoder_decoder = is_encoder_decoder

@abstractmethod
def alignment_loss(self):
@@ -374,6 +372,7 @@ def get_batch_logps(
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
is_encoder_decoder: Whether the model is an encoder-decoder model.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
"""
Expand All @@ -382,9 +381,6 @@ def get_batch_logps(
"Logits (batch and sequence length dim) and labels must have the same shape."
)

if not self.is_encoder_decoder:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
loss_mask = labels != self.ignore_index

# dummy token; we'll ignore the losses on these tokens later
@@ -444,9 +440,6 @@ def concatenated_forward(
def cross_entropy_loss(logits, labels):
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
if not self.is_encoder_decoder:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
@@ -468,12 +461,8 @@ def cross_entropy_loss(logits, labels):
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1]
rejected_logits = all_logits[len_chosen:, :-1]
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (
chosen_logps,
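
The reverted test helper keeps its pre-#486 shape: `get_batch_logps` no longer shifts logits/labels, and `concatenated_forward` slices logits over the full sequence length. A hedged reconstruction of the helper is below; the visible hunk only shows the shape check, the loss mask, and the dummy-token comment, so the gather/reduction tail is assumed to follow the usual DPO-style reference implementation and is written as a free function rather than the method in `test/utils.py`.

```python
# Hedged reconstruction of the test-side reference helper after the revert; the tail of the
# function is collapsed in the diff above and is assumed here, not copied from the commit.
import torch


def get_batch_logps(logits, labels, ignore_index=-100, average_log_prob=False):
    # No label shifting: the revert removed the [..., :-1, :] / [..., 1:] slicing.
    if logits.shape[:-1] != labels.shape:
        raise ValueError(
            "Logits (batch and sequence length dim) and labels must have the same shape."
        )

    loss_mask = labels != ignore_index

    # dummy token; losses on these positions are masked out below
    labels = torch.where(loss_mask, labels, 0)

    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)
```
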
