Skip to content

Commit

Permalink
[XLA:GPU] Add a verifier to IndexingMap and reuse it in IndexingMapAttr.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679504735
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 27, 2024
1 parent f727369 commit b961957
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 8 deletions.
15 changes: 7 additions & 8 deletions xla/service/gpu/fusions/ir/xla_gpu_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
#include <utility>

Expand Down Expand Up @@ -92,15 +93,13 @@ mlir::LogicalResult IndexingMapAttr::verify(
mlir::AffineMap map, ArrayRef<DimVar> dim_vars,
ArrayRef<RangeVar> range_vars,
ArrayRef<std::pair<AffineExpr, Interval>> constraints, bool is_simplified) {
if (map.getNumDims() != dim_vars.size()) {
return emitError() << "dim size must match the number of dimensions in "
"the affine map";
auto indexing_map = IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{},
constraints, is_simplified);
std::stringstream ss;
if (!indexing_map.Verify(ss)) {
return emitError() << ss.str();
}
if (map.getNumSymbols() != range_vars.size()) {
return emitError()
<< "range size must match the number of symbols in the affine map";
}
return mlir::success();
return success();
}

IndexingMap IndexingMapAttr::getIndexingMap() const {
Expand Down
17 changes: 17 additions & 0 deletions xla/service/gpu/model/indexing_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,23 @@ IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs) {
return ComposeIndexingMaps(lhs, rhs);
}

bool IndexingMap::Verify(std::ostream& out) const {
if (IsUndefined()) {
return true;
}
if (affine_map_.getNumDims() != dim_vars_.size()) {
out << "dim size must match the number of dimensions in "
"the affine map";
return false;
}
if (affine_map_.getNumSymbols() != range_vars_.size() + rt_vars_.size()) {
out << "range vars size + rt var size must match the number of "
"symbols in the affine map";
return false;
}
return true;
}

// Simplification of IndexingMap has two main parts.
// At first we optimized constraints to make the domain as small and simple as
// possible. And only then we simplify the affine_map, because its
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/model/indexing_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ class IndexingMap {
absl::Span<const int64_t> symbol_upper_bounds,
bool is_simplified = false);

// Returns true if the indexing map is valid.
bool Verify(std::ostream& out) const;

// Returns true if the map was simplified.
bool Simplify();

Expand Down
25 changes: 25 additions & 0 deletions xla/service/gpu/model/indexing_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ limitations under the License.
#include <limits>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -86,6 +88,29 @@ TEST_F(IndexingMapTest, VariableKind) {
EXPECT_EQ(ToString(VariableKind::kBlockZ), "block_z");
}

TEST_F(IndexingMapTest, VerifyDimensions) {
auto indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0) -> (d0)", &mlir_context_),
/*dim_upper_bounds=*/{10, 10}, /*symbol_upper_bounds=*/{});

std::stringstream ss;
EXPECT_FALSE(indexing_map.Verify(ss));
EXPECT_EQ(ss.str(),
"dim size must match the number of dimensions in the affine map");
}

TEST_F(IndexingMapTest, VerifySymbols) {
auto indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0) -> (d0)", &mlir_context_),
/*dim_upper_bounds=*/{10}, /*symbol_upper_bounds=*/{10});

std::stringstream ss;
EXPECT_FALSE(indexing_map.Verify(ss));
EXPECT_EQ(ss.str(),
"range vars size + rt var size must match the number of symbols in "
"the affine map");
}

TEST_F(IndexingMapTest, RTVar) {
auto zero_dim_map =
AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, &mlir_context_);
Expand Down

0 comments on commit b961957

Please sign in to comment.