Skip to content

Commit

Permalink
make pretty
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Sep 23, 2024
1 parent b25a34d commit d0aedc7
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 83 deletions.
43 changes: 15 additions & 28 deletions turbo_alignment/common/tf/liger_kernels/cross_entropy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import triton
import triton.language as tl

from torch.nn import CrossEntropyLoss


Expand Down Expand Up @@ -63,11 +62,9 @@ def liger_cross_entropy_kernel(
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
m = float('-inf') # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
ori_X_y = tl.load(
X_ptr + y
) # we need to store the original value of X_y for the loss calculation
ori_X_y = tl.load(X_ptr + y) # we need to store the original value of X_y for the loss calculation

# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
Expand All @@ -76,9 +73,7 @@ def liger_cross_entropy_kernel(

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float('-inf'))
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
Expand Down Expand Up @@ -106,10 +101,8 @@ def liger_cross_entropy_kernel(

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
if reduction == "mean":
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float('-inf'))
if reduction == 'mean':
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
Expand Down Expand Up @@ -141,12 +134,12 @@ def liger_cross_entropy_kernel(
loss = loss * (1 - label_smoothing) + smooth_loss

# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
if reduction == 'mean':
loss = loss / n_non_ignore

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
X_y = tl.load(X_ptr + y)
if reduction == "mean":
if reduction == 'mean':
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)
Expand Down Expand Up @@ -268,9 +261,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
"""

@staticmethod
def forward(
ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
):
def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction='mean'):
"""
The forward pass of the Liger Cross Entropy loss.
Expand All @@ -285,9 +276,7 @@ def forward(
Returns:
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing, reduction
)
loss, _input = cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
Expand Down Expand Up @@ -315,21 +304,19 @@ def backward(ctx, grad_output):
None,
None,
)


class LigerCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
), f'label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}'
assert self.reduction in {
"mean",
"sum",
"none",
'mean',
'sum',
'none',
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"

def forward(self, _input, target):
return LigerCrossEntropyFunction.apply(
_input, target, self.ignore_index, self.label_smoothing, self.reduction
)
return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index, self.label_smoothing, self.reduction)
34 changes: 12 additions & 22 deletions turbo_alignment/common/tf/liger_kernels/geglu.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import operator

import torch
import torch.nn as nn
import triton
import triton.language as tl

import torch.nn as nn

from turbo_alignment.common.tf.liger_kernels.utils import calculate_settings, compare_version, ensure_contiguous

from turbo_alignment.common.tf.liger_kernels.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
)


if compare_version("triton", operator.ge, "3.0.0"):
if compare_version('triton', operator.ge, '3.0.0'):
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
Expand All @@ -22,9 +23,7 @@


@triton.jit
def _geglu_tanh_forward_kernel(
a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0)

# locate start index
Expand All @@ -49,9 +48,7 @@ def _geglu_tanh_forward_kernel(


@triton.jit
def _geglu_tanh_backward_kernel(
dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0)

# locate start index
Expand Down Expand Up @@ -80,12 +77,7 @@ def _geglu_tanh_backward_kernel(
# where z = sqrt(2/pi) * (a + 0.044715 * a^3)
term1 = 0.5 * (1 + tanh_result)
tanh_sq = tanh_result * tanh_result
term2 = (
0.5
* a_row
* (1 - tanh_sq)
* (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
)
term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
da_row = dc_row * b_row * (term1 + term2)

tl.store(a + col_offsets, da_row, mask=mask)
Expand Down Expand Up @@ -151,6 +143,7 @@ def backward(ctx, dc):
a, b = geglu_backward(a, b, dc)
return a, b


class LigerGEGLUMLP(nn.Module):
def __init__(self, config):
super().__init__()
Expand All @@ -167,7 +160,4 @@ def __init__(self, config):
# So we can safely assume we use tanh approximation form all the time

def forward(self, x):

return self.down_proj(
LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
)
return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
5 changes: 3 additions & 2 deletions turbo_alignment/common/tf/liger_kernels/monkey_patch_liger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from transformers import PretrainedConfig, PreTrainedModel

from turbo_alignment.common.tf.liger_kernels.cross_entropy import LigerCrossEntropyLoss
from turbo_alignment.common.tf.liger_kernels.geglu import LigerGEGLUMLP
from turbo_alignment.common.tf.liger_kernels.rope import liger_rotary_pos_emb
Expand Down Expand Up @@ -37,7 +38,7 @@ def apply_liger_kernel_to_gemma2(
# instance variables that reference already-instantiated modules
config: PretrainedConfig = model.config

if hasattr(model, "model"):
if hasattr(model, 'model'):
# The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
base_model = model.model
else:
Expand All @@ -50,4 +51,4 @@ def apply_liger_kernel_to_gemma2(
if geglu:
decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype)

print('🙈'*15)
print('🙈' * 15)
35 changes: 9 additions & 26 deletions turbo_alignment/common/tf/liger_kernels/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,36 +61,20 @@ def _triton_rope(
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets = (
tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_half_k_offsets = (
tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
tl.arange(0, pad_hd // 2)[None, :] < hd // 2
)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
tl.arange(0, pad_hd // 2)[None, :] < hd // 2
)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
sin_row.dtype
)
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

# right half of the head
second_half_q_offsets = first_half_q_offsets + (hd // 2)
second_half_k_offsets = first_half_k_offsets + (hd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
sin_row.dtype
)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
sin_row.dtype
)
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

if not BACKWARD_PASS:
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
Expand Down Expand Up @@ -118,7 +102,6 @@ def _triton_rope(


def rope_forward(q, k, cos, sin):

# transpose it back to the physical shape because Triton looks at the physical storage
# note: q and k are incontiguous before the transformation and will become contiguous after transpose
q = q.transpose(1, 2)
Expand Down Expand Up @@ -257,4 +240,4 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation.
"""

return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
6 changes: 3 additions & 3 deletions turbo_alignment/common/tf/liger_kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def calculate_settings(n):
BLOCK_SIZE = triton.next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(
f"Cannot launch Triton kernel since n = {n} exceeds "
f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
f'Cannot launch Triton kernel since n = {n} exceeds '
f'the recommended Triton blocksize = {MAX_FUSED_SIZE}.'
)

num_warps = 4
Expand All @@ -59,4 +59,4 @@ def compare_version(package: str, operator: Callable, target: str):
except ImportError:
return False
pkg_version = Version(pkg.__version__)
return operator(pkg_version, Version(target))
return operator(pkg_version, Version(target))
5 changes: 3 additions & 2 deletions turbo_alignment/common/tf/loaders/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from turbo_alignment.common.tf.liger_kernels.monkey_patch_liger import apply_liger_kernel_to_gemma2
from turbo_alignment.common.tf.liger_kernels.monkey_patch_liger import (
apply_liger_kernel_to_gemma2,
)
from turbo_alignment.common.tf.loaders.model.registry import (
PeftConfigRegistry,
TransformersAutoModelRegistry,
Expand Down Expand Up @@ -44,7 +46,6 @@ def load_model(
model_settings: PreTrainedModelSettings,
tokenizer: PreTrainedTokenizerBase,
) -> PreTrainedModel:

if model_settings.use_liger_kernels:
apply_liger_kernel_to_gemma2()

Expand Down

0 comments on commit d0aedc7

Please sign in to comment.