diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc index eb11cc53f1c45..fc6e55b203678 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc @@ -71,15 +71,17 @@ using ::mlir::func::FuncOp; using ::mlir::mhlo::CustomCallOp; using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::kShardingRuleAttr; using ::mlir::sdy::MeshAttr; +using ::mlir::sdy::OpShardingRuleAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; -// Builds the shardings coming from Shardy previously. This means +// Builds the shardy attributes coming from Shardy previously. This means // the module was exported from Shardy and we are now round-tripping back. // This should happen after the meshes were created from the `ModuleOp` attrs // (see `SdyRoundTripImportShardingsPass`). -void convertShardings(FuncOp funcOp) { +void convertShardyAttrs(FuncOp funcOp) { // Copy over the argument shardings, but not the result shardings yet. // We need to wait until after we've converted all the Operations before // copying the result shardings. @@ -102,7 +104,7 @@ void convertShardings(FuncOp funcOp) { resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr)); } - // Extract the round-tripped SDY shardings from the operations. + // Extract the round-tripped SDY shardy attributes from the operations. funcOp.front().walk([&](Operation* op) { op->removeAttr(kXlaShardingAttr); if (DictionaryAttr dictAttr = getFrontendAttrs(op)) { @@ -141,6 +143,13 @@ void convertShardings(FuncOp funcOp) { } } removeFrontendAttribute(op, kShardingRoundTripAttr); + + // Import sharding rules. + if (auto shardingRuleAttr = parseStringAttr( + dictAttr, kShardingRuleRoundTripAttr)) { + op->setAttr(kShardingRuleAttr, shardingRuleAttr); + removeFrontendAttribute(op, kShardingRuleRoundTripAttr); + } } }); } @@ -176,6 +185,7 @@ class SdyRoundTripImportShardingsPass // Insert the meshes before any functions. builder.setInsertionPointToStart(moduleOp.getBody()); for (NamedAttribute mesh : sdyMeshes) { + mesh.getValue().dump(); auto meshAttr = mlir::cast(mesh.getValue()); symbolTable.insert(builder.create( moduleOp.getLoc(), mesh.getName(), meshAttr)); @@ -183,7 +193,7 @@ class SdyRoundTripImportShardingsPass removeFrontendAttribute(moduleOp, kMeshesRoundTripAttr); for (auto funcOp : moduleOp.getOps()) { - convertShardings(funcOp); + convertShardyAttrs(funcOp); } } diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h index 2f77466af8762..c750ef2870555 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h +++ b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h @@ -23,8 +23,12 @@ limitations under the License. namespace xla { namespace sdy { -// Creates the pass that converts the shardings from strings in MHLO frontend -// attributes to SDY meshes and shardings. +// Creates the pass to convert frontend attributes to SDY attributes: +// +// - Converts shardings from `kShardingRoundTripAttr` to `kShardingAttr` +// - Converts sharding rules from `kShardingRuleRoundTripAttr` to +// `kShardingRuleAttr` +// - Converts meshes from `kMeshesRoundTripAttr` to sdy.mesh symbols std::unique_ptr createSdyRoundTripImportShardingsPass(); // Registers the xla-sdy-round-trip-import-shardings pass. diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 04944f697ba87..aa7caa093de28 100644 --- a/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt %s -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s +// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s // CHECK-LABEL: module @multiple_func_result_shardings module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>}"}} { @@ -110,3 +110,22 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x return %3 : tensor<32xi32> } } + +// ----- + +module @no_mesh_module attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{}"}} { + // CHECK-LABEL: func @no_sharding_rule + func.func @no_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + %0 = stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> + } + + // CHECK-LABEL: func @op_sharding_rule + func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) + {mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> + } +} diff --git a/xla/service/spmd/shardy/utils.h b/xla/service/spmd/shardy/utils.h index 80194b3ca04c4..394b6c48e0bfd 100644 --- a/xla/service/spmd/shardy/utils.h +++ b/xla/service/spmd/shardy/utils.h @@ -61,25 +61,20 @@ void removeFrontendAttribute(mlir::func::FuncOp funcOp, void loadAllRequiredDialects(mlir::MLIRContext* context); -// Parses `stringAttr` to an attribute of type `AttrTy`. -// -// NOTE: assumes `stringAttr` is of type `StringAttr`. -template -AttrTy parseStringAttr(mlir::Attribute stringAttr) { - std::string value; - std::string error; - CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), - &value, &error)) - << error; - return mlir::cast( - mlir::parseAttribute(value, stringAttr.getContext())); -} - // Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`. template AttrTy parseStringAttr(mlir::DictionaryAttr dictAttr, llvm::StringRef attrName) { - return parseStringAttr(dictAttr.get(attrName)); + if (mlir::Attribute stringAttr = dictAttr.get(attrName)) { + std::string value; + std::string error; + CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), + &value, &error)) + << error; + return mlir::cast( + mlir::parseAttribute(value, stringAttr.getContext())); + } + return nullptr; } } // namespace sdy