[Pallas TPU] Add vector support to pl.debug_print
PiperOrigin-RevId: 699920381
ayaka14732 authored and Google-ML-Automation committed Jan 13, 2025
1 parent e72c148 commit 217c67e
Showing 7 changed files with 143 additions and 26 deletions.
6 changes: 6 additions & 0 deletions docs/pallas/CHANGELOG.md
@@ -11,6 +11,12 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->

## Released with jax 0.5.0

* New functionality

  * Added vector support for {func}`jax.experimental.pallas.debug_print` on TPU.
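
A minimal usage sketch of the new capability (illustrative kernel, shape, and dtype, not taken from this commit; on TPU the printed output appears via the TPU log recorder):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # The format string must end with a single `{}` for the vector argument.
  pl.debug_print("x = {}", x_ref[...])
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128)
y = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)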

## Released with jax 0.4.37

* New functionality
82 changes: 62 additions & 20 deletions jax/_src/pallas/mosaic/lowering.py
@@ -3140,29 +3140,71 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int):
def _debug_print_rule(
    ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
  is_scalar_inputs = [aval.shape == () for aval in ctx.avals_in]
  is_all_scalars = all(is_scalar_inputs)
  is_single_vector = len(is_scalar_inputs) == 1 and not is_scalar_inputs[0]
  if not (is_all_scalars or is_single_vector):
    raise ValueError(
        "All inputs to debug_print must be all scalars or a single vector, but"
        f" got {ctx.avals_in}"
    )

  # Scalar case.
  if is_all_scalars:
    primitives.check_debug_print_format(fmt, *args)
    if has_placeholders:
      if not all(
          isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
          for arg in args
      ):
        raise TypeError(
            "All arguments must be 32-bit integers when using"
            " placeholders (`{...}`). If you need to print values of other types,"
            " remove placeholders from the format string."
        )

      # TPU expects $0, $1 etc as placeholders.
      fmt = "".join(
          f"{text}${idx}"
          for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
      )

    tpu.log(args, fmt, formatted=has_placeholders)
    return ()

  # Vector case.
  # Copy the array to vmem for logging.
  # Note that the shape of the array must be explicitly provided here. This is
  # because the underlying implementation aligns shapes to tile boundaries,
  # potentially altering the original shape and making it unrecoverable.
  (aval,) = ctx.avals_in
  (arg,) = args

  if not has_placeholders or not fmt.endswith("{}"):
    raise ValueError("For vector input, the format string must end with {}.")
  fmt = fmt[:-2]

  region = tpu.RegionOp(())
  with ir.InsertionPoint(region.body):
    element_type = _dtype_to_ir_type(aval.dtype)
    ref_type = ir.MemRefType.get(
        aval.shape,
        element_type,
        memory_space=ir.Attribute.parse("#tpu.memory_space<vmem>"),
    )
    ref = memref.alloca(ref_type, [], [])

    index_type = ir.IndexType.get()
    zero = arith.constant(index_type, 0)
    indices = [zero] * len(aval.shape)
    vector.store(arg, ref, indices)
    tpu.log_buffer(ref, aval.shape, fmt)
    tpu.yield_([])
  return ()
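
The `$N` placeholder rewrite in the scalar path above can be reproduced in plain Python. A standalone sketch of the same transformation (the helper name is hypothetical):

import string

def to_tpu_placeholders(fmt: str) -> str:
  # string.Formatter().parse yields (literal_text, field_name, format_spec,
  # conversion) tuples; each parsed `{}` field becomes a positional `$N`.
  return "".join(
      f"{text}${idx}"
      for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
  )

print(to_tpu_placeholders("x = {}, y = {}"))  # prints: x = $0, y = $1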


9 changes: 6 additions & 3 deletions jax/_src/pallas/primitives.py
@@ -732,9 +732,12 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike):
      * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
        contain a placeholder for each value to be printed. Format specs and
        conversions are not supported. All values must be scalars.
      * On TPU, if all inputs are scalars: If ``fmt`` contains placeholders,
        all values must be 32-bit integers. If there are no placeholders, the
        values are printed after the format string.
      * On TPU, if the input is a single vector, the vector is printed after
        the format string. The format string must end with a single placeholder
        ``{}``.
    *args: The values to print.
  """  # fmt: skip
  has_placeholders = False
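
A short sketch of these TPU rules inside a kernel (illustrative only; assumes the usual `pl`/`jnp` imports and an int32 input ref):

def kernel(x_ref, o_ref):
  # With placeholders, TPU accepts only 32-bit integer scalars.
  pl.debug_print("x[0] = {}", x_ref[0])
  # Without placeholders, other scalar types are printed after the string.
  pl.debug_print("x[0] as float:", x_ref[0].astype(jnp.float32))
  o_ref[...] = x_ref[...]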
11 changes: 11 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -781,6 +781,17 @@ def TPU_LogOp : TPU_Op<"log"> {
  let hasVerifier = 1;
}

def TPU_LogBufferOp : TPU_Op<"log_buffer"> {
  let arguments = (ins
    AnyMemRef:$input,
    DenseI64ArrayAttr:$shape,
    StrAttr:$tag
  );
  let results = (outs);
  let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }];
  let hasVerifier = 1;
}

def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> {
  let dependentDialects = [
    "::mlir::func::FuncDialect",
9 changes: 9 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -1134,6 +1134,15 @@ LogicalResult WeirdOp::verify() {
  return success();
}

LogicalResult LogBufferOp::verify() {
  const MemRefType input_type = getInput().getType();
  if (input_type.getRank() != getShape().size()) {
    return emitOpError(
        "Shape must have the same length as the rank of the input");
  }
  return success();
}

void PackSubelementsOp::build(OpBuilder &builder, OperationState &state,
                              const VectorType output_type,
                              const ArrayRef<Value> padded_sources,
7 changes: 4 additions & 3 deletions tests/pallas/ops_test.py
@@ -1339,10 +1339,11 @@ def kernel(x_ref, o_ref):
      "plgpu.TritonCompilerParams unavailable on Windows",
  )
  def test_debug_print(self):
    if jtu.test_device_matches(["tpu"]):
      self.skipTest("Test for TPU is covered in tpu_pallas_test.py")

    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")

    # TODO: this test flakes on gpu
    if jtu.test_device_matches(["gpu"]):
@@ -1369,7 +1370,7 @@ def kernel(x_ref, o_ref):
  )
  def test_debug_print_with_values(self):
    if jtu.test_device_matches(["tpu"]):
      self.skipTest("Test for TPU is covered in tpu_pallas_test.py")

    # TODO: this test flakes on gpu
    if jtu.test_device_matches(["gpu"]):
45 changes: 45 additions & 0 deletions tests/pallas/tpu_pallas_test.py
@@ -2114,6 +2114,51 @@ def kernel(x_ref, o_ref):
    jax.block_until_ready(compiled_kernel(x))
    self.assertIn('x[0] == 42', get_output())

  @parameterized.named_parameters(
      (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype)
      for shape in (
          (2, 8, 128),
          # test unaligned shapes
          (3,),
          (3, 4),
          (2, 3, 4),
          (2, 9, 129),
      )
      for dtype in (jnp.int32, jnp.uint32, jnp.float32)
  )
  def test_debug_print_vector(self, shape, dtype):
    # TODO(ayx): Remove after this date.
    if not jtu.if_cloud_tpu_at_least(2025, 1, 16):
      self.skipTest("Requires libtpu built after 2025-01-16")

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print("{}", x_ref[...])
      o_ref[...] = x_ref[...]

    n = np.prod(shape)
    x = jnp.arange(n, dtype=dtype).reshape(shape)
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({"xla_tpu_enable_log_recorder": "true"})
    )
    with jtu.capture_stderr() as get_output:
      jax.block_until_ready(compiled_kernel(x))
    output = get_output()
    numbers = [
        int(num)
        for line in output.splitlines()
        if (match := re.search(r"\{(.*)", line))  # extract contents after `{`
        for num in re.findall(r"\d+", match.group(1))
    ]
    # Check if the numbers in the output match the values generated by `arange`.
    self.assertLen(numbers, n)
    self.assertTrue(all(num == i for i, num in enumerate(numbers)))


class PallasCallTraceTest(PallasBaseTest):

