From 28ddef3ffdcdd4591fc48999762a80e710a850c9 Mon Sep 17 00:00:00 2001 From: Wonpyo Park Date: Fri, 24 May 2024 07:50:30 -0700 Subject: [PATCH] Move `PallasQTensor` materialization code as a method. Call materialization before `dequant`. This simplifies user interface. PiperOrigin-RevId: 636916800 --- aqt/jax/v2/transpose.py | 70 ++++++++++++++++++++++++++++++++++-- aqt/jax/v2/transpose_test.py | 15 ++++++++ 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/aqt/jax/v2/transpose.py b/aqt/jax/v2/transpose.py index d17cf667..54c821c4 100644 --- a/aqt/jax/v2/transpose.py +++ b/aqt/jax/v2/transpose.py @@ -26,6 +26,70 @@ import jax.numpy as jnp +def transpose(t: jax.Array, axes: list[utils.AxisIdx]): + """Transpose tensor through reshape if possible. + + This is a workaround to use AQT inside pallas kernel. Pallas supports + transposing only 2D tensor. However, tensor with a specific pattern can be + transposed through reshaping. This pattern is observed a lot in scale factor + of QTensor. + + If axes with dimensions larger than one are ordered in ascending order, then + the tensor can be transposed by reshaping. For example, + + Let's say the shape of tensor `x` is [10, 1, 1]. Then, there is only one + axis with dimension larger than one. The tensor can be transposed by + reshaping no matter what the order of permutation, e.g., `jnp.transpose(x, [1, + 2, 0]) == jnp.reshape(x, [x.shape[i] for i in [1, 2, 0]])`. + + Another example is `x` with shape [10, 1, 30]. Then, there are two axes with + dimension larger than one. The tensor can be transposed by reshaping if + tensor is permuted while maintaining the order of axes with dimension larger + than one, i.e., the first axis remains in front of the third axis after + permutation, e.g., `jnp.transpose(x, [0, 2, 1]) == jnp.reshape(x, [x.shape[i] + for i in [0, 2, 1]])`. + + Args: + t: input tensor. + axes: transpose axes. + + Returns: + A transposed tensor. 
+ + Example: + >>> x = np.random.randn(10, 1, 1) # x is transposable by reshaping. + >>> axes = [2, 1, 0] + >>> y = jnp.transpose(x, axes) + >>> (jnp.reshape(x, y.shape) == y).all() + True + >>> (transpose(x, axes) == y).all() + + >>> x = np.random.randn(10, 20, 30) # x is not transposable by reshaping. + >>> axes = [2, 1, 0] + >>> y = jnp.transpose(x, axes) + >>> (jnp.reshape(x, y.shape) == y).all() + False + >>> (transpose(x, axes) == y).all() + True + """ + + axes_bigger_than_one = [i for i, size in enumerate(t.shape) if size > 1] + permutation_axes_bigger_than_one = [ + axes.index(i) for i in axes_bigger_than_one + ] + # Check if axes bigger than one are ordered in ascending order. + reshapable = ( + sorted(permutation_axes_bigger_than_one) + == permutation_axes_bigger_than_one + ) + + if reshapable: + t = jnp.reshape(t, (t.shape[i] for i in axes)) + else: + t = jax.lax.transpose(t, axes) + return t + + def _scale_trans(x, ca, ba): """Transposes x to output dimension order.""" ca = list(ca) @@ -33,7 +97,7 @@ def _scale_trans(x, ca, ba): for i in ca: assert x.shape[i] == 1 ra = utils.get_remaining_axes(x.ndim, ca, ba) - x = jnp.transpose(x, ba + ra + ca) + x = transpose(x, ba + ra + ca) # TODO(lew): x = jnp.squeeze(x, axis=range(len(ba+ra): len(x.shape)) shape_ba = x.shape[: len(ba)] shape_ra = x.shape[len(ba) : len(x.shape) - len(ca)] @@ -107,7 +171,7 @@ def _scale_trans_back( assert -1 not in transpose_back - scale = jnp.transpose(scale, transpose_back) + scale = transpose(scale, transpose_back) return scale @@ -183,7 +247,7 @@ def _scale_trans_for_other_input( assert ra_idx == len(my_ra) # Transpose. - x = jnp.transpose(x, transpose_dim) + x = transpose(x, transpose_dim) # Remove redundant axis. 
if len(x.shape) > other_rank: diff --git a/aqt/jax/v2/transpose_test.py b/aqt/jax/v2/transpose_test.py index 21109b30..b2d93328 100644 --- a/aqt/jax/v2/transpose_test.py +++ b/aqt/jax/v2/transpose_test.py @@ -14,6 +14,7 @@ """Tests for transpose.""" +import math from absl.testing import absltest from absl.testing import parameterized from aqt.jax.v2 import transpose @@ -22,6 +23,20 @@ class AqtTransposeTest(parameterized.TestCase): + @parameterized.parameters( + ((10, 1, 10), (0, 2, 1)), + ((10, 1, 5), (1, 0, 2)), + ((5, 1, 10), (0, 1, 2)), + ((10, 20, 30), (1, 2, 0)), + ((10, 1, 1), (0, 2, 1)), + ((10, 1, 1), (2, 1, 0)), + ((10, 1, 1), (1, 2, 0)), + ) + def test_transpose(self, tensor_shape, transpose_axes): + t = jnp.arange(math.prod(tensor_shape)).reshape(tensor_shape) + t_t = transpose.transpose(t, transpose_axes) + self.assertTrue((t_t == jnp.transpose(t, transpose_axes)).all()) + @parameterized.parameters( # 'bmnts,bsnh->bmtnh' (