Skip to content

Commit

Permalink
[Mosaic:TPU] Fix elementwise inference with i1s
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703263310
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Dec 5, 2024
1 parent d782b24 commit 651ab18
Showing 1 changed file with 38 additions and 21 deletions.
59 changes: 38 additions & 21 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class VectorLayoutInferer {
false_ty.getElementTypeBitWidth() == kNativeBitwidth,
"Only 32-bit select supported");
}
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
if (inferElementwise(&any_op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::ExtUIOp>(any_op)) {
Expand All @@ -198,7 +198,7 @@ class VectorLayoutInferer {
auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth()
: op.getIn().getType().getIntOrFloatBitWidth();
if (in_bitwidth == 1) {
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
if (inferElementwise(&any_op).failed()) {
return failure();
}
} else {
Expand All @@ -214,7 +214,7 @@ class VectorLayoutInferer {
TPU_CHECK_OP(static_cast<bool>(lhs_ty) == static_cast<bool>(rhs_ty),
"Only one side of cmp is a vector?");
// TODO(tlongeri): Check that TPU generation supports comparison.
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
if (inferElementwise(&any_op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::ConstantOp>(any_op)) {
Expand Down Expand Up @@ -1726,7 +1726,7 @@ class VectorLayoutInferer {
return success();
}

LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) {
LogicalResult inferElementwise(Operation *op) {
TPU_CHECK_OP(op->getNumResults() == 1, "only one result supported");
TPU_CHECK_OP(op->getNumOperands() > 0,
"elementwise ops with no operands unsupported");
Expand All @@ -1735,26 +1735,45 @@ class VectorLayoutInferer {
std::optional<VectorLayout> out_layout_candidate;
std::optional<VectorLayout> out_layout;
SmallVector<std::optional<Layout>, 4> in_layouts;
int64_t bit_width = -1;
int64_t bitwidth = -1;
// Find the bitwidth of the operands/results. They must all be the same
// except for the case of i1s, which use a "fake" bitwidth for layouts.
// They can be relayouted (in principle) to any other fake bitwidth, so we
// don't commit to their bitwidth. See comments in VectorLayout class.
for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
if (const VectorType vty = dyn_cast<VectorType>(val.getType())) {
const int64_t val_bitwidth = vty.getElementTypeBitWidth();
if (val_bitwidth != 1) {
if (bitwidth == -1) {
bitwidth = val_bitwidth;
} else if (bitwidth != val_bitwidth) {
return op->emitOpError(
"Mismatched bitwidth in elementwise for non-i1 "
"operands/results");
}
}
}
}
for (int64_t i = 0; i < op->getNumOperands(); ++i) {
if (auto vty = dyn_cast<VectorType>(op->getOperand(i).getType())) {
if (bit_width == -1) {
bit_width = vty.getElementTypeBitWidth();
}
TPU_CHECK_OP(
!check_bitwidth || bit_width == vty.getElementTypeBitWidth(),
"Generic elementwise rule only supports operands of same width");
auto some_layout = getLayout(op->getOperand(i));
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
auto &layout = *some_layout;
// If the input is fully replicated, don't use it to commit to any
// layout. Replicated values are easy to relayout.
if (is_fully_replicated(some_layout)) {
if (bitwidth == -1) {
// All operands/results are i1s, just commit to the first bitwidth
DCHECK(!out_layout.has_value());
bitwidth = layout.bitwidth();
out_layout = layout;
in_layouts.push_back(layout);
} else if (bitwidth != layout.bitwidth()) {
DCHECK_EQ(vty.getElementTypeBitWidth(), 1);
in_layouts.push_back(std::nullopt);
} else if (is_fully_replicated(some_layout)) {
// If the input is fully replicated, don't use it to commit to any
// layout. Replicated values are easy to relayout.
in_layouts.push_back(std::nullopt);
out_layout_candidate = layout;
continue;
}
if (!out_layout) {
} else if (!out_layout) {
// TODO(apaszke): There are probably smarter ways to choose layout.
out_layout = layout;
in_layouts.push_back(some_layout);
Expand All @@ -1768,8 +1787,9 @@ class VectorLayoutInferer {
// any replication bits that might have been present in out_layout,
// since there is no guarantee that the conflicting inputs could
// even become replicated.
DCHECK_EQ(out_layout->bitwidth(), bitwidth);
out_layout =
VectorLayout(out_layout->bitwidth(),
VectorLayout(bitwidth,
{out_layout->offsets()[0].value_or(0),
out_layout->offsets()[1].value_or(0)},
out_layout->tiling(), out_layout->implicit_dim());
Expand All @@ -1784,9 +1804,6 @@ class VectorLayoutInferer {
}
Layout final_out_layout = std::nullopt;
if (auto out_vty = dyn_cast<VectorType>(op->getResult(0).getType())) {
TPU_CHECK_OP(
!check_bitwidth || bit_width == out_vty.getElementTypeBitWidth(),
"Generic elementwise rule can't change element type width");
if (out_layout) {
final_out_layout = *out_layout;
} else if (out_layout_candidate) {
Expand Down

0 comments on commit 651ab18

Please sign in to comment.