diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 94dbeb3aa70d..2b1cad7c9a66 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -11,6 +11,12 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c Remember to align the itemized text with the first line of an item within a list. --> +## Released with jax 0.5.0 + +* New functionality + + * Added vector support for {func}`jax.experimental.pallas.debug_print` on TPU. + ## Released with jax 0.4.37 * New functionality diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 65bdf9d84c45..3ebc4deeedb0 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3140,29 +3140,71 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int): def _debug_print_rule( ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool ): - if any(aval.shape for aval in ctx.avals_in): - raise NotImplementedError("Only scalar values are supported") - - primitives.check_debug_print_format(fmt, *args) - if has_placeholders: - if not all( - isinstance(arg.type, ir.IntegerType) and arg.type.width == 32 - for arg in args - ): - raise TypeError( - "All arguments must be 32-bit integers when using" - " placeholders (`{...}`). If you need to print values of other types," - " remove placeholders from the format string." + is_scalar_inputs = [aval.shape == () for aval in ctx.avals_in] + is_all_scalars = all(is_scalar_inputs) + is_single_vector = len(is_scalar_inputs) == 1 and not is_scalar_inputs[0] + if not (is_all_scalars or is_single_vector): + raise ValueError( + "All inputs to debug_print must be all scalars or a single vector, but" + f" got {ctx.avals_in}" + ) + + # Scalar case. + if is_all_scalars: + primitives.check_debug_print_format(fmt, *args) + if has_placeholders: + if not all( + isinstance(arg.type, ir.IntegerType) and arg.type.width == 32 + for arg in args + ): + raise TypeError( + "All arguments must be 32-bit integers when using" + " placeholders (`{...}`). If you need to print values of other types," + " remove placeholders from the format string." + ) + + # TPU expects $0, $1 etc as placeholders. + fmt = "".join( + f"{text}${idx}" + for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt)) ) - # TPU expects $0, $1 etc as placeholders. - tpu_fmt = "".join( - f"{text}${idx}" - for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt)) + tpu.log(args, fmt, formatted=has_placeholders) + return () + + # Vector case. + # Copy the array to vmem for logging. + # Note that the shape of the array must be explicitly provided here. This is + # because the underlying implementation aligns shapes to tile boundaries, + # potentially altering the original shape and making it unrecoverable. + if len(ctx.avals_in) != 1: + raise ValueError( + "Only one vector input to debug_print is supported." ) - else: - tpu_fmt = fmt - tpu.log(args, tpu_fmt, formatted=has_placeholders) + (aval,) = ctx.avals_in + (arg,) = args + + if not has_placeholders or not fmt.endswith("{}"): + raise ValueError("For vector input, the format string must end with {}.") + + fmt = fmt[:-2] + + region = tpu.RegionOp(()) + with ir.InsertionPoint(region.body): + element_type = _dtype_to_ir_type(aval.dtype) + ref_type = ir.MemRefType.get( + aval.shape, + element_type, + memory_space=ir.Attribute.parse("#tpu.memory_space"), + ) + ref = memref.alloca(ref_type, [], []) + + index_type = ir.IndexType.get() + zero = arith.constant(index_type, 0) + indices = [zero] * len(aval.shape) + vector.store(arg, ref, indices) + tpu.log_buffer(ref, aval.shape, fmt) + tpu.yield_([]) return () diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index d77ca86c152a..8565c266a2aa 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -732,9 +732,12 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike): * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must contain a placeholder for each value to be printed. Format specs and conversions are not supported. All values must be scalars. - * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit - integers. If there are no placeholders, the values are printed after - the format string. All values must be scalars. + * On TPU, if all inputs are scalars: If ``fmt`` contains placeholders, + all values must be 32-bit integers. If there are no placeholders, the + values are printed after the format string. + * On TPU, if the input is a single vector, the vector is printed after + the format string. The format string must end with a single placeholder + ``{}``. *args: The values to print. """ # fmt: skip has_placeholders = False diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 96ecacc6e298..76d4c6c149e6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -781,6 +781,17 @@ def TPU_LogOp : TPU_Op<"log"> { let hasVerifier = 1; } +def TPU_LogBufferOp : TPU_Op<"log_buffer"> { + let arguments = (ins + AnyMemRef:$input, + DenseI64ArrayAttr:$shape, + StrAttr:$tag + ); + let results = (outs); + let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }]; + let hasVerifier = 1; +} + def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> { let dependentDialects = [ "::mlir::func::FuncDialect", diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 093f1616d85a..f3ab13a8e524 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1134,6 +1134,15 @@ LogicalResult WeirdOp::verify() { return success(); } +LogicalResult LogBufferOp::verify() { + const MemRefType input_type = getInput().getType(); + if (input_type.getRank() != getShape().size()) { + return emitOpError( + "Shape must have the same length as the rank of the input"); + } + return success(); +} + void PackSubelementsOp::build(OpBuilder &builder, OperationState &state, const VectorType output_type, const ArrayRef padded_sources, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index ba661ec091cb..544b9e40b322 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1339,10 +1339,11 @@ def kernel(x_ref, o_ref): "plgpu.TritonCompilerParams unavailable on Windows", ) def test_debug_print(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Test for TPU is covered in tpu_pallas_test.py") + if config.use_shardy_partitioner.value: self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): @@ -1369,7 +1370,7 @@ def kernel(x_ref, o_ref): ) def test_debug_print_with_values(self): if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") + self.skipTest("Test for TPU is covered in tpu_pallas_test.py") # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 8eee883a536e..7058cacd8079 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2114,6 +2114,51 @@ def kernel(x_ref, o_ref): jax.block_until_ready(compiled_kernel(x)) self.assertIn('x[0] == 42', get_output()) + @parameterized.named_parameters( + (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype) + for shape in ( + (2, 8, 128), + # test unaligned shapes + (3,), + (3, 4), + (2, 3, 4), + (2, 9, 129), + ) + for dtype in (jnp.int32, jnp.uint32, jnp.float32) + ) + def test_debug_print_vector(self, shape, dtype): + # TODO(ayx): Remove after this date. + if not jtu.if_cloud_tpu_at_least(2025, 1, 16): + self.skipTest("Requires libtpu built after 2025-01-16") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + def kernel(x_ref, o_ref): + pl.debug_print("{}", x_ref[...]) + o_ref[...] = x_ref[...] + + n = np.prod(shape) + x = jnp.arange(n, dtype=dtype).reshape(shape) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({"xla_tpu_enable_log_recorder": "true"}) + ) + with jtu.capture_stderr() as get_output: + jax.block_until_ready(compiled_kernel(x)) + output = get_output() + numbers = [ + int(num) + for line in output.splitlines() + if (match := re.search(r"\{(.*)", line)) # extract contents after `{` + for num in re.findall(r"\d+", match.group(1)) + ] + # Check if the numbers in the output match the values generated by `arange`. + self.assertLen(numbers, n) + self.assertTrue(all(num == i for i, num in enumerate(numbers))) + class PallasCallTraceTest(PallasBaseTest):