Skip to content

Commit

Permalink
Move PallasQTensor materialization code as a method. Call materiali…
Browse files Browse the repository at this point in the history
…zation before `dequant`. This simplifies user interface.

PiperOrigin-RevId: 636916800
  • Loading branch information
lenscloth authored and copybara-github committed May 24, 2024
1 parent 02baee0 commit 28ddef3
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 3 deletions.
70 changes: 67 additions & 3 deletions aqt/jax/v2/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,78 @@
import jax.numpy as jnp


def transpose(t: jax.Array, axes: list[utils.AxisIdx]):
"""Transpose tensor through reshape if possible.
This is a workaround to use AQT inside pallas kernel. Pallas support
transposing only 2D tensor. However, tensor with a specific pattern can be
transposed through reshaping. This pattern is observed a lot in scale factor
of QTensor.
If axes with dimensions larger than one are ordered in ascending order, then
the tensor can be transposed by reshaping. For example,
Let's say the shape of tensor `x` is [10, 1, 1]. Then, there are only one
axis with dimension larger than one. The tensor can be transposed by
reshaping no matter what the order of permutation, e.g., `jnp.transpose(x, [1,
2, 0]) == jnp.reshape(x, [x.shape[i] for i in [1, 2, 0]])`.
Another example is `x` with shape [10, 1, 30]. Then, there are two axes with
dimension larger than one. The tensor can be transposed by reshaping if
tensor is permuted while maintining the order of axes with dimension larger
than one, i.e., the first axes remains in front of third axis after
permutation, e.g., `jnp.transpose(x, [0, 2, 1]) == jnp.reshape(x, [x.shape[i]
for i in [0, 2, 1]])`.
Args:
t: input tensor.
axes: transpose axes.
Returns:
A transposed tensor.
Example:
>>> x = np.random.randn([10, 1, 1]) # x is transposable by reshaping.
>>> axes = [2, 1, 0]
>>> y = jnp.transpose(x, axes)
>>> (jnp.reshape(x, y.shape) == y).all()
True
>>> (transpose(x, axes) == y).all()
>>> x = np.random.randn([10, 20, 30]) # x is not transposable by reshaping.
>>> axes = [2, 1, 0]
>>> y = jnp.transpose(x, axes)
>>> (jnp.reshape(x, y.shape) == y).all()
False
>>> (transpose(x, axes) == y).all()
True
"""

axes_bigger_than_one = [i for i, size in enumerate(t.shape) if size > 1]
permutation_axes_bigger_than_one = [
axes.index(i) for i in axes_bigger_than_one
]
# Check if axes bigger than one are ordered in ascending order.
reshapable = (
sorted(permutation_axes_bigger_than_one)
== permutation_axes_bigger_than_one
)

if reshapable:
t = jnp.reshape(t, (t.shape[i] for i in axes))
else:
t = jax.lax.transpose(t, axes)
return t


def _scale_trans(x, ca, ba):
"""Transposes x to output dimension order."""
ca = list(ca)
ba = list(ba)
for i in ca:
assert x.shape[i] == 1
ra = utils.get_remaining_axes(x.ndim, ca, ba)
x = jnp.transpose(x, ba + ra + ca)
x = transpose(x, ba + ra + ca)
# TODO(lew): x = jnp.squeeze(x, axis=range(len(ba+ra): len(x.shape))
shape_ba = x.shape[: len(ba)]
shape_ra = x.shape[len(ba) : len(x.shape) - len(ca)]
Expand Down Expand Up @@ -107,7 +171,7 @@ def _scale_trans_back(

assert -1 not in transpose_back

scale = jnp.transpose(scale, transpose_back)
scale = transpose(scale, transpose_back)
return scale


Expand Down Expand Up @@ -183,7 +247,7 @@ def _scale_trans_for_other_input(
assert ra_idx == len(my_ra)

# Transpose.
x = jnp.transpose(x, transpose_dim)
x = transpose(x, transpose_dim)

# Remove redundant axis.
if len(x.shape) > other_rank:
Expand Down
15 changes: 15 additions & 0 deletions aqt/jax/v2/transpose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Tests for transpose."""

import math
from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax.v2 import transpose
Expand All @@ -22,6 +23,20 @@

class AqtTransposeTest(parameterized.TestCase):

@parameterized.parameters(
((10, 1, 10), (0, 2, 1)),
((10, 1, 5), (1, 0, 2)),
((5, 1, 10), (0, 1, 2)),
((10, 20, 30), (1, 2, 0)),
((10, 1, 1), (0, 2, 1)),
((10, 1, 1), (2, 1, 0)),
((10, 1, 1), (1, 2, 0)),
)
def test_transpose(self, tensor_shape, transpose_axes):
t = jnp.arange(math.prod(tensor_shape)).reshape(tensor_shape)
t_t = transpose.transpose(t, transpose_axes)
self.assertTrue((t_t == jnp.transpose(t, transpose_axes)).all())

@parameterized.parameters(
# 'bmnts,bsnh->bmtnh'
(
Expand Down

0 comments on commit 28ddef3

Please sign in to comment.