Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pl.BlockSpec that holds block_shape with None. #630

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions aqt/jax/v2/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,78 @@
import jax.numpy as jnp


def transpose(t: jax.Array, axes: list[utils.AxisIdx]):
"""Transpose tensor through reshape if possible.

This is a workaround to use AQT inside pallas kernel. Pallas support
transposing only 2D tensor. However, tensor with a specific pattern can be
transposed through reshaping. This pattern is observed a lot in scale factor
of QTensor.

If axes with dimensions larger than one are ordered in ascending order, then
the tensor can be transposed by reshaping. For example,

Let's say the shape of tensor `x` is [10, 1, 1]. Then, there are only one
axis with dimension larger than one. The tensor can be transposed by
reshaping no matter what the order of permutation, e.g., `jnp.transpose(x, [1,
2, 0]) == jnp.reshape(x, [x.shape[i] for i in [1, 2, 0]])`.

Another example is `x` with shape [10, 1, 30]. Then, there are two axes with
dimension larger than one. The tensor can be transposed by reshaping if
tensor is permuted while maintining the order of axes with dimension larger
than one, i.e., the first axes remains in front of third axis after
permutation, e.g., `jnp.transpose(x, [0, 2, 1]) == jnp.reshape(x, [x.shape[i]
for i in [0, 2, 1]])`.

Args:
t: input tensor.
axes: transpose axes.

Returns:
A transposed tensor.

Example:
>>> x = np.random.randn([10, 1, 1]) # x is transposable by reshaping.
>>> axes = [2, 1, 0]
>>> y = jnp.transpose(x, axes)
>>> (jnp.reshape(x, y.shape) == y).all()
True
>>> (transpose(x, axes) == y).all()

>>> x = np.random.randn([10, 20, 30]) # x is not transposable by reshaping.
>>> axes = [2, 1, 0]
>>> y = jnp.transpose(x, axes)
>>> (jnp.reshape(x, y.shape) == y).all()
False
>>> (transpose(x, axes) == y).all()
True
"""

axes_bigger_than_one = [i for i, size in enumerate(t.shape) if size > 1]
permutation_axes_bigger_than_one = [
axes.index(i) for i in axes_bigger_than_one
]
# Check if axes bigger than one are ordered in ascending order.
reshapable = (
sorted(permutation_axes_bigger_than_one)
== permutation_axes_bigger_than_one
)

if reshapable:
t = jnp.reshape(t, (t.shape[i] for i in axes))
else:
t = jax.lax.transpose(t, axes)
return t


def _scale_trans(x, ca, ba):
"""Transposes x to output dimension order."""
ca = list(ca)
ba = list(ba)
for i in ca:
assert x.shape[i] == 1
ra = utils.get_remaining_axes(x.ndim, ca, ba)
x = jnp.transpose(x, ba + ra + ca)
x = transpose(x, ba + ra + ca)
# TODO(lew): x = jnp.squeeze(x, axis=range(len(ba+ra): len(x.shape))
shape_ba = x.shape[: len(ba)]
shape_ra = x.shape[len(ba) : len(x.shape) - len(ca)]
Expand Down Expand Up @@ -107,7 +171,7 @@ def _scale_trans_back(

assert -1 not in transpose_back

scale = jnp.transpose(scale, transpose_back)
scale = transpose(scale, transpose_back)
return scale


Expand Down Expand Up @@ -183,7 +247,7 @@ def _scale_trans_for_other_input(
assert ra_idx == len(my_ra)

# Transpose.
x = jnp.transpose(x, transpose_dim)
x = transpose(x, transpose_dim)

# Remove redundant axis.
if len(x.shape) > other_rank:
Expand Down
15 changes: 15 additions & 0 deletions aqt/jax/v2/transpose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Tests for transpose."""

import math
from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax.v2 import transpose
Expand All @@ -22,6 +23,20 @@

class AqtTransposeTest(parameterized.TestCase):

@parameterized.parameters(
((10, 1, 10), (0, 2, 1)),
((10, 1, 5), (1, 0, 2)),
((5, 1, 10), (0, 1, 2)),
((10, 20, 30), (1, 2, 0)),
((10, 1, 1), (0, 2, 1)),
((10, 1, 1), (2, 1, 0)),
((10, 1, 1), (1, 2, 0)),
)
def test_transpose(self, tensor_shape, transpose_axes):
t = jnp.arange(math.prod(tensor_shape)).reshape(tensor_shape)
t_t = transpose.transpose(t, transpose_axes)
self.assertTrue((t_t == jnp.transpose(t, transpose_axes)).all())

@parameterized.parameters(
# 'bmnts,bsnh->bmtnh'
(
Expand Down