diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index b0e2d9c86c95a..666cc1a041e8e 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1263,94 +1263,7 @@ class FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM } }; -class FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8 - : public MultiHeadedAttentionTest { - protected: - void TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8() { - if (skip_reason_) GTEST_SKIP() << *skip_reason_; - if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(9, 1, 0)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; - } - XlaBuilder builder(TestName()); - std::string hlo_string_ref = - R"( - HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true} - clip.33 { - Arg_2.36 = bf16[] parameter(2) - broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} - Arg_1.35 = bf16[] parameter(1) - broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} - Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) - ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) - } // clip.33 - ENTRY main.106 { - Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - constant.6 = bf16[] constant(1) - broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} - divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) - constant.5 = bf16[] constant(-448) - constant.4 = bf16[] constant(448) - call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 - convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) - convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) - Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) - divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) - call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 - convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) - convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) - Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) - divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) - call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 - convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) - convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) - custom-call.4.0 = (bf16[4,4,16,16]{3,1,2,0}, u8[16]{0}) custom-call(convert.19, convert.31, convert.43), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 1.0, "dropout_rate": 0.0, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["4", "4", "16", "16"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", 
"1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "seed": 42, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}}} - ROOT get-tuple-element.5.0 = bf16[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.4.0), index=0 - } // main.106 - )"; // NOLINT - std::string hlo_string = R"( - HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true} - clip.33 { - Arg_2.36 = bf16[] parameter(2) - broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} - Arg_1.35 = bf16[] parameter(1) - broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} - Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) - ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) - } // clip.33 - ENTRY main.106 { - constant.99 = f32[] constant(1) - broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} - Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - constant.6 = bf16[] constant(1) - broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} - divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) - constant.5 = bf16[] constant(-448) - constant.4 = bf16[] constant(448) - call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 - convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) - convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) - Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) - divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) - call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 - convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) - convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) - Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) - divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) - call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 - convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) - convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) - custom-call.21.0 = (f8e4m3fn[4,4,16,16]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, u8[16]{0}) custom-call(convert.18, convert.30, convert.42, broadcast.99, broadcast.99, /*index=5*/broadcast.99, broadcast.99, broadcast.99, broadcast.99), custom_call_target="__cudnn$fmhaSoftmaxF8", operand_layout_constraints={f8e4m3fn[4,16,4,16]{3,2,1,0}, f8e4m3fn[4,16,4,16]{3,2,1,0}, f8e4m3fn[4,16,4,16]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}} - get-tuple-element.5.0 = f8e4m3fn[4,4,16,16]{3,1,2,0} 
get-tuple-element(custom-call.21.0), index=0 - ROOT out = bf16[4,4,16,16]{3,1,2,0} convert(get-tuple-element.5.0) - } // main.106 - )"; // NOLINT - EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, - ErrorSpec{1e-2, 1e-2})); - } -}; +class FlashAttentionBMMScaleSoftmaxBMMF8 : public MultiHeadedAttentionTest {}; class FlashAttentionBMMScaleSoftmaxDropoutBMM : public MultiHeadedAttentionTest { @@ -1465,10 +1378,434 @@ XLA_TEST_F(FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM, bfloat16>(); // NOLINT } +absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef() { + static constexpr absl::string_view hlo_text = + R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,16,4,16]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true} + + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) + } + + ENTRY main.106 { + Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) + + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} + + divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) + call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) + convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) + + divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) + convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) + + divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) + convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) + )"; + return hlo_text; +} + +absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8() { + static constexpr absl::string_view hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,16,4,16]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true} + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) + } // clip.33 + ENTRY main.106 { + constant.99 = f32[] constant(1) + broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} + Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} 
parameter(0) + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} + divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) + constant.5 = bf16[] constant(-448) + constant.4 = bf16[] constant(448) + call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 + convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) + convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) + Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) + divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 + convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) + convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) + Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) + divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 + convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) + convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) + )"; + return hlo_text; +} // BMM1 - Scale - Softmax - BMM2 fp8 -XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8, - Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8) { - TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8(); +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BNTH_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + XlaBuilder builder(TestName()); + std::string ref_bnth = R"( + custom-call.4.0 = ( + bf16[4,4,16,16]{3,1,2,0}, + u8[0]{0} + ) custom-call( + convert.19, + convert.31, + convert.43 + ), + custom_call_target="__cudnn$fmhaSoftmax", + operand_layout_constraints={ + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "4", "16", "16"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"] + } + } + } + get-tuple-element.5.0 = bf16[4,4,16,16]{3,1,2,0} 
get-tuple-element(custom-call.4.0), index=0 + ROOT transpose.7 = bf16[4,16,4,16]{3,2,1,0} transpose(get-tuple-element.5.0), dimensions={0,2,1,3} + } +)"; + + std::string fp8_bnth = R"( + custom-call.21.0 = ( + f8e4m3fn[4,4,16,16]{3,1,2,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + u8[16]{0} + ) custom-call( + convert.18, + convert.30, + convert.42, + broadcast.99, + broadcast.99, + /*index=5*/broadcast.99, + broadcast.99, + broadcast.99, + broadcast.99 + ), + custom_call_target="__cudnn$fmhaSoftmaxF8", + operand_layout_constraints={ + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "4", "16", "16"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"] + } + } + } + get-tuple-element.5.0 = f8e4m3fn[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.21.0), index=0 + transpose.26 = f8e4m3fn[4,16,4,16]{3,2,1,0} transpose(get-tuple-element.5.0), dimensions={0,2,1,3} + ROOT out = bf16[4,16,4,16]{3,2,1,0} convert(transpose.26) + } + )"; + + std::string hlo_string = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8()) + + fp8_bnth; + std::string hlo_string_ref = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef()) + + ref_bnth; + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{5e-2, 5e-2})); +} + +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BTNH_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + XlaBuilder builder(TestName()); + + std::string ref_btnh = R"( + custom-call.4.0 = ( + bf16[4,16,4,16]{3,2,1,0}, + u8[0]{0} + ) custom-call( + convert.19, + convert.31, + convert.43 + ), + custom_call_target="__cudnn$fmhaSoftmax", + operand_layout_constraints={ + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + 
"cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "16", "4", "4"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + ROOT get-tuple-element.5.0 = bf16[4,16,4,16]{3,2,1,0} get-tuple-element(custom-call.4.0), index=0 + } +)"; + + std::string fp8_btnh = R"( + custom-call.21.0 = ( + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + u8[16]{0} + ) custom-call( + convert.18, + convert.30, + convert.42, + broadcast.99, + broadcast.99, + /*index=5*/broadcast.99, + broadcast.99, + broadcast.99, + broadcast.99 + ), + custom_call_target="__cudnn$fmhaSoftmaxF8", + operand_layout_constraints={ + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "16", "4", "4"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + get-tuple-element.5.0 = f8e4m3fn[4,16,4,16]{3,2,1,0} get-tuple-element(custom-call.21.0), index=0 + ROOT out = bf16[4,16,4,16]{3,2,1,0} convert(get-tuple-element.5.0) + } + )"; + + std::string 
hlo_string = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8()) + + fp8_btnh; + std::string hlo_string_ref = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef()) + + ref_btnh; + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{5e-2, 5e-2})); } // BMM1 - Scale - Softmax - BMM2 fp8 diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 29f7ddd1754df..93bb5e46b14f3 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -5240,10 +5240,12 @@ absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( .set_uid(next_uid()); amax_s->set_output(true) .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::FLOAT) .set_uid(next_uid()); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::FLOAT) .set_uid(next_uid());
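
Note (illustrative addition, not part of the patch): the cuda_dnn.cc hunk above gives the 1x1x1x1 amax_s/amax_o outputs of the FP8 flash-attention graph an explicit unit stride alongside their dim, data type, and uid. A minimal sketch of that cudnn_frontend tensor-attribute pattern follows; the helper name ConfigureAmaxOutput and its signature are hypothetical, and only the chained setters mirror what the patch itself does.

// Sketch only, assuming the cudnn_frontend 1.x graph API as used in the hunk above.
#include <cudnn_frontend.h>

#include <cstdint>
#include <memory>

namespace fe = cudnn_frontend;

// Hypothetical helper: marks a scalar (1x1x1x1) amax tensor as a graph output
// with an explicit unit stride, FP32 data type, and a caller-provided uid.
void ConfigureAmaxOutput(std::shared_ptr<fe::graph::Tensor_attributes> amax,
                         int64_t uid) {
  amax->set_output(true)
      .set_dim({1, 1, 1, 1})
      .set_stride({1, 1, 1, 1})  // stride set explicitly, matching the patch
      .set_data_type(fe::DataType_t::FLOAT)
      .set_uid(uid);
}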