Skip to content

Commit

Permalink
#sdy Support OpShardingRule in SDY round trip import.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677862331
  • Loading branch information
bixia1 authored and Google-ML-Automation committed Sep 27, 2024
1 parent 6c3194e commit 9654502
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 22 deletions.
18 changes: 14 additions & 4 deletions xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ using ::mlir::func::FuncOp;
using ::mlir::mhlo::CustomCallOp;

using ::mlir::sdy::kShardingAttr;
using ::mlir::sdy::kShardingRuleAttr;
using ::mlir::sdy::MeshAttr;
using ::mlir::sdy::OpShardingRuleAttr;
using ::mlir::sdy::TensorShardingAttr;
using ::mlir::sdy::TensorShardingPerValueAttr;

// Builds the shardings coming from Shardy previously. This means
// Builds the shardy attributes coming from Shardy previously. This means
// the module was exported from Shardy and we are now round-tripping back.
// This should happen after the meshes were created from the `ModuleOp` attrs
// (see `SdyRoundTripImportShardingsPass`).
void convertShardings(FuncOp funcOp) {
void convertShardyAttrs(FuncOp funcOp) {
// Copy over the argument shardings, but not the result shardings yet.
// We need to wait until after we've converted all the Operations before
// copying the result shardings.
Expand All @@ -102,7 +104,7 @@ void convertShardings(FuncOp funcOp) {
resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr));
}

// Extract the round-tripped SDY shardings from the operations.
// Extract the round-tripped SDY shardy attributes from the operations.
funcOp.front().walk([&](Operation* op) {
op->removeAttr(kXlaShardingAttr);
if (DictionaryAttr dictAttr = getFrontendAttrs(op)) {
Expand Down Expand Up @@ -141,6 +143,13 @@ void convertShardings(FuncOp funcOp) {
}
}
removeFrontendAttribute(op, kShardingRoundTripAttr);

// Import sharding rules.
if (auto shardingRuleAttr = parseStringAttr<OpShardingRuleAttr>(
dictAttr, kShardingRuleRoundTripAttr)) {
op->setAttr(kShardingRuleAttr, shardingRuleAttr);
removeFrontendAttribute(op, kShardingRuleRoundTripAttr);
}
}
});
}
Expand Down Expand Up @@ -176,14 +185,15 @@ class SdyRoundTripImportShardingsPass
// Insert the meshes before any functions.
builder.setInsertionPointToStart(moduleOp.getBody());
// Create one `sdy.mesh` op per round-tripped mesh entry and register it in the
// symbol table. The entry's name becomes the mesh name and its value is cast
// to `MeshAttr`.
for (NamedAttribute mesh : sdyMeshes) {
  auto meshAttr = mlir::cast<MeshAttr>(mesh.getValue());
  symbolTable.insert(builder.create<mlir::sdy::MeshOp>(
      moduleOp.getLoc(), mesh.getName(), meshAttr));
}
removeFrontendAttribute(moduleOp, kMeshesRoundTripAttr);

for (auto funcOp : moduleOp.getOps<FuncOp>()) {
convertShardings(funcOp);
convertShardyAttrs(funcOp);
}
}

Expand Down
8 changes: 6 additions & 2 deletions xla/service/spmd/shardy/sdy_round_trip/import_shardings.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ limitations under the License.
namespace xla {
namespace sdy {

// Creates the pass that converts the strings stored in MHLO frontend
// attributes to SDY attributes and symbols:
//
// - Converts shardings from `kShardingRoundTripAttr` to `kShardingAttr`
// - Converts sharding rules from `kShardingRuleRoundTripAttr` to
// `kShardingRuleAttr`
// - Converts meshes from `kMeshesRoundTripAttr` to sdy.mesh symbols
std::unique_ptr<mlir::Pass> createSdyRoundTripImportShardingsPass();

// Registers the xla-sdy-round-trip-import-shardings pass.
Expand Down
21 changes: 20 additions & 1 deletion xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: sdy_opt %s -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s
// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s

// CHECK-LABEL: module @multiple_func_result_shardings
module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>}"}} {
Expand Down Expand Up @@ -110,3 +110,22 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x
return %3 : tensor<32xi32>
}
}

// -----

// Sharding-rule import cases. The module's round-tripped meshes dictionary is
// empty (`xla.sdy.meshes = "{}"`).
module @no_mesh_module attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{}"}} {
// An op without frontend attributes is expected to be printed back with no
// attributes attached.
// CHECK-LABEL: func @no_sharding_rule
func.func @no_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> {
// CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64>
%0 = stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64>
return %0 : tensor<8x2xf64>
}

// The string stored under the `xla.sdy.sharding_rule` frontend attribute is
// expected to show up on the op as a parsed `sdy.sharding_rule` attribute.
// CHECK-LABEL: func @op_sharding_rule
func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> {
// CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>}
%0 = stablehlo.custom_call @foo(%arg0, %arg1)
{mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64>
return %0 : tensor<8x2xf64>
}
}
25 changes: 10 additions & 15 deletions xla/service/spmd/shardy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,20 @@ void removeFrontendAttribute(mlir::func::FuncOp funcOp,

void loadAllRequiredDialects(mlir::MLIRContext* context);

// Parses the string held by `stringAttr` to an attribute of type `AttrTy`.
//
// The string value is first C-unescaped with `absl::CUnescape`; if that fails,
// the `CHECK` trips and reports the unescape error. The unescaped text is then
// parsed with `mlir::parseAttribute` in `stringAttr`'s context and the result
// is cast to `AttrTy`.
//
// NOTE: assumes `stringAttr` is of type `StringAttr`. There is no null check
// here, so the caller must pass an attribute that is present.
template <typename AttrTy>
AttrTy parseStringAttr(mlir::Attribute stringAttr) {
std::string value;
std::string error;
CHECK(absl::CUnescape(mlir::cast<mlir::StringAttr>(stringAttr).getValue(),
&value, &error))
<< error;
return mlir::cast<AttrTy>(
mlir::parseAttribute(value, stringAttr.getContext()));
}

// Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`.
//
// Returns a null `AttrTy` when `dictAttr` has no entry named `attrName`, so the
// result can be tested directly (`if (auto attr = parseStringAttr<...>(...))`).
// When the entry exists it is expected to be a `StringAttr` holding the
// C-escaped textual form of the attribute; a failure to unescape it trips the
// `CHECK` and reports the unescape error.
template <typename AttrTy>
AttrTy parseStringAttr(mlir::DictionaryAttr dictAttr,
                       llvm::StringRef attrName) {
  // Guard the lookup: a missing entry yields a null attribute, which must not
  // reach the `StringAttr` cast below.
  if (mlir::Attribute stringAttr = dictAttr.get(attrName)) {
    std::string value;
    std::string error;
    CHECK(absl::CUnescape(mlir::cast<mlir::StringAttr>(stringAttr).getValue(),
                          &value, &error))
        << error;
    return mlir::cast<AttrTy>(
        mlir::parseAttribute(value, stringAttr.getContext()));
  }
  return nullptr;
}

} // namespace sdy
Expand Down

0 comments on commit 9654502

Please sign in to comment.