Skip to content

Commit

Permalink
fix linters and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Sep 25, 2024
1 parent d56301d commit 2410c6f
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 248 deletions.
8 changes: 5 additions & 3 deletions turbo_alignment/cherry_picks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def _get_dataset_metrics(
labels = [record['labels'] for record in dataset]

contexts = [
Conversation(system_prompt=dataset.source.system_prompt, messages=record.messages).get_prompt_repr(
0, len(record.messages)
)
Conversation(
system_prompt=dataset.source.system_prompt,
messages=record.messages,
ignore_system_prompt=dataset.settings.chat_settings.ignore_system_prompt,
).get_prompt_repr(0, len(record.messages))
for record in generations
]

Expand Down
139 changes: 7 additions & 132 deletions turbo_alignment/common/tf/liger_kernels/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,28 @@ def liger_cross_entropy_kernel(
n_non_ignore,
ignore_index,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
reduction: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
loss_ptr: Pointer to tensor to store the loss.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The string for the reduction to apply
BLOCK_SIZE (int): The block size for Triton operations.
"""

# https://github.com/triton-lang/triton/issues/1058
# If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
program_id = tl.program_id(0).to(tl.int64)

# 1. Load Y_ptr first because if the target is ignore_index, we can return right away
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)

# 2. locate the start index
X_ptr += program_id * X_stride

if y == ignore_index:
# set all X_ptr as 0
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
return

loss_ptr += program_id * loss_stride

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
m = float('-inf')
d = 0.0
ori_X_y = tl.load(X_ptr + y)

# 3. [Online softmax] first pass: find max + sum
m = float('-inf') # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
ori_X_y = tl.load(X_ptr + y) # we need to store the original value of X_y for the loss calculation

# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
scaled_x_sum = 0.0
eps = label_smoothing / n_cols

Expand All @@ -76,29 +46,11 @@ def liger_cross_entropy_kernel(
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float('-inf'))
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new

# 4. [Online Softmax] Second pass: compute gradients
# For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# For label smoothing:
# dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
#
# For 'sum' reduction, no normalization is applied:
# dx_y = softmax(x_y) - 1
# dx_i = softmax(x_i), for i ≠ y
# For label smoothing:
# dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
# = dx_i - (1 - label_smoothing)

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float('-inf'))
Expand All @@ -109,35 +61,17 @@ def liger_cross_entropy_kernel(

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
tl.debug_barrier()

# 5. Calculate the loss

# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log (softmax(X_y)) without overflow
loss = -(ori_X_y - m - tl.log(d))

# Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
loss = loss * (1 - label_smoothing) + smooth_loss

# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == 'mean':
loss = loss / n_non_ignore

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
X_y = tl.load(X_ptr + y)
if reduction == 'mean':
X_y += -(1 - label_smoothing) / (n_non_ignore)
Expand All @@ -148,10 +82,7 @@ def liger_cross_entropy_kernel(
tl.store(X_ptr + y, X_y)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
MAX_FUSED_SIZE = 65536 // 2


@triton.jit
Expand All @@ -162,28 +93,12 @@ def element_mul_kernel(
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
The multiplication is performed in-place on the tensor pointed by X_ptr.
Parameters:
X_ptr: Pointer to the input tensor.
X_stride (int): The stride of the input tensor.
grad_output_ptr: Pointer to the gradient output value.
n_cols (int): The number of columns in the input tensor.
BLOCK_SIZE (int): The block size for Triton operations.
"""

# Get the program ID and convert it to int64 to avoid overflow
program_id = tl.program_id(0).to(tl.int64)

# Locate the start index
X_ptr += program_id * X_stride

# Load the gradient output value
grad_output = tl.load(grad_output_ptr)

# Perform the element-wise multiplication
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
Expand All @@ -196,33 +111,28 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)

n_non_ignore = (target != ignore_index).sum().item()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
_input = _input.contiguous()
if target.stride(-1) != 1:
target = target.contiguous()

# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
liger_cross_entropy_kernel[(n_rows,)](
X_ptr=_input,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
Y_stride=target.stride(-1),
loss_ptr=loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
loss_stride=loss_1d.stride(-1),
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
reduction=reduction,
BLOCK_SIZE=BLOCK_SIZE,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32,
)

Expand All @@ -231,12 +141,9 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti


def cross_entropy_backward(_input, grad_output):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
pass

# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
else:
BT, V = _input.shape
n_rows = BT
Expand All @@ -255,46 +162,14 @@ def cross_entropy_backward(_input, grad_output):


class LigerCrossEntropyFunction(torch.autograd.Function):
"""
This class implements a custom autograd function for the Liger Cross Entropy loss.
It overrides the forward and backward methods of the torch.autograd.Function class.
"""

@staticmethod
def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction='mean'):
"""
The forward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
Returns:
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
ctx.save_for_backward(_input.detach())
return loss

@staticmethod
def backward(ctx, grad_output):
"""
The backward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
return (
Expand Down
19 changes: 2 additions & 17 deletions turbo_alignment/common/tf/liger_kernels/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@

if compare_version('triton', operator.ge, '3.0.0'):
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import tanh
else:
from triton.language.math import tanh
Expand All @@ -26,7 +24,6 @@
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0)

# locate start index
a += program_id * stride
b += program_id * stride
c += program_id * stride
Expand All @@ -36,9 +33,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask=mask, other=0)

# tanh approximation form of GELU is computed with:
# 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
sqrt_2_over_pi = 0.7978845608028654
a_cubed = a_row * a_row * a_row
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
tanh_result = tanh(tanh_arg)
Expand All @@ -51,7 +46,6 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0)

# locate start index
dc += program_id * stride
a += program_id * stride
b += program_id * stride
Expand All @@ -63,18 +57,14 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask=mask, other=0)

# recomputation to save memory
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
sqrt_2_over_pi = 0.7978845608028654
a_cubed = a_row * a_row * a_row
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
tanh_result = tanh(tanh_arg)
geglu_a = 0.5 * a_row * (1 + tanh_result)

db_row = dc_row * geglu_a

# Gradient w.r.t. a can be computed with:
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
# where z = sqrt(2/pi) * (a + 0.044715 * a^3)
term1 = 0.5 * (1 + tanh_result)
tanh_sq = tanh_result * tanh_result
term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
Expand Down Expand Up @@ -153,11 +143,6 @@ def __init__(self, config):
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
# TODO: support exact GELU
# Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
# So we can safely assume we use tanh approximation form all the time

def forward(self, x):
return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
Loading

0 comments on commit 2410c6f

Please sign in to comment.