diff --git a/tests/test_rope_padded.py b/tests/test_rope_padded.py
index 821612ffec..3159ff9c2b 100644
--- a/tests/test_rope_padded.py
+++ b/tests/test_rope_padded.py
@@ -3,8 +3,6 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
-from functools import partial
 from typing import Optional
 
 import pytest
@@ -23,33 +21,6 @@
 )
 
 
-def apply_scaling(
-    freqs: torch.Tensor,
-    old_context_len: float,
-    low_freq_factor: float,
-    high_freq_factor: float,
-    dynamic_scale_factor: float,
-):
-    low_freq_wavelen = old_context_len / low_freq_factor
-    high_freq_wavelen = old_context_len / high_freq_factor
-    assert low_freq_wavelen >= high_freq_wavelen
-
-    for idx, freq in enumerate(freqs):
-        wavelen = 2 * math.pi / freq
-        if wavelen > low_freq_wavelen:
-            freqs[idx] = freq / dynamic_scale_factor
-
-        if high_freq_wavelen <= wavelen and wavelen <= low_freq_wavelen:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (
-                high_freq_factor - low_freq_factor
-            )
-            freqs[idx] = (1 - smooth) * freqs[
-                idx
-            ] / dynamic_scale_factor + smooth * freqs[idx]
-    return freqs
-
-
 def _slow_rope(
     x: torch.Tensor,
     *,
@@ -57,11 +28,6 @@ def _slow_rope(
     theta=10000,
     linear_scale=1,
     adjacents: bool = True,
-    use_dynamic_scaling: bool = False,
-    dynamic_old_context_len: float = 8192.0,
-    dynamic_scale_factor: float = 16.0,
-    dynamic_low_freq_factor: float = 1.0,
-    dynamic_high_freq_factor: float = 32.0,
 ):
     """
     Simple rope calculation of rope of one tensor
@@ -79,15 +45,7 @@ def _slow_rope(
     if seqpos is None:
         seqpos = torch.arange(M, device=x.device)
     power = torch.arange(0, dim, 2, device=x.device)[: (dim // 2)].float() / dim
-    freqs: torch.Tensor = 1.0 / (theta**power)  # type: ignore
-    if use_dynamic_scaling:
-        freqs = apply_scaling(
-            freqs,
-            dynamic_old_context_len,
-            dynamic_low_freq_factor,
-            dynamic_high_freq_factor,
-            dynamic_scale_factor,
-        )
+    freqs = 1.0 / (theta**power)
     all_freqs = torch.outer(seqpos / linear_scale, freqs)
     freqs_cis = torch.polar(torch.ones_like(all_freqs), all_freqs)  # complex64
     for _ in range(x.ndim - seq_dim - 2):
@@ -160,9 +118,7 @@ def _slow_rope2(
 @pytest.mark.parametrize("dim", [100, 4098])
 @pytest.mark.parametrize("padding", [87, 18300])
 @pytest.mark.parametrize("groups", [1, 3])
-@pytest.mark.parametrize(
-    "linear_scale, use_dynamic_scaling", [(1.0, False), (4.0, False), (1.0, True)]
-)
+@pytest.mark.parametrize("linear_scale", [1.0, 4.0])
 def test_consistency(
     adjacents: bool,
     dim: int,
@@ -171,7 +127,6 @@ def test_consistency(
     internal_dtype: str,
     dtype_str: str,
     linear_scale: float,
-    use_dynamic_scaling: bool,
 ):
     torch.manual_seed(1)
     heads, kvheads = 10, 2
@@ -226,7 +181,6 @@ def test_consistency(
         linear_scale=linear_scale,
         adjacents=adjacents,
         internal_dtype=internal_dtype,
-        use_dynamic_scaling=use_dynamic_scaling,
     )
 
     seqpos = torch.tensor(
@@ -235,9 +189,7 @@
     )
     cache_locs = [seqpos[0], seqpos[1], padding + seqpos[2], 2 * padding + seqpos[3]]
     baseline = _slow_rope if dtype_str == "f32" else _slow_rope2
-    if use_dynamic_scaling:
-        baseline = partial(_slow_rope, use_dynamic_scaling=True)  # type: ignore
-    expected_out = baseline(  # type: ignore
+    expected_out = baseline(
        xq, linear_scale=linear_scale, seqpos=seqpos, adjacents=adjacents
     )
     atol, rtol = ROPE_ATOL_RTOL[dtype_str]
@@ -248,11 +200,7 @@
     assert torch.allclose(cache_v, cache_v_orig)
 
     slow_roped_xk = _slow_rope(
-        xk,
-        linear_scale=linear_scale,
-        seqpos=seqpos,
-        adjacents=adjacents,
-        use_dynamic_scaling=use_dynamic_scaling,
+        xk, linear_scale=linear_scale, seqpos=seqpos, adjacents=adjacents
     )
     assert_allclose(
         cache_k[:, cache_locs],
diff --git a/xformers/ops/_triton/rope_padded_kernels.py b/xformers/ops/_triton/rope_padded_kernels.py
index 80781b30c9..cc788bb14f 100644
--- a/xformers/ops/_triton/rope_padded_kernels.py
+++ b/xformers/ops/_triton/rope_padded_kernels.py
@@ -27,11 +27,6 @@ def _rope_padded_kernel(
     seqlenk,
     theta,
     linear_scale,
-    use_dynamic_scaling: tl.constexpr,
-    dynamic_old_context_len: tl.constexpr,
-    dynamic_scale_factor: tl.constexpr,
-    dynamic_low_freq_factor: tl.constexpr,
-    dynamic_high_freq_factor: tl.constexpr,
     first_seqpos,
     seqpos,
     k_start: tl.constexpr,
@@ -187,28 +182,8 @@ def _rope_padded_kernel(
         re_x = tl.load(x_in + cols_re, mask=mask)
         im_x = tl.load(x_in + cols_im, mask=mask)
         # freqs = seq_pos / (theta ** (powers / dim))
-        freqs = pow(theta, powers / (-dim))
-
-        if use_dynamic_scaling:
-            lo_freq_wavelen = dynamic_old_context_len / dynamic_low_freq_factor
-            hi_freq_wavelen = dynamic_old_context_len / dynamic_high_freq_factor
-
-            wavelens = 6.28318530718 / freqs  # 2*pi
-            is_low_freq = wavelens > lo_freq_wavelen
-            freqs = tl.where(is_low_freq, freqs / dynamic_scale_factor, freqs)
-
-            is_mid_freq = hi_freq_wavelen <= wavelens and wavelens <= lo_freq_wavelen
-
-            smooth = (dynamic_old_context_len / wavelens - dynamic_low_freq_factor) / (
-                dynamic_high_freq_factor - dynamic_low_freq_factor
-            )
-            freqs = tl.where(
-                is_mid_freq,
-                (1 - smooth) * freqs / dynamic_scale_factor + smooth * freqs,
-                freqs,
-            )
-
-        freqs = seq_pos * freqs / linear_scale
+        freqs = seq_pos * pow(theta, powers / (-dim))
+        freqs = freqs / linear_scale
         sines = tl.sin(freqs)
         cosines = tl.cos(freqs)
         re_out = re_x * cosines - im_x * sines
diff --git a/xformers/ops/rope_padded.py b/xformers/ops/rope_padded.py
index 1666b3a3ee..3aad1d64cc 100644
--- a/xformers/ops/rope_padded.py
+++ b/xformers/ops/rope_padded.py
@@ -23,11 +23,6 @@ def rope_padded(
     *,
     theta: float = 10000.0,
     linear_scale: float = 1.0,
-    use_dynamic_scaling: bool = False,
-    dynamic_old_context_len: float = 8192.0,
-    dynamic_scale_factor: float = 16.0,
-    dynamic_low_freq_factor: float = 1.0,
-    dynamic_high_freq_factor: float = 32.0,
     out_q: Optional[torch.Tensor] = None,
     first_seqpos: Optional[torch.Tensor] = None,
     seqpos: Optional[torch.Tensor] = None,
@@ -85,11 +80,6 @@ def rope_padded(
         linear_scale: A scaling factor to apply to the sequence ids when computing
             the RoPE frequencies. When set to K, all sequence indices are divided
             by K.
-        use_dynamic_scaling: If true, dynamic scaling in use, using the following
-            dynamic_old_context_len
-            dynamic_scale_factor
-            dynamic_low_freq_factor
-            dynamic_high_freq_factor
         internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation
     """
     if torch.is_grad_enabled() and (
@@ -255,11 +245,6 @@ def rope_padded(
         seqlenk,
         theta,
         linear_scale,
-        use_dynamic_scaling,
-        dynamic_old_context_len if use_dynamic_scaling else 0,
-        dynamic_scale_factor if use_dynamic_scaling else 0,
-        dynamic_low_freq_factor if use_dynamic_scaling else 0,
-        dynamic_high_freq_factor if use_dynamic_scaling else 0,
        first_seqpos,
        seqpos,
        k_start,
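
For reference, a minimal standalone sketch (not part of the diff) of the RoPE path that remains after this change: rotation angles depend only on theta and linear_scale, with no dynamic/wavelength-based rescaling, roughly following the adjacent-pairs convention of the test's _slow_rope helper. The rope_reference name and the shapes used are illustrative assumptions, not library API.

# Illustrative sketch only, assuming x is laid out as (seq, heads, dim) with even dim
# and adjacent pairs (x0, x1), (x2, x3), ... forming the rotated 2D components.
import torch


def rope_reference(
    x: torch.Tensor,
    seqpos: torch.Tensor,
    theta: float = 10000.0,
    linear_scale: float = 1.0,
) -> torch.Tensor:
    dim = x.shape[-1]
    # freqs_i = 1 / theta ** (2i / dim), matching the kernel comment
    # "freqs = seq_pos / (theta ** (powers / dim))".
    power = torch.arange(0, dim, 2, device=x.device).float() / dim
    freqs = 1.0 / (theta**power)  # (dim // 2,)
    # Positions are divided by linear_scale before computing the angles.
    angles = torch.outer(seqpos / linear_scale, freqs)  # (seq, dim // 2)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64
    # Rotate adjacent pairs as complex numbers, then flatten back to (seq, heads, dim).
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    out = torch.view_as_real(x_complex * freqs_cis[:, None, :]).flatten(-2)
    return out.type_as(x)


# Example: rotate a (seq=5, heads=2, dim=8) tensor at positions 0..4.
# y = rope_reference(torch.randn(5, 2, 8), torch.arange(5))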