diff --git a/xla/service/gpu/transforms/windowed_einsum_handler.cc b/xla/service/gpu/transforms/windowed_einsum_handler.cc
index b7ac16438ecb5..0abe00005316a 100644
--- a/xla/service/gpu/transforms/windowed_einsum_handler.cc
+++ b/xla/service/gpu/transforms/windowed_einsum_handler.cc
@@ -56,12 +56,15 @@ namespace m = match;
 // and type conversions of FP8 operands into the bodies of their while loops,
 // i.e. rewrites
 //
-// inputs --> dequant --> while loop {collective-permute/dot/etc}
+// inputs --> dequant --> (unary) --> while loop {collective-permute/dot/etc}
 //
 // into
 //
-// inputs --> while loop {dequant --> collective-permute/dot/etc}.
-// Returns whether the input computation has been changed.
+// inputs --> (unary) --> while loop {dequant --> collective-permute/dot/etc}.
+//
+// Unary bitcast, broadcast, copy, reshape and transpose ops are allowed
+// between the dequantization and the while loop. Returns whether the input
+// computation has been changed.
 absl::StatusOr<bool> ShiftDequantizationF8(HloComputation* while_body) {
   HloInstruction* while_instr = while_body->WhileCallInstruction();
   // The input of the while loop will be modified and must have no other users.
@@ -73,8 +76,21 @@ absl::StatusOr<bool> ShiftDequantizationF8(HloComputation* while_body) {
   // while loop.
   HloInstruction* param_tuple = while_instr->mutable_operand(0);
   std::array<HloInstruction*, 2> binaries, operands, scales;
+  std::array<std::vector<HloInstruction*>, 2> unaries;
   for (int k = 0; k < 2; ++k) {
-    if (!Match(param_tuple->mutable_operand(k),
+    HloInstruction* operand = param_tuple->mutable_operand(k);
+    // Capture bitcast, broadcast, copy, reshape and transpose ops between
+    // dequantization and the loop.
+    while (operand->opcode() == HloOpcode::kBitcast ||
+           operand->opcode() == HloOpcode::kBroadcast ||
+           operand->opcode() == HloOpcode::kCopy ||
+           operand->opcode() == HloOpcode::kReshape ||
+           operand->opcode() == HloOpcode::kTranspose) {
+      unaries[k].emplace_back(operand);
+      operand = operand->mutable_operand(0);
+    }
+    std::reverse(unaries[k].begin(), unaries[k].end());
+    if (!Match(operand,
               m::AnyOf<HloInstruction>(
                   m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])),
                             m::Broadcast(m::Op(&scales[k]))),
@@ -156,6 +172,19 @@ absl::StatusOr<bool> ShiftDequantizationF8(HloComputation* while_body) {
     return false;
   }
 
+  // Replace any bitcast, broadcast, copy, reshape and transpose ops applied
+  // to the dequantized operands before the while loop with FP8 unary ops.
+  for (int k = 0; k < 2; ++k) {
+    for (HloInstruction* unary : unaries[k]) {
+      Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
+          operands[k]->shape().element_type(), unary->shape().dimensions(),
+          unary->shape().layout().minor_to_major());
+
+      operands[k] = unary->AddInstruction(
+          unary->CloneWithNewOperands(new_shape, {operands[k]}));
+    }
+  }
+
   // Replace the dequantized dot operands in the parameter tuple used by while
   // with FP8 operands.
   for (int k = 0; k < 2; ++k) {
diff --git a/xla/service/gpu/transforms/windowed_einsum_handler_test.cc b/xla/service/gpu/transforms/windowed_einsum_handler_test.cc
index e5f1e57f30630..4f736ef861497 100644
--- a/xla/service/gpu/transforms/windowed_einsum_handler_test.cc
+++ b/xla/service/gpu/transforms/windowed_einsum_handler_test.cc
@@ -634,127 +634,142 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204
 TEST_F(WindowedEinsumHandlerTest, AllGatherF8) {
   constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[1536,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4
 
 windowed_dot_general_body_ag {
-  param.1 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element.lhs = f32[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
-  collective-permute.send_first_lhs_shard = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.lhs), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
-  collective-permute.send_second_lhs_shard = f32[2,512,24576]{2,1,0} collective-permute(collective-permute.send_first_lhs_shard), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
-  get-tuple-element.rhs = f32[24576,24576]{1,0} get-tuple-element(param.1), index=1
-  get-tuple-element.3 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
-  dot.first_shard_dot = f32[2,512,24576]{2,1,0} dot(get-tuple-element.lhs, get-tuple-element.rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  constant.12 = s32[] constant(0)
-  constant.13 = s32[4]{0} constant({0, 512, 1024, 1536})
-  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
-  partition-id = u32[] partition-id()
-  add = u32[] add(get-tuple-element.5, partition-id)
-  constant.11 = u32[] constant(4)
-  remainder = u32[] remainder(add, constant.11)
-  dynamic-slice = s32[1]{0} dynamic-slice(constant.13, remainder), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(dynamic-slice)
-  dynamic-update-slice.update_first_shard_result = f32[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot.first_shard_dot, constant.12, reshape, constant.12)
-  dot.second_shard_dot = f32[2,512,24576]{2,1,0} dot(collective-permute.send_first_lhs_shard, get-tuple-element.rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  constant.15 = u32[] constant(1)
-  add.1 = u32[] add(get-tuple-element.5, constant.15)
-  add.2 = u32[] add(add.1, partition-id)
-  remainder.1 = u32[] remainder(add.2, constant.11)
-  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.13, remainder.1), dynamic_slice_sizes={1}
-  reshape.1 = s32[] reshape(dynamic-slice.1)
-  dynamic-update-slice.update_second_shard_result = f32[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice.update_first_shard_result, dot.second_shard_dot, constant.12, reshape.1, constant.12)
-  get-tuple-element.4 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
-  add.3 = u32[] add(add.1, constant.15)
-  ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.send_second_lhs_shard, get-tuple-element.rhs, dynamic-update-slice.update_second_shard_result, get-tuple-element.4, add.3)
+  input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+  lhs = f32[2,512,24576]{2,1,0} get-tuple-element(input), index=0
+  permuted_lhs0 = f32[2,512,24576]{2,1,0} collective-permute(lhs), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+  permuted_lhs1 = f32[2,512,24576]{2,1,0} collective-permute(permuted_lhs0), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+  rhs = f32[24576,24576]{1,0} get-tuple-element(input), index=1
+  partial_dot_output = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=2
+  dot0 = f32[2,512,24576]{2,1,0} dot(lhs, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  c0 = s32[] constant(0)
+  dot_update_slice_offsets = s32[4]{0} constant({0, 512, 1024, 1536})
+  loop_counter = u32[] get-tuple-element(input), index=4
+  partition_id = u32[] partition-id()
+  loop_counter_plus_partition_id = u32[] add(loop_counter, partition_id)
+  c4 = u32[] constant(4)
+  dot_update_slice_offsets_index0 = u32[] remainder(loop_counter_plus_partition_id, c4)
+  dot_update_slice_offset0 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index0), dynamic_slice_sizes={1}
+  dot_update_slice_offset_scalar0 = s32[] reshape(dot_update_slice_offset0)
+  updated_dot_output0 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(partial_dot_output, dot0, c0, dot_update_slice_offset_scalar0, c0)
+  dot1 = f32[2,512,24576]{2,1,0} dot(permuted_lhs0, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  c1 = u32[] constant(1)
+  loop_counter_plus_one = u32[] add(loop_counter, c1)
+  loop_counter_plus_partition_id_plus_one = u32[] add(loop_counter_plus_one, partition_id)
+  dot_update_slice_offsets_index1 = u32[] remainder(loop_counter_plus_partition_id_plus_one, c4)
+  dot_update_slice_offset1 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index1), dynamic_slice_sizes={1}
+  dot_update_slice_offset1_scalar = s32[] reshape(dot_update_slice_offset1)
+  updated_dot_output1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(updated_dot_output0, dot1, c0, dot_update_slice_offset1_scalar, c0)
+  pass_through = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=3
+  next_loop_counter = u32[] add(loop_counter_plus_one, c1)
+  ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(permuted_lhs1, rhs, updated_dot_output1, pass_through, next_loop_counter)
 } // windowed_dot_general_body_ag
 
 windowed_dot_general_cond_ag {
-  param = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element = u32[] get-tuple-element(param), index=4
-  constant.10 = u32[] constant(4)
-  ROOT compare = pred[] compare(get-tuple-element, constant.10), direction=LT
+  input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+  loop_counter = u32[] get-tuple-element(input), index=4
+  loop_limit = u32[] constant(4)
+  ROOT compare = pred[] compare(loop_counter, loop_limit), direction=LT
 }
 
-ENTRY test_main {
-  param.4 = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
-  reshape.8 = f8e4m3fn[2,512,24576]{2,1,0} reshape(param.4)
-  param.5 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
-  constant.18 = f32[] constant(0)
-  broadcast = f32[2,2048,24576]{2,1,0} broadcast(constant.18), dimensions={}
-  constant.20 = u32[] constant(0)
+ENTRY main {
+  lhs = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  rhs = f8e4m3fn[1536,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  c0_f32 = f32[] constant(0)
+  c0_f32_bcast = f32[2,2048,24576]{2,1,0} broadcast(c0_f32), dimensions={}
+  c0_u32 = u32[] constant(0)
   scale_lhs = f32[] parameter(2)
   scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
-  lhs_bf32 = f32[2,512,24576]{2,1,0} convert(reshape.8)
-  lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_bf32, scale_lhs_bcast)
+  lhs_f32 = f32[2,512,24576]{2,1,0} convert(lhs)
+  lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_f32, scale_lhs_bcast)
   scale_rhs = f32[] parameter(3)
-  scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
-  rhs_bf32 = f32[24576,24576]{1,0} convert(param.5)
-  rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf32, scale_rhs_bcast)
-  tuple.2 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, broadcast, broadcast, constant.20)
-  while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+  scale_rhs_bcast = f32[1536,24576]{1,0} broadcast(scale_rhs), dimensions={}
+  rhs_f32 = f32[1536,24576]{1,0} convert(rhs)
+  rhs_scaled = f32[1536,24576]{1,0} multiply(rhs_f32, scale_rhs_bcast)
+  rhs_bcast = f32[16,1536,24576]{2,1,0} broadcast(rhs_scaled), dimensions={1,2}
+  rhs_reshaped = f32[24576,24576]{1,0} reshape(rhs_bcast)
+  while_input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_reshaped, c0_f32_bcast, c0_f32_bcast, c0_u32)
+  while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(while_input), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
   ROOT get-tuple-element.13 = f32[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
 }
 )";
 
   RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), R"(
-; CHECK-LABEL: unrolled_windowed_dot_general_body_ag
-; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
-; CHECK-NEXT: [[GTE0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=0
-; CHECK-NEXT: [[CP0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[GTE0]]), channel_id=6
-; CHECK-NEXT: [[CP1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[CP0]]), channel_id=7
-; CHECK-NEXT: [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
-; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=2
-; CHECK-NEXT: [[CONVERT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[GTE0]])
-; CHECK-NEXT: [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
-; CHECK-NEXT: [[BCAST0:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
-; CHECK-NEXT: [[MUL0:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
-; CHECK-NEXT: [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
-; CHECK-NEXT: [[GTE4:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
-; CHECK-NEXT: [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE4]]), dimensions={}
-; CHECK-NEXT: [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
-; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL0]], [[MUL1]]),
+; CHECK-LABEL: %unrolled_windowed_dot_general_body_ag
+; CHECK-NEXT: [[INPUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
+; CHECK-NEXT: [[LHS:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[INPUT]]), index=0
+; CHECK-NEXT: [[PERMUTED_LHS0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[LHS]]), channel_id=6
+; CHECK-NEXT: [[PERMUTED_LHS1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[PERMUTED_LHS0]]), channel_id=7
+; CHECK-NEXT: [[RHS:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[INPUT]]), index=1
+; CHECK-NEXT: [[PARTIAL_DOT_OUTPUT:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[INPUT]]), index=2
+; CHECK-NEXT: [[LHS_F32:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[LHS]])
+; CHECK-NEXT: [[SCALE_LHS:%[^ ]+]] = f32[] get-tuple-element([[INPUT]]), index=5
+; CHECK-NEXT: [[SCALE_LHS_BCAST:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[SCALE_LHS]]), dimensions={}
+; CHECK-NEXT: [[LHS_SCALED:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[LHS_F32]], [[SCALE_LHS_BCAST]])
+; CHECK-NEXT: [[RHS_F32:%[^ ]+]] = f32[24576,24576]{1,0} convert([[RHS]])
+; CHECK-NEXT: [[SCALE_RHS:%[^ ]+]] = f32[] get-tuple-element([[INPUT]]), index=6
+; CHECK-NEXT: [[SCALE_RHS_BCAST:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[SCALE_RHS]]), dimensions={}
+; CHECK-NEXT: [[RHS_SCALED:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[RHS_F32]], [[SCALE_RHS_BCAST]])
+; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[LHS_SCALED]], [[RHS_SCALED]]),
 ; CHECK-DAG: lhs_contracting_dims={2},
 ; CHECK-DAG: rhs_contracting_dims={0},
 ; CHECK-DAG: backend_config={
 ; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
 ; CHECK-DAG: "wait_on_operation_queues":[],
 ; CHECK-DAG: "force_earliest_schedule":false}
-; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0)
-; CHECK-NEXT: [[C4:%[^ ]+]] = u32[] constant(0)
+; CHECK-NEXT: [[C0_S32:%[^ ]+]] = s32[] constant(0)
+; CHECK-NEXT: [[C0_U32:%[^ ]+]] = u32[] constant(0)
 ; CHECK-NEXT: [[C5:%[^ ]+]] = u32[] constant(0)
-; CHECK-NEXT: [[PID:%[^ ]+]] = u32[] partition-id()
-; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[C5]], [[PID]])
-; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(3)
-; CHECK-NEXT: [[AND0:%[^ ]+]] = u32[] and([[ADD0]], [[C2]])
-; CHECK-NEXT: [[CLAMP0:%[^ ]+]] = u32[] clamp([[C4]], [[AND0]], [[C2]])
+; CHECK-NEXT: [[PARTITION_ID:%[^ ]+]] = u32[] partition-id()
+; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[C5]], [[PARTITION_ID]])
+; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(3)
+; CHECK-NEXT: [[AND0:%[^ ]+]] = u32[] and([[ADD0]], [[C3]])
+; CHECK-NEXT: [[CLAMP0:%[^ ]+]] = u32[] clamp([[C0_U32]], [[AND0]], [[C3]])
 ; CHECK-NEXT: [[CONVERT3:%[^ ]+]] = s32[] convert([[CLAMP0]])
-; CHECK-NEXT: [[C6:%[^ ]+]] = s32[] constant(512)
-; CHECK-NEXT: [[MUL3:%[^ ]+]] = s32[] multiply([[CONVERT3]], [[C6]])
+; CHECK-NEXT: [[C512:%[^ ]+]] = s32[] constant(512)
+; CHECK-NEXT: [[MUL3:%[^ ]+]] = s32[] multiply([[CONVERT3]], [[C512]])
 ; CHECK-NEXT: [[RESHAPE0:%[^ ]+]] = s32[] reshape([[MUL3]])
-; CHECK-NEXT: [[DUPDATESLICE0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[GTE2]], [[DOT0]], [[C0]], [[RESHAPE0]], [[C0]]),
+; CHECK-NEXT: [[UPDATED_DOT_OUTPUT0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[PARTIAL_DOT_OUTPUT]], [[DOT0]], [[C0_S32]], [[RESHAPE0]], [[C0_S32]]),
 ; CHECK-DAG: backend_config={
 ; CHECK-DAG: "operation_queue_id":"0",
 ; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"],
 ; CHECK-DAG: "force_earliest_schedule":false}
-; CHECK-NEXT: [[CONVERT2:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[CP0]])
-; CHECK-NEXT: [[MUL2:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT2]], [[BCAST0]])
-; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL2]], [[MUL1]]),
+; CHECK-NEXT: [[PERMUTED_LHS0_F32:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[PERMUTED_LHS0]])
+; CHECK-NEXT: [[PERMUTED_LHS_SCALED:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[PERMUTED_LHS0_F32]], [[SCALE_LHS_BCAST]])
+; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[PERMUTED_LHS_SCALED]], [[RHS_SCALED]]),
 ; CHECK-DAG: lhs_contracting_dims={2},
 ; CHECK-DAG: rhs_contracting_dims={0}
-; CHECK-NEXT: [[GTE7:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
-; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(1)
-; CHECK-NEXT: [[ADD1:%[^ ]+]] = u32[] add([[GTE7]], [[C3]])
-; CHECK-NEXT: [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
-; CHECK-NEXT: [[AND1:%[^ ]+]] = u32[] and([[ADD2]], [[C2]])
-; CHECK-NEXT: [[CLAMP1:%[^ ]+]] = u32[] clamp([[C4]], [[AND1]], [[C2]])
+; CHECK-NEXT: [[LOOP_COUNTER:%[^ ]+]] = u32[] get-tuple-element([[INPUT]]), index=4
+; CHECK-NEXT: [[C1:%[^ ]+]] = u32[] constant(1)
+; CHECK-NEXT: [[LOOP_COUNTER_PLUS_ONE:%[^ ]+]] = u32[] add([[LOOP_COUNTER]], [[C1]])
+; CHECK-NEXT: [[LOOP_COUNTER_PLUS_ONE_PLUS_PARTITION_ID:%[^ ]+]] = u32[] add([[LOOP_COUNTER_PLUS_ONE]], [[PARTITION_ID]])
+; CHECK-NEXT: [[AND1:%[^ ]+]] = u32[] and([[LOOP_COUNTER_PLUS_ONE_PLUS_PARTITION_ID]], [[C3]])
+; CHECK-NEXT: [[CLAMP1:%[^ ]+]] = u32[] clamp([[C0_U32]], [[AND1]], [[C3]])
 ; CHECK-NEXT: [[CONVERT4:%[^ ]+]] = s32[] convert([[CLAMP1]])
-; CHECK-NEXT: [[MUL4:%[^ ]+]] = s32[] multiply([[CONVERT4]], [[C6]])
+; CHECK-NEXT: [[MUL4:%[^ ]+]] = s32[] multiply([[CONVERT4]], [[C512]])
 ; CHECK-NEXT: [[RESHAPE1:%[^ ]+]] = s32[] reshape([[MUL4]])
-; CHECK-NEXT: [[DUPDATESLICE1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[DUPDATESLICE0]], [[DOT1]], [[C0]], [[RESHAPE1]], [[C0]])
-; CHECK-NEXT: [[GTE6:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=3
-; CHECK-NEXT: [[C7:%[^ ]+]] = u32[] constant(2)
-; CHECK-NEXT: [[ADD3:%[^ ]+]] = u32[] add([[GTE7]], [[C7]])
-; CHECK-NEXT: [[TUPLE0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[CP1]], [[GTE1]], [[DUPDATESLICE1]], [[GTE6]], [[ADD3]], /*index=5*/[[GTE3]], [[GTE4]])
+; CHECK-NEXT: [[UPDATED_DOT_OUTPUT1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[UPDATED_DOT_OUTPUT0]], [[DOT1]], [[C0_S32]], [[RESHAPE1]], [[C0_S32]])
+; CHECK-NEXT: [[PASS_THROUGH:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[INPUT]]), index=3
+; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(2)
+; CHECK-NEXT: [[NEXT_LOOP_COUNTER:%[^ ]+]] = u32[] add([[LOOP_COUNTER]], [[C2]])
+; CHECK-NEXT: [[TUPLE:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[PERMUTED_LHS1]], [[RHS]], [[UPDATED_DOT_OUTPUT1]], [[PASS_THROUGH]], [[NEXT_LOOP_COUNTER]], /*index=5*/[[SCALE_LHS]], [[SCALE_RHS]])
+; CHECK-LABEL: ENTRY %main
+; CHECK: [[LHS:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+; CHECK-NEXT: [[RHS:%[^ ]+]] = f8e4m3fn[1536,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+; CHECK-NEXT: [[RHS_BCAST:%[^ ]+]] = f8e4m3fn[16,1536,24576]{2,1,0} broadcast([[RHS]]), dimensions={1,2}
+; CHECK-NEXT: [[RHS_RESHAPED:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} reshape([[RHS_BCAST]])
+; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[C0]]), dimensions={}
+; CHECK-NEXT: [[C0_U32:%[^ ]+]] = u32[] constant(0)
+; CHECK-NEXT: [[SCALE_LHS:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT: [[SCALE_RHS:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT: [[WHILE_INPUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[LHS]], [[RHS_RESHAPED]], [[C0_BCAST]], [[C0_BCAST]], [[C0_U32]], /*index=5*/[[SCALE_LHS]], [[SCALE_RHS]])
+; CHECK: [[WHILE:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) while([[WHILE_INPUT]]),
+; CHECK-DAG: condition=%unrolled_windowed_dot_general_cond_ag,
+; CHECK-DAG: body=%unrolled_windowed_dot_general_body_ag
 )");
 }
diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc
index 479935f7d01f6..442e07a71d6ef 100644
--- a/xla/tests/collective_ops_e2e_test.cc
+++ b/xla/tests/collective_ops_e2e_test.cc
@@ -101,10 +101,11 @@ class CollectiveOpsTestE2E : public HloTestBase {
                             CreateExecutable(std::move(module),
                                              /*run_hlo_passes=*/true));
     EXPECT_TRUE(executable->has_module());
-    HloInstruction* gemm_op =
-        FindInstruction(&executable->module(), HloOpcode::kCustomCall);
-    EXPECT_THAT(gemm_op, NotNull());
-    EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
+    std::vector<HloInstruction*> gemm_ops =
+        FindInstructions(&executable->module(), HloOpcode::kCustomCall);
+    for (HloInstruction* gemm_op : gemm_ops) {
+      EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
+    }
   }
 
   absl::StatusOr<std::vector<Literal>> ExecuteReplicated(Executable* executable,
@@ -867,46 +868,62 @@ ENTRY main.12 {
   CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
 }
 
+TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherF8) {
+  absl::string_view kModuleReplicatedStr = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
+
+ENTRY main {
+  lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  rhs = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  scale_lhs = bf16[] parameter(2)
+  scale_rhs = bf16[] parameter(3)
+  scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={}
+  scale_rhs_bcast = bf16[48,192]{1,0} broadcast(scale_rhs), dimensions={}
+  lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs)
+  rhs_bf16 = bf16[48,192]{1,0} convert(rhs)
+  lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
+  rhs_scaled = bf16[48,192]{1,0} multiply(scale_rhs_bcast, rhs_bf16)
+  dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]}
+} // main
+)";
+
+  // Disable the dot merger pass which can prevent the creation of FP8 GEMM
+  // Custom Calls.
+  CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
+                                          /*disable_dot_merger=*/true);
+
+  // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer
+  // architectures.
+  DebugOptions opts = GetDebugOptionsForTest();
+  opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
+  opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
+  opts.set_xla_gpu_graph_min_graph_size(200);
+  opts.set_xla_gpu_enable_triton_gemm(false);
+  opts.add_xla_disable_hlo_passes("dot-merger");
+  CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
+}
+
 TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
-       WindowedEinsumE2EAllGatherAndReduceScatterF8) {
+       WindowedEinsumE2EAllGatherReshapeF8) {
   absl::string_view kModuleReplicatedStr = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<<F8E4M3>>[2,16,48]{2,1,0}, <<F8E4M3>>[48,192]{1,0}, <<F8E4M3>>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
+HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[2,24,192]{2,1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
 
-ENTRY main.12 {
-  Arg_0.1 = <<F8E4M3>>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
-  Arg_1.2 = <<F8E4M3>>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
-  Arg_2.3 = bf16[] parameter(3)
-  Arg_3.4 = bf16[] parameter(4)
-  broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
-  broadcast.1 = bf16[48,192]{1,0} broadcast(Arg_3.4), dimensions={}
-  convert = bf16[2,16,48]{2,1,0} convert(Arg_0.1)
-  convert.1 = bf16[48,192]{1,0} convert(Arg_1.2)
-  multiply = bf16[2,16,48]{2,1,0} multiply(broadcast, convert)
-  multiply.1 = bf16[48,192]{1,0} multiply(broadcast.1, convert.1)
-  dot.5 = bf16[2,16,192]{2,1,0} dot(multiply, multiply.1), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  custom-call.7 = bf16[2,16,192]{2,1,0} custom-call(dot.5), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]}
-  Arg_4.5 = bf16[] parameter(5)
-  broadcast.2 = bf16[2,16,192]{2,1,0} broadcast(Arg_4.5), dimensions={}
-  divide = bf16[2,16,192]{2,1,0} divide(custom-call.7, broadcast.2)
-  constant = bf16[] constant(-448.)
-  broadcast.3 = bf16[2,16,192]{2,1,0} broadcast(constant), dimensions={}
-  constant.1 = bf16[] constant(448.)
-  broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
-  clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
-  convert.2 = <<F8E4M3>>[2,16,192]{2,1,0} convert(clamp)
-  Arg_5.6 = bf16[] parameter(6)
-  broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
-  convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
-  multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
-  Arg_6.7 = <<F8E4M3>>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
-  Arg_7.8 = bf16[] parameter(7)
-  broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
-  convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
-  multiply.3 = bf16[192,48]{1,0} multiply(convert.4, broadcast.6)
-  dot.6 = bf16[2,16,48]{2,1,0} dot(multiply.2, multiply.3), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  tuple.10 = (bf16[2,16,48]{2,1,0}) tuple(dot.6)
-  ROOT get-tuple-element.11 = bf16[2,16,48]{2,1,0} get-tuple-element(tuple.10), index=0, sharding={devices=[1,4,1]<=[4]}
-} // main.12
+ENTRY main {
+  lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  rhs = f8e4m3fn[2,24,192]{2,1,0} parameter(1), sharding={devices=[1,1,4]<=[4]}
+  scale_lhs = bf16[] parameter(2)
+  scale_rhs = bf16[] parameter(3)
+  scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={}
+  scale_rhs_bcast = bf16[2,24,192]{2,1,0} broadcast(scale_rhs), dimensions={}
+  lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs)
+  rhs_bf16 = bf16[2,24,192]{2,1,0} convert(rhs)
+  lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
+  rhs_scaled = bf16[2,24,192]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16)
+  rhs_reshaped = bf16[48,192]{1,0} reshape(rhs_scaled)
+  dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_reshaped), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]}
+} // main
 )";
 
   // Disable the dot merger pass which can prevent the creation of FP8 GEMM
   // Custom Calls.
@@ -933,24 +950,61 @@ TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
 HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
 
 ENTRY main {
-  rhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
-  lhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  rhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  scale_lhs = bf16[] parameter(3)
+  scale_rhs0 = bf16[] parameter(4)
+  scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={}
+  scale_rhs0_bcast = bf16[48,192]{1,0} broadcast(scale_rhs0), dimensions={}
+  lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs)
+  rhs0_bf16 = bf16[48,192]{1,0} convert(rhs0)
+  lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
+  rhs0_scaled = bf16[48,192]{1,0} multiply(scale_rhs0_bcast, rhs0_bf16)
+  dot0 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  rhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
+  scale_rhs1 = bf16[] parameter(5)
+  scale_rhs1_bcast = bf16[48,192]{1,0} broadcast(scale_rhs1), dimensions={}
+  rhs1_bf16 = bf16[48,192]{1,0} convert(rhs1)
+  rhs1_scaled = bf16[48,192]{1,0} multiply(scale_rhs1_bcast, rhs1_bf16)
+  dot1 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  ROOT add = bf16[2,16,192]{2,1,0} add(dot0, dot1)
+} // main
+)";
+
+  // Disable the dot merger pass which can prevent the creation of FP8 GEMM
+  // Custom Calls.
+  CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
+                                          /*disable_dot_merger=*/true);
+
+  // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer
+  // architectures.
+  DebugOptions opts = GetDebugOptionsForTest();
+  opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
+  opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
+  opts.set_xla_gpu_graph_min_graph_size(200);
+  opts.set_xla_gpu_enable_triton_gemm(false);
+  opts.add_xla_disable_hlo_passes("dot-merger");
+  CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
+}
+
+TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
+       WindowedEinsumE2EReduceScatterF8) {
+  absl::string_view kModuleReplicatedStr = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,192]{2,1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
+
+ENTRY main {
+  lhs = f8e4m3fn[2,16,192]{2,1,0} parameter(0), sharding={devices=[1,1,4]<=[4]}
+  rhs = f8e4m3fn[192,48]{1,0} parameter(1), sharding={devices=[4,1]<=[4]}
+  scale_lhs = bf16[] parameter(2)
   scale_rhs = bf16[] parameter(3)
-  scale_lhs0 = bf16[] parameter(4)
-  scale_rhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_rhs), dimensions={}
-  scale_lhs0_bcast = bf16[48,192]{1,0} broadcast(scale_lhs0), dimensions={}
-  rhs_bf16 = bf16[2,16,48]{2,1,0} convert(rhs)
-  lhs0_bf16 = bf16[48,192]{1,0} convert(lhs0)
-  rhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16)
-  lhs0_scaled = bf16[48,192]{1,0} multiply(scale_lhs0_bcast, lhs0_bf16)
-  dot0 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  lhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
-  scale_lhs1 = bf16[] parameter(5)
-  scale_lhs1_bcast = bf16[48,192]{1,0} broadcast(scale_lhs1), dimensions={}
-  lhs1_bf16 = bf16[48,192]{1,0} convert(lhs1)
-  lhs1_scaled = bf16[48,192]{1,0} multiply(scale_lhs1_bcast, lhs1_bf16)
-  dot1 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  ROOT add.8 = bf16[2,16,192]{2,1,0} add(dot0, dot1)
+  scale_lhs_bcast = bf16[2,16,192]{2,1,0} broadcast(scale_lhs), dimensions={}
+  scale_rhs_bcast = bf16[192,48]{1,0} broadcast(scale_rhs), dimensions={}
+  lhs_bf16 = bf16[2,16,192]{2,1,0} convert(lhs)
+  rhs_bf16 = bf16[192,48]{1,0} convert(rhs)
+  lhs_scaled = bf16[2,16,192]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
+  rhs_scaled = bf16[192,48]{1,0} multiply(scale_rhs_bcast, rhs_bf16)
+  dot = bf16[2,16,48]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  ROOT custom-call = bf16[2,16,48]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,4,1]<=[4]}
 } // main
 )";
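
Note on the handler change above: the unary-op capture added to ShiftDequantizationF8 can be read in isolation as a small traversal over the operand chain. The sketch below is illustrative only; the free-standing helper name CollectUnaryChain is hypothetical and not part of this patch, and it assumes only the public XLA HLO IR headers.

#include <algorithm>
#include <utility>
#include <vector>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"

namespace xla::gpu {

// Hypothetical helper mirroring the capture loop in ShiftDequantizationF8:
// walk up a chain of bitcast, broadcast, copy, reshape and transpose ops
// starting at `operand`, returning the collected ops ordered from the
// dequantization towards the while loop, plus the instruction below the chain.
std::pair<std::vector<HloInstruction*>, HloInstruction*> CollectUnaryChain(
    HloInstruction* operand) {
  std::vector<HloInstruction*> unaries;
  while (operand->opcode() == HloOpcode::kBitcast ||
         operand->opcode() == HloOpcode::kBroadcast ||
         operand->opcode() == HloOpcode::kCopy ||
         operand->opcode() == HloOpcode::kReshape ||
         operand->opcode() == HloOpcode::kTranspose) {
    unaries.push_back(operand);
    operand = operand->mutable_operand(0);
  }
  // The walk collects loop-side ops first; reverse so the dequantization-side
  // op comes first, matching the order in which the pass re-emits them.
  std::reverse(unaries.begin(), unaries.end());
  return {unaries, operand};
}

}  // namespace xla::gpu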