Commit

rearrange fns for readability
shivam15s committed Dec 7, 2024
1 parent e381569 commit 8aa842a
Showing 2 changed files with 159 additions and 159 deletions.
116 changes: 58 additions & 58 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -53,6 +53,64 @@ def chunk_forward(

return student_logits_chunk, teacher_logits_chunk, ce_loss

@staticmethod
def _compute_loss(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=None,
teacher_bias=None,
distillation_loss_fn=None,
full_target=None,
ignore_index=-100,
temperature=1.0,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
Args:
distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor, used to normalize the loss. Shape: (batch_size * sequence_length,).
ignore_index (int): Index to ignore for loss computation.
temperature (float): Temperature used to soften the student and teacher distributions for the soft loss.
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
compute_ce_loss (bool): Whether to compute CE loss.
loss_kwargs (dict): Additional arguments for the loss function.
"""
student_logits_chunk, teacher_logits_chunk, hard_loss = (
LigerFusedLinearDistillationBase.chunk_forward(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=student_bias,
teacher_bias=teacher_bias,
ignore_index=ignore_index,
compute_ce_loss=compute_ce_loss,
)
)

hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
soft_loss /= full_target.shape[0]

loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)

@staticmethod
def forward(
ctx,
@@ -190,61 +248,3 @@ def backward(ctx, grad_output):
grad_bias = grad_bias * grad_output if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias

@staticmethod
def _compute_loss(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=None,
teacher_bias=None,
distillation_loss_fn=None,
full_target=None,
ignore_index=-100,
temperature=1.0,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
Args:
distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor, used to normalize the loss. Shape: (batch_size * sequence_length,).
ignore_index (int): Index to ignore for loss computation.
temperature (float): Temperature used to soften the student and teacher distributions for the soft loss.
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
compute_ce_loss (bool): Whether to compute CE loss.
loss_kwargs (dict): Additional arguments for the loss function.
"""
student_logits_chunk, teacher_logits_chunk, hard_loss = (
LigerFusedLinearDistillationBase.chunk_forward(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=student_bias,
teacher_bias=teacher_bias,
ignore_index=ignore_index,
compute_ce_loss=compute_ce_loss,
)
)

hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
soft_loss /= full_target.shape[0]

loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
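
For context on the distillation_loss_fn hook that _compute_loss passes its chunked logits to, a minimal sketch of a temperature-scaled KL-divergence soft loss follows. The function name, the sum reduction, and the T^2 scaling are illustrative assumptions and not part of this commit; _compute_loss divides the result by full_target.shape[0] afterwards, so per-token normalization happens outside the hook.

import torch.nn.functional as F

def kl_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    # Soften both distributions with the temperature before comparing them.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Sum-reduced KL(teacher || student); the caller normalizes by the full
    # target length, turning this into a mean over the batch.
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="sum")
    # T^2 scaling (Hinton et al., 2015) keeps gradient magnitudes comparable
    # across temperatures; whether the actual kernel applies it is an assumption.
    return kl * (temperature ** 2)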
202 changes: 101 additions & 101 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -64,6 +64,103 @@ def chunk_forward(
chosen_nll_loss,
)

@staticmethod
def _compute_loss(
input_chunk,
weight,
target_chunk,
bias=None,
preference_loss_fn=None,
full_target=None,
ignore_index=-100,
alpha=1.0,
beta=0.1,
compute_nll_loss=True,
use_ref_model=False,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, using an alignment/preference loss function.
Args:
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute NLL loss.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
loss_kwargs (dict): Additional arguments for the loss function.
"""
(
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
weight,
target_chunk,
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)

if use_ref_model:
with torch.no_grad():
(
ref_chosen_logps,
ref_rejected_logps,
ref_chosen_logits,
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

preference_loss_outputs = preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
)
if isinstance(preference_loss_outputs, tuple):
preference_loss, *aux_outputs = preference_loss_outputs
else:
preference_loss, aux_outputs = preference_loss_outputs, []

loss = alpha * chosen_nll_loss - preference_loss
return_vars = (
chosen_logps,
rejected_logps,
chosen_logits_mean,
rejected_logits_mean,
chosen_nll_loss,
)
return loss, (*return_vars, *aux_outputs)

@staticmethod
def forward(
ctx,
@@ -134,7 +231,7 @@ def forward(
**loss_kwargs,
)

def accumulate_helper(input_chunk, target_chunk):
def accumulate_core(input_chunk, target_chunk):
if bias is not None:
return torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1, 3), has_aux=True
@@ -156,7 +253,7 @@ def accumulate_chunk(input_chunk, target_chunk):
chunk_nll_loss,
*aux_outputs,
),
) = accumulate_helper(input_chunk, target_chunk)
) = accumulate_core(input_chunk, target_chunk)
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
else:
(chunk_grad_input, chunk_grad_weight), (
@@ -169,7 +266,7 @@ def accumulate_chunk(input_chunk, target_chunk):
chunk_nll_loss,
*aux_outputs,
),
) = accumulate_helper(input_chunk, target_chunk)
) = accumulate_core(input_chunk, target_chunk)

grad_weight.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
@@ -199,7 +296,7 @@ def accumulate_chunk(input_chunk, target_chunk):
return chunk_grad_input

if compiled:
accumulate_helper = torch.compile(accumulate_helper)
accumulate_core = torch.compile(accumulate_core)

len_chosen = target.shape[0] // 2
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
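
The accumulate_core helper introduced by this rename wraps torch.func.grad_and_value with has_aux=True, which is why each call site above unpacks a (grads, (loss, aux)) pair. A self-contained sketch of that pattern with toy shapes and a toy loss (placeholders for illustration, not the kernel's real computation):

import torch
import torch.nn.functional as F

def toy_loss(input_chunk, weight, target_chunk):
    logits = input_chunk @ weight.T
    loss = F.cross_entropy(logits, target_chunk)
    # With has_aux=True, everything after the first return value is passed
    # through untouched and excluded from differentiation.
    return loss, logits.detach()

x = torch.randn(4, 8)
w = torch.randn(16, 8)
t = torch.randint(0, 16, (4,))

# argnums selects which positional arguments receive gradients, mirroring
# argnums=(0, 1, 3) (input, weight, bias) in the diff above.
(grad_x, grad_w), (loss, logits) = torch.func.grad_and_value(
    toy_loss, argnums=(0, 1), has_aux=True
)(x, w, t)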
@@ -270,100 +367,3 @@ def backward(ctx, *grad_output):
grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias, None, None, None

@staticmethod
def _compute_loss(
input_chunk,
weight,
target_chunk,
bias=None,
preference_loss_fn=None,
full_target=None,
ignore_index=-100,
alpha=1.0,
beta=0.1,
compute_nll_loss=True,
use_ref_model=False,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, using an alignment/preference loss function.
Args:
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute NLL loss.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
loss_kwargs (dict): Additional arguments for the loss function.
"""
(
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
weight,
target_chunk,
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)

if use_ref_model:
with torch.no_grad():
(
ref_chosen_logps,
ref_rejected_logps,
ref_chosen_logits,
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

preference_loss_outputs = preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
)
if isinstance(preference_loss_outputs, tuple):
preference_loss, *aux_outputs = preference_loss_outputs
else:
preference_loss, aux_outputs = preference_loss_outputs, []

loss = alpha * chosen_nll_loss - preference_loss
return_vars = (
chosen_logps,
rejected_logps,
chosen_logits_mean,
rejected_logits_mean,
chosen_nll_loss,
)
return loss, (*return_vars, *aux_outputs)
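
As with the distillation hook, preference_loss_fn is only a signature contract here: it receives the chunked chosen/rejected log-probabilities, the full target, beta, and any extra keyword arguments (including ref_chosen_logps and ref_rejected_logps when use_ref_model is set). Below is a minimal sketch of a compatible function; the name, the log-sigmoid margin, and the normalization by the number of chosen sequences are illustrative assumptions, not the losses shipped in the library:

import torch.nn.functional as F

def margin_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1, **loss_kwargs):
    # Reward margin between chosen and rejected log-probabilities. Because
    # _compute_loss forms loss = alpha * chosen_nll_loss - preference_loss,
    # this function returns the quantity to be maximized.
    margin = beta * (chosen_logps - rejected_logps)
    # Normalize by the number of chosen sequences (half of full_target's batch).
    return F.logsigmoid(margin).sum() / (full_target.shape[0] // 2)

A DPO-style variant would subtract the reference log-ratio, ref_chosen_logps - ref_rejected_logps, from the margin before applying the log-sigmoid.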
