Skip to content

Commit

Permalink
[Mosaic:TPU] Allow null parts for tpu.pack_subelements, meaning "don'…
Browse files Browse the repository at this point in the history
…t care"

PiperOrigin-RevId: 707439259
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Dec 18, 2024
1 parent 3262770 commit dc0b774
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 39 deletions.
14 changes: 11 additions & 3 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,21 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
}

// Integer packs are always signed at the moment.
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> {
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> {
let arguments = (ins
Variadic<AnyVectorOfNonZeroRank>:$sources,
Variadic<TPU_Vreg>:$sources,
DenseI32ArrayAttr:$positions,
TPU_PackFormatEnum:$pack_format
);
let results = (outs AnyVectorOfNonZeroRank:$output);
let results = (outs TPU_Vreg:$output);
let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
let builders = [
OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>,
];
let extraClassDeclaration = [{
static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef<int32_t> positions, int packing_factor);
}];
let hasVerifier = 1;
}

def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
Expand Down
49 changes: 49 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string_view>
Expand Down Expand Up @@ -1113,6 +1114,54 @@ LogicalResult WeirdOp::verify() {
return success();
}

void PackSubelementsOp::build(OpBuilder &builder, OperationState &state,
const VectorType output_type,
const ArrayRef<Value> padded_sources,
const PackFormat pack_format) {
SmallVector<Value> sources;
SmallVector<int32_t> positions;
for (size_t i = 0; i < padded_sources.size(); ++i) {
if (padded_sources[i] != nullptr) {
sources.push_back(padded_sources[i]);
positions.push_back(i);
}
}
build(builder, state, output_type, sources, positions, pack_format);
}

SmallVector<Value> PackSubelementsOp::getPaddedSources(
ValueRange sources, const ArrayRef<int32_t> positions,
const int packing_factor) {
SmallVector<Value> padded_sources(packing_factor);
for (const auto [source, position] : llvm::zip(sources, positions)) {
padded_sources[position] = source;
}
return padded_sources;
}

LogicalResult PackSubelementsOp::verify() {
if (getSources().empty()) {
return emitOpError("At least one source is required");
}
if (getPositions().size() != getSources().size()) {
return emitOpError("Size of sources and positions must match");
}
const int packing_factor = cast<VectorType>(getSources().front().getType())
.getElementTypeBitWidth() /
getType().getElementTypeBitWidth();
SmallVector<bool> seen_positions(packing_factor, false);
for (const int32_t position : getPositions()) {
if (position < 0 || packing_factor <= position) {
return emitOpError("Positions must be between 0 and the packing factor");
}
if (seen_positions[position]) {
return emitOpError("Positions must be unique");
}
seen_positions[position] = true;
}
return success();
}

} // namespace tpu
} // namespace mlir

Expand Down
113 changes: 77 additions & 36 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1035,12 +1035,19 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
// Pack any data lying around if OOB
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
if (!layout_out.offsets()[1].has_value()) {
idxs_local.back() = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, input_vregs(idxs_local));
} else {
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
if (idxs_local.back() < input_vregs.dimensions().back()) {
parts.push_back(input_vregs(idxs_local));
++idxs_local.back();
} else {
parts.push_back(nullptr);
}
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
Expand All @@ -1053,16 +1060,19 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= packing;
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
while (parts.size() < packing) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
} else {
// Once we run out of tiles, we can pick any one we like.
parts.push_back(parts.back());
if (!layout_out.offsets()[0].has_value()) {
*(idxs_local.end() - 2) = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, input_vregs(idxs_local));
} else {
*(idxs_local.end() - 2) *= packing;
for (int64_t i = 0; i < packing; ++i) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
++*(idxs_local.end() - 2);
} else {
parts.push_back(nullptr);
}
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
Expand Down Expand Up @@ -6253,6 +6263,11 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
ctx.target_shape[1]}) {
// Note: for int4, retiling with scratch is always faster.
if (bitwidth != 4 || !has_enough_scratch) {
// Note: The code below does not work when src is replicated and dst is
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src_offsets);
xla::Array<Value> retiled(dst_tiles_shape);
VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
Expand All @@ -6263,19 +6278,29 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
SmallVector<Value, 8> parts;
parts.reserve(packing);
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
src_idx[src_idx.size() - 2] *= packing;
src_idx[src_idx.size() - 1] /= packing;
for (int i = 0; i < packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
if (src_idx[src_idx.size() - 2] <
vregs.dim(vregs.num_dimensions() - 2) - 1) {
++src_idx[src_idx.size() - 2];
*(src_idx.end() - 1) /= packing;
if (!dst.offsets()[0].has_value()) {
*(src_idx.end() - 2) = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
} else {
*(src_idx.end() - 2) *= packing;
for (int i = 0; i < packing; ++i) {
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
++*(src_idx.end() - 2);
} else {
parts.push_back(nullptr);
}
}
}
*tile = builder.create<tpu::PackSubelementsOp>(
loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed);
loc, cast<VectorType>(vregs.begin()->getType()), parts,
tpu::PackFormat::kCompressed);
});
return std::pair(dst, std::move(retiled));
}
Expand Down Expand Up @@ -6334,6 +6359,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
// moving to the next one. This is exactly an interleaving of the sublanes
// of the vreg parts.

// Note: The code below does not work when src is replicated and dst is
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src.offsets());
xla::Array<Value> retiled(dst_tiles_shape);
const VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
Expand All @@ -6343,20 +6374,30 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
SmallVector<Value> parts;
parts.reserve(packing);
SmallVector<int64_t> src_idx(toArrayRef(idx));
*(src_idx.end() - 2) *= packing;
const int64_t vreg_part = *(src_idx.end() - 1) % packing;
*(src_idx.end() - 1) /= packing;
for (int i = 0; i < packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) {
++*(src_idx.end() - 2);
} // The rest is padding, so just pick any of the input parts (but not
// an arbitrary vreg so we don't add an extra dependency).
if (!dst.offsets()[0].has_value()) {
*(src_idx.end() - 2) = 0;
// Make sure we set all parts of the output vreg to make it replicated
parts.append(packing, builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
} else {
*(src_idx.end() - 2) *= packing;
for (int i = 0; i < packing; ++i) {
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
++*(src_idx.end() - 2);
} else {
parts.push_back(nullptr);
}
}
}
*tile = builder.create<tpu::PackSubelementsOp>(
loc, vregs.begin()->getType(), parts, tpu::PackFormat::kInterleaved);
loc, cast<VectorType>(vregs.begin()->getType()), parts,
tpu::PackFormat::kInterleaved);
});
return std::pair(dst, std::move(retiled));
}
Expand Down

0 comments on commit dc0b774

Please sign in to comment.