Skip to content

Commit

Permalink
[PJRT-IFRT] Improve IFRT SE GPU client test coverage
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694579017
  • Loading branch information
sizhit2 authored and Google-ML-Automation committed Nov 8, 2024
1 parent ddcd348 commit b8ec2c4
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions xla/python/ifrt/remap_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ TEST(RemapImplTest, ExtractSingleShard) {
RemapPlan plan;
plan.input_specs.push_back(
ArraySpec{/*dtype=*/DType(DType::kS32),
/*shape=*/Shape({8, 3}),
/*shape=*/Shape({4, 3}),
/*sharding=*/
ConcreteEvenSharding::Create(
test_util::GetDevices(client.get(), {0, 1, 2, 3}).value(),
MemoryKind(), /*shape=*/Shape({8, 3}),
test_util::GetDevices(client.get(), {0, 1}).value(),
MemoryKind(), /*shape=*/Shape({4, 3}),
/*shard_shape=*/Shape({2, 3}))});
plan.output_specs.push_back(
ArraySpec{/*dtype=*/DType(DType::kS32),
Expand All @@ -171,10 +171,9 @@ TEST(RemapImplTest, ExtractSingleShard) {
TF_ASSERT_OK(plan.Validate());

std::vector<tsl::RCReference<Array>> arrays;
TF_ASSERT_OK_AND_ASSIGN(
arrays.emplace_back(),
CreateArray(client.get(), /*base_values=*/{0, 6, 100, 106},
/*device_indices=*/{0, 1, 2, 3}));
TF_ASSERT_OK_AND_ASSIGN(arrays.emplace_back(),
CreateArray(client.get(), /*base_values=*/{0, 6},
/*device_indices=*/{0, 1}));

{
TF_ASSERT_OK_AND_ASSIGN(
Expand Down

0 comments on commit b8ec2c4

Please sign in to comment.