diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e8..82b2b176c3400 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,86 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/include/clang/AST/DeclID.h b/clang/include/clang/AST/DeclID.h +--- a/clang/include/clang/AST/DeclID.h ++++ b/clang/include/clang/AST/DeclID.h +@@ -189,6 +189,7 @@ + // Every Decl ID is a local decl ID to the module being writing in ASTWriter. + friend class ASTWriter; + friend class GlobalDeclID; ++ friend struct llvm::DenseMapInfo; + + public: + LocalDeclID() : Base() {} +@@ -266,6 +267,27 @@ + return L == R; + } + }; ++ ++template <> struct DenseMapInfo { ++ using LocalDeclID = clang::LocalDeclID; ++ using DeclID = LocalDeclID::DeclID; ++ ++ static LocalDeclID getEmptyKey() { ++ return LocalDeclID(DenseMapInfo::getEmptyKey()); ++ } ++ ++ static LocalDeclID getTombstoneKey() { ++ return LocalDeclID(DenseMapInfo::getTombstoneKey()); ++ } ++ ++ static unsigned getHashValue(const LocalDeclID &Key) { ++ return DenseMapInfo::getHashValue(Key.getRawValue()); ++ } ++ ++ static bool isEqual(const LocalDeclID &L, const LocalDeclID &R) { ++ return L == R; ++ } ++}; + + } // namespace llvm + +diff -ruN --strip-trailing-cr a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h +--- a/clang/include/clang/Serialization/ASTWriter.h ++++ b/clang/include/clang/Serialization/ASTWriter.h +@@ -233,13 +233,13 @@ + /// instead of comparing the result of `getDeclID()` or `GetDeclRef()`. + llvm::SmallPtrSet PredefinedDecls; + +- /// Mapping from FunctionDecl to the list of lambda IDs inside the function. ++ /// Mapping from FunctionDecl ID to the list of lambda IDs inside the ++ /// function. + /// + /// These lambdas have to be loaded right after the function they belong to. + /// In order to have canonical declaration for lambda class from the same + /// module as enclosing function during deserialization. +- llvm::DenseMap> +- FunctionToLambdasMap; ++ llvm::DenseMap> FunctionToLambdasMap; + + /// Offset of each declaration in the bitstream, indexed by + /// the declaration's ID. +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp +--- a/clang/lib/Serialization/ASTWriter.cpp ++++ b/clang/lib/Serialization/ASTWriter.cpp +@@ -5713,8 +5713,7 @@ + // efficent becuase it allows lazy deserialization. + RecordData FunctionToLambdasMapRecord; + for (const auto &Pair : FunctionToLambdasMap) { +- FunctionToLambdasMapRecord.push_back( +- GetDeclRef(Pair.first).getRawValue()); ++ FunctionToLambdasMapRecord.push_back(Pair.first.getRawValue()); + FunctionToLambdasMapRecord.push_back(Pair.second.size()); + for (const auto &Lambda : Pair.second) + FunctionToLambdasMapRecord.push_back(Lambda.getRawValue()); +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp +--- a/clang/lib/Serialization/ASTWriterDecl.cpp ++++ b/clang/lib/Serialization/ASTWriterDecl.cpp +@@ -1524,7 +1524,8 @@ + // For lambdas inside canonical FunctionDecl remember the mapping. + if (auto FD = llvm::dyn_cast_or_null(D->getDeclContext()); + FD && FD->isCanonicalDecl()) { +- Writer.FunctionToLambdasMap[FD].push_back(Writer.GetDeclRef(D)); ++ Writer.FunctionToLambdasMap[Writer.GetDeclRef(FD)].push_back( ++ Writer.GetDeclRef(D)); + } + } else { + Record.push_back(CXXRecNotTemplate); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 7b11086785b61..106f3665c46e9 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "29b92d07746fac26cd64c914bc9c5c3833974f6d" - LLVM_SHA256 = "3e8e93e3749454af4b64f7f34b792a4748b62fc533bca1703d33b2b04e34eb70" + LLVM_COMMIT = "23487be4903630a4c06160562fb939f6389aa99d" + LLVM_SHA256 = "7c4c8c99df91e9e9859006b0435f83b5ed1260289a649befacfb529dc0a5f68f" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 014b81b4e7518..fbdfd0815e3ce 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,8 +1,670 @@ +diff --git a/docs/sdy_export_passes.md b/docs/sdy_export_passes.md +index 7a7e3ef..d5f678f 100755 +--- a/docs/sdy_export_passes.md ++++ b/docs/sdy_export_passes.md +@@ -1,4 +1,53 @@ + ++### `-sdy-insert-explicit-reshards` ++ ++_Inserts explicit reshards to make all operations have compatible shardings._ ++ ++A compatible sharding essentially means that the operation can accept the ++sharded operands and produce a sharded result without requiring any reshard ++communications (note that the operation might still require communication ++such as all-reduce or halo-swaps). ++ ++After propagation, some opeartions may still have incompatible shardings. ++ ++Please note, when an axis (or sub-axis) is used to shard non-corresponding ++dimensions (e.g. non-contracting dimensions in matmul) across multiple ++tensors, or when an axis shards a dimension in one tensor but not the ++corresponding dimension in the other tensor, it is said that the operation ++has a sharding conflict. Hence, after this pass, the opeartions become ++conflict-free. ++ ++This pass injects reshard operations explicitly so that, for each operation, ++corresponding dimensions become sharded in the same way across all operands ++and results, and every axis (or sub-axis) can only be used to shard a single ++dimension type. ++ ++A clarifying example: ++ ++Input: ++```mlir ++mesh = <"x"=4, "y"=2> ++%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"y"},{"x"}\]>} ++%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} ++stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} ++ : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> ++``` ++ ++Output: ++```mlir ++sdy.mesh = <"x"=4, "y"=2> ++%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>} ++%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} ++%0 = sdy.reshard %rhs <@mesh, \[{"y"}, {}\]> : tensor<32x16xf32> ++stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} ++ : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> ++``` ++ ++In the example above, there is a conflict since `lhs` and `rhs` tensors ++are both sharded on axis "x" on their non-contracting dimensions. Here, ++`rhs` tensor is resharded, before the dot operation, explicitly to be ++sharded only on its first dimension and on axis "x". This way, the dot ++opearation becomes compatible. + ### `-sdy-sharding-constraint-to-reshard` + + _Converts ShardingConstraintOp into ReshardOp._ +diff --git a/shardy/dialect/sdy/transforms/export/BUILD b/shardy/dialect/sdy/transforms/export/BUILD +index deb699d..16aed54 100644 +--- a/shardy/dialect/sdy/transforms/export/BUILD ++++ b/shardy/dialect/sdy/transforms/export/BUILD +@@ -36,6 +36,7 @@ cc_library( + name = "passes", + srcs = [ + "export_pipeline.cc", ++ "insert_explicit_reshards.cc", + "sharding_constraint_to_reshard.cc", + "sink_data_flow_edges.cc", + "update_non_divisible_input_output_shardings.cc", +diff --git a/shardy/dialect/sdy/transforms/export/export_pipeline.cc b/shardy/dialect/sdy/transforms/export/export_pipeline.cc +index 1dde661..1de197c 100644 +--- a/shardy/dialect/sdy/transforms/export/export_pipeline.cc ++++ b/shardy/dialect/sdy/transforms/export/export_pipeline.cc +@@ -28,6 +28,7 @@ void addExportPipeline(OpPassManager& pm, StringRef dumpDirectory) { + pm.addNestedPass(createShardingConstraintToReshardPass()); + pm.addNestedPass( + createUpdateNonDivisibleInputOutputShardingsPass()); ++ pm.addNestedPass(createInsertExplicitReshardsPass()); + pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory, + "sdy_module_after_sdy_export")); + } +diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +new file mode 100644 +index 0000000..e92b13a +--- /dev/null ++++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +@@ -0,0 +1,42 @@ ++/* Copyright 2024 The Shardy Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include ++ ++#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep ++#include "mlir/Pass/Pass.h" // IWYU pragma: keep ++#include "shardy/dialect/sdy/ir/dialect.h" // IWYU pragma: keep ++ ++namespace mlir { ++namespace sdy { ++ ++#define GEN_PASS_DEF_INSERTEXPLICITRESHARDSPASS ++#include "shardy/dialect/sdy/transforms/export/passes.h.inc" ++ ++namespace { ++ ++struct InsertExplicitReshardsPass ++ : public impl::InsertExplicitReshardsPassBase { ++ using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase; ++ ++ void runOnOperation() final { ++ // Not ready yet. It is currently a no-op. ++ } ++}; ++ ++} // namespace ++ ++} // namespace sdy ++} // namespace mlir +diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td +index c87a24f..1c0725c 100644 +--- a/shardy/dialect/sdy/transforms/export/passes.td ++++ b/shardy/dialect/sdy/transforms/export/passes.td +@@ -43,3 +43,55 @@ def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible- + }]; + let dependentDialects = ["mlir::sdy::SdyDialect"]; + } ++ ++def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::FuncOp"> { ++ let summary = "Inserts explicit reshards to make all operations have compatible shardings."; ++ let description = [{ ++ A compatible sharding essentially means that the operation can accept the ++ sharded operands and produce a sharded result without requiring any reshard ++ communications (note that the operation might still require communication ++ such as all-reduce or halo-swaps). ++ ++ After propagation, some opeartions may still have incompatible shardings. ++ ++ Please note, when an axis (or sub-axis) is used to shard non-corresponding ++ dimensions (e.g. non-contracting dimensions in matmul) across multiple ++ tensors, or when an axis shards a dimension in one tensor but not the ++ corresponding dimension in the other tensor, it is said that the operation ++ has a sharding conflict. Hence, after this pass, the opeartions become ++ conflict-free. ++ ++ This pass injects reshard operations explicitly so that, for each operation, ++ corresponding dimensions become sharded in the same way across all operands ++ and results, and every axis (or sub-axis) can only be used to shard a single ++ dimension type. ++ ++ A clarifying example: ++ ++ Input: ++ ```mlir ++ mesh = <"x"=4, "y"=2> ++ %lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"y"},{"x"}\]>} ++ %rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} ++ stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} ++ : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> ++ ``` ++ ++ Output: ++ ```mlir ++ sdy.mesh = <"x"=4, "y"=2> ++ %lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>} ++ %rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} ++ %0 = sdy.reshard %rhs <@mesh, \[{"y"}, {}\]> : tensor<32x16xf32> ++ stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} ++ : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> ++ ``` ++ ++ In the example above, there is a conflict since `lhs` and `rhs` tensors ++ are both sharded on axis "x" on their non-contracting dimensions. Here, ++ `rhs` tensor is resharded, before the dot operation, explicitly to be ++ sharded only on its first dimension and on axis "x". This way, the dot ++ opearation becomes compatible. ++ }]; ++ let dependentDialects = ["mlir::sdy::SdyDialect"]; ++} +diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +new file mode 100644 +index 0000000..3eb7c42 +--- /dev/null ++++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir +@@ -0,0 +1,20 @@ ++// RUN: sdy_opt %s -sdy-insert-explicit-reshards | FileCheck %s ++ ++sdy.mesh @mesh = <["x"=4, "y"=2]> ++ ++// CHECK-LABEL: func @dot_matrix_matrix_compatible ++func.func @dot_matrix_matrix_compatible(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> tensor<8x16xf32> { ++ // CHECK: stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} ++ // CHECK-NOT: sdy.reshard ++ %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> ++ return %0 : tensor<8x16xf32> ++} ++ ++ ++// CHECK-LABEL: func @dot_matrix_matrix_incompatible_same_non_contracting_dims ++func.func @dot_matrix_matrix_incompatible_same_non_contracting_dims(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<8x16xf32> { ++ // CHECK: stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} ++ // CHECK-NOT: sdy.reshard ++ %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> ++ return %0 : tensor<8x16xf32> ++} +diff --git a/shardy/dialect/sdy/transforms/import/import_pipeline.cc b/shardy/dialect/sdy/transforms/import/import_pipeline.cc +index 5ff4027..7c89a85 100644 +--- a/shardy/dialect/sdy/transforms/import/import_pipeline.cc ++++ b/shardy/dialect/sdy/transforms/import/import_pipeline.cc +@@ -36,6 +36,9 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) { + pm.addNestedPass(createConstantSplitterPass()); + pm.addNestedPass(createAddDataFlowEdgesPass()); + pm.addNestedPass(createApplyShardingConstraintsPass()); ++ // The sharding group import pass must run after applying sharding ++ // constraints. This ensures we can detect sharding conflicts between group ++ // members which have pre-propagation shardings due to sharding constraints. + pm.addPass(createShardingGroupImportPass()); + pm.addPass(createImportMaximalShardingPass()); + +diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc +index 348685f..74d66fe 100644 +--- a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc ++++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc +@@ -17,6 +17,7 @@ limitations under the License. + #include // IWYU pragma: keep + + #include "llvm/ADT/DenseMap.h" ++#include "llvm/ADT/MapVector.h" + #include "llvm/ADT/EquivalenceClasses.h" + #include "llvm/ADT/SmallVector.h" + #include "mlir/IR/BuiltinOps.h" +@@ -38,47 +39,10 @@ namespace { + using llvm::DenseMap; + using llvm::EquivalenceClasses; + using llvm::SmallDenseMap; +-using llvm::SmallVector; + + using ValueToShardingGroup = +- llvm::DenseMap>; +- +-void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) { +- if (tensorToGroups.empty()) { +- return; +- } +- // Merge the equivalence classes of group ids which had the same tensors +- // within them. (unionSets uses the default comparator and will consider the +- // minimum group_id as the representative element of the equivalence class). +- EquivalenceClasses shardingGroupEquivalences; +- for (auto& [_, groupsForTensor] : tensorToGroups) { +- const int64_t canonicalId = groupsForTensor.front().getGroupId(); +- for (ShardingGroupOp group : groupsForTensor) { +- shardingGroupEquivalences.unionSets(canonicalId, group.getGroupId()); +- } +- } +- +- // After merging groups we reindex the group IDs so that they take values +- // from the set {0,1,...,N-1} (N is the number of equivalence classes). +- // The leader element of each equivalent class corresponds to the minimum +- // group_id, so by looping over the group leaders in order their reindexed +- // ids can be set to maintain the same relative ordering. +- int64_t reindexId = 0; +- SmallDenseMap reindexMap; +- for (const auto& group : shardingGroupEquivalences) { +- if (group.isLeader()) { +- reindexMap[group.getData()] = reindexId++; +- } +- } +- +- // Update the graph to replace group_ids with their canonical id. +- for (auto& [_, groupsForTensor] : tensorToGroups) { +- for (ShardingGroupOp op : groupsForTensor) { +- op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue( +- op.getGroupId())]); +- } +- } +-} ++ llvm::MapVector>; ++using GroupIdToShardingGroups = SmallVector>; + + LogicalResult buildShardingGroupMappingAndValidateGroups( + ModuleOp module, ValueToShardingGroup& tensorToGroups) { +@@ -126,6 +90,83 @@ LogicalResult buildShardingGroupMappingAndValidateGroups( + return failure(result.wasInterrupted()); + } + ++GroupIdToShardingGroups unifyShardingGroups( ++ ValueToShardingGroup& tensorToGroups) { ++ // Merge the equivalence classes of group ids which had the same tensors ++ // within them. (unionSets uses the default comparator and will consider the ++ // minimum group_id as the representative element of the equivalence class). ++ EquivalenceClasses shardingGroupEquivalences; ++ for (auto& [_, groupsForTensor] : tensorToGroups) { ++ int64_t canonicalId = groupsForTensor.front().getGroupId(); ++ for (ShardingGroupOp group : groupsForTensor) { ++ shardingGroupEquivalences.unionSets(canonicalId, group.getGroupId()); ++ } ++ } ++ ++ // After merging groups we reindex the group IDs so that they take values ++ // from the set {0,1,...,N-1} (N is the number of equivalence classes). ++ // The leader element of each equivalent class corresponds to the minimum ++ // group_id, so by looping over the group leaders in order their reindexed ++ // ids can be set to maintain the same relative ordering. ++ int64_t reindexId = 0; ++ SmallDenseMap reindexMap; ++ for (const auto& group : shardingGroupEquivalences) { ++ if (group.isLeader()) { ++ reindexMap[group.getData()] = reindexId++; ++ } ++ } ++ ++ GroupIdToShardingGroups reindexGroups(reindexId); ++ // Update the graph to replace group_ids with their canonical id. ++ for (auto& [_, groupsForTensor] : tensorToGroups) { ++ for (ShardingGroupOp op : groupsForTensor) { ++ op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue( ++ op.getGroupId())]); ++ reindexGroups[op.getGroupId()].push_back(op); ++ } ++ } ++ return reindexGroups; ++} ++ ++// This function verifies that sharding groups with pre-existing shardings are ++// compatible. Compatibility means all values in the group must have either no ++// sharding or the same sharding. ++LogicalResult validateCompatibilityAndApplyInitialShardingConstraints( ++ ModuleOp module, GroupIdToShardingGroups& groupIdToShardingGroups) { ++ SmallDenseMap groupIdToSharding; ++ // Tensors can have initial shardings defined in several ways (e.g., sharding ++ // constraints, function arguments, manual computations). These initial ++ // shardings only conflict with Sharding Groups if their value belongs to a ++ // group. Therefore, we only need to validate the consistency of shardings ++ // within ShardingGroupOps to ensure no conflicts. ++ for (const auto& shardingGroups : groupIdToShardingGroups) { ++ for (ShardingGroupOp shardingGroupOp : shardingGroups) { ++ TensorShardingAttr sharding = getSharding(shardingGroupOp.getInput()); ++ int64_t groupId = shardingGroupOp.getGroupId(); ++ if (!sharding) { ++ continue; ++ } ++ auto [it, inserted] = groupIdToSharding.try_emplace(groupId, sharding); ++ if (!inserted && it->second != sharding) { ++ shardingGroupOp.emitError( ++ "Inconsistent shardings prior to propagation for ShardingGroupOps " ++ "with canonicalized groupId: ") ++ << groupId; ++ return failure(); ++ } ++ } ++ } ++ ++ // Apply initial shardings to all values in the group. ++ for (auto& [groupId, sharding] : groupIdToSharding) { ++ for (ShardingGroupOp shardingGroupOp : groupIdToShardingGroups[groupId]) { ++ setSharding(shardingGroupOp.getInput(), sharding); ++ } ++ } ++ ++ return success(); ++} ++ + struct ShardingGroupImportPass + : public impl::ShardingGroupImportPassBase { + using ShardingGroupImportPassBase::ShardingGroupImportPassBase; +@@ -134,12 +175,26 @@ struct ShardingGroupImportPass + // Extract the sharding group ids and tensor -> {group_id} mapping from the + // high level module and validate any sharding group constrainst are met. + ValueToShardingGroup tensorToGroups; +- if (failed(buildShardingGroupMappingAndValidateGroups(getOperation(), ++ ModuleOp module = getOperation(); ++ if (failed(buildShardingGroupMappingAndValidateGroups(module, + tensorToGroups))) { + signalPassFailure(); + } ++ // If there are no sharding groups, the rest of the preprocessing steps ++ // are not necessary. ++ if (tensorToGroups.empty()) { ++ return; ++ } + +- unifyShardingGroups(tensorToGroups); ++ GroupIdToShardingGroups groupIdToReindexedTensors = ++ unifyShardingGroups(tensorToGroups); ++ // This pass assumes sharding constraints are already applied to values. ++ // Compatibility constraints are applied after group unification to detect ++ // conflicts within the unified groups. ++ if (failed(validateCompatibilityAndApplyInitialShardingConstraints( ++ module, groupIdToReindexedTensors))) { ++ signalPassFailure(); ++ } + } + }; + +diff --git a/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir b/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir +index c3086e7..aa4dcca 100644 +--- a/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir ++++ b/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir +@@ -53,3 +53,59 @@ func.func @main(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) { + sdy.sharding_group %arg1 group_id = 3456 : tensor<8x8xf32> + func.return + } ++ ++// ----- ++ ++// Verifies that the `-apply-sharding-constraints` pass is applied before the ++// `-sharding-group-import` pass. This is validated by asserting that members ++// of a sharding group pick up the sharding of a group member with a sharding ++// constraint (the constraint needs to be added to the value in order for it to ++// be applied to other group members). ++sdy.mesh @mesh = <["a"=2]> ++// CHECK-LABEL: func.func @main ++func.func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { ++ // CHECK: %0 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} ++ %0 = stablehlo.add %arg0, %arg0 : tensor<16x16xf32> ++ %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"a"}]> : tensor<16x16xf32> ++ // CHECK: %2 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} ++ %2 = stablehlo.add %arg0, %arg0 : tensor<16x16xf32> ++ sdy.sharding_group %0 group_id = 32 : tensor<16x16xf32> ++ sdy.sharding_group %2 group_id = 32 : tensor<16x16xf32> ++ return %1 : tensor<16x16xf32> ++} ++ ++// ----- ++ ++// Verifies that the `-sdy-add-data-flow-edges` pass is applied before the ++// `-sharding-group-import` pass. This is validated by adding a block argument ++// of a while op to a sharding group which has a sharding constraint. This ++// should be applied to other members of the group but can only happen if the ++// `-sdy-add-data-flow-edges` pass is applied first. ++ ++sdy.mesh @mesh = <["a"=2]> ++ ++// CHECK: func.func @main ++// CHECK-NEXT %arg0: tensor<16x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>} ++func.func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { ++ %0 = stablehlo.constant dense<0> : tensor ++ %inc = stablehlo.constant dense<1> : tensor ++ %comp = stablehlo.constant dense<32> : tensor ++ %1:2 = stablehlo.while(%iterArg = %arg0, %iterArg_2 = %0) : tensor<16x16xf32>, tensor ++ cond { ++ %2 = stablehlo.compare LT, %iterArg_2, %comp : (tensor, tensor) -> tensor ++ stablehlo.return %2 : tensor ++ } do { ++ %2 = stablehlo.add %iterArg_2, %inc : tensor ++ // Add a value with an explicit sharding to group_id=50 which will apply an ++ // initial sharding to the result of the WhileOp outside of the loop. ++ %3 = stablehlo.add %iterArg, %iterArg : tensor<16x16xf32> ++ %4 = sdy.sharding_constraint %3 <@mesh, [{"a"}, {}]> : tensor<16x16xf32> ++ sdy.sharding_group %3 group_id = 50 : tensor<16x16xf32> ++ stablehlo.return %3, %2 : tensor<16x16xf32>, tensor ++ } ++ ++ // CHECK: sdy.data_flow_edge %3#0 sharding=<@mesh, [{"a"}, {}]> : tensor<16x16xf32> ++ sdy.sharding_group %1#0 group_id = 50 : tensor<16x16xf32> ++ return %1#0 : tensor<16x16xf32> ++} ++ +diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir +index 341d14d..6c0ba90 100644 +--- a/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir ++++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir +@@ -184,3 +184,76 @@ func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + func.return %0: tensor<8x8xf32> + } + ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++// Throw error for sharding groups which have incompatible shardings inferred ++// from initial constraints. ++func.func @error_for_incompatible_shardings_in_sharding_group( ++ %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, ++ %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}, {}]>}) { ++ // Sharding Group and Sharding Constraint compatibility checks happend after ++ // unification + canonicalization of group ids, which is why the group id ++ // below (555) corresponds to group id: 0 in the check-error. ++ sdy.sharding_group %arg0 group_id = 555 : tensor<8x8xf32> ++ // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} ++ sdy.sharding_group %arg1 group_id = 555 : tensor<8x8xf32> ++ func.return ++} ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++// Throw error for sharding groups which have incompatible shardings inferred ++// from initial constraints. ++func.func @error_for_transitively_inferred_incompatible_shardings_in_unified_sharding_group( ++ %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}, ++ %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { ++ ++ %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ ++ sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32> ++ sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32> ++ sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32> ++ sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32> ++ ++ // The shard group below will cause the above sharding groups to be merged ++ // by transitivity this implies that all of {%arg0, %arg1, 0, 1} should have ++ // the same sharding. Note that %0 and %1 are compatible by them selves but ++ // %arg0 and %arg1 are not due to their initial shardings. ++ sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32> ++ // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} ++ sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32> ++ func.return ++} ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++func.func @error_for_incompatible_shardings_in_manual_computation(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) { ++ %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"a"}, {}]>, <@mesh, [{"b"}, {}]>] out_shardings=[<@mesh, [{"b"}, {}]>] manual_axes={} (%arg2: tensor<8x8xf32>, %arg3: tensor<8x8xf32>) { ++ sdy.sharding_group %arg2 group_id = 8675 : tensor<8x8xf32> ++ // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} ++ sdy.sharding_group %arg3 group_id = 8675 : tensor<8x8xf32> ++ sdy.return %arg2 : tensor<8x8xf32> ++ } : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> ++ func.return ++} ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++func.func @error_for_incompatible_shardings_with_sharding_constraint(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { ++ %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> ++ %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32> ++ sdy.sharding_group %arg0 group_id = 1000 : tensor<8x8xf32> ++ // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} ++ sdy.sharding_group %1 group_id = 1000 : tensor<8x8xf32> ++ func.return ++} ++ +diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir +index 7cd8589..9fd7e88 100644 +--- a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir ++++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir +@@ -79,3 +79,107 @@ func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0: + sdy.sharding_group %arg2 group_id = 123456 : tensor<4xf32> + func.return + } ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++// CHECK-LABEL: set_existing_shardings_for_sharding_group_members ++func.func @set_existing_shardings_for_sharding_group_members( ++ %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}, ++ %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}) { ++ // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32> ++ %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ ++ sdy.sharding_group %arg0 group_id = 43210 : tensor<8x8xf32> ++ sdy.sharding_group %arg1 group_id = 43210 : tensor<8x8xf32> ++ sdy.sharding_group %0 group_id = 43210 : tensor<8x8xf32> ++ func.return ++} ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++// CHECK-LABEL: transitively_update_shardings_for_sharding_group_members ++func.func @transitively_update_shardings_for_sharding_group_members( ++ %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, ++ %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { ++ // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32> ++ // CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32> ++ %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ ++ sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32> ++ sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32> ++ sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32> ++ sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32> ++ sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32> ++ sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32> ++ func.return ++} ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++// CHECK-LABEL: set_existing_shards_for_disjoint_groups ++// CHECK-SAMEL %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>} ++// CHECK-SAMEL %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>} ++func.func @set_existing_shards_for_disjoint_groups( ++ %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, ++ %arg1: tensor<8x8xf32>, ++ %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}, ++ %arg3: tensor<8x8xf32>) { ++ // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32> ++ %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ // CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32> ++ %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ // CHECK: %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8xf32> ++ %2 = stablehlo.constant dense<0.0> : tensor<8x8xf32> ++ ++ sdy.sharding_group %arg0 group_id = 11111 : tensor<8x8xf32> ++ sdy.sharding_group %arg1 group_id = 11111 : tensor<8x8xf32> ++ sdy.sharding_group %0 group_id = 11111 : tensor<8x8xf32> ++ ++ sdy.sharding_group %arg2 group_id = 22222 : tensor<8x8xf32> ++ sdy.sharding_group %arg3 group_id = 22222 : tensor<8x8xf32> ++ sdy.sharding_group %1 group_id = 22222 : tensor<8x8xf32> ++ func.return ++} ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++// CHECK-LABEL: set_existing_shardings_in_manual_computation_op ++func.func @set_existing_shardings_in_manual_computation_op(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) { ++ %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"a"}, {}]>, <@mesh, [{"a"}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={} (%arg2: tensor<8x8xf32>, %arg3: tensor<8x8xf32>) { ++ // CHECK: %1 = stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} : tensor<8x8xf32> ++ %1 = stablehlo.add %arg2, %arg2 : tensor<8x8xf32> ++ // CHECK: %2 = stablehlo.add %arg3, %arg3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} : tensor<8x8xf32> ++ %2 = stablehlo.add %arg3, %arg3 : tensor<8x8xf32> ++ ++ sdy.sharding_group %1 group_id = 1000 : tensor<8x8xf32> ++ sdy.sharding_group %2 group_id = 1000 : tensor<8x8xf32> ++ sdy.sharding_group %arg2 group_id = 1000 : tensor<8x8xf32> ++ sdy.sharding_group %arg3 group_id = 1000 : tensor<8x8xf32> ++ sdy.return %1 : tensor<8x8xf32> ++ } : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> ++ func.return ++} ++ ++// ----- ++ ++sdy.mesh @mesh = <["a"=2, "b"=2]> ++ ++func.func @set_existing_shardings_in_groups_with_sharding_constraint(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { ++ %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> ++ %1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {}]> : tensor<8x8xf32> ++ // CHECK: %2 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} : tensor<8x8xf32> ++ %2 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> ++ sdy.sharding_group %arg0 group_id = 1000 : tensor<8x8xf32> ++ sdy.sharding_group %1 group_id = 1000 : tensor<8x8xf32> ++ sdy.sharding_group %2 group_id = 1000 : tensor<8x8xf32> ++ func.return ++} diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index de92cb4..509398d 100644 +index de92cb4..82b2b17 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,4095 +1 @@ +@@ -1,4095 +1,86 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst ---- a/llvm/docs/NVPTXUsage.rst @@ -219,7 +881,20 @@ index de92cb4..509398d 100644 - Expand = false; - -@@ -2271,117 +2258,6 @@ -- } ++diff -ruN --strip-trailing-cr a/clang/include/clang/AST/DeclID.h b/clang/include/clang/AST/DeclID.h ++--- a/clang/include/clang/AST/DeclID.h +++++ b/clang/include/clang/AST/DeclID.h ++@@ -189,6 +189,7 @@ ++ // Every Decl ID is a local decl ID to the module being writing in ASTWriter. ++ friend class ASTWriter; ++ friend class GlobalDeclID; +++ friend struct llvm::DenseMapInfo; ++ ++ public: ++ LocalDeclID() : Base() {} ++@@ -266,6 +267,27 @@ ++ return L == R; + } - } - --static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, @@ -985,12 +1660,80 @@ index de92cb4..509398d 100644 -+ : NVPTX::cvta_local_64) -+ : NVPTX::cvta_local; - break; -- } ++ }; +++ +++template <> struct DenseMapInfo { +++ using LocalDeclID = clang::LocalDeclID; +++ using DeclID = LocalDeclID::DeclID; +++ +++ static LocalDeclID getEmptyKey() { +++ return LocalDeclID(DenseMapInfo::getEmptyKey()); +++ } +++ +++ static LocalDeclID getTombstoneKey() { +++ return LocalDeclID(DenseMapInfo::getTombstoneKey()); +++ } +++ +++ static unsigned getHashValue(const LocalDeclID &Key) { +++ return DenseMapInfo::getHashValue(Key.getRawValue()); +++ } +++ +++ static bool isEqual(const LocalDeclID &L, const LocalDeclID &R) { +++ return L == R; +++ } +++}; ++ ++ } // namespace llvm ++ ++diff -ruN --strip-trailing-cr a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h ++--- a/clang/include/clang/Serialization/ASTWriter.h +++++ b/clang/include/clang/Serialization/ASTWriter.h ++@@ -233,13 +233,13 @@ ++ /// instead of comparing the result of `getDeclID()` or `GetDeclRef()`. ++ llvm::SmallPtrSet PredefinedDecls; ++ ++- /// Mapping from FunctionDecl to the list of lambda IDs inside the function. +++ /// Mapping from FunctionDecl ID to the list of lambda IDs inside the +++ /// function. ++ /// ++ /// These lambdas have to be loaded right after the function they belong to. ++ /// In order to have canonical declaration for lambda class from the same ++ /// module as enclosing function during deserialization. ++- llvm::DenseMap> ++- FunctionToLambdasMap; +++ llvm::DenseMap> FunctionToLambdasMap; ++ ++ /// Offset of each declaration in the bitstream, indexed by ++ /// the declaration's ID. ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp ++--- a/clang/lib/Serialization/ASTWriter.cpp +++++ b/clang/lib/Serialization/ASTWriter.cpp ++@@ -5713,8 +5713,7 @@ ++ // efficent becuase it allows lazy deserialization. ++ RecordData FunctionToLambdasMapRecord; ++ for (const auto &Pair : FunctionToLambdasMap) { ++- FunctionToLambdasMapRecord.push_back( ++- GetDeclRef(Pair.first).getRawValue()); +++ FunctionToLambdasMapRecord.push_back(Pair.first.getRawValue()); ++ FunctionToLambdasMapRecord.push_back(Pair.second.size()); ++ for (const auto &Lambda : Pair.second) ++ FunctionToLambdasMapRecord.push_back(Lambda.getRawValue()); ++diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp ++--- a/clang/lib/Serialization/ASTWriterDecl.cpp +++++ b/clang/lib/Serialization/ASTWriterDecl.cpp ++@@ -1524,7 +1524,8 @@ ++ // For lambdas inside canonical FunctionDecl remember the mapping. ++ if (auto FD = llvm::dyn_cast_or_null(D->getDeclContext()); ++ FD && FD->isCanonicalDecl()) { ++- Writer.FunctionToLambdasMap[FD].push_back(Writer.GetDeclRef(D)); +++ Writer.FunctionToLambdasMap[Writer.GetDeclRef(FD)].push_back( +++ Writer.GetDeclRef(D)); + } -- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); -+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -+ Src)); - return; -- } else { + } else { - // Generic to specific -@@ -1153,28 +1153,30 @@ - Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; @@ -4098,8 +4841,9 @@ index de92cb4..509398d 100644 - ; CHECK-NEXT:.b8 1 // DW_AT_call_file - ; CHECK-NEXT:.b8 6 // DW_AT_call_line - ; CHECK-NEXT:.b8 37 // DW_AT_call_column ++ Record.push_back(CXXRecNotTemplate); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index af35fe7..7b11086 100644 +index af35fe7..106f366 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") @@ -4108,8 +4852,23 @@ index af35fe7..7b11086 100644 """Imports LLVM.""" - LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" - LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" -+ LLVM_COMMIT = "29b92d07746fac26cd64c914bc9c5c3833974f6d" -+ LLVM_SHA256 = "3e8e93e3749454af4b64f7f34b792a4748b62fc533bca1703d33b2b04e34eb70" ++ LLVM_COMMIT = "23487be4903630a4c06160562fb939f6389aa99d" ++ LLVM_SHA256 = "7c4c8c99df91e9e9859006b0435f83b5ed1260289a649befacfb529dc0a5f68f" tf_http_archive( name = name, +diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl +index 2e87599..0a9d3d0 100644 +--- a/third_party/stablehlo/workspace.bzl ++++ b/third_party/stablehlo/workspace.bzl +@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + + def repo(): + # +- STABLEHLO_COMMIT = "ca13d31b5ed0b2053dde0a624480ad765e219ebf" +- STABLEHLO_SHA256 = "123462093f087f2576bb6a6cc471370eed2d43c291f881ff359fd4ca812003db" ++ STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee" ++ STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815" + # + + tf_http_archive( diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index 509398da979e8..82b2b176c3400 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1 +1,86 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/include/clang/AST/DeclID.h b/clang/include/clang/AST/DeclID.h +--- a/clang/include/clang/AST/DeclID.h ++++ b/clang/include/clang/AST/DeclID.h +@@ -189,6 +189,7 @@ + // Every Decl ID is a local decl ID to the module being writing in ASTWriter. + friend class ASTWriter; + friend class GlobalDeclID; ++ friend struct llvm::DenseMapInfo; + + public: + LocalDeclID() : Base() {} +@@ -266,6 +267,27 @@ + return L == R; + } + }; ++ ++template <> struct DenseMapInfo { ++ using LocalDeclID = clang::LocalDeclID; ++ using DeclID = LocalDeclID::DeclID; ++ ++ static LocalDeclID getEmptyKey() { ++ return LocalDeclID(DenseMapInfo::getEmptyKey()); ++ } ++ ++ static LocalDeclID getTombstoneKey() { ++ return LocalDeclID(DenseMapInfo::getTombstoneKey()); ++ } ++ ++ static unsigned getHashValue(const LocalDeclID &Key) { ++ return DenseMapInfo::getHashValue(Key.getRawValue()); ++ } ++ ++ static bool isEqual(const LocalDeclID &L, const LocalDeclID &R) { ++ return L == R; ++ } ++}; + + } // namespace llvm + +diff -ruN --strip-trailing-cr a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h +--- a/clang/include/clang/Serialization/ASTWriter.h ++++ b/clang/include/clang/Serialization/ASTWriter.h +@@ -233,13 +233,13 @@ + /// instead of comparing the result of `getDeclID()` or `GetDeclRef()`. + llvm::SmallPtrSet PredefinedDecls; + +- /// Mapping from FunctionDecl to the list of lambda IDs inside the function. ++ /// Mapping from FunctionDecl ID to the list of lambda IDs inside the ++ /// function. + /// + /// These lambdas have to be loaded right after the function they belong to. + /// In order to have canonical declaration for lambda class from the same + /// module as enclosing function during deserialization. +- llvm::DenseMap> +- FunctionToLambdasMap; ++ llvm::DenseMap> FunctionToLambdasMap; + + /// Offset of each declaration in the bitstream, indexed by + /// the declaration's ID. +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp +--- a/clang/lib/Serialization/ASTWriter.cpp ++++ b/clang/lib/Serialization/ASTWriter.cpp +@@ -5713,8 +5713,7 @@ + // efficent becuase it allows lazy deserialization. + RecordData FunctionToLambdasMapRecord; + for (const auto &Pair : FunctionToLambdasMap) { +- FunctionToLambdasMapRecord.push_back( +- GetDeclRef(Pair.first).getRawValue()); ++ FunctionToLambdasMapRecord.push_back(Pair.first.getRawValue()); + FunctionToLambdasMapRecord.push_back(Pair.second.size()); + for (const auto &Lambda : Pair.second) + FunctionToLambdasMapRecord.push_back(Lambda.getRawValue()); +diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp +--- a/clang/lib/Serialization/ASTWriterDecl.cpp ++++ b/clang/lib/Serialization/ASTWriterDecl.cpp +@@ -1524,7 +1524,8 @@ + // For lambdas inside canonical FunctionDecl remember the mapping. + if (auto FD = llvm::dyn_cast_or_null(D->getDeclContext()); + FD && FD->isCanonicalDecl()) { +- Writer.FunctionToLambdasMap[FD].push_back(Writer.GetDeclRef(D)); ++ Writer.FunctionToLambdasMap[Writer.GetDeclRef(FD)].push_back( ++ Writer.GetDeclRef(D)); + } + } else { + Record.push_back(CXXRecNotTemplate); diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 7b11086785b61..106f3665c46e9 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "29b92d07746fac26cd64c914bc9c5c3833974f6d" - LLVM_SHA256 = "3e8e93e3749454af4b64f7f34b792a4748b62fc533bca1703d33b2b04e34eb70" + LLVM_COMMIT = "23487be4903630a4c06160562fb939f6389aa99d" + LLVM_SHA256 = "7c4c8c99df91e9e9859006b0435f83b5ed1260289a649befacfb529dc0a5f68f" tf_http_archive( name = name,