Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#sdy Support OpShardingRule in SDY round trip import. #17520

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ using ::mlir::func::FuncOp;
using ::mlir::mhlo::CustomCallOp;

using ::mlir::sdy::kShardingAttr;
using ::mlir::sdy::kShardingRuleAttr;
using ::mlir::sdy::MeshAttr;
using ::mlir::sdy::OpShardingRuleAttr;
using ::mlir::sdy::TensorShardingAttr;
using ::mlir::sdy::TensorShardingPerValueAttr;

// Builds the shardings coming from Shardy previously. This means
// Builds the shardy attributes coming from Shardy previously. This means
// the module was exported from Shardy and we are now round-tripping back.
// This should happen after the meshes were created from the `ModuleOp` attrs
// (see `SdyRoundTripImportShardingsPass`).
void convertShardings(FuncOp funcOp) {
void convertShardyAttrs(FuncOp funcOp) {
// Copy over the argument shardings, but not the result shardings yet.
// We need to wait until after we've converted all the Operations before
// copying the result shardings.
Expand All @@ -102,7 +104,7 @@ void convertShardings(FuncOp funcOp) {
resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr));
}

// Extract the round-tripped SDY shardings from the operations.
// Extract the round-tripped SDY shardy attributes from the operations.
funcOp.front().walk([&](Operation* op) {
op->removeAttr(kXlaShardingAttr);
if (DictionaryAttr dictAttr = getFrontendAttrs(op)) {
Expand Down Expand Up @@ -141,6 +143,13 @@ void convertShardings(FuncOp funcOp) {
}
}
removeFrontendAttribute(op, kShardingRoundTripAttr);

// Import sharding rules.
if (auto shardingRuleAttr = parseStringAttr<OpShardingRuleAttr>(
dictAttr, kShardingRuleRoundTripAttr)) {
op->setAttr(kShardingRuleAttr, shardingRuleAttr);
removeFrontendAttribute(op, kShardingRuleRoundTripAttr);
}
}
});
}
Expand Down Expand Up @@ -176,14 +185,15 @@ class SdyRoundTripImportShardingsPass
// Insert the meshes before any functions.
builder.setInsertionPointToStart(moduleOp.getBody());
for (NamedAttribute mesh : sdyMeshes) {
mesh.getValue().dump();
auto meshAttr = mlir::cast<MeshAttr>(mesh.getValue());
symbolTable.insert(builder.create<mlir::sdy::MeshOp>(
moduleOp.getLoc(), mesh.getName(), meshAttr));
}
removeFrontendAttribute(moduleOp, kMeshesRoundTripAttr);

for (auto funcOp : moduleOp.getOps<FuncOp>()) {
convertShardings(funcOp);
convertShardyAttrs(funcOp);
}
}

Expand Down
8 changes: 6 additions & 2 deletions xla/service/spmd/shardy/sdy_round_trip/import_shardings.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ limitations under the License.
namespace xla {
namespace sdy {

// Creates the pass that converts the shardings from strings in MHLO frontend
// attributes to SDY meshes and shardings.
// Creates the pass to convert frontend attributes to SDY attributes:
//
// - Converts shardings from `kShardingRoundTripAttr` to `kShardingAttr`
// - Converts sharding rules from `kShardingRuleRoundTripAttr` to
// `kShardingRuleAttr`
// - Converts meshes from `kMeshesRoundTripAttr` to sdy.mesh symbols
std::unique_ptr<mlir::Pass> createSdyRoundTripImportShardingsPass();

// Registers the xla-sdy-round-trip-import-shardings pass.
Expand Down
21 changes: 20 additions & 1 deletion xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: sdy_opt %s -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s
// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s

// CHECK-LABEL: module @multiple_func_result_shardings
module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>}"}} {
Expand Down Expand Up @@ -110,3 +110,22 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x
return %3 : tensor<32xi32>
}
}

// -----

module @no_mesh_module attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{}"}} {
// CHECK-LABEL: func @no_sharding_rule
func.func @no_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> {
// CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64>
%0 = stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64>
return %0 : tensor<8x2xf64>
}

// CHECK-LABEL: func @op_sharding_rule
func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> {
// CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>}
%0 = stablehlo.custom_call @foo(%arg0, %arg1)
{mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64>
return %0 : tensor<8x2xf64>
}
}
25 changes: 10 additions & 15 deletions xla/service/spmd/shardy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,20 @@ void removeFrontendAttribute(mlir::func::FuncOp funcOp,

void loadAllRequiredDialects(mlir::MLIRContext* context);

// Parses `stringAttr` to an attribute of type `AttrTy`.
//
// NOTE: assumes `stringAttr` is of type `StringAttr`.
template <typename AttrTy>
AttrTy parseStringAttr(mlir::Attribute stringAttr) {
std::string value;
std::string error;
CHECK(absl::CUnescape(mlir::cast<mlir::StringAttr>(stringAttr).getValue(),
&value, &error))
<< error;
return mlir::cast<AttrTy>(
mlir::parseAttribute(value, stringAttr.getContext()));
}

// Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`.
template <typename AttrTy>
AttrTy parseStringAttr(mlir::DictionaryAttr dictAttr,
llvm::StringRef attrName) {
return parseStringAttr<AttrTy>(dictAttr.get(attrName));
if (mlir::Attribute stringAttr = dictAttr.get(attrName)) {
std::string value;
std::string error;
CHECK(absl::CUnescape(mlir::cast<mlir::StringAttr>(stringAttr).getValue(),
&value, &error))
<< error;
return mlir::cast<AttrTy>(
mlir::parseAttribute(value, stringAttr.getContext()));
}
return nullptr;
}

} // namespace sdy
Expand Down
Loading