Skip to content

Commit

Permalink
Cleanup. Remove unnecessary TF_ASSIGN_OR_RETURN in dot_handler.cc. No…
Browse files Browse the repository at this point in the history
… behavior change.

PiperOrigin-RevId: 694665378
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Nov 8, 2024
1 parent 6d95565 commit 3ed5632
Showing 1 changed file with 13 additions and 31 deletions.
44 changes: 13 additions & 31 deletions xla/service/spmd/dot_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1866,10 +1866,7 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
return nullptr;
}
auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(
auto dot,
create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
return dot;
return create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window);
}
// RHS and output are batch partitioned in the same way.
if (rhs_batch_partitions == num_partitions &&
Expand All @@ -1883,10 +1880,7 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
return nullptr;
}
auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
TF_ASSIGN_OR_RETURN(
auto dot,
create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
return dot;
return create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window);
}
return nullptr;
};
Expand Down Expand Up @@ -2018,19 +2012,15 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
output_lhs_non_contracting_partitions == num_partitions &&
lhs_sharding_transposed_to_match_output == output_sharding) {
auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
b, conv_window));
return dot;
return create_sharded_dot(lhs.hlo(), rhs_replicated, b, conv_window);
}

// RHS and output have the same partitioned non-contracting dimensions.
if (rhs_non_contracting_partitions == num_partitions &&
output_rhs_non_contracting_partitions == num_partitions &&
rhs_sharding_transposed_to_match_output == output_sharding) {
auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
b, conv_window));
return dot;
return create_sharded_dot(lhs_replicated, rhs.hlo(), b, conv_window);
}

if (may_reshard_without_detecting_match) {
Expand All @@ -2040,30 +2030,24 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
b, conv_window));
return dot;
return create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), b,
conv_window);
}
// Output is partitioned along LHS non-contracting dimensions.
if (output_lhs_non_contracting_partitions == num_partitions) {
auto resharded_lhs =
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(resharded_lhs.hlo(),
replicated_rhs.hlo(), b, conv_window));
return dot;
return create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b,
conv_window);
}
// Output is partitioned along RHS non-contracting dimensions.
if (output_rhs_non_contracting_partitions == num_partitions) {
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(replicated_lhs.hlo(),
resharded_rhs.hlo(), b, conv_window));
return dot;
return create_sharded_dot(replicated_lhs.hlo(), resharded_rhs.hlo(), b,
conv_window);
}
}

Expand Down Expand Up @@ -4264,11 +4248,9 @@ absl::StatusOr<HloInstruction*> ReshardLHSRHSToMatchOutputSharding(
consider_other_operand,
may_combine_partial_sharding);

TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(lhs.Reshard(infered_lhs_sharding).hlo(),
rhs.Reshard(infered_rhs_sharding).hlo(), b,
conv_window));
return dot;
return create_sharded_dot(lhs.Reshard(infered_lhs_sharding).hlo(),
rhs.Reshard(infered_rhs_sharding).hlo(), b,
conv_window);
}

absl::StatusOr<HloInstruction*> PartitionDot(
Expand Down

0 comments on commit 3ed5632

Please sign in to comment.