Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python][Relax] Update Rotary positional embedding scaling #17506

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 103 additions & 10 deletions python/tvm/relax/frontend/nn/llm/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype:
return cos_freq, sin_freq, {freq_var: freq}


def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str):
"""Compute the inverse frequency of RoPE for gptj RoPE scaling."""
freq = s / tir.power(theta, 2 * (d // 2) % d_range / tir.const(d_range, "float32"))
freq_var = tir.Var("freq", "float32")
cos_freq = tir.cos(freq_var).astype(dtype)
sin_freq = tir.sin(freq_var).astype(dtype)
return cos_freq, sin_freq, {freq_var: freq}


def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
s: tir.Var,
d: tir.Var,
Expand Down Expand Up @@ -123,12 +132,74 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments
return cos_freq, sin_freq, {freq_var: freq}


def yarn_find_correction_dim(
num_rotations: int,
d: tir.Var,
theta: float,
max_position_embeddings: int,
):
"""Inverse dim formula to find dim based on number of rotations"""
return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(theta)
)


def yarn_find_correction_range(
low_rot: int,
high_rot: int,
d: tir.Var,
theta: float,
max_position_embeddings: int,
):
"""Find the correction range based on the number of rotations"""
low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings))
high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings))
return tir.max(low, 0), tir.min(high, d - 1)


def rope_freq_yarn(
s: tir.Var,
d: tir.Var,
d_range: int,
theta: float,
dtype: str,
original_max_position_embeddings: int,
scaling_factor: float,
beta_fast: int,
beta_slow: int,
): # pylint: disable=too-many-arguments, too-many-locals
"""Compute the inverse frequency of RoPE for yarn RoPE scaling."""
freq_extra = tir.const(1, "float32") / tir.power(
theta, d * 2 % d_range / tir.const(d_range, "float32")
)

freq_inter = tir.const(1, "float32") / tir.power(
scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
)

low, high = yarn_find_correction_range(
beta_fast, beta_slow, d, theta, original_max_position_embeddings
)
high = tir.if_then_else(low == high, high + 0.001, high)
inv_freq_mask = tir.const(1, "float32") - tir.max(
tir.min((d - low) / (high - low), 1.0), 0.0
).astype("float32")
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
freq = s * inv_freq
freq_var = tir.Var("freq", "float32")
cos_freq = tir.cos(freq_var).astype(dtype)
sin_freq = tir.sin(freq_var).astype(dtype)
return cos_freq, sin_freq, {freq_var: freq}


def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
"""Return the RoPE inverse frequency computation function based
on the given RoPE scaling.
"""
if "rope_type" not in rope_scaling:
return rope_freq_default
if rope_scaling["rope_type"] == "gptj":
return rope_freq_gptj
if rope_scaling["rope_type"] == "llama3":
return partial(
rope_freq_llama3,
Expand All @@ -143,6 +214,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
max_position_embeddings=rope_scaling["max_position_embeddings"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
if rope_scaling["rope_type"] == "yarn":
return partial(
rope_freq_yarn,
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
scaling_factor=rope_scaling["factor"],
beta_fast=rope_scaling["beta_fast"],
beta_slow=rope_scaling["beta_slow"],
)
raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')


Expand Down Expand Up @@ -220,11 +299,18 @@ def _rope( # pylint: disable=too-many-arguments
(s + offset) * scale, d, rotary_dim, theta, dtype
)
cos = cos_freq * x[b, s, h, d]
sin = sin_freq * tir.if_then_else(
d < rotary_dim // 2,
-x[b, s, h, d + rotary_dim // 2],
x[b, s, h, d - rotary_dim // 2],
)
if rope_scaling["rope_type"] == "gptj":
sin = sin_freq * tir.if_then_else(
d % 2 == 0,
-x[b, s, h, d + 1],
x[b, s, h, d - 1],
)
else:
sin = sin_freq * tir.if_then_else(
d < rotary_dim // 2,
-x[b, s, h, d + rotary_dim // 2],
x[b, s, h, d - rotary_dim // 2],
)
expr = cos + sin
for var, value in var_map.items():
expr = tir.Let(var, value, expr)
Expand Down Expand Up @@ -341,11 +427,18 @@ def _rope( # pylint: disable=too-many-arguments
pos * scale, d, rotary_dim, theta, "float32", **kwargs
)
cos = cos_freq * x[s, h, d].astype("float32")
sin = sin_freq * tir.if_then_else(
d < rotary_dim // 2,
-x[s, h, d + rotary_dim // 2],
x[s, h, d - rotary_dim // 2],
).astype("float32")
if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
sin = sin_freq * tir.if_then_else(
d % 2 == 0,
-x[s, h, d + 1],
x[s, h, d - 1],
).astype("float32")
else:
sin = sin_freq * tir.if_then_else(
d < rotary_dim // 2,
-x[s, h, d + rotary_dim // 2],
x[s, h, d - rotary_dim // 2],
).astype("float32")
expr = (cos + sin).astype(dtype)
for var, value in var_map.items():
expr = tir.Let(var, value, expr)
Expand Down
Loading