[Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6 #25829

Open: wants to merge 1 commit into main
21 changes: 21 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -371,6 +371,7 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
}

// Integer packs are always signed at the moment.
// Float-to-integer packing rounds to nearest even.
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> {
let arguments = (ins
Variadic<TPU_Vreg>:$sources,
@@ -414,6 +415,26 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> {
}];
}

def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [
I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">,
I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tpu";
}

def TPU_RoundingModeEnum : EnumAttr<TPU_Dialect, TPU_RoundingMode, "rounding_mode"> {
let assemblyFormat = "`<` $value `>`";
}

// Internal operation. All arith.fptosi operations that change the bitwidth
// must be canonicalized to this operation.
def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> {
let arguments = (ins AnyVectorOfAnyRank:$input, TPU_RoundingModeEnum:$rounding_mode);
let results = (outs AnyVectorOfAnyRank:$output);
let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
let hasCanonicalizeMethod = 1;
}
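
For reference, the assembly format above should print roughly like this (a sketch; the exact attribute syntax and the vreg shape are illustrative assumptions):

    // bf16 -> s8 conversion carrying an explicit rounding mode:
    %y = tpu.fptosi %x {rounding_mode = #tpu.rounding_mode<towards_zero>} : vector<8x128xbf16> -> vector<8x128xi8>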

def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> {
let parameters = (ins
11 changes: 11 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/strings/str_format.h"
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
@@ -1053,6 +1054,16 @@ LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op,
return success();
}

LogicalResult FPToSIOp::canonicalize(FPToSIOp op, PatternRewriter &rewriter) {
if (auto round_op = op.getInput().getDefiningOp<mlir::math::RoundEvenOp>()) {
rewriter.replaceOpWithNewOp<tpu::FPToSIOp>(
op, op.getType(), round_op.getOperand(),
tpu::RoundingMode::kToNearestEven);
return success();
}
return failure();
}
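
This pattern absorbs an explicit round into the conversion itself; a minimal before/after sketch (same assumed syntax as above):

    // Before: round-to-nearest-even feeding a towards-zero fptosi.
    %r = math.roundeven %x : vector<8x128xbf16>
    %y = tpu.fptosi %r {rounding_mode = #tpu.rounding_mode<towards_zero>} : vector<8x128xbf16> -> vector<8x128xi8>
    // After: the round op is gone and its mode is carried by the conversion.
    // Sound because truncating an already-integral value is the identity.
    %y = tpu.fptosi %x {rounding_mode = #tpu.rounding_mode<to_nearest_even>} : vector<8x128xbf16> -> vector<8x128xi8>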

LogicalResult ConcatenateOp::verify() {
auto dimension = getDimension();
if (getOperands().size() < 2) {
43 changes: 42 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -54,6 +54,7 @@
#include "llvm/include/llvm/Support/LogicalResult.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Builders.h"
@@ -859,7 +860,7 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto source = cast<TypedValue<VectorType>>(op.getIn());
auto source = cast<TypedValue<VectorType>>(op.getOperand());
auto result_ty = cast<VectorType>(op.getResult().getType());
auto output_vregs_shape =
layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape);
@@ -1062,6 +1063,45 @@ LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op,
*layouts_out.front());
}

LogicalResult tpu_fptosi_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
TPU_ASSERT_OP(layouts_in.front().has_value());
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
TPU_ASSERT_OP(layouts_out.front().has_value());
auto& layout_in = *layouts_in.front();
auto& layout_out = *layouts_out.front();
if (layout_in.bitwidth() == layout_out.bitwidth()) {
return elementwise_op_rule(ctx, op, layouts_in, layouts_out);
} else if (layout_in.bitwidth() > layout_out.bitwidth()) {
// FPToSI semantics require rounding towards zero, but packing instructions
// use rounding towards nearest even. We need to insert explicit rounding,
// unless the input is already rounded to nearest even.
auto fptosi_op = cast<tpu::FPToSIOp>(op);
switch (fptosi_op.getRoundingMode()) {
case tpu::RoundingMode::kToNearestEven:
break; // That is the mode used by tpu.pack_subelements.
case tpu::RoundingMode::kTowardsZero: {
auto input = cast<TypedValue<VectorType>>(fptosi_op.getInput());
ImplicitLocOpBuilder builder(op.getLoc(), fptosi_op);
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> vregs,
disassemble(builder, layout_in, input, ctx.target_shape));
vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<mlir::math::TruncOp>(op.getLoc(), v->getType(),
*v);
});
fptosi_op->replaceUsesOfWith(
input, assemble(builder, input.getType(), layout_in, vregs,
ctx.target_shape));
} break;
}
return trunc_op_rule_impl(ctx, fptosi_op, layout_in, layout_out);
}
return op.emitOpError("Unsupported FPToSI conversion");
}
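
In the kTowardsZero branch, the per-vreg IR looks roughly like this before trunc_op_rule_impl emits the packing (the vreg shapes and the final pack step are illustrative assumptions):

    // math.trunc makes every element integral, so the round-to-nearest-even
    // applied by the packing instruction is exact and FPToSI's towards-zero
    // semantics are preserved:
    %t0 = math.trunc %v0 : vector<8x128xbf16>
    %t1 = math.trunc %v1 : vector<8x128xbf16>
    // ...%t0 and %t1 are then packed into a single s8 vreg via tpu.pack_subelements.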

LogicalResult func_return_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
@@ -4672,6 +4712,7 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
{tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule},
{tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule},
{vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
{vector::ExtractOp::getOperationName(), vector_extract_rule},
{vector::LoadOp::getOperationName(), vector_load_rule},
20 changes: 20 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -36,6 +36,7 @@
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/IR/PatternMatch.h"
#include "mlir/include/mlir/IR/Region.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
@@ -540,6 +541,7 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
return success();
}

// All arith.fptosi conversions that change bitwidth must be canonicalized to
// tpu.fptosi.
LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
Operation &raw_op) {
auto op = cast<arith::FPToSIOp>(raw_op);
@@ -561,6 +563,24 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
if (dst_bitwidth > 32) {
return op.emitOpError("Target bitwidth too large");
}
if (ctx.hardware_generation >= 6 && is_vector &&
src_vty.getElementType().isBF16() &&
dst_vty.getElementType().isSignlessInteger(8)) {
auto new_op = builder.create<tpu::FPToSIOp>(
op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero);
op.replaceAllUsesWith(new_op.getResult());
op.erase();
// We briefly trigger canonicalization here to potentially fuse the rounding
// ops into the newly created tpu.fptosi.
{
PatternRewriter rewriter(new_op.getContext());
rewriter.setInsertionPoint(new_op);
// We don't care if the canonicalization pattern matched or not.
(void)tpu::FPToSIOp::canonicalize(new_op, rewriter);
new_op = nullptr; // Canonicalization may have erased the op!
}
return success();
}
Value x = op.getIn();
// Upcast the input to f32.
if (src_bitwidth < 32) {
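The new fast path amounts to the following rewrite on TPUv6 (shapes are illustrative):

    // Before: a narrowing vector conversion from bf16 to i8.
    %y = arith.fptosi %x : vector<256x256xbf16> to vector<256x256xi8>
    // After: the internal op, keeping fptosi's towards-zero semantics; the
    // immediate canonicalization may further fuse a preceding math.roundeven
    // into the to_nearest_even mode.
    %y = tpu.fptosi %x {rounding_mode = #tpu.rounding_mode<towards_zero>} : vector<256x256xbf16> -> vector<256x256xi8>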
8 changes: 8 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -165,6 +165,14 @@ class VectorLayoutInferer {
if (inferTrunc(&any_op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::FPToSIOp>(any_op);
op &&
cast<VectorType>(op.getOperand().getType())
.getElementTypeBitWidth() >
cast<VectorType>(op.getType()).getElementTypeBitWidth()) {
if (inferTrunc(&any_op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::SelectOp>(any_op)) {
auto true_ty = dyn_cast<VectorType>(op.getTrueValue().getType());
auto false_ty = dyn_cast<VectorType>(op.getFalseValue().getType());
26 changes: 26 additions & 0 deletions tests/pallas/tpu_ops_test.py
@@ -14,6 +14,7 @@
"""Tests for TPU specific operations within pallas_call."""

import functools
import math
import sys
import unittest

@@ -370,6 +371,31 @@ def kernel(x_ref, mask_ref, o_ref):
expected = jnp.where(mask, x, jnp.zeros_like(x))
self.assertArraysEqual(out, expected)

@parameterized.product(
target=(jnp.int8,), # TODO(apaszke): Add int4.
round=(False, True),
)
def test_quantize(self, target, round):
if not jtu.if_cloud_tpu_at_least(2025, 1, 15):
self.skipTest("Requires libtpu built after 2025-01-15")
shape = (256, 256)
# NOTE: 256 * 256 == 2 ** 16, so x covers every possible bf16 bit pattern.
x = lax.bitcast_convert_type(
np.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape),
jnp.bfloat16,
)

round_fn = jnp.rint if round else lambda x: x

def kernel(x_ref, o_ref):
o_ref[...] = round_fn(x_ref[...]).astype(target)
out = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct(shape, target)
)(x)

ref = jax.jit(lambda x: round_fn(x).astype(target))(x)
np.testing.assert_array_equal(out, ref)
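
With round=True, jnp.rint presumably lowers to a math.roundeven in front of the conversion (an assumption about the Pallas lowering), so the kernel body should reach the canonicalizer roughly as:

    %r = math.roundeven %x : vector<256x256xbf16>
    %y = arith.fptosi %r : vector<256x256xbf16> to vector<256x256xi8>
    // ...which the passes above collapse into a single tpu.fptosi with
    // to_nearest_even, mapping onto the faster TPUv6 packing sequence.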


class OpsInterpretTest(OpsTest):
INTERPRET = True