diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 96ecacc6e298..f709a429c1a1 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -371,6 +371,7 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { } // Integer packs are always signed at the moment. +// Float to integer packing rounds to nearest even. def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> { let arguments = (ins Variadic:$sources, @@ -414,6 +415,26 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> { }]; } +def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [ + I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">, + I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">, +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_RoundingModeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +// Internal operation. All arith.fptosi operations that change the bitwidth +// must be canonicalized to this operation. +def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyVectorOfAnyRank:$input, TPU_RoundingModeEnum:$rounding_mode); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasCanonicalizeMethod = 1; +} def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> { let parameters = (ins diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 093f1616d85a..01f1acdcd43d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" #include "absl/strings/str_format.h" +#include "mlir/include/mlir/Dialect/Math/IR/Math.h" #include "mlir/include/mlir/IR/Builders.h" #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" @@ -1053,6 +1054,16 @@ LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op, return success(); } +LogicalResult FPToSIOp::canonicalize(FPToSIOp op, PatternRewriter &rewriter) { + if (auto round_op = op.getInput().getDefiningOp()) { + rewriter.replaceOpWithNewOp( + op, op.getType(), round_op.getOperand(), + tpu::RoundingMode::kToNearestEven); + return success(); + } + return failure(); +} + LogicalResult ConcatenateOp::verify() { auto dimension = getDimension(); if (getOperands().size() < 2) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 505df7311feb..5529dcd99819 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -54,6 +54,7 @@ #include "llvm/include/llvm/Support/LogicalResult.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/include/mlir/Dialect/Math/IR/Math.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/Builders.h" @@ -859,7 +860,7 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorLayout &layout_in, const VectorLayout &layout_out) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); - auto source = cast>(op.getIn()); + auto source = cast>(op.getOperand()); auto result_ty = cast(op.getResult().getType()); auto output_vregs_shape = layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape); @@ -1062,6 +1063,45 @@ LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op, *layouts_out.front()); } +LogicalResult tpu_fptosi_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + TPU_ASSERT_EQ_OP(layouts_in.size(), 1); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_EQ_OP(layouts_out.size(), 1); + TPU_ASSERT_OP(layouts_out.front().has_value()); + auto& layout_in = *layouts_in.front(); + auto& layout_out = *layouts_out.front(); + if (layout_in.bitwidth() == layout_out.bitwidth()) { + return elementwise_op_rule(ctx, op, layouts_in, layouts_out); + } else if (layout_in.bitwidth() > layout_out.bitwidth()) { + // FPToSI semantics require rounding towards zero, but packing instructions + // use rounding towards nearest even. We need to insert explicit rounding, + // unless the input is already rounded to nearest even. + auto fptosi_op = cast(op); + switch (fptosi_op.getRoundingMode()) { + case tpu::RoundingMode::kToNearestEven: + break; // That is the mode used by tpu.pack_subelements. + case tpu::RoundingMode::kTowardsZero: { + auto input = cast>(fptosi_op.getInput()); + ImplicitLocOpBuilder builder(op.getLoc(), fptosi_op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array vregs, + disassemble(builder, layout_in, input, ctx.target_shape)); + vregs.Each([&](absl::Span idxs, Value *v) { + *v = builder.create(op.getLoc(), v->getType(), + *v); + }); + fptosi_op->replaceUsesOfWith( + input, assemble(builder, input.getType(), layout_in, vregs, + ctx.target_shape)); + } break; + } + return trunc_op_rule_impl(ctx, fptosi_op, layout_in, layout_out); + } + return op.emitOpError("Unsupported FPToSI conversion"); +} + LogicalResult func_return_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -4672,6 +4712,7 @@ const llvm::StringMap &rules() { {tpu::TraceOp::getOperationName(), tpu_trace_rule}, {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule}, + {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule}, {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, {vector::ExtractOp::getOperationName(), vector_extract_rule}, {vector::LoadOp::getOperationName(), vector_load_rule}, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 11a587e5d5ee..609da2b2013b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -36,6 +36,7 @@ #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/IR/PatternMatch.h" #include "mlir/include/mlir/IR/Region.h" #include "mlir/include/mlir/IR/Value.h" #include "mlir/include/mlir/Support/LLVM.h" @@ -540,6 +541,7 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx, return success(); } +// All conversions that change bitwidth must be canonicalized to tpu.fptosi. LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, Operation &raw_op) { auto op = cast(raw_op); @@ -561,6 +563,24 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, if (dst_bitwidth > 32) { return op.emitOpError("Target bitwidth too large"); } + if (ctx.hardware_generation >= 6 && is_vector && + src_vty.getElementType().isBF16() && + dst_vty.getElementType().isSignlessInteger(8)) { + auto new_op = builder.create( + op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero); + op.replaceAllUsesWith(new_op.getResult()); + op.erase(); + // We briefly trigger canonicalization here to potentially fuse the rounding + // ops into the newly created tpu.fptosi. + { + PatternRewriter rewriter(new_op.getContext()); + rewriter.setInsertionPoint(new_op); + // We don't care if the canonicalization pattern matched or not. + (void)tpu::FPToSIOp::canonicalize(new_op, rewriter); + new_op = nullptr; // Canonicalization may have erased the op! + } + return success(); + } Value x = op.getIn(); // Upcast the input to f32. if (src_bitwidth < 32) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index b49f1991cce0..3af2864ce1e0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -165,6 +165,14 @@ class VectorLayoutInferer { if (inferTrunc(&any_op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op); + op && + cast(op.getOperand().getType()) + .getElementTypeBitWidth() > + cast(op.getType()).getElementTypeBitWidth()) { + if (inferTrunc(&any_op).failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { auto true_ty = dyn_cast(op.getTrueValue().getType()); auto false_ty = dyn_cast(op.getFalseValue().getType()); diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 7e1da537f732..5a0a0be0d3ce 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -14,6 +14,7 @@ """Tests for TPU specific operations within pallas_call.""" import functools +import math import sys import unittest @@ -370,6 +371,31 @@ def kernel(x_ref, mask_ref, o_ref): expected = jnp.where(mask, x, jnp.zeros_like(x)) self.assertArraysEqual(out, expected) + @parameterized.product( + target=(jnp.int8,), # TODO(apaszke): Add int4. + round=(False, True), + ) + def test_quantize(self, target, round): + if not jtu.if_cloud_tpu_at_least(2025, 1, 15): + self.skipTest("Requires libtpu built after 2025-01-15") + shape = (256, 256) + # NOTE: 256 * 256 == 2 ** 16, so those are all bf16 values. + x = lax.bitcast_convert_type( + np.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape), + jnp.bfloat16, + ) + + round_fn = jnp.rint if round else lambda x: x + + def kernel(x_ref, o_ref): + o_ref[...] = round_fn(x_ref[...]).astype(target) + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(shape, target) + )(x) + + ref = jax.jit(lambda x: round_fn(x).astype(target))(x) + np.testing.assert_array_equal(out, ref) + class OpsInterpretTest(OpsTest): INTERPRET = True