#sdy add JAX Shardy support for shard_map.
For example, the following JAX program:
```py
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(fwd.lower(a).as_text())  # fwd is already jitted by the decorator above
```

prints:

```mlir
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```
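
The `sdy.mesh` and `sdy.manual_computation` ops above only appear when JAX lowers through Shardy rather than the default GSPMD path. A minimal sketch of enabling it, assuming the `jax_use_shardy_partitioner` config flag is available in your JAX version:

```py
import jax

# Assumption: flag name taken from JAX's Shardy integration; with it off,
# the same program lowers to the GSPMD sharding custom-call pattern instead.
jax.config.update('jax_use_shardy_partitioner', True)
```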

PiperOrigin-RevId: 679165100
bartchr808 authored and Google-ML-Automation committed Sep 26, 2024
1 parent 3b6122b commit 0923e35
Showing 9 changed files with 109 additions and 51 deletions.
1 change: 1 addition & 0 deletions xla/service/spmd/shardy/BUILD
@@ -30,6 +30,7 @@ cc_library(
deps = [
"//xla/hlo/ir:hlo",
"//xla/service:call_inliner",
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/strings",
],
)
5 changes: 4 additions & 1 deletion xla/service/spmd/shardy/sdy_round_trip/BUILD
@@ -107,7 +107,6 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

@@ -119,11 +118,15 @@ cc_library(
":export_ops",
":export_shardings",
":import_shardings",
":shard_map_export",
":shard_map_import",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service:hlo_proto_cc",
"//xla/service/spmd/shardy/mhlo_round_trip:export_shardings",
"//xla/service/spmd/shardy/mhlo_round_trip:shard_map_import",
"//xla/service/spmd/shardy/round_trip_common:pipeline_passes",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
13 changes: 8 additions & 5 deletions xla/service/spmd/shardy/sdy_round_trip/pipelines.cc
@@ -20,26 +20,30 @@ limitations under the License.
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h"
#include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h"
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h"
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h"

namespace xla {
namespace sdy {

using ::mlir::PassPipelineRegistration;

void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) {
// NOTE: we don't do any exporting for ManualComputationOp, since during
// SDY round-trip we expect the same pattern of custom calls to continue to
// exist. We save `sdy.sharding`s on those custom calls during
// Run canonicalizer to simplify `ManualComputationOp`s.
pm.addPass(mlir::createCanonicalizerPass());
// We save `sdy.sharding`s on those custom calls during
// `createSdyRoundTripExportShardingsPass` and make use of
// `createSdyRoundTripImportShardingsPass` to import them.
pm.addPass(createSdyRoundTripExportOpsPass());
pm.addPass(createSdyRoundTripShardMapExportPass());
// Preserve the SDY shardings for `createExportMhloShardingsPass` so that
// we have both `mhlo.sharding`s and hidden `sdy.sharding`s on the module. We
// want to have `mhlo.sharding`s for Pathways to read from.
@@ -50,8 +54,7 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) {
void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) {
addCommonPreImportPasses(pm);
pm.addPass(createSdyRoundTripImportShardingsPass());
// TODO(bartchr): replace with an sdy round trip shard map pass.
pm.addPass(createMhloRoundTripShardMapImportPass());
pm.addPass(createSdyRoundTripShardMapImportPass());
addCommonPostImportPasses(pm);
}

20 changes: 9 additions & 11 deletions xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc
@@ -44,7 +44,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"

@@ -60,8 +60,8 @@ using ::mlir::StringRef;
using ::mlir::SymbolTable;
using ::mlir::func::CallOp;
using ::mlir::func::FuncOp;
using ::mlir::mhlo::CustomCallOp;

namespace stablehlo = ::mlir::stablehlo;
namespace sdy = ::mlir::sdy;

// Converts a CallOp calling a @xla.sdy.manual_computation_body func with in/out
@@ -86,23 +86,22 @@ class ManualComputationPattern : public OpConversionPattern<CallOp> {
// we have to take the operands/results of the newly created
// `ManualComputationOp` differently depending on whether the original had
// operands/results.
stablehlo::CustomCallOp fullToShard;
CustomCallOp fullToShard;
mlir::ValueRange operands = callOp->getOperands();
if (!operands.empty()) {
fullToShard =
callOp->getOperand(0).getDefiningOp<stablehlo::CustomCallOp>();
operands = fullToShard->getOperands();
fullToShard = callOp->getOperand(0).getDefiningOp<CustomCallOp>();
CHECK(fullToShard);
CHECK(fullToShard.getCallTargetName() ==
kGlobalToLocalShapeCallTargetName);
operands = fullToShard->getOperands();
}
mlir::TypeRange resultTypes = callOp->getResultTypes();
stablehlo::CustomCallOp shardToFull;
CustomCallOp shardToFull;
if (!resultTypes.empty()) {
CHECK(callOp->getResult(0).hasOneUse())
<< "all CallOp results should be used by a single ShardToFull";
shardToFull = mlir::cast<stablehlo::CustomCallOp>(
*callOp->getResult(0).getUsers().begin());
shardToFull =
mlir::cast<CustomCallOp>(*callOp->getResult(0).getUsers().begin());
CHECK(shardToFull.getCallTargetName() ==
kLocalToGlobalShapeCallTargetName);
resultTypes = shardToFull->getResultTypes();
@@ -161,8 +160,7 @@ class SdyRoundTripShardMapImportPass
target.addDynamicallyLegalOp<CallOp>([](CallOp op) {
return !absl::StartsWith(op.getCallee(), kManualComputationBodyFuncName);
});
target.addLegalOp<sdy::ManualComputationOp, sdy::ReturnOp,
stablehlo::CustomCallOp>();
target.addLegalOp<sdy::ManualComputationOp, sdy::ReturnOp, CustomCallOp>();
mlir::RewritePatternSet patterns(&context);
patterns.add<ManualComputationPattern>(&context, symbolTable);
if (mlir::failed(mlir::applyPartialConversion(module, target,
5 changes: 4 additions & 1 deletion xla/service/spmd/shardy/shardy_call_inliner.cc
@@ -18,13 +18,16 @@ limitations under the License.
#include "absl/strings/match.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/call_inliner.h"
#include "xla/service/spmd/shardy/constants.h"

namespace xla {

bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
return CallInliner::IsInlineableCallOp(instruction) &&
!(instruction->GetModule()->config().use_shardy_partitioner() &&
absl::StrContains(instruction->to_apply()->name(), "shmap_body"));
(absl::StrContains(instruction->to_apply()->name(), "shmap_body") ||
absl::StartsWith(instruction->to_apply()->name(),
sdy::kManualComputationBodyFuncName)));
}

} // namespace xla
54 changes: 54 additions & 0 deletions xla/service/spmd/shardy/shardy_call_inliner_test.cc
@@ -57,5 +57,59 @@ TEST_F(ShardyCallInlinerTest, MhloToHloShmapBodyNotInlined) {
EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4");
}

// Don't inline when the name starts with "xla.sdy.manual_computation_body".
TEST_F(ShardyCallInlinerTest, ManualComputationBodyNotInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// The single call in the module is not inlined.
EXPECT_FALSE(changed);

HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall);
EXPECT_NE(call, nullptr);
EXPECT_TRUE(call->has_to_apply());
EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4");
}

// The inliner only checks whether the function name starts with
// "xla.sdy.manual_computation_body", not whether it contains it.
TEST_F(ShardyCallInlinerTest, ManualComputationBodyInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// Will be inlined.
EXPECT_TRUE(changed);
}

} // namespace sdy
} // namespace xla
@@ -18,17 +18,13 @@ func.func @main(%arg0: tensor<16x32xf32>) -> tensor<128x32xf32> {
// CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>)
// CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32>
// CHECK-NEXT: return %[[ADD]] : tensor<128x32xf32>
%0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32>
%1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32>
%2:2 = call @shmap_body_4(%1) : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>)
%3 = mhlo.custom_call @Sharding(%2#0) : (tensor<16x32xf32>) -> tensor<16x32xf32>
%4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a", "b"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32>
%5 = mhlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32>
%6 = mhlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32>
%7 = mhlo.add %4, %6 : tensor<128x32xf32>
return %7 : tensor<128x32xf32>
%0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32>
%1:2 = call @xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={\\\22a\\\22, \\\22b\\\22}>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22, \\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22, \\\22b\\\22}, {}]>, <@mesh_1, [{\\\22b\\\22, \\\22a\\\22}, {}]>]>"}} : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>)
%2:2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1#0, %1#1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>)
%3 = mhlo.add %2#0, %2#1 : tensor<128x32xf32>
return %3 : tensor<128x32xf32>
}
// CHECK-NOT: func.func private @shmap_body_4
func.func private @shmap_body_4(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) {
// CHECK-NOT: func.func private @xla.sdy.manual_computation_body
func.func private @xla.sdy.manual_computation_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) {
return %arg0, %arg0 : tensor<16x32xf32>, tensor<16x32xf32>
}