Skip to content

Commit

Permalink
[XLA:GPU][IndexAnalysis] Move RTVars folding tests to indexing_analys…
Browse files Browse the repository at this point in the history
…is_test.

Next step will be moving the folding to indexing_analysis.cc and removing `hlo` and `map` fields from RTVar struct.

PiperOrigin-RevId: 680600028
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 30, 2024
1 parent 6d6512c commit bb4f5ad
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 393 deletions.
233 changes: 233 additions & 0 deletions xla/service/gpu/model/indexing_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2589,6 +2589,239 @@ TEST_F(IndexingAnalysisTest, BroadcastingElementwise) {
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_ScalarConstant) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = s32[4096] parameter(0)
offset = s64[] constant(42)
ROOT dynamic-slice = s32[10]
dynamic-slice(p0, offset), dynamic_slice_sizes={10}
}
ENTRY main {
p0 = s32[4096] parameter(0)
ROOT fusion = s32[10] fusion(p0), kind=kInput, calls=fused_computation
}
)hlo"));

EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0
(d0) -> (d0 + 42),
domain:
d0 in [0, 9]
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Iota) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = f32[33,76] parameter(0)
iota = s64[42,1] iota(), iota_dimension=0
ROOT gather = f32[42,1,1] gather(p0, iota),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1,1}
}
ENTRY main {
p0 = f32[33,76] parameter(0)
ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation
}
)hlo"));
EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0
(d0, d1, d2) -> (d0, 0),
domain:
d0 in [0, 41],
d1 in [0, 0],
d2 in [0, 0]
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_IotaAsConstant) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = f32[33,76] parameter(0)
iota = s64[42,1] iota(), iota_dimension=1
ROOT gather = f32[42,1,1] gather(p0, iota),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1,1}
}
ENTRY main {
p0 = f32[33,76] parameter(0)
ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation
}
)hlo"));
EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0
(d0, d1, d2) -> (0, 0),
domain:
d0 in [0, 41],
d1 in [0, 0],
d2 in [0, 0]
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Broadcast) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = f32[33,76] parameter(0)
c42 = s64[] constant(42)
bcast = s64[42, 1] broadcast(s64[] c42), dimensions={}
ROOT gather = f32[42,1,1] gather(p0, bcast),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1,1}
}
ENTRY main {
p0 = f32[33,76] parameter(0)
ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation
}
)hlo"));
EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0
(d0, d1, d2) -> (42, 0),
domain:
d0 in [0, 41],
d1 in [0, 0],
d2 in [0, 0]
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Reverse) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = f32[33,76] parameter(0)
iota = s64[42,1] iota(), iota_dimension=0
reverse = s64[42,1] reverse(iota), dimensions={0}
ROOT gather = f32[42,1,1] gather(p0, reverse),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1,1}
}
ENTRY main {
p0 = f32[33,76] parameter(0)
ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation
}
)hlo"));
EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0
(d0, d1, d2) -> (-d0 + 41, 0),
domain:
d0 in [0, 41],
d1 in [0, 0],
d2 in [0, 0]
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Add) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = s32[4096] parameter(0)
p1 = s64[] parameter(1)
c42 = s64[] constant(42)
add = s64[] add(c42, p1)
ROOT dynamic-slice = s32[10]
dynamic-slice(p0, add), dynamic_slice_sizes={10}
}
ENTRY main {
p0 = s32[4096] parameter(0)
p1 = s64[] parameter(1)
ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation
}
)hlo"));
EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0 (d0)[rt0] -> (d0 + rt0 + 42),
domain:
d0 in [0, 9],
rt0 in [0, 4086],
hlo: %p1 = s64[] parameter(1),
(d0) -> ()
operand id = 1
(d0) -> (),
domain:
d0 in [0, 9]
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Multiply) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = s32[4096] parameter(0)
p1 = s64[] parameter(1)
c42 = s64[] constant(42)
add = s64[] multiply(c42, p1)
ROOT dynamic-slice = s32[10]
dynamic-slice(p0, add), dynamic_slice_sizes={10}
}
ENTRY main {
p0 = s32[4096] parameter(0)
p1 = s64[] parameter(1)
ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation
}
)hlo"));
// TODO: Figure out why the bounds are not updated.
EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0 (d0)[rt0] -> (d0 + rt0 * 42),
domain:
d0 in [0, 9],
rt0 in [0, 4086],
hlo: %p1 = s64[] parameter(1),
(d0) -> ()
operand id = 1
(d0) -> (),
domain:
d0 in [0, 9]
)"));
}

TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_ChainedOps) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
fused_computation {
p0 = s32[4096] parameter(0)
p1 = s64[] parameter(1)
c42 = s64[] constant(42)
c2 = s64[] constant(2)
add = s64[] add(c42, p1)
multiply = s64[] multiply(c2, add)
ROOT dynamic-slice = s32[10]
dynamic-slice(p0, multiply), dynamic_slice_sizes={10}
}
ENTRY main {
p0 = s32[4096] parameter(0)
p1 = s64[] parameter(1)
ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation
}
)hlo"));
EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"(
operand id = 0
(d0)[rt0] -> (d0 + rt0 * 2 + 84),
domain: d0 in [0, 9],
rt0 in [0, 4086],
hlo: %p1 = s64[] parameter(1),
(d0) -> ()
operand id = 1
(d0) -> (),
domain:
d0 in [0, 9]
)"));
}

TEST_F(IndexingAnalysisTest, FusionOpWithDUS) {
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo(
HloModule m
Expand Down
Loading

0 comments on commit bb4f5ad

Please sign in to comment.