From 6440945d2e5a14cd2345ebb5aaa52a1d9ccb9c62 Mon Sep 17 00:00:00 2001
From: Qianfeng
Date: Sun, 5 Jan 2025 19:13:39 +0800
Subject: [PATCH] merge recent changes from ROCm/xformers (#1182)

---
 .github/actions/setup-build-cuda/action.yml | 5 +-
 .github/workflows/rocm_build.yml | 2 +-
 .github/workflows/rocm_ci.yml | 10 +-
 .github/workflows/rocm_docker.yml | 27 +
 .gitignore | 6 +-
 .gitmodules | 2 +-
 setup.py | 13 +-
 tests/test_mem_eff_attention.py | 19 +-
 third_party/composable_kernel_tiled | 2 +-
 xformers/csrc/attention/attention.cpp | 2 +-
 .../{hip_fmha => hip_decoder}/CMakeLists.txt | 0
 .../attention_forward_decoder.cpp | 2 +-
 .../hip_decoder/attention_forward_splitk.cpp | 320 +++++
 .../ck_attention_forward_decoder.h | 0
 .../ck_attention_forward_decoder_splitk.h | 458 +++++++
 .../ck_attention_inner_product.h | 0
 .../ck_attention_math_ext.h | 0
 .../attention_backward_generic_ck_tiled.cpp | 9 +-
 .../hip_fmha/attention_ck_rand_uniform.cpp | 6 +-
 .../attention_forward_generic_ck_tiled.cpp | 106 +-
 .../hip_fmha/attention_forward_splitk.cpp | 1184 -----------------
 .../ck_attention_forward_decoder_splitk.h | 713 ----------
 .../csrc/attention/hip_fmha/ck_fmha_util.h | 14 +-
 .../hip_fmha/ck_tiled_fmha_batched_backward.h | 158 ++-
 .../ck_tiled_fmha_batched_backward_bf16.cpp | 10 +-
 .../ck_tiled_fmha_batched_backward_fp16.cpp | 10 +-
 .../hip_fmha/ck_tiled_fmha_batched_forward.h | 234 +---
 .../ck_tiled_fmha_batched_forward_bf16.cpp | 10 +-
 .../ck_tiled_fmha_batched_forward_dispatch.h | 171 +++
 .../ck_tiled_fmha_batched_forward_fp16.cpp | 10 +-
 ...ed_fmha_batched_forward_splitkv_dispatch.h | 358 +++++
 ..._batched_forward_splitkv_smallq_dispatch.h | 357 +++++
 .../hip_fmha/ck_tiled_fmha_batched_infer.h | 267 +---
 .../ck_tiled_fmha_batched_infer_bf16.cpp | 10 +-
 .../ck_tiled_fmha_batched_infer_dispatch.h | 206 +++
 .../ck_tiled_fmha_batched_infer_fp16.cpp | 10 +-
 ...iled_fmha_batched_infer_splitkv_dispatch.h | 371 ++++++
 ...ha_batched_infer_splitkv_smallq_dispatch.h | 370 ++++++
 .../hip_fmha/ck_tiled_fmha_bwd_setting.h | 24 +-
 .../hip_fmha/ck_tiled_fmha_fwd_setting.h | 203 +--
 .../ck_tiled_fmha_fwd_splitkv_selector.h | 95 ++
 .../ck_tiled_fmha_fwd_splitkv_setting.h | 177 +++
 ...k_tiled_fmha_fwd_splitkv_smallq_selector.h | 22 +
 ...ck_tiled_fmha_fwd_splitkv_smallq_setting.h | 137 ++
 .../hip_fmha/ck_tiled_fmha_fwd_type_config.h | 46 +
 .../hip_fmha/ck_tiled_fmha_grouped_backward.h | 159 ++-
 .../ck_tiled_fmha_grouped_backward_bf16.cpp | 10 +-
 .../ck_tiled_fmha_grouped_backward_fp16.cpp | 10 +-
 .../hip_fmha/ck_tiled_fmha_grouped_forward.h | 229 +---
 .../ck_tiled_fmha_grouped_forward_bf16.cpp | 10 +-
 .../ck_tiled_fmha_grouped_forward_dispatch.h | 157 +++
 .../ck_tiled_fmha_grouped_forward_fp16.cpp | 10 +-
 ...ed_fmha_grouped_forward_splitkv_dispatch.h | 336 +++++
 ..._grouped_forward_splitkv_smallq_dispatch.h | 333 +++++
 .../hip_fmha/ck_tiled_fmha_grouped_infer.h | 274 +---
 .../ck_tiled_fmha_grouped_infer_bf16.cpp | 10 +-
 .../ck_tiled_fmha_grouped_infer_dispatch.h | 189 +++
 .../ck_tiled_fmha_grouped_infer_fp16.cpp | 10 +-
 ...iled_fmha_grouped_infer_splitkv_dispatch.h | 360 +++++
 ...ha_grouped_infer_splitkv_smallq_dispatch.h | 359 +++++
 .../ck_tiled_fmha_num_kv_split_switch.h | 23 +
 .../attention/hip_fmha/ck_tiled_fmha_params.h | 36 +-
 .../hip_fmha/ck_tiled_fmha_seqlen_q_switch.h | 21 +
 .../hip_fmha/ck_tiled_headdim_switch.h | 12 +
 .../hip_fmha/ck_tiled_rand_uniform_kernel.h | 25 +-
 .../attention/hip_fmha/generate_instances.py | 79 +-
 ...ias_has_biasgrad_has_dropout_maxk_128.cpp} | 2 +-
 [... several hundred further per-instance diffstat entries, paths truncated in the
 original diffstat: the generated kernel sources under
 xformers/csrc/attention/hip_fmha/instances/ are renamed with 2-6 lines changed each,
 new *_maxk_96.cpp instances add 19-20 lines each, and the regenerated
 *_instances_ref.h headers change 192 lines (backward) or 120 lines (forward/infer)
 each ...]
 xformers/ops/fmha/ck.py | 133 +-
 639 files changed, 9944 insertions(+), 4223 deletions(-)
 create mode 100644 .github/workflows/rocm_docker.yml
 rename xformers/csrc/attention/{hip_fmha => hip_decoder}/CMakeLists.txt (100%)
 rename xformers/csrc/attention/{hip_fmha => hip_decoder}/attention_forward_decoder.cpp (99%)
 create mode 100644 xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp
 rename xformers/csrc/attention/{hip_fmha => hip_decoder}/ck_attention_forward_decoder.h (100%)
 create mode 100644 xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h
 rename xformers/csrc/attention/{hip_fmha => hip_decoder}/ck_attention_inner_product.h (100%)
 rename xformers/csrc/attention/{hip_fmha => hip_decoder}/ck_attention_math_ext.h (100%)
 delete mode 100644 xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp
 delete mode 100644 xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_smallq_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_smallq_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_selector.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_setting.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_selector.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_setting.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_type_config.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_smallq_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_smallq_dispatch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_num_kv_split_switch.h
 create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_seqlen_q_switch.h
 [... rename/create entries for the generated instance files omitted: each source
 under xformers/csrc/attention/hip_fmha/instances/ moves from the old
 fmha_*_{has,no}_causalmask_* naming to the new fmha_*_{has,no}_mask_* naming
 (87% similarity), and new *_maxk_96.cpp instances are created for every variant ...]
 rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp =>
fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp => 
fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp => 
fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 
xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp => 
fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp => 
fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp => fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp} (87%) rename 
xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp} (87%) rename xformers/csrc/attention/hip_fmha/instances/{fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp => fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp} (87%) create mode 100644 xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp diff --git a/.github/actions/setup-build-cuda/action.yml b/.github/actions/setup-build-cuda/action.yml index dba3488c9e..824be1bd6b 100644 --- a/.github/actions/setup-build-cuda/action.yml +++ b/.github/actions/setup-build-cuda/action.yml @@ -29,8 +29,9 @@ runs: "124": ("12.4.1", "https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run"), "121": ("12.1.0", "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"), "118": ("11.8.0", "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"), - "6.0": ("6.0.2", "https://repo.radeon.com/amdgpu-install/6.0.2/rhel/7.9/amdgpu-install-6.0.60002-1.el7.noarch.rpm"), - "6.1": ("6.1.2", "https://repo.radeon.com/amdgpu-install/6.1.2/el/7/amdgpu-install-6.1.60102-1.el7.noarch.rpm"), + "6.0": ("6.0.2", "https://repo.radeon.com/amdgpu-install/6.0.2/rhel/8.9/amdgpu-install-6.0.60002-1.el8.noarch.rpm"), + "6.1": ("6.1.2", "https://repo.radeon.com/amdgpu-install/6.1.3/rhel/8.9/amdgpu-install-6.1.60103-1.el8.noarch.rpm"), + "6.2": ("6.2.3", 
"https://repo.radeon.com/amdgpu-install/6.2.3/rhel/8.9/amdgpu-install-6.2.60203-1.el8.noarch.rpm"), }[cushort] with open(os.environ['GITHUB_OUTPUT'], "r+") as fp: fp.write("CUDA_VERSION=" + full_version + "\n") diff --git a/.github/workflows/rocm_build.yml b/.github/workflows/rocm_build.yml index 8c8bd9b081..37fe17b4ec 100644 --- a/.github/workflows/rocm_build.yml +++ b/.github/workflows/rocm_build.yml @@ -24,7 +24,7 @@ jobs: python: ['3.11'] torch_version: ['2.5.1'] toolkit_type: ['rocm'] - toolkit_short_version: ['6.0', '6.1'] + toolkit_short_version: ['6.1', '6.2'] uses: ./.github/workflows/wheels_build.yml if: github.repository == 'rocm/xformers' diff --git a/.github/workflows/rocm_ci.yml b/.github/workflows/rocm_ci.yml index 0dc8d1cefd..1897eab1d1 100644 --- a/.github/workflows/rocm_ci.yml +++ b/.github/workflows/rocm_ci.yml @@ -12,10 +12,10 @@ on: jobs: build: if: github.repository == 'rocm/xformers' - runs-on: self-hosted + runs-on: self-hosted-rocm-ci container: image: 'rocm/pytorch-nightly:latest' - options: ' --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G ' + options: ' --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G --memory 32G ' steps: - uses: actions/checkout@v4 with: @@ -57,7 +57,7 @@ jobs: export PATH=/opt/conda/envs/xformers/bin:$PATH python -VV - python -m pip install -U torch --index-url=https://download.pytorch.org/whl/nightly/rocm6.1 + python -m pip install -U torch --index-url=https://download.pytorch.org/whl/rocm6.2 python -c "import torch; print(f'PyTorch version {torch.__version__}')" python -m pip install ninja scipy pytest pytest-html @@ -71,7 +71,7 @@ jobs: - name: Build xformers run: | export PATH=/opt/conda/envs/xformers/bin:$PATH - export MAX_JOBS=144 + export MAX_JOBS=20 python -m pip install -e ./_xformers --verbose python -m xformers.info @@ -97,7 +97,7 @@ jobs: cd .. 
clean: - runs-on: self-hosted + runs-on: self-hosted-rocm-ci if: ${{ needs.build.result != 'skipped' }} needs: [build] steps: diff --git a/.github/workflows/rocm_docker.yml b/.github/workflows/rocm_docker.yml new file mode 100644 index 0000000000..31fc242a71 --- /dev/null +++ b/.github/workflows/rocm_docker.yml @@ -0,0 +1,27 @@ +name: Build and Publish ROCm Docker Image + +on: + push: + branches: + - develop + +jobs: + build-and-push: + runs-on: rocm + if: github.repository == 'rocm/xformers' + steps: + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ vars.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push + uses: docker/build-push-action@v6 + with: + push: true + tags: rocm/xformers:latest + file: Dockerfile.rocm diff --git a/.gitignore b/.gitignore index b37d0b1b53..978b6be3e0 100644 --- a/.gitignore +++ b/.gitignore @@ -67,7 +67,9 @@ xformers/csrc/attention/hip_fmha/*.hip xformers/csrc/attention/hip_fmha/*_hip.h xformers/csrc/attention/hip_fmha/instances/*.cu xformers/csrc/attention/hip_fmha/instances/*.hip -xformers/csrc/attention/hip_fmha/instances/*.cu -xformers/csrc/attention/hip_fmha/instances/*.hip xformers/csrc/attention/hip_fmha/instances/*_hip.h +xformers/csrc/attention/hip_decoder/*.cu +xformers/csrc/attention/hip_decoder/*.hip +xformers/csrc/attention/hip_decoder/*_hip.h + diff --git a/.gitmodules b/.gitmodules index b642ad5b97..176104791f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "third_party/composable_kernel_tiled"] path = third_party/composable_kernel_tiled url = https://github.com/ROCm/composable_kernel.git - branch = develop + branch = develop diff --git a/setup.py b/setup.py index 1afc501fbc..0a88185867 100644 --- a/setup.py +++ b/setup.py @@ -381,11 +381,12 @@ def get_extensions(): ] source_hip = glob.glob( - os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cpp"), + os.path.join(extensions_dir, "attention", "hip_*", "**", "*.cpp"), recursive=True, ) + source_hip_generated = glob.glob( - os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cu"), + os.path.join(extensions_dir, "attention", "hip_*", "**", "*.cu"), recursive=True, ) # avoid the temporary .cu files generated under xformers/csrc/attention/hip_fmha @@ -539,7 +540,8 @@ def get_extensions(): extension = CUDAExtension sources += source_hip_cu include_dirs += [ - Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_fmha" + Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_fmha", + Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_decoder", ] include_dirs += [ @@ -557,12 +559,17 @@ def get_extensions(): arch_list = os.getenv("HIP_ARCHITECTURES", "native").split() + offload_compress_flag = [] + if hip_version >= "6.2.": + offload_compress_flag = ["--offload-compress"] + extra_compile_args = { "cxx": ["-O3", "-std=c++17"] + generator_flag, "nvcc": [ "-O3", "-std=c++17", *[f"--offload-arch={arch}" for arch in arch_list], + *offload_compress_flag, "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 6c59f4e4ac..aebf81e15f 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -674,6 +674,8 @@ def test_backward( if op_bw == fmha.ck.BwOp: op_fw = fmha.ck.FwOp if dtype == torch.bfloat16: + # bfloat16 testing can be enabled by export 
ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 when + # building xformers and get accurate results pytest.skip( "CK Fmha backward for bfloat16 currently is not very accurate for some cases!" ) @@ -1937,7 +1939,7 @@ def test_forward_gqa(opFW_biasT, Mq: int): "opBW", [ fmha.flash.BwOp, - fmha.cutlass.BwOp, + fmha.ck.BwOp if torch.version.hip else fmha.cutlass.BwOp, ], ) def test_backward_gqa(opBW): @@ -1949,7 +1951,7 @@ def test_backward_gqa(opBW): attn_bias_requires_grad=False, fmt="BMHK", ) - op = (fmha.cutlass.FwOp, opBW) + op = (fmha.ck.FwOp if torch.version.hip else fmha.cutlass.FwOp, opBW) key = key[:, :, :1].expand(-1, -1, H, -1) value = value[:, :, :1].expand(-1, -1, H, -1) key.requires_grad_(True) @@ -2278,6 +2280,19 @@ def test_paged_attention( ) +@cuda_only +@pytest.mark.parametrize("B", [1, 5, 128]) +@pytest.mark.parametrize("MAX_T", [64, 128, 2048, 4096, 8192]) +@pytest.mark.parametrize("page_size", [128, 256]) +@pytest.mark.parametrize("gappy", [False, True], ids=lambda x: "gappy" if x else "") +def test_paged_attention_ck(B, MAX_T: int, page_size: int, gappy: bool): + op = fmha.ck.FwOp + num_quant_groups = 0 + paged_attention_run_inner( + B, MAX_T, num_quant_groups, page_size, op, bench=False, gappy=gappy + ) + + @sm80_or_better_only @disable_on_rocm @pytest.mark.parametrize("B", [1, 5, 128]) diff --git a/third_party/composable_kernel_tiled b/third_party/composable_kernel_tiled index 73b67f290f..4e076909b6 160000 --- a/third_party/composable_kernel_tiled +++ b/third_party/composable_kernel_tiled @@ -1 +1 @@ -Subproject commit 73b67f290f6602fe0461d48a2c103de460f14084 +Subproject commit 4e076909b6c1e1404d9ff5dc0e71e3be1c06569e diff --git a/xformers/csrc/attention/attention.cpp b/xformers/csrc/attention/attention.cpp index f4ea8696b1..bdc77889b9 100644 --- a/xformers/csrc/attention/attention.cpp +++ b/xformers/csrc/attention/attention.cpp @@ -26,7 +26,7 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { "xformers::efficient_attention_forward_ck(Tensor query, " "Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, " "Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, " - "bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size) -> (Tensor, Tensor, int, int)")); + "bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size, Tensor? block_tables, int? page_size) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_forward_decoder_ck(Tensor query, " "Tensor key, Tensor value, Tensor? 
seq_positions, float scale) -> Tensor")); diff --git a/xformers/csrc/attention/hip_fmha/CMakeLists.txt b/xformers/csrc/attention/hip_decoder/CMakeLists.txt similarity index 100% rename from xformers/csrc/attention/hip_fmha/CMakeLists.txt rename to xformers/csrc/attention/hip_decoder/CMakeLists.txt diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp b/xformers/csrc/attention/hip_decoder/attention_forward_decoder.cpp similarity index 99% rename from xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp rename to xformers/csrc/attention/hip_decoder/attention_forward_decoder.cpp index 7f126dd335..dbdb944b95 100644 --- a/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp +++ b/xformers/csrc/attention/hip_decoder/attention_forward_decoder.cpp @@ -96,7 +96,7 @@ at::Tensor& efficient_attention_forward_decoder_ck_out_impl( int32_t smem_output = K_MAX * sizeof(float) * threads.y; // 4 * threadsPerBlock * sizeof(float) == sizeof(O[b][0][h][:]) const size_t lds_bytes = max(smem_softmax, smem_output); - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream().stream(); AT_DISPATCH_SWITCH_3( at::ScalarType::Half, diff --git a/xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp b/xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp new file mode 100644 index 0000000000..647e540d37 --- /dev/null +++ b/xformers/csrc/attention/hip_decoder/attention_forward_splitk.cpp @@ -0,0 +1,320 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ck_attention_forward_decoder_splitk.h" + +namespace { +constexpr int32_t kThreadsPerWavefront = 64; +constexpr int32_t kWavefrontsPerBlock = 4; +constexpr int32_t kMaxHeadDimension = 4 * kThreadsPerWavefront; +constexpr int32_t kMaxKVSequenceLength = 4096; +constexpr int32_t kLoopUnroll = 16; +constexpr int32_t kLoopUnrollTail = 2; +using compute_t = float; +} // namespace + +namespace { + +template +struct c10_to_data_t; +template <> +struct c10_to_data_t { + using type = float; +}; + +template <> +struct c10_to_data_t { + using type = ck::half_t; +}; + +template <> +struct c10_to_data_t { + using type = ck::bhalf_t; +}; +} // namespace + +#define AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_SWITCH_3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +namespace { + +template +void instantiate_and_launch_kernels( + typename ck_tile::ForwardDecoderSplitKArgument arg, + dim3 attn_grid_size, + dim3 attn_block_size, + int32_t lds_bytes, + dim3 reduce_grid_size, + dim3 reduce_block_size, + hipStream_t stream) { + auto attn_kernel_impl = ck_tile::ForwardDecoderSplitKAttnKernelImpl< + ck_data_t, + vec_size, + kLoopUnroll, + kLoopUnrollTail, + kMaxKVSequenceLength, + compute_t>{}; + auto reduce_kernel_impl = ck_tile:: + ForwardDecoderSplitKReduceKernelImpl{}; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, /* benchmark */ false}, + ck_tile::make_kernel( + attn_kernel_impl, attn_grid_size, attn_block_size, lds_bytes, arg)); + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, /* benchmark */ false}, + ck_tile::make_kernel( + reduce_kernel_impl, + reduce_grid_size, + reduce_block_size, + 0 /* lds_bytes */, + arg)); +} + +template < + int32_t ThreadsPerWavefront, + int32_t WavefrontsPerBlock> +at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] + const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, G, H or 1, D] + at::optional seq_kv_lens, // [B] + float qk_scale, + int32_t split_k, + at::Tensor& split_max, + at::Tensor& split_sumexp, + at::Tensor& split_O, + at::Tensor& O) { + static_assert(4 * ThreadsPerWavefront == kMaxHeadDimension, ""); + static_assert(WavefrontsPerBlock <= ThreadsPerWavefront, ""); + + at::OptionalDeviceGuard guard(XQ.device()); + TORCH_CHECK(XQ.is_cuda()); + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + + TORCH_CHECK(!seq_kv_lens || seq_kv_lens->is_cuda()); + + TORCH_CHECK(cache_K.size(1) / split_k <= kMaxKVSequenceLength); + TORCH_CHECK(cache_K.size(4) <= kMaxHeadDimension); + + constexpr auto rank = 5; + + auto B = XQ.size(0); + auto M = XQ.size(1); + auto G = XQ.size(2); + auto H = XQ.size(3); + auto HDim = XQ.size(4); + + TORCH_CHECK(B <= 1024); + TORCH_CHECK(M <= 1024); + TORCH_CHECK(H <= 1024); + + const dim3 attn_grid_size(B * H * M * G, split_k); + const dim3 attn_block_size(ThreadsPerWavefront, WavefrontsPerBlock); + + const dim3 reduce_grid_size = {attn_grid_size.x}; + const dim3 reduce_block_size = {attn_block_size.x}; + + int32_t smem_softmax = kMaxKVSequenceLength * sizeof(compute_t) + + WavefrontsPerBlock * sizeof(compute_t); + int32_t smem_output = kMaxHeadDimension * sizeof(compute_t) * + WavefrontsPerBlock; // 4 * threadsPerBlock * sizeof(float) == + // sizeof(O[b][0][h][:]) + const size_t attn_lds_bytes = max(smem_softmax, smem_output); + auto stream = at::hip::getCurrentHIPStream().stream(); + + AT_DISPATCH_SWITCH_3( + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Float, + XQ.scalar_type(), + "efficient_attention_forward_decoder_splitk_ck", + [&] { + using ck_data_t = c10_to_data_t::type; + + auto XQ_acc = + XQ.packed_accessor32(); + auto K_acc = + cache_K.packed_accessor64(); + auto V_acc = + cache_V.packed_accessor64(); + auto split_O_acc = + split_O + .packed_accessor32(); + auto O_acc = + O.packed_accessor32(); + auto seq_acc_ptr = seq_kv_lens + ? 
seq_kv_lens + ->packed_accessor32() + .data() + : nullptr; + auto split_max_acc = + split_max.packed_accessor32(); + auto split_sumexp_acc = + split_sumexp + .packed_accessor32(); + auto arg = ck_tile::ForwardDecoderSplitKArgument{ + reinterpret_cast(XQ_acc.data()), + reinterpret_cast(K_acc.data()), + reinterpret_cast(V_acc.data()), + reinterpret_cast(O_acc.data()), + reinterpret_cast(split_O_acc.data()), + split_max_acc.data(), + split_sumexp_acc.data(), + seq_acc_ptr, + XQ_acc.stride(0), + XQ_acc.stride(1), + XQ_acc.stride(2), + XQ_acc.stride(3), + K_acc.stride(0), + K_acc.stride(1), + K_acc.stride(2), + K_acc.stride(3), + split_O_acc.stride(0), + static_cast(XQ_acc.size(1)), + static_cast(XQ_acc.size(2)), + static_cast(XQ_acc.size(3)), + static_cast(XQ_acc.size(4)), + static_cast(K_acc.size(1)), + K_acc.size(3) == 1, + qk_scale, + split_k}; + + auto required_vec_size = 0; + + for (auto vec_size : {4, 2, 1}) { + if (arg.Q_size_k <= vec_size * ThreadsPerWavefront) { + required_vec_size = vec_size; + } + } + + TORCH_CHECK(required_vec_size > 0); + + switch (required_vec_size) { + case 4: + instantiate_and_launch_kernels( + arg, + attn_grid_size, + attn_block_size, + attn_lds_bytes, + reduce_grid_size, + reduce_block_size, + stream); + break; + case 2: + instantiate_and_launch_kernels( + arg, + attn_grid_size, + attn_block_size, + attn_lds_bytes, + reduce_grid_size, + reduce_block_size, + stream); + break; + case 1: + instantiate_and_launch_kernels( + arg, + attn_grid_size, + attn_block_size, + attn_lds_bytes, + reduce_grid_size, + reduce_block_size, + stream); + break; + default: + break; + } + }); + + return O; +} + +template +at::Tensor efficient_attention_forward_decoder_splitk_ck_impl( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] + const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, H or 1, D] + at::optional seq_kv_lens, // [B] + float qk_scale, + int32_t split_k) { + auto O = at::empty_like(XQ); + constexpr auto rank = 5; + + TORCH_CHECK(XQ.dim() == rank); + TORCH_CHECK(cache_K.dim() == rank); + TORCH_CHECK(cache_V.dim() == rank); + + auto B = XQ.size(0); + auto M = XQ.size(1); + auto G = XQ.size(2); + auto H = XQ.size(3); + auto K = XQ.size(4); + + auto O_splits = at::empty({split_k, B, M, G, H, K}, XQ.options()); + auto split_max = + at::empty({B, M, G, H, split_k}, XQ.options().dtype(at::kFloat)); + auto split_sumexp = at::empty_like(split_max); + + efficient_attention_forward_decoder_splitk_ck_out_impl< + ThreadsPerWavefront, + WavefrontsPerBlock>( + XQ, + cache_K, + cache_V, + seq_kv_lens, + qk_scale, + split_k, + split_max, + split_sumexp, + O_splits, + O); + + return O; +} + +at::Tensor efficient_attention_forward_decoder_splitk_ck( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] + const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, G, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale, + int64_t split_k) { + return efficient_attention_forward_decoder_splitk_ck_impl< + kThreadsPerWavefront, + kWavefrontsPerBlock>( + XQ, + cache_K, + cache_V, + seq_kv_lens, + static_cast(qk_scale), + static_cast(split_k)); +} +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME( + "xformers::efficient_attention_forward_decoder_splitk_ck"), + TORCH_FN(efficient_attention_forward_decoder_splitk_ck)); +} + +#undef AT_DISPATCH_CASE_3 +#undef AT_DISPATCH_SWITCH_3 diff --git 
a/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder.h b/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder.h similarity index 100% rename from xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder.h rename to xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder.h diff --git a/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h b/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h new file mode 100644 index 0000000000..5389affacc --- /dev/null +++ b/xformers/csrc/attention/hip_decoder/ck_attention_forward_decoder_splitk.h @@ -0,0 +1,458 @@ +#pragma once + +#include +#include + +#include "ck_attention_inner_product.h" +#include "ck_attention_math_ext.h" + +namespace { + +template +__device__ typename ck::vector_type::type scalar_scale_acc( + typename ck::vector_type::type acc, + typename ck::vector_type::type a, + float b) { + union { + decltype(acc) vec; + float arr[vec_size]; + } acc_u{acc}; + union { + decltype(a) vec; + data_t arr[vec_size]; + } a_u{a}; + +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + acc_u.arr[i] += ck::type_convert(a_u.arr[i]) * b; + } + + return acc_u.vec; +} + +template +float __device__ __forceinline__ wavefrontReduce(float val, F f) { +#pragma unroll + for (int32_t mask = n_threads_per_wavefront >> 1; mask > 0; mask >>= 1) { + val = f(__shfl_xor(val, mask, n_threads_per_wavefront), val); + } + return val; +} + +template +__forceinline__ __device__ void load_v( + const TData* __restrict__ data_ptr, + int32_t vector_offset, + TDataVec* __restrict__ load_to) { + *load_to = *(reinterpret_cast(data_ptr) + vector_offset); +} + +template +__forceinline__ __device__ void store_v( + TData* __restrict__ data_ptr, + int32_t vector_offset, + TDataVec value) { + *(reinterpret_cast(data_ptr) + vector_offset) = value; +} + +} // namespace + +namespace ck_tile { +template +struct ForwardDecoderSplitKArgument { + const scalar_t* __restrict__ XQ; + const scalar_t* __restrict__ cache_K; + const scalar_t* __restrict__ cache_V; + scalar_t* __restrict__ O; + scalar_t* __restrict__ split_O; + compute_t* __restrict__ split_max; + compute_t* __restrict__ split_sumexp; + const int32_t* __restrict__ seq_kv_lens; + const ptrdiff_t XQ_stride_b; + const ptrdiff_t XQ_stride_m; + const ptrdiff_t XQ_stride_g; + const ptrdiff_t XQ_stride_h; + const ptrdiff_t K_stride_b; + const ptrdiff_t K_stride_m; + const ptrdiff_t K_stride_g; + const ptrdiff_t K_stride_h; + const ptrdiff_t O_stride_split; + const int32_t Q_size_m; + const int32_t Q_size_g; + const int32_t Q_size_h; + const int32_t Q_size_k; + const int32_t K_size_m; + const bool multiquery; + const float qk_scale; + const int32_t split_k; +}; + +template +struct ForwardDecoderSplitKReduceKernelImpl { + CK_TILE_DEVICE void operator()( + ForwardDecoderSplitKArgument arg) { + // Each block handles a single batch and head and query and group + const int32_t b = blockIdx.x / (arg.Q_size_m * arg.Q_size_g * arg.Q_size_h); + const int32_t m = + (blockIdx.x / (arg.Q_size_g * arg.Q_size_h)) % arg.Q_size_m; + const int32_t g = (blockIdx.x / arg.Q_size_h) % arg.Q_size_g; + const int32_t h = blockIdx.x % arg.Q_size_h; + + using data_t = scalar_t; + using data_vec_t = typename ck::vector_type::type; + using compute_vec_t = typename ck::vector_type::type; + + union { + data_vec_t vec; + data_t arr[vec_size]; + } O_split_data; + union { + compute_vec_t vec; + compute_t arr[vec_size]; + } O_split_compute; + union { + data_vec_t vec; + data_t arr[vec_size]; + } 
global_O_data; + union { + compute_vec_t vec; + compute_t arr[vec_size]; + } global_O_compute; + + global_O_compute.vec = 0; + + const int32_t lane_idx = threadIdx.x; + const bool lane_active_for_io = lane_idx * vec_size < arg.Q_size_k; + + if (!lane_active_for_io) { + return; + } + + compute_t global_sumexp = 0; + compute_t global_max = ck::NumericLimits::Lowest(); + + for (int32_t split_idx = 0; split_idx < arg.split_k; ++split_idx) { + load_v( + arg.split_O + b * arg.XQ_stride_b + m * arg.XQ_stride_m + + g * arg.XQ_stride_g + h * arg.XQ_stride_h + + split_idx * arg.O_stride_split, + lane_idx, + &O_split_data.vec); +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + O_split_compute.arr[i] = + ck::type_convert(O_split_data.arr[i]); + } + compute_t local_max = + *(arg.split_max + blockIdx.x * arg.split_k + split_idx); + compute_t local_sumexp = + *(arg.split_sumexp + blockIdx.x * arg.split_k + split_idx); + + compute_t log_alpha = -std::abs(local_max - global_max); + compute_t alpha = + isnan(log_alpha) ? compute_t{1.} : ck::math::exp(log_alpha); + + bool pick_new = local_max < global_max; + compute_t pick_current_coef = pick_new ? 1. : alpha; + compute_t pick_new_coef = pick_new ? alpha : 1.; + + global_sumexp = + pick_current_coef * global_sumexp + pick_new_coef * local_sumexp; + global_O_compute.vec = pick_current_coef * global_O_compute.vec + + pick_new_coef * O_split_compute.vec; + global_max = ck::math::max(local_max, global_max); + } + global_O_compute.vec /= global_sumexp; +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + global_O_data.arr[i] = ck::type_convert(global_O_compute.arr[i]); + } + store_v( + arg.O + b * arg.XQ_stride_b + m * arg.XQ_stride_m + + g * arg.XQ_stride_g + h * arg.XQ_stride_h, + lane_idx, + global_O_data.vec); + } +}; + +template < + typename scalar_t, + int32_t vec_size, + int32_t n_loop_unroll, + int32_t n_loop_unroll_tail, + int32_t KV_M_MAX, + typename compute_t> +struct ForwardDecoderSplitKAttnKernelImpl { + CK_TILE_DEVICE void operator()( + ForwardDecoderSplitKArgument arg) { + static_assert( + n_loop_unroll_tail < n_loop_unroll || n_loop_unroll_tail == 1, + "tail unroll must be smaller than main loop untoll; pragma unroll 0 is illegal " + "(and tail is no-op)"); + + // Each block handles a single batch and head and query and group + const int32_t b = blockIdx.x / (arg.Q_size_m * arg.Q_size_g * arg.Q_size_h); + const int32_t m = + (blockIdx.x / (arg.Q_size_g * arg.Q_size_h)) % arg.Q_size_m; + const int32_t g = (blockIdx.x / arg.Q_size_h) % arg.Q_size_g; + const int32_t h = blockIdx.x % arg.Q_size_h; + const int32_t split_idx = blockIdx.y; + + // Note: this is decoding case where we attend to current and all previous + // tokens. + const int32_t t_max = arg.seq_kv_lens ? 
arg.seq_kv_lens[b] : arg.K_size_m; + + const int32_t lane_idx = threadIdx.x; + const int32_t wavefront_idx = threadIdx.y; + // TODO: `threads_per_wavefront` and `wavefronts_per_block` may be compile + // time constants; investigate when optimizing + const int32_t threads_per_wavefront = blockDim.x; + const int32_t wavefronts_per_block = blockDim.y; + const int32_t threads_per_block = + threads_per_wavefront * wavefronts_per_block; + const int32_t thread_linear_idx = + lane_idx + wavefront_idx * threads_per_wavefront; + // const auto* q_ = &(XQ_acc[b][m][g][h][0]); + const auto XQO_base_offset = b * arg.XQ_stride_b + m * arg.XQ_stride_m + + g * arg.XQ_stride_g + h * arg.XQ_stride_h; + const auto* __restrict__ q_ = arg.XQ + XQO_base_offset; + + const auto cache_KV_base_offset = b * arg.K_stride_b + 0 * arg.K_stride_m + + g * arg.K_stride_g + (arg.multiquery ? 0 : h * arg.K_stride_h); + const auto* __restrict__ cache_K_base = arg.cache_K + cache_KV_base_offset; + const auto* __restrict__ cache_V_base = arg.cache_V + cache_KV_base_offset; + + using data_t = scalar_t; + using data_vec_t = typename ck::vector_type::type; + using compute_vec_t = typename ck::vector_type::type; + + const bool lane_active_for_io = lane_idx * vec_size < arg.Q_size_k; + + extern __shared__ __align__(16) compute_t smem[]; + + data_vec_t q_thread = 0; + // Load Q into registers in all wavefronts. + // Each thread handles `vec_size` D dimensions + if (lane_active_for_io) { + load_v(q_, lane_idx, &q_thread); + } + + compute_t max_qk_acc = ck::NumericLimits::Lowest(); + + // Compute S[0:t_max] = + // ``` + // for t in range(t_max): + // S[t] = dot(Q, K[t]) + // ``` + // Split the 0:t_max range across wavefronts in a block, + // unroll loads to expose more parallelism. + // Reduce the dot product with cross-lane operation; + // Q and K[t] are in the registers of threads in a single wavefront. + + data_vec_t k_loads[n_loop_unroll] = {}; + + const auto dtt = wavefronts_per_block * n_loop_unroll; + // only last split gets the tail. + // the first (split_k - 1) splits have a number of iterations divisible by + // `dtt` + const auto n_unrolled_loops = t_max / dtt / arg.split_k; // +1? + const int32_t tt_low = + wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * split_idx; + const int32_t tt_high = wavefront_idx * n_loop_unroll + + n_unrolled_loops * dtt * (split_idx + 1); + const int32_t dtt_tail = wavefronts_per_block * n_loop_unroll_tail; + const int32_t tt_tail_low = wavefront_idx * n_loop_unroll_tail + + n_unrolled_loops * dtt * (split_idx + 1); + const int32_t tt_tail_high = + (split_idx == arg.split_k - 1) ? 
t_max : tt_tail_low; + + for (auto tt = tt_low; tt < tt_high; tt += dtt) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + const int32_t t = tt + ttt; + // load the K[b][t][g][h|0][:] row into registers + load_v( + cache_K_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]); + } + } +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + compute_t qk_acc = 0; + ck::inner_product( + q_thread, k_loads[ttt], qk_acc); + qk_acc *= arg.qk_scale; + + qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); + max_qk_acc = ck::math::max(qk_acc, max_qk_acc); + if (lane_idx == 0) { + smem[tt + ttt - n_unrolled_loops * dtt * split_idx] = qk_acc; + } + } + } + + for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + // load the K[b][t][g][h|0][:] row into registers + load_v( + cache_K_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]); + } + } + } +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + compute_t qk_acc = 0; + const int32_t t = tt + ttt; + if (t < t_max) { + ck::inner_product( + q_thread, k_loads[ttt], qk_acc); + qk_acc *= arg.qk_scale; + + qk_acc = + wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); + max_qk_acc = ck::math::max(qk_acc, max_qk_acc); + + // write accumulated sums to smem. + if (lane_idx == 0) { + smem[t - n_unrolled_loops * dtt * split_idx] = qk_acc; + } + } + } + } + + // Use shared reduction to compute max and compute softmax on shared memory. + // write max acc + if (lane_idx == 0) { + smem[KV_M_MAX + wavefront_idx] = max_qk_acc; + } + __syncthreads(); + if (lane_idx < wavefronts_per_block) { + max_qk_acc = ck::math::max(max_qk_acc, smem[KV_M_MAX + lane_idx]); + } + // shared across all threads in block + max_qk_acc = wavefrontReduce( + max_qk_acc, [](auto a, auto b) { return a > b ? a : b; }); + + if (wavefront_idx == 0 && lane_idx == 0) { + arg.split_max[blockIdx.x * arg.split_k + split_idx] = max_qk_acc; + } + + // each wavefront computes partial sum of exp. + { // softmax reduce begin + compute_t softmax_denominator = 0.0f; + const int32_t t_low = n_unrolled_loops * dtt * split_idx; + const int32_t t_high = (split_idx + 1 < arg.split_k) + ? n_unrolled_loops * dtt * (split_idx + 1) + : t_max; + for (int32_t t = t_low + thread_linear_idx; t < t_high; + t += threads_per_block) { + const auto s = ck::math::exp(smem[t - t_low] - max_qk_acc); + softmax_denominator += s; + smem[t - t_low] = s; + } + softmax_denominator = wavefrontReduce( + softmax_denominator, [](auto a, auto b) { return a + b; }); + + if (lane_idx == 0) { + smem[KV_M_MAX + wavefront_idx] = softmax_denominator; + } + __syncthreads(); + + // now, compute sum of exp(x - max(x)) over all intermediate results. 
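+      // Each wavefront has stored its partial denominator at smem[KV_M_MAX + wavefront_idx];
+      // below, lane i of every wavefront reads wavefront i's partial and the cross-lane
+      // reduction adds them, so each wavefront ends up with the full sum of
+      // exp(S[t] - max_qk_acc) for this split. Only this split's local max has been
+      // subtracted; the separate reduce kernel later rescales the per-split
+      // (max, sumexp, split_O) triples when merging them into the final O.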
+ softmax_denominator = 0.0; + if (lane_idx < wavefronts_per_block) { + softmax_denominator = smem[KV_M_MAX + lane_idx]; + } + softmax_denominator = wavefrontReduce( + softmax_denominator, [](auto a, auto b) { return a + b; }); + + if (wavefront_idx == 0 && lane_idx == 0) { + arg.split_sumexp[blockIdx.x * arg.split_k + split_idx] = + softmax_denominator; + } + } // softmax reduce end + + // Split T across wavefronts in a block + // each wavefront compute sum(t_subset) P[t] * V[t_subset, d] + // outputs are of size float[D] + + compute_t ps[n_loop_unroll] = {}; + compute_vec_t o_acc = 0; + if (lane_active_for_io) { + for (auto tt = tt_low; tt < tt_high; tt += dtt) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + const int32_t t = tt + ttt; + // load the V[b][t][g][h|0][:] row into registers, reusing K register + // storage + load_v( + cache_V_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]); + ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx]; + } + +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + o_acc = + scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + + for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + // load the V[b][t][g][h|0][:] row into registers, reusing K + // register storage + load_v( + cache_V_base + t * arg.K_stride_m, lane_idx, &k_loads[ttt]); + ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx]; + o_acc = scalar_scale_acc( + o_acc, k_loads[ttt], ps[ttt]); + } + } + } + } + __syncthreads(); + + // NB: needs sizeof(smem) >= `vec_size` * (sizeof(float)==4) * + // threadsPerBlock + if (lane_active_for_io) { + store_v(&smem[0], thread_linear_idx, o_acc); + } + + __syncthreads(); + // sum up partial D rows from other wavefronts + if (wavefront_idx == 0 && lane_active_for_io) { + union { + compute_vec_t vec = 0; + compute_t arr[vec_size]; + } r; + for (int32_t w = 0; w < wavefronts_per_block; ++w) { + compute_vec_t partial_r; + load_v( + smem, w * threads_per_wavefront + lane_idx, &partial_r); + r.vec += partial_r; + } + // elementwise convert from compute_t result to data_t out to be written + union { + data_vec_t vec; + data_t arr[vec_size]; + } bf_r; +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + bf_r.arr[i] = ck::type_convert(r.arr[i]); + } + // write output row O[b][m][g][h][:] + data_t* __restrict__ o_ = + arg.split_O + XQO_base_offset + split_idx * arg.O_stride_split; + store_v(o_, lane_idx, bf_r.vec); + } + } +}; + +} // namespace ck_tile diff --git a/xformers/csrc/attention/hip_fmha/ck_attention_inner_product.h b/xformers/csrc/attention/hip_decoder/ck_attention_inner_product.h similarity index 100% rename from xformers/csrc/attention/hip_fmha/ck_attention_inner_product.h rename to xformers/csrc/attention/hip_decoder/ck_attention_inner_product.h diff --git a/xformers/csrc/attention/hip_fmha/ck_attention_math_ext.h b/xformers/csrc/attention/hip_decoder/ck_attention_math_ext.h similarity index 100% rename from xformers/csrc/attention/hip_fmha/ck_attention_math_ext.h rename to xformers/csrc/attention/hip_decoder/ck_attention_math_ext.h diff --git a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp index 2bc96fa7ee..ffe12981bb 100644 --- a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp +++ 
b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp @@ -111,8 +111,7 @@ efficient_attention_backward_ck( TORCH_CHECK(max_seqlen_k_.has_value()); } - // at::cuda::CUDAGuard device_guard(query.device()); - hipStream_t stream = at::cuda::getCurrentHIPStream().stream(); + hipStream_t stream = at::hip::getCurrentHIPStream().stream(); int64_t B = query.size(0); int64_t M = query.size(1); @@ -154,8 +153,8 @@ efficient_attention_backward_ck( grad_q = at::empty_strided(query.sizes(), query.strides(), query.options()); } else { grad_q = at::empty_strided(query.sizes(), query.strides(), query.options()); - grad_k = at::empty_strided(key.sizes(), key.strides(), key.options()); - grad_v = at::empty_strided(value.sizes(), value.strides(), value.options()); + grad_k = at::empty(key.sizes(), key.options()); + grad_v = at::empty(value.sizes(), value.options()); } at::Tensor grad_q_f32; @@ -174,9 +173,7 @@ efficient_attention_backward_ck( TORCH_CHECK(query.sizes() == grad_q.sizes()); TORCH_CHECK(query.strides() == grad_q.strides()); TORCH_CHECK(key.sizes() == grad_k.sizes()); - TORCH_CHECK(key.strides() == grad_k.strides()); TORCH_CHECK(value.sizes() == grad_v.sizes()); - TORCH_CHECK(value.strides() == grad_v.strides()); const bool bias_requires_grad = bias.has_value() && bias->requires_grad(); diff --git a/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp b/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp index 94a7250a6d..cbcc3a1fc1 100644 --- a/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp +++ b/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp @@ -33,8 +33,7 @@ at::Tensor rand_uniform_int( int M = out_pattern.size(2); int N = out_pattern.size(3); - // at::cuda::CUDAGuard device_guard(out_pattern.device()); - hipStream_t stream = at::cuda::getCurrentHIPStream().stream(); + hipStream_t stream = at::hip::getCurrentHIPStream().stream(); at::CUDAGeneratorImpl* gen = at::get_generator_or_default( @@ -59,8 +58,7 @@ at::Tensor rand_uniform_int( { // only work for batched mode - using FmhaRandUniformKernel_ = - FmhaRandUniformKernel<128, 64, 32, uint8_t, false>; + using FmhaRandUniformKernel_ = FmhaRandUniformKernel; const auto kargs = FmhaRandUniformKernel_::MakeKargs( randvals.data_ptr(), diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp index b17c036aee..fbc43d21dd 100644 --- a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp +++ b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp @@ -6,6 +6,7 @@ */ #include #include +#include #include #include @@ -19,6 +20,7 @@ #include #include "ck_fmha_util.h" +#include "ck_tiled_fmha_fwd_splitkv_selector.h" #include "ck_tiled_fmha_params.h" extern void batched_forward_fp16( @@ -65,7 +67,9 @@ efficient_attention_forward_ck( int64_t custom_mask_type, c10::optional scale, const c10::optional& seqlen_k, - const c10::optional window_size) { + const c10::optional window_size, + const c10::optional& block_tables, + const c10::optional page_size) { TORCH_CHECK(query.dim() == 4); TORCH_CHECK(key.dim() == 4); TORCH_CHECK(value.dim() == 4); @@ -92,20 +96,27 @@ efficient_attention_forward_ck( TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int); TORCH_CHECK(seqstart_k->scalar_type() == at::ScalarType::Int); TORCH_CHECK(seqstart_q->dim() == 1 && seqstart_k->dim() == 1); - TORCH_CHECK(seqstart_q->size(0) == seqstart_k->size(0)); + TORCH_CHECK( + 
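+        // Note: both layouts are accepted on purpose. In the usual layout
+        // seqstart_q and seqstart_k each hold num_batches + 1 cumulative
+        // offsets; the "+ 1" alternative presumably corresponds to the
+        // gappy-keys / paged-KVcache path further below, where seqstart_k
+        // carries only per-batch start offsets and is therefore one entry
+        // shorter than seqstart_q.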
seqstart_q->size(0) == seqstart_k->size(0) || + seqstart_q->size(0) == seqstart_k->size(0) + 1); TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); TORCH_CHECK(max_seqlen_q_.has_value()); CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_q)); CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_k)); }; + TORCH_CHECK(block_tables.has_value() == page_size.has_value()); + TORCH_CHECK(!block_tables.has_value() || block_tables->dim() == 2); + + // Currently xformers only use Paged-KVcache in grouped mode + TORCH_CHECK(seqstart_q.has_value() || !block_tables.has_value()); + // last dim is contiguous, device is kCUDA CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); - // at::cuda::CUDAGuard device_guard(query.device()); - hipStream_t stream = at::cuda::getCurrentHIPStream().stream(); + hipStream_t stream = at::hip::getCurrentHIPStream().stream(); int64_t B = query.size(0); int64_t M = query.size(1); @@ -121,6 +132,9 @@ efficient_attention_forward_ck( at::Tensor out = at::empty({B, M, Hq, Kv}, opts); + at::Tensor logsumexp_acc; + at::Tensor out_acc; + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; int64_t philox_seed; int64_t philox_offset; @@ -225,6 +239,38 @@ efficient_attention_forward_ck( p.logsumexp_ptr = nullptr; p.lse_strides = {0, 0, 0}; } + + bool use_split_kv; + int num_kv_splits; + + std::tie(use_split_kv, num_kv_splits) = + get_num_kv_splits_heuristic(p.B, p.Hq, p.M, std::max(p.K, p.Kv), 8); + + // 1) fmha fwd split-kv kernel does not support dropout + p.use_split_kv = (!use_dropout && use_split_kv) ? true : false; + + p.num_kv_splits = num_kv_splits; + + if (p.use_split_kv && p.num_kv_splits > 1) { + out_acc = + at::empty({p.num_kv_splits, B, M, Hq, Kv}, opts.dtype(at::kFloat)); + p.out_acc_ptr = out_acc.data_ptr(); + p.out_acc_strides = { + static_cast(out_acc.stride(0)), + static_cast(out_acc.stride(1)), + static_cast(out_acc.stride(2)), + static_cast(out_acc.stride(3)), + static_cast(out_acc.stride(4))}; + + logsumexp_acc = + at::empty({p.num_kv_splits, B, Hq, M}, opts.dtype(at::kFloat)); + p.logsumexp_acc_ptr = logsumexp_acc.data_ptr(); + p.lse_acc_strides = { + static_cast(logsumexp_acc.stride(0)), + static_cast(logsumexp_acc.stride(1)), + static_cast(logsumexp_acc.stride(2)), + static_cast(logsumexp_acc.stride(3))}; + } }; auto set_grouped_forward_params = [&](GroupedForwardParams& p) { @@ -305,6 +351,22 @@ efficient_attention_forward_ck( } else p.seqlen_k_dev_ptr = nullptr; + p.is_gappy = false; + if (block_tables.has_value()) { + p.block_table_ptr = block_tables->data_ptr(); + p.page_block_size = *page_size; + p.batch_stride_block_table = block_tables->stride(0); + p.use_paged_kvcache = true; + + TORCH_CHECK(seqlen_k.has_value()); + + // PageBlockDiagonalGappyKeysMask has special way to use seqstart_k, + // somehow ck_tile kernel need know this + if (seqstart_k->size(0) == seqlen_k->size(0)) + p.is_gappy = true; + } else + p.use_paged_kvcache = false; + p.philox_seed = philox_seed; p.philox_offset = philox_offset; p.compute_logsumexp = compute_logsumexp; @@ -325,6 +387,38 @@ efficient_attention_forward_ck( p.logsumexp_ptr = nullptr; p.lse_strides = {0, 0}; } + + bool use_split_kv; + int num_kv_splits; + + // added for support split_kv + std::tie(use_split_kv, num_kv_splits) = get_num_kv_splits_heuristic( + p.num_batches, p.Hq, p.max_seqlen_q, std::max(p.K, p.Kv), 8); + + // 1) fmha fwd split-kv kernel does not support dropout + // 2) Paged-KVcache is only available from the 
split-kv kernel at present + p.use_split_kv = + (p.use_paged_kvcache || (!use_dropout && use_split_kv)) ? true : false; + + p.num_kv_splits = num_kv_splits; + + if (p.use_split_kv && p.num_kv_splits > 1) { + out_acc = at::empty({p.num_kv_splits, M, Hq, Kv}, opts.dtype(at::kFloat)); + p.out_acc_ptr = out_acc.data_ptr(); + p.out_acc_strides = { + static_cast(out_acc.stride(0)), + static_cast(out_acc.stride(1)), + static_cast(out_acc.stride(2)), + static_cast(out_acc.stride(3))}; + + logsumexp_acc = + at::empty({p.num_kv_splits, 1, Hq, M}, opts.dtype(at::kFloat)); + p.logsumexp_acc_ptr = logsumexp_acc.data_ptr(); + p.lse_acc_strides = { + static_cast(logsumexp_acc.stride(0)), + static_cast(logsumexp_acc.stride(2)), + static_cast(logsumexp_acc.stride(3))}; + } }; auto inDataType = query.scalar_type(); @@ -398,7 +492,9 @@ efficient_attention_forward_ck_meta( int64_t custom_mask_type, c10::optional scale, const c10::optional& seqlen_k, - const c10::optional window_size) { + const c10::optional window_size, + const c10::optional& block_tables, + const c10::optional page_size) { int64_t B = query.size(0); int64_t M = query.size(1); int64_t N = key.size(1); diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp deleted file mode 100644 index fd70436a36..0000000000 --- a/xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp +++ /dev/null @@ -1,1184 +0,0 @@ -#include -#include -#include -#include -#include - -#include "ck_attention_forward_decoder_splitk.h" - -namespace { -constexpr int32_t kThreadsPerWavefront = 64; -constexpr int32_t kWavefrontsPerBlock = 4; -constexpr int32_t kMaxHeadDimension = 4 * kThreadsPerWavefront; -constexpr int32_t kMaxKVSequenceLength = 4096; -constexpr int32_t kLoopUnroll = 16; -constexpr int32_t kLoopUnrollTail = 2; -using compute_t = float; -} // namespace - -namespace { - -template -struct c10_to_data_t; -template <> -struct c10_to_data_t { - using type = float; -}; - -template <> -struct c10_to_data_t { - using type = ck::half_t; -}; - -template <> -struct c10_to_data_t { - using type = ck::bhalf_t; -}; -} // namespace - -#define AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ - AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) - -#define AT_DISPATCH_SWITCH_3( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) 
\ - AT_DISPATCH_SWITCH( \ - TYPE, \ - NAME, \ - AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) - -namespace { - -template < - int32_t ThreadsPerWavefront, - int32_t WavefrontsPerBlock> -at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( - const at::Tensor& XQ, // [B, 1, G, H, D] - const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] - const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, G, H or 1, D] - at::optional seq_kv_lens, // [B] - double qk_scale, - int64_t split_k, - at::Tensor& split_max, - at::Tensor& split_sumexp, - at::Tensor& split_O, - at::Tensor& O) { - static_assert(4 * ThreadsPerWavefront == kMaxHeadDimension, ""); - static_assert(WavefrontsPerBlock <= ThreadsPerWavefront, ""); - - at::OptionalDeviceGuard guard(XQ.device()); - TORCH_CHECK(XQ.is_cuda()); - TORCH_CHECK(cache_K.is_cuda()); - TORCH_CHECK(cache_V.is_cuda()); - - TORCH_CHECK(!seq_kv_lens || seq_kv_lens->is_cuda()); - - TORCH_CHECK(cache_K.size(1) / split_k <= kMaxKVSequenceLength); - TORCH_CHECK(cache_K.size(4) <= kMaxHeadDimension); - - constexpr auto rank = 5; - - auto B = XQ.size(0); - auto M = XQ.size(1); - auto G = XQ.size(2); - auto H = XQ.size(3); - - TORCH_CHECK(B <= 1024); - TORCH_CHECK(M <= 1024); - TORCH_CHECK(H <= 1024); - - dim3 blocks(B * H * M * G, split_k); - dim3 threads(ThreadsPerWavefront, WavefrontsPerBlock); - - int32_t smem_softmax = kMaxKVSequenceLength * sizeof(compute_t) + - WavefrontsPerBlock * sizeof(compute_t); - int32_t smem_output = kMaxHeadDimension * sizeof(compute_t) * - threads.y; // 4 * threadsPerBlock * sizeof(float) == sizeof(O[b][0][h][:]) - const size_t lds_bytes = max(smem_softmax, smem_output); - auto stream = at::cuda::getCurrentHIPStream().stream(); - - AT_DISPATCH_SWITCH_3( - at::ScalarType::Half, - at::ScalarType::BFloat16, - at::ScalarType::Float, - XQ.scalar_type(), - "efficient_attention_forward_decoder_splitk_ck", - [&] { - using ck_data_t = c10_to_data_t::type; - using device_op_t = - ck::tensor_operation::device::FMHADecoderSplitKDeviceOp< - ck_data_t, - kMaxKVSequenceLength, - kLoopUnroll, - kLoopUnrollTail, - compute_t>; - auto op = device_op_t{}; - - auto XQ_acc = - XQ.packed_accessor32(); - auto K_acc = - cache_K.packed_accessor64(); - auto V_acc = - cache_V.packed_accessor64(); - auto split_O_acc = - split_O - .packed_accessor32(); - auto O_acc = - O.packed_accessor32(); - auto seq_acc_ptr = seq_kv_lens - ? 
seq_kv_lens - ->packed_accessor32() - .data() - : nullptr; - auto split_max_acc = - split_max.packed_accessor32(); - auto split_sumexp_acc = - split_sumexp - .packed_accessor32(); - auto arg = device_op_t::Argument( - reinterpret_cast(XQ_acc.data()), - reinterpret_cast(K_acc.data()), - reinterpret_cast(V_acc.data()), - reinterpret_cast(O_acc.data()), - reinterpret_cast(split_O_acc.data()), - split_max_acc.data(), - split_sumexp_acc.data(), - seq_acc_ptr, - XQ_acc.stride(0), - XQ_acc.stride(1), - XQ_acc.stride(2), - XQ_acc.stride(3), - K_acc.stride(0), - K_acc.stride(1), - K_acc.stride(2), - K_acc.stride(3), - split_O_acc.stride(0), - XQ_acc.size(1), - XQ_acc.size(2), - XQ_acc.size(3), - XQ_acc.size(4), - K_acc.size(1), - K_acc.size(3) == 1, - qk_scale, - split_k, - blocks, - threads, - lds_bytes); - - auto invoker = device_op_t::Invoker{}; - (void)invoker.Run(&arg, {stream}); - }); - - return O; -} - -template -at::Tensor efficient_attention_forward_decoder_splitk_ck_impl( - const at::Tensor& XQ, // [B, 1, G, H, D] - const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] - const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, H or 1, D] - at::optional seq_kv_lens, // [B] - double qk_scale, - int64_t split_k) { - auto O = at::empty_like(XQ); - constexpr auto rank = 5; - - TORCH_CHECK(XQ.dim() == rank); - TORCH_CHECK(cache_K.dim() == rank); - TORCH_CHECK(cache_V.dim() == rank); - - auto B = XQ.size(0); - auto M = XQ.size(1); - auto G = XQ.size(2); - auto H = XQ.size(3); - auto K = XQ.size(4); - - auto O_splits = at::empty({split_k, B, M, G, H, K}, XQ.options()); - auto split_max = - at::empty({B, M, G, H, split_k}, XQ.options().dtype(at::kFloat)); - auto split_sumexp = at::empty_like(split_max); - - efficient_attention_forward_decoder_splitk_ck_out_impl< - ThreadsPerWavefront, - WavefrontsPerBlock>( - XQ, - cache_K, - cache_V, - seq_kv_lens, - qk_scale, - split_k, - split_max, - split_sumexp, - O_splits, - O); - - return O; -} - -at::Tensor efficient_attention_forward_decoder_splitk_ck( - const at::Tensor& XQ, // [B, 1, G, H, D] - const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] - const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, G, H or 1, D] - at::optional seq_kv_lens, // [B] - double qk_scale, - int64_t split_k) { - return efficient_attention_forward_decoder_splitk_ck_impl< - kThreadsPerWavefront, - kWavefrontsPerBlock>( - XQ, cache_K, cache_V, seq_kv_lens, qk_scale, split_k); -} -} // namespace - -TORCH_LIBRARY_IMPL(xformers, CUDA, m) { - m.impl( - TORCH_SELECTIVE_NAME( - "xformers::efficient_attention_forward_decoder_splitk_ck"), - TORCH_FN(efficient_attention_forward_decoder_splitk_ck)); -} - -#ifdef ATTN_FWD_SPLITK_DECODER_MAIN - -#include - -// clang-format off - -/* - -(1) hipify - > pip install -e /xformers - - For obtaining the executed build commands, add `--verbose`. - For efficient utilization of CPU cores for compilation use MAX_JOBS env variable. 
- -(2) compile - > mkdir build - > cd build - > cmake /xformers/xformers/csrc/attention/hip_fmha/ \ - -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_BUILD_TYPE=Debug \ - -D GPU_TARGETS="native" - > make - -(3a) run correctness check - > ./attention_forward_splitk_decoder_main - -(3b) run specific input shape - > ./attention_forward_splitk_decoder_main n_keys padding batch_size n_heads is_multiquery dtype n_wavefronts_per_block -*/ - -// clang-format on - -static std::tuple split_attention_torch( - const at::Tensor& Q, - const at::Tensor& K, - const at::Tensor& V, - const at::Tensor& k_seqlens, - const int32_t split_k, - const int32_t block_size) { - auto Q_scaled = at::div(Q, sqrt(Q.size(-1))); - - std::vector O_splits; - std::vector m_splits; - std::vector l_splits; - - for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { - std::vector O_batch; - std::vector m_batch; - std::vector l_batch; - - for (size_t b = 0; b < k_seqlens.numel(); ++b) { - auto seqlen = k_seqlens[b].item(); - const int64_t t_low = - split_idx * (seqlen / split_k / block_size) * block_size; - const int64_t t_high = (split_idx + 1 < split_k) - ? (1 + split_idx) * (seqlen / split_k / block_size) * block_size - : seqlen; - - const bool empty = t_low == t_high; - - auto S = at::einsum( - "mghk, nghk -> mghn", - {Q_scaled[b], - at::slice(K[b], /*dim*/ 0, /*start*/ t_low, /*end*/ t_high)}, - /* einsum eval path */ at::nullopt); - auto m = empty - ? at::empty_like(S) - : std::get<0>(at::max(S, /* dim */ -1, /* keepdim */ true)); - auto s = at::exp(at::sub(S, m)); - auto l = at::sum(s, /* dim */ -1, /* keepdim */ true); - auto O = at::einsum( - "mghn, nghk -> mghk", - {s, at::slice(V[b], /*dim*/ 0, /*start*/ t_low, /*end*/ t_high)}, - /* einsum eval path */ at::nullopt); - if (empty) { - m = at::empty_like(at::slice(O, -1, 0, 1)); - l = at::zeros_like(m); - m.fill_(ck::NumericLimits::Lowest()); - } - O_batch.push_back(O); - m_batch.push_back(m); - l_batch.push_back(l); - } - - auto O_cat = at::stack(O_batch); - auto m_cat = at::stack(m_batch); - auto l_cat = at::stack(l_batch); - - O_splits.push_back(O_cat); - m_splits.push_back(m_cat); - l_splits.push_back(l_cat); - } - - auto O_cat = at::stack(O_splits); - auto m_cat = at::transpose(at::stack(m_splits), 0, -1); - auto l_cat = at::transpose(at::stack(l_splits), 0, -1); - - return std::make_tuple(O_cat, m_cat, l_cat); -} - -static at::Tensor split_reduce_torch( - const at::Tensor& O_splits, - const at::Tensor& m_splits, - const at::Tensor& l_splits, - int32_t split_k) { - auto O = at::zeros_like(at::slice(O_splits, 0, 0, 1)); - auto global_max = - at::empty_like(at::slice(m_splits, -1, 0, 1)).fill_(-65535.); - auto global_sumexp = at::zeros_like(global_max); - - for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { - auto local_O = at::slice(O_splits, 0, split_idx, split_idx + 1); - auto local_max = at::slice(m_splits, -1, split_idx, split_idx + 1); - auto local_sumexp = at::slice(l_splits, -1, split_idx, split_idx + 1); - - auto log_alpha = at::neg(at::abs(at::sub(local_max, global_max))); - auto alpha = at::exp(log_alpha); - alpha.nan_to_num_(1.); - - auto pick_new = at::less(local_max, global_max); - auto pick_current_coef = at::where(pick_new, 1., alpha); - auto pick_new_coef = at::where(pick_new, alpha, 1.); - - O = at::add(at::mul(pick_current_coef, O), at::mul(pick_new_coef, local_O)); - global_sumexp = at::add( - at::mul(pick_current_coef, global_sumexp), - at::mul(pick_new_coef, local_sumexp)); - 
global_max = at::max(local_max, global_max); - } - - return at::div(O, global_sumexp); -} - -static at::Tensor efficient_attention_forward_decoder_splitk_torch( - const at::Tensor& XQ, // [B, 1, G, H, D] - const at::Tensor& cache_K, // [B, kMaxKVSequenceLength, G, H or 1, D] - const at::Tensor& cache_V, // [B, kMaxKVSequenceLength, G, H or 1, D] - at::optional seq_kv_lens, // [B] - double qk_scale, - int32_t split_k, - int32_t block_size) { - auto [O_split, m, l] = split_attention_torch( - XQ, cache_K, cache_V, *seq_kv_lens, split_k, block_size); - auto O = split_reduce_torch(O_split, m, l, split_k); - return O.reshape_as(XQ); -} - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct FMHADecoderSplitAttentionDeviceOp : public BaseOperator { - using DeviceOp = FMHADecoderSplitAttentionDeviceOp; - struct Argument : public BaseArgument { - const scalar_t* __restrict__ XQ; - const scalar_t* __restrict__ cache_K; - const scalar_t* __restrict__ cache_V; - scalar_t* __restrict__ O; - scalar_t* __restrict__ split_O; - compute_t* __restrict__ split_max; - compute_t* __restrict__ split_sumexp; - const int32_t* __restrict__ seq_kv_lens; - const ptrdiff_t XQ_stride_b; - const ptrdiff_t XQ_stride_m; - const ptrdiff_t XQ_stride_g; - const ptrdiff_t XQ_stride_h; - const ptrdiff_t K_stride_b; - const ptrdiff_t K_stride_m; - const ptrdiff_t K_stride_g; - const ptrdiff_t K_stride_h; - const ptrdiff_t O_stride_split; - const int32_t Q_size_m; - const int32_t Q_size_g; - const int32_t Q_size_h; - const int32_t Q_size_k; - const int32_t K_size_m; - const bool multiquery; - const float qk_scale; - const int32_t split_k; - - const dim3 grid_dim; - const dim3 block_dim; - const size_t lds_bytes; - - Argument( - const scalar_t* __restrict__ XQ, - const scalar_t* __restrict__ cache_K, - const scalar_t* __restrict__ cache_V, - scalar_t* __restrict__ O, - scalar_t* __restrict__ split_O, - compute_t* __restrict__ split_max, - compute_t* __restrict__ split_sumexp, - const int32_t* __restrict__ seq_kv_lens, - const ptrdiff_t XQ_stride_b, - const ptrdiff_t XQ_stride_m, - const ptrdiff_t XQ_stride_g, - const ptrdiff_t XQ_stride_h, - const ptrdiff_t K_stride_b, - const ptrdiff_t K_stride_m, - const ptrdiff_t K_stride_g, - const ptrdiff_t K_stride_h, - const ptrdiff_t O_stride_split, - const int32_t Q_size_m, - const int32_t Q_size_g, - const int32_t Q_size_h, - const int32_t Q_size_k, - const int32_t K_size_m, - const bool multiquery, - const float qk_scale, - const int32_t split_k, - // launch params - const dim3 grid_dim, - const dim3 block_dim, - const size_t lds_bytes) - : XQ(XQ), - cache_K(cache_K), - cache_V(cache_V), - O(O), - split_O(split_O), - split_max(split_max), - split_sumexp(split_sumexp), - seq_kv_lens(seq_kv_lens), - XQ_stride_b(XQ_stride_b), - XQ_stride_m(XQ_stride_m), - XQ_stride_g(XQ_stride_g), - XQ_stride_h(XQ_stride_h), - K_stride_b(K_stride_b), - K_stride_m(K_stride_m), - K_stride_g(K_stride_g), - K_stride_h(K_stride_h), - O_stride_split(O_stride_split), - Q_size_m(Q_size_m), - Q_size_g(Q_size_g), - Q_size_h(Q_size_h), - Q_size_k(Q_size_k), - K_size_m(K_size_m), - multiquery(multiquery), - qk_scale(qk_scale), - split_k(split_k), - // launch params - grid_dim(grid_dim), - block_dim(block_dim), - lds_bytes(lds_bytes) {} - - std::string str() const { - std::ostringstream oss; - oss << "Argument { " << std::endl - << " XQ: " << XQ << std::endl - << " cache_K: " << cache_K << std::endl - << " cache_V: " << cache_V << std::endl - << " O: " << O << std::endl - << " split_O: 
" << split_O << std::endl - << " split_max: " << split_max << std::endl - << " split_sumexp: " << split_sumexp << std::endl - << " seq_kv_lens: " << seq_kv_lens << std::endl - << " XQ_stride_b: " << XQ_stride_b << std::endl - << " XQ_stride_m: " << XQ_stride_m << std::endl - << " XQ_stride_g: " << XQ_stride_g << std::endl - << " XQ_stride_h: " << XQ_stride_h << std::endl - << " K_stride_b: " << K_stride_b << std::endl - << " K_stride_m: " << K_stride_m << std::endl - << " K_stride_g: " << K_stride_g << std::endl - << " K_stride_h: " << K_stride_h << std::endl - << " O_stride_split: " << O_stride_split << std::endl - << " Q_size_m: " << Q_size_m << std::endl - << " Q_size_g: " << Q_size_g << std::endl - << " Q_size_h: " << Q_size_h << std::endl - << " Q_size_k: " << Q_size_k << std::endl - << " K_size_m: " << K_size_m << std::endl - << " multiquery: " << multiquery << std::endl - << " qk_scale: " << qk_scale << std::endl - << " split_k: " << split_k << std::endl - << std::endl - << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." - << grid_dim.z << std::endl - << " block_dim: " << block_dim.x << "." << block_dim.y << "." - << block_dim.z << std::endl - << " lds_bytes: " << lds_bytes << std::endl - << "}"; - return oss.str(); - } - }; - - struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run( - const Argument& arg, - const StreamConfig& stream_config = StreamConfig{}) { - auto threads_per_wavefront = arg.block_dim.x; - auto Q_size_k_alignment_necessary = 0; - - for (auto vec_size : {4, 2, 1}) { - if (arg.Q_size_k <= vec_size * threads_per_wavefront) { - Q_size_k_alignment_necessary = vec_size; - } - } - - if (!Q_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported Q_size_k"); - } - - if (arg.Q_size_k % Q_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported alignment for Q_size_k"); - } - - float split_attention_result = launch_and_time_kernel( - stream_config, - Q_size_k_alignment_necessary == 4 - ? efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - 4, - kLoopUnroll, - kLoopUnrollTail, - kMaxKVSequenceLength, - compute_t> - : Q_size_k_alignment_necessary == 2 - ? efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - 2, - kLoopUnroll, - kLoopUnrollTail, - kMaxKVSequenceLength, - compute_t> - : Q_size_k_alignment_necessary == 1 - ? 
efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - 1, - kLoopUnroll, - kLoopUnrollTail, - kMaxKVSequenceLength, - compute_t> - : nullptr, - arg.grid_dim, - arg.block_dim, - arg.lds_bytes, - arg.XQ, - arg.cache_K, - arg.cache_V, - arg.split_O, - arg.split_max, - arg.split_sumexp, - arg.seq_kv_lens, - arg.XQ_stride_b, - arg.XQ_stride_m, - arg.XQ_stride_g, - arg.XQ_stride_h, - arg.K_stride_b, - arg.K_stride_m, - arg.K_stride_g, - arg.K_stride_h, - arg.O_stride_split, - arg.Q_size_m, - arg.Q_size_g, - arg.Q_size_h, - arg.Q_size_k, - arg.K_size_m, - arg.multiquery, - arg.qk_scale, - arg.split_k); - - return split_attention_result; - } - }; -}; - -template -struct FMHADecoderSplitReduceDeviceOp : public BaseOperator { - using DeviceOp = FMHADecoderSplitReduceDeviceOp; - struct Argument : public BaseArgument { - const scalar_t* __restrict__ split_O; - const compute_t* __restrict__ split_max; - const compute_t* __restrict__ split_sumexp; - scalar_t* __restrict__ O; - - const int32_t O_size_m; - const int32_t O_size_g; - const int32_t O_size_h; - const int32_t O_size_k; - - const ptrdiff_t O_stride_split; - const ptrdiff_t O_stride_b; - const ptrdiff_t O_stride_m; - const ptrdiff_t O_stride_g; - const ptrdiff_t O_stride_h; - - const int32_t split_k; - - const dim3 grid_dim; - const dim3 block_dim; - const size_t lds_bytes; - - Argument( - const scalar_t* __restrict__ split_O, - const compute_t* __restrict__ split_max, - const compute_t* __restrict__ split_sumexp, - scalar_t* __restrict__ O, - const int32_t O_size_m, - const int32_t O_size_g, - const int32_t O_size_h, - const int32_t O_size_k, - const ptrdiff_t O_stride_split, - const ptrdiff_t O_stride_b, - const ptrdiff_t O_stride_m, - const ptrdiff_t O_stride_g, - const ptrdiff_t O_stride_h, - const int32_t split_k, - // launch params - const dim3 grid_dim, - const dim3 block_dim, - const size_t lds_bytes) - : split_O(split_O), - split_max(split_max), - split_sumexp(split_sumexp), - O(O), - O_size_m(O_size_m), - O_size_g(O_size_g), - O_size_h(O_size_h), - O_size_k(O_size_k), - O_stride_split(O_stride_split), - O_stride_b(O_stride_b), - O_stride_m(O_stride_m), - O_stride_g(O_stride_g), - O_stride_h(O_stride_h), - split_k(split_k), - // launch params - grid_dim(grid_dim), - block_dim(block_dim), - lds_bytes(lds_bytes) {} - - std::string str() const { - std::ostringstream oss; - oss << "Argument { " << std::endl - << " O: " << O << std::endl - << " split_O: " << split_O << std::endl - << " split_max: " << split_max << std::endl - << " split_sumexp: " << split_sumexp << std::endl - << " O_stride_b: " << O_stride_b << std::endl - << " O_stride_m: " << O_stride_m << std::endl - << " O_stride_g: " << O_stride_g << std::endl - << " O_stride_h: " << O_stride_h << std::endl - << " O_stride_split: " << O_stride_split << std::endl - << " O_size_m: " << O_size_m << std::endl - << " O_size_g: " << O_size_g << std::endl - << " O_size_h: " << O_size_h << std::endl - << " O_size_k: " << O_size_k << std::endl - << " split_k: " << split_k << std::endl - << std::endl - << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." - << grid_dim.z << std::endl - << " block_dim: " << block_dim.x << "." << block_dim.y << "." 
- << block_dim.z << std::endl - << " lds_bytes: " << lds_bytes << std::endl - << "}"; - return oss.str(); - } - }; - - struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run( - const Argument& arg, - const StreamConfig& stream_config = StreamConfig{}) { - auto threads_per_wavefront = arg.block_dim.x; - auto O_size_k_alignment_necessary = 0; - - for (auto vec_size : {4, 2, 1}) { - if (arg.O_size_k <= vec_size * threads_per_wavefront) { - O_size_k_alignment_necessary = vec_size; - } - } - - if (!O_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported O_size_k"); - } - - if (arg.O_size_k % O_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported alignment for O_size_k"); - } - - const dim3 reduce_gridsize = {arg.grid_dim.x}; - const dim3 reduce_blocksize = {arg.block_dim.x}; - constexpr int32_t reduce_lds_bytes = 0; - float reduce_result = launch_and_time_kernel( - stream_config, - O_size_k_alignment_necessary == 4 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 4> - : O_size_k_alignment_necessary == 2 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 2> - : O_size_k_alignment_necessary == 1 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 1> - : nullptr, - reduce_gridsize, - reduce_blocksize, - reduce_lds_bytes, - arg.split_O, - arg.split_max, - arg.split_sumexp, - arg.O, - arg.O_size_m, - arg.O_size_g, - arg.O_size_h, - arg.O_size_k, - arg.O_stride_split, - arg.O_stride_b, - arg.O_stride_m, - arg.O_stride_g, - arg.O_stride_h, - arg.split_k); - return reduce_result; - } - }; -}; -} // namespace device -} // namespace tensor_operation -} // namespace ck - -static std::tuple split_attention_hip( - const at::Tensor& XQ, - const at::Tensor& K, - const at::Tensor& V, - const at::Tensor& seqlen, - const int32_t split_k, - const int32_t wavefronts_per_block) { - at::OptionalDeviceGuard guard(XQ.device()); - - auto B = XQ.size(0); - auto M = XQ.size(1); - auto G = XQ.size(2); - auto H = XQ.size(3); - auto D = XQ.size(4); - - double qk_scale = 1. 
/ sqrt(D); - - auto O = at::empty_like(XQ); - constexpr auto rank = 5; - auto split_O = at::zeros({split_k, B, M, G, H, D}, XQ.options()); - auto split_max = - at::empty({B, M, G, H, split_k}, XQ.options().dtype(at::kFloat)) - .fill_(ck::NumericLimits::Lowest()); - auto split_sumexp = at::zeros_like(split_max); - - dim3 blocks(B * H * M * G, split_k); - dim3 threads(kThreadsPerWavefront, wavefronts_per_block); - - int32_t smem_softmax = - kMaxKVSequenceLength * sizeof(float) + threads.y * sizeof(float); - int32_t smem_output = kMaxHeadDimension * sizeof(float) * - wavefronts_per_block; // 4 * threadsPerBlock * sizeof(float) == - // sizeof(O[b][0][h][:]) - const size_t lds_bytes = max(smem_softmax, smem_output); - auto stream = at::cuda::getCurrentHIPStream().stream(); - - AT_DISPATCH_SWITCH_3( - at::ScalarType::Half, - at::ScalarType::BFloat16, - at::ScalarType::Float, - XQ.scalar_type(), - "efficient_attention_forward_decoder_split_attention_ck_test", - [&] { - using ck_data_t = c10_to_data_t::type; - using device_op_t = - ck::tensor_operation::device::FMHADecoderSplitAttentionDeviceOp< - ck_data_t>; - auto op = device_op_t{}; - - auto XQ_acc = - XQ.packed_accessor32(); - auto K_acc = - K.packed_accessor64(); - auto V_acc = - V.packed_accessor64(); - auto split_O_acc = - split_O - .packed_accessor32(); - auto O_acc = - O.packed_accessor32(); - auto seq_acc = - seqlen.packed_accessor32(); - auto split_max_acc = - split_max.packed_accessor32(); - auto split_sumexp_acc = - split_sumexp - .packed_accessor32(); - auto arg = device_op_t::Argument( - reinterpret_cast(XQ_acc.data()), - reinterpret_cast(K_acc.data()), - reinterpret_cast(V_acc.data()), - reinterpret_cast(O_acc.data()), - reinterpret_cast(split_O_acc.data()), - split_max_acc.data(), - split_sumexp_acc.data(), - seq_acc.data(), - XQ_acc.stride(0), - XQ_acc.stride(1), - XQ_acc.stride(2), - XQ_acc.stride(3), - K_acc.stride(0), - K_acc.stride(1), - K_acc.stride(2), - K_acc.stride(3), - split_O_acc.stride(0), - XQ_acc.size(1), - XQ_acc.size(2), - XQ_acc.size(3), - XQ_acc.size(4), - K_acc.size(1), - K_acc.size(3) == 1, - qk_scale, - split_k, - blocks, - threads, - lds_bytes); - - auto invoker = device_op_t::Invoker{}; - (void)invoker.Run(arg, {stream}); - }); - return std::make_tuple(split_O, split_max, split_sumexp); -} - -static at::Tensor split_reduce_hip( - const at::Tensor& split_O, - const at::Tensor& split_max, - const at::Tensor& split_sumexp, - const int32_t split_k) { - at::OptionalDeviceGuard guard(split_O.device()); - - auto B = split_O.size(1); - auto M = split_O.size(2); - auto G = split_O.size(3); - auto H = split_O.size(4); - auto D = split_O.size(5); - - TORCH_CHECK_EQ(split_k, split_O.size(0)); - TORCH_CHECK_EQ(split_k, split_max.size(-1)); - TORCH_CHECK_EQ(split_k, split_sumexp.size(-1)); - - constexpr auto rank = 5; - - TORCH_CHECK_EQ(split_O.dim(), 1 + rank); - TORCH_CHECK_EQ(split_max.dim(), rank); - TORCH_CHECK_EQ(split_sumexp.dim(), rank); - - auto O = at::zeros({B, M, G, H, D}, split_O.options()); - - auto stream = at::cuda::getCurrentHIPStream().stream(); - auto lds_bytes = 0; - - dim3 blocks(B * H * M * G); - dim3 threads(kThreadsPerWavefront); - - AT_DISPATCH_SWITCH_3( - at::ScalarType::Half, - at::ScalarType::BFloat16, - at::ScalarType::Float, - O.scalar_type(), - "efficient_attention_forward_decoder_split_reduce_ck_test", - [&] { - using ck_data_t = c10_to_data_t::type; - using device_op_t = - ck::tensor_operation::device::FMHADecoderSplitReduceDeviceOp< - ck_data_t>; - auto op = device_op_t{}; - - auto 
split_O_acc = - split_O - .packed_accessor32(); - auto O_acc = - O.packed_accessor32(); - auto split_max_acc = - split_max.packed_accessor32(); - auto split_sumexp_acc = - split_sumexp - .packed_accessor32(); - auto arg = device_op_t::Argument( - reinterpret_cast(split_O_acc.data()), - split_max_acc.data(), - split_sumexp_acc.data(), - reinterpret_cast(O_acc.data()), - O_acc.size(1), - O_acc.size(2), - O_acc.size(3), - O_acc.size(4), - split_O_acc.stride(0), - O_acc.stride(0), - O_acc.stride(1), - O_acc.stride(2), - O_acc.stride(3), - split_k, - blocks, - threads, - lds_bytes); - - auto invoker = device_op_t::Invoker{}; - (void)invoker.Run(arg, {stream}); - }); - return O; -} - -std::tuple generate_inputs( - const int32_t padding, - const int32_t B, - const int32_t Hq, - const int32_t Hkv, - const decltype(torch::kFloat32) dtype = torch::kFloat32) { - const int32_t D = 4 * kThreadsPerWavefront; - const int32_t G = Hq / Hkv; - const int32_t num_queries = 1; - - at::manual_seed(1); - - auto options = torch::TensorOptions() - .dtype(dtype) - .layout(torch::kStrided) - .device(torch::kCUDA, 1) - .requires_grad(false); - auto int_options = options.dtype(torch::kInt); - auto XQ = at::randn({B, num_queries, G, Hq, D}, options); - auto K = (G == 1) ? at::randn({B, padding, G, Hkv, D}, options) - : at::randn({B, padding, G, 1, D}, options) - .expand({B, padding, G, Hq, D}); - auto V = at::randn_like(K); - auto seqlen = at::randint(num_queries, padding + 1, {B}, int_options); - - return std::make_tuple(XQ, K, V, seqlen); -} - -static float percent_mismatch(const at::Tensor& a, const at::Tensor& b) { - auto mask = - at::isclose(a, b, /*atol*/ 1e-3, /*rtol*/ 1e-5, /*equal_nan*/ false); - auto percent_match = at::sum(mask.to(torch::kFloat32)) / mask.numel(); - return 1. 
- percent_match.item(); -} - -static void test_split_attention( - int32_t padding, - int32_t batch_size, - int32_t Hq, - int32_t Hkv, - int32_t split_k) { - auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); - - auto [O_ref, m_ref, l_ref] = split_attention_torch( - XQ, K, V, seqlen, split_k, /* block_size */ kWavefrontsPerBlock * 16); - - auto [O_hip, m_hip, l_hip] = - split_attention_hip(XQ, K, V, seqlen, split_k, kWavefrontsPerBlock); - - auto O_percent_mismatch = percent_mismatch(O_ref, O_hip); - auto m_percent_mismatch = percent_mismatch(m_ref, m_hip); - auto l_percent_mismatch = percent_mismatch(l_ref, l_hip); - - printf( - "[Test split attention] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched split_O " - "elements percentage: %.2f Mismatched split_max elements percentage: %.2f Mismatched " - "split_sumexp elements percentage: %.2f\n", - padding, - batch_size, - Hq, - Hkv, - split_k, - O_percent_mismatch, - m_percent_mismatch, - l_percent_mismatch); -} - -static void test_split_reduce( - int32_t padding, - int32_t batch_size, - int32_t Hq, - int32_t Hkv, - int32_t split_k) { - auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); - - auto [O_ref, m_ref, l_ref] = - split_attention_hip(XQ, K, V, seqlen, split_k, kWavefrontsPerBlock); - - auto O_torch = split_reduce_torch( - O_ref, m_ref.unsqueeze(0), l_ref.unsqueeze(0), split_k); - auto O_hip = split_reduce_hip(O_ref, m_ref, l_ref, split_k); - - auto hip_torch_mismatch = percent_mismatch(O_hip, O_torch); - printf( - "[Test split reduce] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched elements " - "percentage: %.2f \n", - padding, - batch_size, - Hq, - Hkv, - split_k, - hip_torch_mismatch); -} - -static void test_splitk_decoder_e2e_correctness( - int32_t padding, - int32_t batch_size, - int32_t Hq, - int32_t Hkv, - int32_t split_k) { - auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); - - double qk_scale = 1. 
/ sqrt(XQ.size(-1)); - - auto result = efficient_attention_forward_decoder_splitk_ck_impl< - kThreadsPerWavefront, - kWavefrontsPerBlock>(XQ, K, V, seqlen, qk_scale, split_k); - auto gold_result = efficient_attention_forward_decoder_splitk_torch( - XQ, K, V, seqlen, qk_scale, /* split_k */ 1, /* block_size */ 1); - auto e2e_mismatch = percent_mismatch(result, gold_result); - printf( - "[Test e2e split-k decoder] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched " - "elements percentage: %.2f\n", - padding, - batch_size, - Hq, - Hkv, - split_k, - e2e_mismatch); -} - -int main(int argc, char** argv) { - if (argc == 1) { - for (auto padding : {32, 4096}) { - for (auto batch_size : {1, 8}) { - for (auto Hq : {16}) { - for (auto Hkv : {16}) { - for (auto split_k : {1, 2, 4, 8, 16}) { - test_splitk_decoder_e2e_correctness( - padding, batch_size, Hq, Hkv, split_k); - } - } - } - } - } - - for (auto padding : {32, 4096}) { - for (auto batch_size : {1, 8}) { - for (auto Hq : {16}) { - for (auto Hkv : {16}) { - for (auto split_k : {1, 2, 4, 8, 16}) { - test_split_attention(padding, batch_size, Hq, Hkv, split_k); - } - } - } - } - } - - for (auto padding : {32, 4096}) { - for (auto batch_size : {1, 8}) { - for (auto Hq : {16}) { - for (auto Hkv : {16}) { - for (auto split_k : {1, 2}) { - test_split_reduce(padding, batch_size, Hq, Hkv, split_k); - } - } - } - } - } - } else { - const auto args = std::vector(argv + 1, argv + argc); - if (args.size() != 6) { - std::cout << "Usage: ./a.out padding batch_size nq_heads nkv_heads dtype " - "n_wavefronts_per_block" - << std::endl; - return 0; - } - const int32_t padding = std::stoi(args[0]); - const int32_t batch_size = std::stoi(args[1]); - const int32_t nq_heads = std::stoi(args[2]); - const int32_t nkv_heads = std::stoi(args[3]); - const auto dtype = (args[4] == "f32") ? torch::kFloat32 - : (args[4] == "f16") ? torch::kFloat16 - : torch::kBFloat16; - const int32_t n_wavefronts_per_block = std::stoi(args[5]); - - auto [Q, K, V, seq] = - generate_inputs(padding, batch_size, nq_heads, nkv_heads, dtype); - auto O = at::empty_like(Q); - - constexpr auto splitk_dim = 0; - constexpr auto split_k = 1; - auto O_splits = at::stack(O, splitk_dim); - - auto split_max = at::empty( - {batch_size, padding, Q.size(2), Q.size(3), split_k}, - Q.options().dtype(at::kFloat)); - auto split_sumexp = at::empty_like(split_max); - - const double qk_scale = 1. 
/ sqrt(Q.size(-1)); - auto call_ptr = - decltype(&efficient_attention_forward_decoder_splitk_ck_out_impl< - kThreadsPerWavefront, - kWavefrontsPerBlock>){}; - -#define SWITCH_CASE_SET_CALLPTR(n) \ - case (n): \ - call_ptr = &efficient_attention_forward_decoder_splitk_ck_out_impl< \ - kThreadsPerWavefront, \ - (n)>; \ - break; - - switch (n_wavefronts_per_block) { - SWITCH_CASE_SET_CALLPTR(1); - SWITCH_CASE_SET_CALLPTR(2); - SWITCH_CASE_SET_CALLPTR(4); - SWITCH_CASE_SET_CALLPTR(8); - SWITCH_CASE_SET_CALLPTR(16); - - default: - call_ptr = nullptr; - break; - } -#undef SWITCH_CASE_SET_CALLPTR - - if (call_ptr) { - call_ptr( - Q, - K, - V, - seq, - qk_scale, - split_k, - split_max, - split_sumexp, - O_splits, - O); - } else { - std::cout << "Warning: no kernel was found for wavefronts_per_block=" - << n_wavefronts_per_block << std::endl; - } - } - return 0; -} - -#endif // MAIN - -#undef AT_DISPATCH_CASE_3 -#undef AT_DISPATCH_SWITCH_3 diff --git a/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h b/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h deleted file mode 100644 index e4d575a588..0000000000 --- a/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h +++ /dev/null @@ -1,713 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include "ck_attention_inner_product.h" -#include "ck_attention_math_ext.h" - -namespace { - -template -__device__ typename ck::vector_type::type scalar_scale_acc( - typename ck::vector_type::type acc, - typename ck::vector_type::type a, - float b) { - union { - decltype(acc) vec; - float arr[vec_size]; - } acc_u{acc}; - union { - decltype(a) vec; - data_t arr[vec_size]; - } a_u{a}; - -#pragma unroll - for (int32_t i = 0; i < vec_size; ++i) { - acc_u.arr[i] += ck::type_convert(a_u.arr[i]) * b; - } - - return acc_u.vec; -} - -template -float __device__ __forceinline__ wavefrontReduce(float val, F f) { -#pragma unroll - for (int32_t mask = n_threads_per_wavefront >> 1; mask > 0; mask >>= 1) { - val = f(__shfl_xor(val, mask, n_threads_per_wavefront), val); - } - return val; -} - -template -__forceinline__ __device__ void load_v( - const TData* __restrict__ data_ptr, - int32_t vector_offset, - TDataVec* __restrict__ load_to) { - *load_to = *(reinterpret_cast(data_ptr) + vector_offset); -} - -template -__forceinline__ __device__ void store_v( - TData* __restrict__ data_ptr, - int32_t vector_offset, - TDataVec value) { - *(reinterpret_cast(data_ptr) + vector_offset) = value; -} - -template -__global__ void efficient_attention_forward_decoder_splitk_reduce_ck_kernel( - const scalar_t* __restrict__ O_splits, - const compute_t* __restrict__ split_max, - const compute_t* __restrict__ split_sumexp, - scalar_t* __restrict__ O, - const int32_t Q_size_m, - const int32_t Q_size_g, - const int32_t Q_size_h, - const int32_t Q_size_k, - const ptrdiff_t O_stride_split, - const ptrdiff_t O_stride_b, - const ptrdiff_t O_stride_m, - const ptrdiff_t O_stride_g, - const ptrdiff_t O_stride_h, - const int32_t split_k) { - // Each block handles a single batch and head and query and group - const int32_t b = blockIdx.x / (Q_size_m * Q_size_g * Q_size_h); - const int32_t m = (blockIdx.x / (Q_size_g * Q_size_h)) % Q_size_m; - const int32_t g = (blockIdx.x / Q_size_h) % Q_size_g; - const int32_t h = blockIdx.x % Q_size_h; - - using data_t = scalar_t; - using data_vec_t = typename ck::vector_type::type; - using compute_vec_t = typename ck::vector_type::type; - - union { - data_vec_t vec; - data_t arr[vec_size]; - 
} O_split_data; - union { - compute_vec_t vec; - compute_t arr[vec_size]; - } O_split_compute; - union { - data_vec_t vec; - data_t arr[vec_size]; - } global_O_data; - union { - compute_vec_t vec; - compute_t arr[vec_size]; - } global_O_compute; - - global_O_compute.vec = 0; - - const int32_t lane_idx = threadIdx.x; - const bool lane_active_for_io = lane_idx * vec_size < Q_size_k; - - if (!lane_active_for_io) { - return; - } - - compute_t global_sumexp = 0; - compute_t global_max = ck::NumericLimits::Lowest(); - - for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { - load_v( - O_splits + b * O_stride_b + m * O_stride_m + g * O_stride_g + - h * O_stride_h + split_idx * O_stride_split, - lane_idx, - &O_split_data.vec); -#pragma unroll - for (int32_t i = 0; i < vec_size; ++i) { - O_split_compute.arr[i] = ck::type_convert(O_split_data.arr[i]); - } - compute_t local_max = *(split_max + blockIdx.x * split_k + split_idx); - compute_t local_sumexp = *(split_sumexp + blockIdx.x * split_k + split_idx); - - compute_t log_alpha = -std::abs(local_max - global_max); - compute_t alpha = - isnan(log_alpha) ? compute_t{1.} : ck::math::exp(log_alpha); - - bool pick_new = local_max < global_max; - compute_t pick_current_coef = pick_new ? 1. : alpha; - compute_t pick_new_coef = pick_new ? alpha : 1.; - - global_sumexp = - pick_current_coef * global_sumexp + pick_new_coef * local_sumexp; - global_O_compute.vec = pick_current_coef * global_O_compute.vec + - pick_new_coef * O_split_compute.vec; - global_max = ck::math::max(local_max, global_max); - } - global_O_compute.vec /= global_sumexp; -#pragma unroll - for (int32_t i = 0; i < vec_size; ++i) { - global_O_data.arr[i] = ck::type_convert(global_O_compute.arr[i]); - } - store_v( - O + b * O_stride_b + m * O_stride_m + g * O_stride_g + h * O_stride_h, - lane_idx, - global_O_data.vec); -} - -template < - typename scalar_t, - int32_t vec_size, - int32_t n_loop_unroll, - int32_t n_loop_unroll_tail, - int32_t KV_M_MAX, - typename compute_t> -__global__ void efficient_attention_forward_decoder_splitk_ck_kernel( - const scalar_t* __restrict__ XQ, - const scalar_t* __restrict__ cache_K, - const scalar_t* __restrict__ cache_V, - scalar_t* __restrict__ O_splits, - compute_t* __restrict__ split_max, - compute_t* __restrict__ split_sumexp, - const int32_t* __restrict__ seq_kv_lens, - const ptrdiff_t XQ_stride_b, - const ptrdiff_t XQ_stride_m, - const ptrdiff_t XQ_stride_g, - const ptrdiff_t XQ_stride_h, - const ptrdiff_t K_stride_b, - const ptrdiff_t K_stride_m, - const ptrdiff_t K_stride_g, - const ptrdiff_t K_stride_h, - const ptrdiff_t O_stride_split, - const int32_t Q_size_m, - const int32_t Q_size_g, - const int32_t Q_size_h, - const int32_t Q_size_k, - const int32_t K_size_m, - const bool multiquery, - const float qk_scale, - const int32_t split_k) { - static_assert( - n_loop_unroll_tail < n_loop_unroll || n_loop_unroll_tail == 1, - "tail unroll must be smaller than main loop untoll; pragma unroll 0 is illegal " - "(and tail is no-op)"); - - // Each block handles a single batch and head and query and group - const int32_t b = blockIdx.x / (Q_size_m * Q_size_g * Q_size_h); - const int32_t m = (blockIdx.x / (Q_size_g * Q_size_h)) % Q_size_m; - const int32_t g = (blockIdx.x / Q_size_h) % Q_size_g; - const int32_t h = blockIdx.x % Q_size_h; - const int32_t split_idx = blockIdx.y; - - // Note: this is decoding case where we attend to current and all previous - // tokens. - const int32_t t_max = seq_kv_lens ? 
seq_kv_lens[b] : K_size_m; - - const int32_t lane_idx = threadIdx.x; - const int32_t wavefront_idx = threadIdx.y; - // TODO: `threads_per_wavefront` and `wavefronts_per_block` may be compile - // time constants; investigate when optimizing - const int32_t threads_per_wavefront = blockDim.x; - const int32_t wavefronts_per_block = blockDim.y; - const int32_t threads_per_block = - threads_per_wavefront * wavefronts_per_block; - const int32_t thread_linear_idx = - lane_idx + wavefront_idx * threads_per_wavefront; - // const auto* q_ = &(XQ_acc[b][m][g][h][0]); - const auto XQO_base_offset = - b * XQ_stride_b + m * XQ_stride_m + g * XQ_stride_g + h * XQ_stride_h; - const auto* __restrict__ q_ = XQ + XQO_base_offset; - - const auto cache_KV_base_offset = b * K_stride_b + 0 * K_stride_m + - g * K_stride_g + (multiquery ? 0 : h * K_stride_h); - const auto* __restrict__ cache_K_base = cache_K + cache_KV_base_offset; - const auto* __restrict__ cache_V_base = cache_V + cache_KV_base_offset; - - using data_t = scalar_t; - using data_vec_t = typename ck::vector_type::type; - using compute_vec_t = typename ck::vector_type::type; - - const bool lane_active_for_io = lane_idx * vec_size < Q_size_k; - - extern __shared__ __align__(16) compute_t smem[]; - - data_vec_t q_thread = 0; - // Load Q into registers in all wavefronts. - // Each thread handles `vec_size` D dimensions - if (lane_active_for_io) { - load_v(q_, lane_idx, &q_thread); - } - - compute_t max_qk_acc = ck::NumericLimits::Lowest(); - - // Compute S[0:t_max] = - // ``` - // for t in range(t_max): - // S[t] = dot(Q, K[t]) - // ``` - // Split the 0:t_max range across wavefronts in a block, - // unroll loads to expose more parallelism. - // Reduce the dot product with cross-lane operation; - // Q and K[t] are in the registers of threads in a single wavefront. - - data_vec_t k_loads[n_loop_unroll] = {}; - - const auto dtt = wavefronts_per_block * n_loop_unroll; - // only last split gets the tail. - // the first (split_k - 1) splits have a number of iterations divisible by - // `dtt` - const auto n_unrolled_loops = t_max / dtt / split_k; // +1? - const int32_t tt_low = - wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * split_idx; - const int32_t tt_high = - wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * (split_idx + 1); - const int32_t dtt_tail = wavefronts_per_block * n_loop_unroll_tail; - const int32_t tt_tail_low = wavefront_idx * n_loop_unroll_tail + - n_unrolled_loops * dtt * (split_idx + 1); - const int32_t tt_tail_high = (split_idx == split_k - 1) ? 
t_max : tt_tail_low; - - for (auto tt = tt_low; tt < tt_high; tt += dtt) { - if (lane_active_for_io) { -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - const int32_t t = tt + ttt; - // load the K[b][t][g][h|0][:] row into registers - load_v( - cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); - } - } -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - compute_t qk_acc = 0; - ck::inner_product( - q_thread, k_loads[ttt], qk_acc); - qk_acc *= qk_scale; - - qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); - max_qk_acc = ck::math::max(qk_acc, max_qk_acc); - if (lane_idx == 0) { - smem[tt + ttt - n_unrolled_loops * dtt * split_idx] = qk_acc; - } - } - } - - for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { - if (lane_active_for_io) { -#pragma unroll n_loop_unroll_tail - for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { - const int32_t t = tt + ttt; - if (t < t_max) { - // load the K[b][t][g][h|0][:] row into registers - load_v( - cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); - } - } - } -#pragma unroll n_loop_unroll_tail - for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { - compute_t qk_acc = 0; - const int32_t t = tt + ttt; - if (t < t_max) { - ck::inner_product( - q_thread, k_loads[ttt], qk_acc); - qk_acc *= qk_scale; - - qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); - max_qk_acc = ck::math::max(qk_acc, max_qk_acc); - - // write accumulated sums to smem. - if (lane_idx == 0) { - smem[t - n_unrolled_loops * dtt * split_idx] = qk_acc; - } - } - } - } - - // Use shared reduction to compute max and compute softmax on shared memory. - // write max acc - if (lane_idx == 0) { - smem[KV_M_MAX + wavefront_idx] = max_qk_acc; - } - __syncthreads(); - if (lane_idx < wavefronts_per_block) { - max_qk_acc = ck::math::max(max_qk_acc, smem[KV_M_MAX + lane_idx]); - } - // shared across all threads in block - max_qk_acc = - wavefrontReduce(max_qk_acc, [](auto a, auto b) { return a > b ? a : b; }); - - if (wavefront_idx == 0 && lane_idx == 0) { - split_max[blockIdx.x * split_k + split_idx] = max_qk_acc; - } - - // each wavefront computes partial sum of exp. - { // softmax reduce begin - compute_t softmax_denominator = 0.0f; - const int32_t t_low = n_unrolled_loops * dtt * split_idx; - const int32_t t_high = (split_idx + 1 < split_k) - ? n_unrolled_loops * dtt * (split_idx + 1) - : t_max; - for (int32_t t = t_low + thread_linear_idx; t < t_high; - t += threads_per_block) { - const auto s = ck::math::exp(smem[t - t_low] - max_qk_acc); - softmax_denominator += s; - smem[t - t_low] = s; - } - softmax_denominator = wavefrontReduce( - softmax_denominator, [](auto a, auto b) { return a + b; }); - - if (lane_idx == 0) { - smem[KV_M_MAX + wavefront_idx] = softmax_denominator; - } - __syncthreads(); - - // now, compute sum of exp(x - max(x)) over all intermediate results. 
- softmax_denominator = 0.0; - if (lane_idx < wavefronts_per_block) { - softmax_denominator = smem[KV_M_MAX + lane_idx]; - } - softmax_denominator = wavefrontReduce( - softmax_denominator, [](auto a, auto b) { return a + b; }); - - if (wavefront_idx == 0 && lane_idx == 0) { - split_sumexp[blockIdx.x * split_k + split_idx] = softmax_denominator; - } - } // softmax reduce end - - // Split T across wavefronts in a block - // each wavefront compute sum(t_subset) P[t] * V[t_subset, d] - // outputs are of size float[D] - - compute_t ps[n_loop_unroll] = {}; - compute_vec_t o_acc = 0; - if (lane_active_for_io) { - for (auto tt = tt_low; tt < tt_high; tt += dtt) { -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - const int32_t t = tt + ttt; - // load the V[b][t][g][h|0][:] row into registers, reusing K register - // storage - load_v( - cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]); - ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx]; - } - -#pragma unroll n_loop_unroll - for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { - o_acc = - scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); - } - } - - for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { -#pragma unroll n_loop_unroll_tail - for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { - const int32_t t = tt + ttt; - if (t < t_max) { - // load the V[b][t][g][h|0][:] row into registers, reusing K register - // storage - load_v( - cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]); - ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx]; - o_acc = - scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); - } - } - } - } - __syncthreads(); - - // NB: needs sizeof(smem) >= `vec_size` * (sizeof(float)==4) * threadsPerBlock - if (lane_active_for_io) { - store_v(&smem[0], thread_linear_idx, o_acc); - } - - __syncthreads(); - // sum up partial D rows from other wavefronts - if (wavefront_idx == 0 && lane_active_for_io) { - union { - compute_vec_t vec = 0; - compute_t arr[vec_size]; - } r; - for (int32_t w = 0; w < wavefronts_per_block; ++w) { - compute_vec_t partial_r; - load_v( - smem, w * threads_per_wavefront + lane_idx, &partial_r); - r.vec += partial_r; - } - // elementwise convert from compute_t result to data_t out to be written - union { - data_vec_t vec; - data_t arr[vec_size]; - } bf_r; -#pragma unroll - for (int32_t i = 0; i < vec_size; ++i) { - bf_r.arr[i] = ck::type_convert(r.arr[i]); - } - // write output row O[b][m][g][h][:] - data_t* __restrict__ o_ = - O_splits + XQO_base_offset + split_idx * O_stride_split; - store_v(o_, lane_idx, bf_r.vec); - } -} - -} // namespace - -namespace ck { -namespace tensor_operation { -namespace device { -template < - typename scalar_t, - int32_t KV_M_MAX, - int32_t n_loop_unroll, - int32_t n_loop_unroll_tail, - typename compute_t> -struct FMHADecoderSplitKDeviceOp : public BaseOperator { - using DeviceOp = FMHADecoderSplitKDeviceOp; - struct Argument : public BaseArgument { - const scalar_t* __restrict__ XQ; - const scalar_t* __restrict__ cache_K; - const scalar_t* __restrict__ cache_V; - scalar_t* __restrict__ O; - scalar_t* __restrict__ split_O; - compute_t* __restrict__ split_max; - compute_t* __restrict__ split_sumexp; - const int32_t* __restrict__ seq_kv_lens; - const ptrdiff_t XQ_stride_b; - const ptrdiff_t XQ_stride_m; - const ptrdiff_t XQ_stride_g; - const ptrdiff_t XQ_stride_h; - const ptrdiff_t K_stride_b; - const ptrdiff_t K_stride_m; - const ptrdiff_t K_stride_g; - const ptrdiff_t K_stride_h; - const ptrdiff_t O_stride_split; - const int32_t 
Q_size_m; - const int32_t Q_size_g; - const int32_t Q_size_h; - const int32_t Q_size_k; - const int32_t K_size_m; - const bool multiquery; - const float qk_scale; - const int32_t split_k; - - const dim3 grid_dim; - const dim3 block_dim; - const size_t lds_bytes; - - Argument( - const scalar_t* __restrict__ XQ, - const scalar_t* __restrict__ cache_K, - const scalar_t* __restrict__ cache_V, - scalar_t* __restrict__ O, - scalar_t* __restrict__ split_O, - compute_t* __restrict__ split_max, - compute_t* __restrict__ split_sumexp, - const int32_t* __restrict__ seq_kv_lens, - const ptrdiff_t XQ_stride_b, - const ptrdiff_t XQ_stride_m, - const ptrdiff_t XQ_stride_g, - const ptrdiff_t XQ_stride_h, - const ptrdiff_t K_stride_b, - const ptrdiff_t K_stride_m, - const ptrdiff_t K_stride_g, - const ptrdiff_t K_stride_h, - const ptrdiff_t O_stride_split, - const int32_t Q_size_m, - const int32_t Q_size_g, - const int32_t Q_size_h, - const int32_t Q_size_k, - const int32_t K_size_m, - const bool multiquery, - const float qk_scale, - const int32_t split_k, - // launch params - const dim3 grid_dim, - const dim3 block_dim, - const size_t lds_bytes) - : XQ(XQ), - cache_K(cache_K), - cache_V(cache_V), - O(O), - split_O(split_O), - split_max(split_max), - split_sumexp(split_sumexp), - seq_kv_lens(seq_kv_lens), - XQ_stride_b(XQ_stride_b), - XQ_stride_m(XQ_stride_m), - XQ_stride_g(XQ_stride_g), - XQ_stride_h(XQ_stride_h), - K_stride_b(K_stride_b), - K_stride_m(K_stride_m), - K_stride_g(K_stride_g), - K_stride_h(K_stride_h), - O_stride_split(O_stride_split), - Q_size_m(Q_size_m), - Q_size_g(Q_size_g), - Q_size_h(Q_size_h), - Q_size_k(Q_size_k), - K_size_m(K_size_m), - multiquery(multiquery), - qk_scale(qk_scale), - split_k(split_k), - // launch params - grid_dim(grid_dim), - block_dim(block_dim), - lds_bytes(lds_bytes) {} - - std::string str() const { - std::ostringstream oss; - oss << "Argument { " << std::endl - << " XQ: " << XQ << std::endl - << " cache_K: " << cache_K << std::endl - << " cache_V: " << cache_V << std::endl - << " O: " << O << std::endl - << " split_O: " << split_O << std::endl - << " split_max: " << split_max << std::endl - << " split_sumexp: " << split_sumexp << std::endl - << " seq_kv_lens: " << seq_kv_lens << std::endl - << " XQ_stride_b: " << XQ_stride_b << std::endl - << " XQ_stride_m: " << XQ_stride_m << std::endl - << " XQ_stride_g: " << XQ_stride_g << std::endl - << " XQ_stride_h: " << XQ_stride_h << std::endl - << " K_stride_b: " << K_stride_b << std::endl - << " K_stride_m: " << K_stride_m << std::endl - << " K_stride_g: " << K_stride_g << std::endl - << " K_stride_h: " << K_stride_h << std::endl - << " O_stride_split: " << O_stride_split << std::endl - << " Q_size_m: " << Q_size_m << std::endl - << " Q_size_g: " << Q_size_g << std::endl - << " Q_size_h: " << Q_size_h << std::endl - << " Q_size_k: " << Q_size_k << std::endl - << " K_size_m: " << K_size_m << std::endl - << " multiquery: " << multiquery << std::endl - << " qk_scale: " << qk_scale << std::endl - << " split_k: " << split_k << std::endl - << std::endl - << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." - << grid_dim.z << std::endl - << " block_dim: " << block_dim.x << "." << block_dim.y << "." 
- << block_dim.z << std::endl - << " lds_bytes: " << lds_bytes << std::endl - << "}"; - return oss.str(); - } - }; - - struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run( - const BaseArgument* argp_, - const StreamConfig& stream_config = StreamConfig{}) { - const Argument* argp = dynamic_cast(argp_); - - auto threads_per_wavefront = argp->block_dim.x; - auto Q_size_k_alignment_necessary = 0; - - for (auto vec_size : {4, 2, 1}) { - if (argp->Q_size_k <= vec_size * threads_per_wavefront) { - Q_size_k_alignment_necessary = vec_size; - } - } - - if (!Q_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported Q_size_k"); - } - - if (argp->Q_size_k % Q_size_k_alignment_necessary) { - throw std::runtime_error("Unsupported alignment for Q_size_k"); - } - - float split_attention_result = launch_and_time_kernel( - stream_config, - Q_size_k_alignment_necessary == 4 - ? efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - /* vec_size */ 4, - n_loop_unroll, - n_loop_unroll_tail, - KV_M_MAX, - compute_t> - : Q_size_k_alignment_necessary == 2 - ? efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - /* vec_size */ 2, - n_loop_unroll, - n_loop_unroll_tail, - KV_M_MAX, - compute_t> - : Q_size_k_alignment_necessary == 1 - ? efficient_attention_forward_decoder_splitk_ck_kernel< - scalar_t, - /* vec_size */ 1, - n_loop_unroll, - n_loop_unroll_tail, - KV_M_MAX, - compute_t> - : nullptr, - argp->grid_dim, - argp->block_dim, - argp->lds_bytes, - argp->XQ, - argp->cache_K, - argp->cache_V, - argp->split_O, - argp->split_max, - argp->split_sumexp, - argp->seq_kv_lens, - argp->XQ_stride_b, - argp->XQ_stride_m, - argp->XQ_stride_g, - argp->XQ_stride_h, - argp->K_stride_b, - argp->K_stride_m, - argp->K_stride_g, - argp->K_stride_h, - argp->O_stride_split, - argp->Q_size_m, - argp->Q_size_g, - argp->Q_size_h, - argp->Q_size_k, - argp->K_size_m, - argp->multiquery, - argp->qk_scale, - argp->split_k); - - const dim3 reduce_gridsize = {argp->grid_dim.x}; - const dim3 reduce_blocksize = {argp->block_dim.x}; - constexpr int32_t reduce_lds_bytes = 0; - float reduce_result = launch_and_time_kernel( - stream_config, - Q_size_k_alignment_necessary == 4 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 4> - : Q_size_k_alignment_necessary == 2 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 2> - : Q_size_k_alignment_necessary == 1 - ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< - scalar_t, - 1> - : nullptr, - reduce_gridsize, - reduce_blocksize, - reduce_lds_bytes, - argp->split_O, - argp->split_max, - argp->split_sumexp, - argp->O, - argp->Q_size_m, - argp->Q_size_g, - argp->Q_size_h, - argp->Q_size_k, - argp->O_stride_split, - argp->XQ_stride_b, - argp->XQ_stride_m, - argp->XQ_stride_g, - argp->XQ_stride_h, - argp->split_k); - return split_attention_result + reduce_result; - } - }; -}; -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/xformers/csrc/attention/hip_fmha/ck_fmha_util.h b/xformers/csrc/attention/hip_fmha/ck_fmha_util.h index b782f96ee0..7ce9f03c4b 100644 --- a/xformers/csrc/attention/hip_fmha/ck_fmha_util.h +++ b/xformers/csrc/attention/hip_fmha/ck_fmha_util.h @@ -75,7 +75,7 @@ static inline size_t get_size_in_bytes(size_t n, at::ScalarType dtype) { * expand the bias as needed - be careful to only create a view with different * shape/strides, no copies allowed. 
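The helper documented above returns a [batch, heads, M, N] view of a lower-rank bias without copying. As a rough illustration of the same idea (only the 2-D case, and not the helper that follows in this hunk), ATen's unsqueeze/expand produce exactly such a stride-0 broadcast view:

    // Illustrative sketch only: view a 2-D bias [M, N] as [batch, heads, M, N]
    // without materializing a copy; expand() gives the new dims stride 0.
    #include <ATen/ATen.h>

    at::Tensor bias_as_4d_view(const at::Tensor& bias, int64_t batch, int64_t heads) {
      TORCH_CHECK(bias.dim() == 2, "this sketch only covers the 2-D case");
      return bias.unsqueeze(0).unsqueeze(0).expand(
          {batch, heads, bias.size(0), bias.size(1)});
    }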
*/ -inline at::Tensor get_bias_4d_view( +static inline at::Tensor get_bias_4d_view( const at::Tensor& bias, int batch_sz, int n_heads, @@ -108,3 +108,15 @@ inline at::Tensor get_bias_4d_view( TORCH_CHECK(false, "bias can only have ndims in {2, 3, 4}"); } } + +static inline int get_number_of_cu() { + int device; + + HIP_CALL_CHECK(hipGetDevice(&device)); + + hipDeviceProp_t props; + + HIP_CALL_CHECK(hipGetDeviceProperties(&props, device)); + + return props.multiProcessorCount; +} diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h index 8bcb29bee8..dbb9f451b0 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h @@ -7,7 +7,8 @@ #pragma once #include -#include +#include +#include #include #include @@ -17,12 +18,12 @@ template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasBiasGrad, bool kHasDropout, ck_tile::index_t MaxK> -struct batched_backward_causalmask_bias_dropout_dispatch { +struct batched_backward_mask_bias_dropout_dispatch { using FmhaBlockDropout = typename FmhaBwdBlockDropoutMaker::dropout; @@ -59,8 +60,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch { constexpr ck_tile::index_t kBlockSize = 64; const bool pad_seqlen_q = !(param.M % kBlockSize == 0); - const bool pad_headdim_v = - !(param.Kv % FmhaBwdShape::kVHeaddim == 0); + const bool pad_headdim_v = !(param.Kv % MaxK == 0); BOOL_SWITCH_2( pad_seqlen_q, kPadSeqLenQ, pad_headdim_v, kPadHeadDimV, [&] { @@ -77,7 +77,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch { typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, kBlockSize, - FmhaBwdShape::kVHeaddim, + MaxK, // kVHeaddim false, // kIsGroupMode FmhaOGradDotOTraits_>; @@ -93,83 +93,73 @@ struct batched_backward_causalmask_bias_dropout_dispatch { } { - const bool has_local_attention = (param.window_size > 0) ? true : false; - - BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { - constexpr ck_tile::index_t occupancy = 1; - constexpr bool has_masking = kHasCausalMask || USE_LOCAL_ATTENTION; - - using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - - constexpr auto kBiasEnum = kHasBias - ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS - : ck_tile::BlockAttentionBiasEnum::NO_BIAS; - - constexpr bool kPadSeqLenQ = true; - constexpr bool kPadSeqLenK = true; - - const bool pad_headdim_q = - !(param.K % FmhaBwdShape::kQKHeaddim == 0); - const bool pad_headdim_v = - !(param.Kv % FmhaBwdShape::kVHeaddim == 0); - - // usually headdim_q and headdim_v are same, consider them together - // to determine whether to do padding saving some compiling time - const bool pad_headdim = (pad_headdim_q || pad_headdim_v); - - BOOL_SWITCH(pad_headdim, kPadHeadDim, [&] { - using FmhaBwdTraits_ = ck_tile::TileFmhaTraits< - kPadSeqLenQ, - kPadSeqLenK, - kPadHeadDim, // kPadHeadDimQ, - kPadHeadDim, // kPadHeadDimV, - kBiasEnum, - kHasBiasGrad, - false, // kStoreLSE - false, // place-holder for kHasDropout, not used actually - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaBwdPipelineProblem = - FmhaBwdPipelineProblemTemp; - - constexpr auto FmhaBwdPipelineEnum_ = - FmhaBwdPipelineEnumSelector:: - value; - - using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker< - FmhaBwdPipelineEnum_, - FmhaBwdPipelineProblem>::pipeline; - - using FmhaBwdKGradEpilogue_ = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig::KGradDataType, - kPadSeqLenK, - kPadHeadDim>>; - - using FmhaBwdVGradEpilogue_ = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig::VGradDataType, - kPadSeqLenK, - kPadHeadDim>>; - - using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel< - FmhaBwdPipeline_, - FmhaBwdKGradEpilogue_, - FmhaBwdVGradEpilogue_>; - - RunWithBwdDQDKDVKernel(param, stream); - }); - }); + constexpr ck_tile::index_t occupancy = 1; + + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + constexpr auto kBiasEnum = kHasBias + ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + const bool pad_headdim_q = + !(param.K % FmhaBwdShape::kQKHeaddim == 0); + const bool pad_headdim_v = + !(param.Kv % FmhaBwdShape::kVHeaddim == 0); + + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaBwdTraits_ = ck_tile::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + kHasBiasGrad, + false, // kStoreLSE + false, // place-holder for kHasDropout, not used actually + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaBwdPipelineProblem = + FmhaBwdPipelineProblemTemp; + + constexpr auto FmhaBwdPipelineEnum_ = + FmhaBwdPipelineEnumSelector::value; + + using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker< + FmhaBwdPipelineEnum_, + FmhaBwdPipelineProblem>::pipeline; + + using FmhaBwdKGradEpilogue_ = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + kPadSeqLenK, + kPadHeadDimQ>>; + + using FmhaBwdVGradEpilogue_ = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + kPadSeqLenK, + kPadHeadDimV>>; + + using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel< + FmhaBwdPipeline_, + FmhaBwdKGradEpilogue_, + FmhaBwdVGradEpilogue_>; + + RunWithBwdDQDKDVKernel(param, stream); + }); }; if constexpr (NeedConvertGradQ) { constexpr ck_tile::index_t kBlockSize = 256; const bool pad_seqlen_q = !(param.M % kBlockSize == 0); - const bool pad_headdim_q = - !(param.K % FmhaBwdShape::kQKHeaddim == 0); + const bool pad_headdim_q = !(param.K % MaxK == 0); BOOL_SWITCH_2( pad_seqlen_q, kPadSeqLenQ, pad_headdim_q, kPadHeadDimQ, [&] { @@ -188,7 +178,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch { kBlockSize, FmhaBwdShape::kM0, FmhaBwdShape::kN0, - FmhaBwdShape::kQKHeaddim, + MaxK, // kQKHeaddim false, // kIsGroupMode false, // kIsDeterministic FmhaBwdConvertQGradTraits_>; @@ -309,7 +299,7 @@ struct batched_backward_causalmask_bias_dropout_dispatch { (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size param.custom_mask_type, param.dropout_prob, // dropout ratio - {param.philox_seed, param.philox_offset}); + std::make_pair(param.philox_seed, param.philox_offset)); }(); dim3 kGridSize = FmhaBwdDQDKDVKernel::GridSize(param.B, param.Hq, param.N); @@ -357,17 +347,17 @@ struct batched_backward_causalmask_bias_dropout_dispatch { template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasBiasGrad, bool kHasDropout, ck_tile::index_t MaxK> -void run_batched_backward_causalmask_bias_dropout_dispatch( +void run_batched_backward_mask_bias_dropout_dispatch( BatchedBackwardParams& param, hipStream_t stream) { - batched_backward_causalmask_bias_dropout_dispatch< + batched_backward_mask_bias_dropout_dispatch< ScalarType, - kHasCausalMask, + kHasMask, kHasBias, kHasBiasGrad, kHasDropout, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp index 3cf339b834..f6d6fb4eb6 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_bf16.cpp @@ -25,16 +25,18 @@ void batched_backward_bf16(BatchedBackwardParams& param, hipStream_t stream) { [&] { if constexpr (kHasBias || !kHasBiasGrad) { FMHA_BWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_batched_backward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, kHasBias, kHasBiasGrad, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_batched_backward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp index 807169ccd0..342677ae88 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward_fp16.cpp @@ -25,16 +25,18 @@ void batched_backward_fp16(BatchedBackwardParams& param, hipStream_t stream) { [&] { if constexpr (kHasBias || !kHasBiasGrad) { FMHA_BWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_batched_backward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, kHasBias, kHasBiasGrad, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_batched_backward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h index 20c1b2c3ef..a79887c55b 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h @@ -6,192 +6,68 @@ */ #pragma once -#include -#include -#include 
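The dispatcher changes above replace the old kHasCausalMask flag with the more general kHasMask: the masked kernel is selected whenever either a causal mask type or a local (sliding-window) attention is requested. A tiny standalone predicate restating that routing, with a hypothetical name but the same conditions as the code above:

    // Mirrors the selection logic above: custom_mask_type 1/2 (causal variants)
    // or a positive window_size both require the masked instantiation; only
    // mask type 0 with no window uses the mask-free kernel.
    inline bool needs_mask_kernel(int custom_mask_type, int window_size) {
      const bool causal = (custom_mask_type == 1) || (custom_mask_type == 2);
      const bool local = (window_size > 0);
      return causal || local;
    }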
-#include - -#include "ck_tiled_bool_switch.h" +#include +#include "ck_tiled_fmha_batched_forward_dispatch.h" +#include "ck_tiled_fmha_batched_forward_splitkv_dispatch.h" +#include "ck_tiled_fmha_batched_forward_splitkv_smallq_dispatch.h" #include "ck_tiled_fmha_fwd_setting.h" -#include "ck_tiled_fmha_params.h" - -template < - typename ScalarType, - bool kHasCausalMask, - bool kHasBias, - bool kHasDropout, - ck_tile::index_t MaxK> -struct batched_forward_causalmask_bias_dropout_dispatch { - template - using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - FmhaFwdShape, - false, // kIsGroupMode - FmhaMask, - FmhaTraits>; - - static void Run(BatchedForwardParams& param, hipStream_t stream) { - const bool has_local_attention = (param.window_size > 0) ? true : false; - - BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { - constexpr bool has_masking = kHasCausalMask || USE_LOCAL_ATTENTION; - - using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - - using FmhaFwdShape_ = FmhaFwdShape; - using FmhaFwdTilePartitioner_ = - ck_tile::FmhaFwdTilePartitioner; - constexpr ck_tile::index_t occupancy = - (MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2); - - constexpr auto kBiasEnum = kHasBias - ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS - : ck_tile::BlockAttentionBiasEnum::NO_BIAS; - - const bool pad_seqlen_q = !(param.M % FmhaFwdShape_::kM0 == 0); - const bool pad_seqlen_k = - (param.N == 0) || !(param.N % FmhaFwdShape_::kN0 == 0); - const bool pad_headdim_q = - !(param.K % FmhaFwdShape_::kK0BlockLength == 0); - const bool pad_headdim_v = !(param.Kv % FmhaFwdShape_::kN1 == 0); - - // usually headdim_q and headdim_v are same, consider them together to - // determine whether to do padding saving some compiling time - const bool pad_headdim = (pad_headdim_q || pad_headdim_v); - - const bool use_async_pipeline = - ((param.K % 8 == 0) && (param.Kv % 8 == 0) && (MaxK <= 128)); - - BOOL_SWITCH_3( - pad_seqlen_q, - kPadSeqLenQ, - pad_seqlen_k, - kPadSeqLenK, - pad_headdim, - kPadHeadDim, - [&] { - using FmhaFwdTraits_ = ck_tile::TileFmhaTraits< - kPadSeqLenQ, - kPadSeqLenK, - kPadHeadDim, // kPadHeadDimQ - kPadHeadDim, // kPadHeadDimV - kBiasEnum, - false, // kHasBiasGrad place-holder - true, // kStoreLSE - kHasDropout, - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaPipelineProblem = - FmhaPipelineProblemTemp; - - using FmhaFwdPipeline_ = - ck_tile::BlockFmhaPipelineQRKSVS; - - using FmhaFwdEpilogue_ = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - kPadSeqLenQ, - kPadHeadDim>>; - - using FmhaFwdKernel_ = ck_tile::FmhaFwdKernel< - FmhaFwdTilePartitioner_, - FmhaFwdPipeline_, - FmhaFwdEpilogue_>; - - RunWithKernel(param, stream); - }); - }); - }; - - template - static void RunWithKernel(BatchedForwardParams& param, hipStream_t stream) { - const auto kargs = [&] { - return FmhaFwdKernel::MakeKargs( - param.q_ptr, - param.k_ptr, - param.v_ptr, - param.attn_bias_ptr, - nullptr, // rand_val_ptr - param.logsumexp_ptr, - param.out_ptr, - param.M, // 
seqlen_q - param.N, // seqlen_k - param.K, // hdim_q - param.Kv, // hdim_v - param.Hq, // nhead_q - param.Hq / param.Hkv, // nhead_ratio_qk - param.scale, - 1.0f, // scale_p - 1.0f, // scale_o - param.q_strides[1], // q, k, v, bias, randval, out tensor seq-dim - // stride - param.k_strides[1], - param.v_strides[1], - param.attn_bias_strides[2], - 0, // stride_randval - param.out_strides[1], - param.q_strides[2], // q, k, v, bias, randval, lse, out tensor - // head-dim stride - param.k_strides[2], - param.v_strides[2], - param.attn_bias_strides[1], - 0, // nhead_randva - param.lse_strides[1], // nhead_stride_lse - param.out_strides[2], - param.q_strides[0], // q, k, v, bias, randval, lse, out tensor - // batch-dim stride - param.k_strides[0], - param.v_strides[0], - param.attn_bias_strides[0], - 0, // batch_stride_randval - param.lse_strides[0], // batch_stride_lse - param.out_strides[0], - (param.window_size > 0) ? param.window_size - 1 - : -1, // window_left_size - (param.custom_mask_type == 0) ? -1 : 0, // window_right_size - param.custom_mask_type, - param.dropout_prob, // dropout ratio - false, // is_store_randval - {param.philox_seed, param.philox_offset}); - }(); - - dim3 kGridSize = - FmhaFwdKernel::GridSize(param.B, param.Hq, param.M, param.Kv); - constexpr dim3 kBlockSize = FmhaFwdKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = FmhaFwdKernel::kBlockPerCu; - - (void)ck_tile::launch_kernel( - ck_tile::stream_config{stream, false}, - ck_tile::make_kernel( - FmhaFwdKernel{}, kGridSize, kBlockSize, 0, kargs)); - }; -}; +#include "ck_tiled_fmha_fwd_splitkv_smallq_selector.h" +#include "ck_tiled_fmha_seqlen_q_switch.h" template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasDropout, ck_tile::index_t MaxK> -void run_batched_forward_causalmask_bias_dropout_dispatch( +void run_batched_forward_mask_bias_dropout_dispatch( BatchedForwardParams& param, hipStream_t stream) { - batched_forward_causalmask_bias_dropout_dispatch< - ScalarType, - kHasCausalMask, - kHasBias, - kHasDropout, - MaxK>::Run(param, stream); + // currently split-kv implementation does not support dropout + if constexpr (!kHasDropout) { + if (param.use_split_kv) { + if (use_splitkv_smallq(param.M, std::max(param.K, param.Kv))) { + batched_forward_splitkv_smallq_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK>::Run(param, stream); + } else { + FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] { + batched_forward_splitkv_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK, + MaxSeqlenQ>::Run(param, stream); + }); + } + } else { + if (get_fmha_fwd_mtile(param.B, param.Hq, param.M) == 128) + batched_forward_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + else + batched_forward_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 64>::Run(param, stream); + } + } else { + // at present, dropout of fwd kernel requires 32x32 WarpTile + batched_forward_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + } }; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp index bd2e076e0c..216dab5347 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bf16.cpp @@ -17,15 
+17,17 @@ void batched_forward_bf16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_batched_forward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_batched_forward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_dispatch.h new file mode 100644 index 0000000000..6fdd1c6bb5 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_dispatch.h @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_setting.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + bool kHasDropout, + ck_tile::index_t MaxK, + ck_tile::index_t MTile> +struct batched_forward_mask_bias_dropout_dispatch { + template + using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + typename FmhaFwdShape::Type, + false, // kIsGroupMode + FmhaMask, + FmhaTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaFwdShape_ = typename FmhaFwdShape::Type; + constexpr ck_tile::index_t occupancy = + (MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2); + + constexpr auto kBiasEnum = kHasBias + ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + const bool pad_seqlen_q = !(param.M % FmhaFwdShape_::kM0 == 0); + const bool pad_seqlen_k = + (param.N == 0) || !(param.N % FmhaFwdShape_::kN0 == 0); + const bool pad_headdim_q = !(param.K % FmhaFwdShape_::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.Kv % FmhaFwdShape_::kN1 == 0); + + // usually headdim_q and headdim_v are same, consider them together to + // determine whether to do padding saving some compiling time + const bool pad_headdim = (pad_headdim_q || pad_headdim_v); + + const bool use_async_pipeline = + ((param.K % 8 == 0) && (param.Kv % 8 == 0) && (MaxK <= 128)); + + BOOL_SWITCH_3( + pad_seqlen_q, + kPadSeqLenQ, + pad_seqlen_k, + kPadSeqLenK, + pad_headdim, + kPadHeadDim, + [&] { + using FmhaFwdTraits_ = ck_tile::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ + kPadHeadDim, // kPadHeadDimV + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + kHasDropout, + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaFwdPipeline_ = + ck_tile::BlockFmhaPipelineQRKSVS; + + using FmhaFwdEpilogue_ = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDim>>; + + using FmhaFwdKernel_ = + ck_tile::FmhaFwdKernel; + + RunWithKernel(param, stream); + }); + }; + + template + static void RunWithKernel(BatchedForwardParams& param, hipStream_t stream) { + const auto kargs = [&] { + return FmhaFwdKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // rand_val_ptr + param.logsumexp_ptr, + param.out_ptr, + param.M, // seqlen_q + param.N, // seqlen_k + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.scale, + 1.0f, // scale_p + 1.0f, // scale_o + param.q_strides[1], // q, k, v, bias, randval, out tensor seq-dim + // stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + 0, // stride_randval + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, randval, lse, out tensor + // head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + 0, // nhead_randva + param.lse_strides[1], // nhead_stride_lse + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, randval, lse, out tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + 0, // batch_stride_randval + param.lse_strides[0], // batch_stride_lse + param.out_strides[0], + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type, + param.dropout_prob, // dropout ratio + false, // is_store_randval + std::make_pair(param.philox_seed, param.philox_offset)); + }(); + + dim3 kGridSize = + FmhaFwdKernel::GridSize(param.B, param.Hq, param.M, param.Kv, false); + constexpr dim3 kBlockSize = FmhaFwdKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp index 3c3791bdfb..e1d2e95557 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp @@ -17,15 +17,17 @@ void batched_forward_fp16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_batched_forward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_batched_forward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_dispatch.h new file mode 100644 index 0000000000..df1ece8930 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_dispatch.h @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
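The new dispatch header above, like the rest of these files, leans on BOOL_SWITCH / BOOL_SWITCH_2 / BOOL_SWITCH_3 to turn runtime padding decisions into compile-time template parameters. A minimal sketch of what such a switch could look like, assuming the usual integral_constant idiom rather than the project's actual macro definitions:

    // Assumed-semantics sketch of a BOOL_SWITCH-style helper: the callable is
    // instantiated once per bool value, so kernels see a compile-time constant.
    #include <type_traits>
    #include <utility>

    template <typename F>
    void bool_switch(bool cond, F&& f) {
      if (cond)
        std::forward<F>(f)(std::integral_constant<bool, true>{});
      else
        std::forward<F>(f)(std::integral_constant<bool, false>{});
    }

    // usage sketch:
    //   bool_switch(pad_headdim, [&](auto kPadHeadDim) {
    //     launch_kernel<decltype(kPadHeadDim)::value>(/* ... */);
    //   });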
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK, + ck_tile::index_t MaxSeqlenQ> +struct batched_forward_splitkv_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVShape::Type, + false, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + false, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + const bool pad_seqlen_q = !(param.M % FmhaTileShape::kM0 == 0); + const bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + const bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + + // usually headdim_q and headdim_v are same, consider them together to + // determine whether to do padding saving some compiling time + const bool pad_headdim = (pad_headdim_q || pad_headdim_v); + + const bool has_uneven_splits = + !(param.N % (param.num_kv_splits * FmhaTileShape::kN0) == 0); + + BOOL_SWITCH_3( + pad_seqlen_q, + kPadSeqLenQ, + pad_headdim, + kPadHeadDim, + has_uneven_splits, + kHasUnevenSplits, + [&] { + constexpr bool kPadSeqLenK = kHasUnevenSplits ? 
true : false; + + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ, + kPadHeadDim, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + kHasUnevenSplits, + occupancy>; + + if (param.num_kv_splits > 1) { + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + } + + if (param.num_kv_splits > 1) { + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + + constexpr ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + const bool pad_seqlen_q = !(param.M % kM0 == 0); + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH_2( + pad_seqlen_q, kPadSeqLenQ, pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH( + param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + FmhaPipelineProblem>; + + using FmhaEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem< + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = ck_tile:: + FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr, not used + 0, // batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_acc_strides[2], + 
param.q_strides[2], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + param.lse_acc_strides[2], + param.out_acc_strides[3], + param.q_strides[0], // q, k, v, bias, lse_acc, out_acc tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + param.lse_acc_strides[1], + param.out_acc_strides[1], + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr, not used + 0, // batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, lse, out tensor head-dim + // stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + param.lse_strides[1], + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, lse, out tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + param.lse_strides[0], + param.out_strides[0], + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.B, param.Hq, param.M, param.Kv, param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.B, // batches + param.M, // seqlen_q + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[2], // row_stride_o_acc + param.out_strides[1], // row_stride_o + param.lse_acc_strides[2], // head_stride_lse_acc + param.out_acc_strides[3], // head_stride_o_acc + param.lse_strides[1], // head_stride_lse + param.out_strides[2], // head_stride_o + param.lse_acc_strides[1], // batch_stride_lse_acc + param.out_acc_strides[1], // batch_stride_o_acc + param.lse_strides[0], // batch_stride_lse + param.out_strides[0], // batch_stride_o + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0]); // split_stride_out_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.B, param.Hq, param.M, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_smallq_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_smallq_dispatch.h new file mode 100644 index 0000000000..806a507fd2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_splitkv_smallq_dispatch.h @@ -0,0 +1,357 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
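In the split-KV path above, kPadSeqLenK is only forced on when the key length does not divide evenly across num_kv_splits kN0-sized tiles (has_uneven_splits), and the combine kernel only runs when more than one split was produced. A hypothetical sketch of carving the key range into tile-aligned splits, just to show where the uneven remainder ends up (this is not the kernel's actual partitioner):

    // Hypothetical illustration: give each split a tile-aligned slice of the N
    // keys; only the last split sees a ragged tail, the case that requires
    // seqlen-k padding in the traits above.
    #include <utility>

    inline std::pair<int, int> split_key_range(
        int N, int kN0, int num_splits, int split_idx) {
      const int tiles = (N + kN0 - 1) / kN0;           // total kN0 tiles
      const int tiles_per_split = tiles / num_splits;  // even share per split
      const int begin = split_idx * tiles_per_split * kN0;
      const int end =
          (split_idx + 1 == num_splits) ? N : begin + tiles_per_split * kN0;
      return {begin, end};
    }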
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_smallq_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK> +struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVSmallQShape::Type, + false, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + false, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + const bool pad_seqlen_q = !(param.M % FmhaTileShape::kM0 == 0); + const bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + const bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + + // usually headdim_q and headdim_v are same, consider them together to + // determine whether to do padding saving some compiling time + const bool pad_headdim = (pad_headdim_q || pad_headdim_v); + + const bool has_uneven_splits = + !(param.N % (param.num_kv_splits * FmhaTileShape::kN0) == 0); + + BOOL_SWITCH_3( + pad_seqlen_q, + kPadSeqLenQ, + pad_headdim, + kPadHeadDim, + has_uneven_splits, + kHasUnevenSplits, + [&] { + constexpr bool kPadSeqLenK = kHasUnevenSplits ? 
true : false; + + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ, + kPadHeadDim, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + kHasUnevenSplits, + occupancy>; + + if (param.num_kv_splits > 1) { + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + } + + if (param.num_kv_splits > 1) { + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + + constexpr ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + const bool pad_seqlen_q = !(param.M % kM0 == 0); + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH_2( + pad_seqlen_q, kPadSeqLenQ, pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH( + param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + FmhaPipelineProblem>; + + using FmhaEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem< + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = ck_tile:: + FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr, not used + 0, // batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[1], + param.v_strides[1], + 
param.attn_bias_strides[2], + param.out_acc_strides[2], + param.q_strides[2], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + param.lse_acc_strides[2], + param.out_acc_strides[3], + param.q_strides[0], // q, k, v, bias, lse_acc, out_acc tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + param.lse_acc_strides[1], + param.out_acc_strides[1], + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr, not used + 0, // batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, lse, out tensor head-dim + // stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + param.lse_strides[1], + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, lse, out tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + param.lse_strides[0], + param.out_strides[0], + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.B, param.Hq, param.M, param.Kv, param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.B, // batches + param.M, // seqlen_q + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[2], // row_stride_o_acc + param.out_strides[1], // row_stride_o + param.lse_acc_strides[2], // head_stride_lse_acc + param.out_acc_strides[3], // head_stride_o_acc + param.lse_strides[1], // head_stride_lse + param.out_strides[2], // head_stride_o + param.lse_acc_strides[1], // batch_stride_lse_acc + param.out_acc_strides[1], // batch_stride_o_acc + param.lse_strides[0], // batch_stride_lse + param.out_strides[0], // batch_stride_o + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0]); // split_stride_out_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.B, param.Hq, param.M, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h index 36cf1b56e7..06b3b66232 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h @@ -6,225 +6,68 @@ */ #pragma once -#include -#include -#include -#include - -#include "ck_tiled_bool_switch.h" +#include +#include "ck_tiled_fmha_batched_infer_dispatch.h" +#include "ck_tiled_fmha_batched_infer_splitkv_dispatch.h" +#include "ck_tiled_fmha_batched_infer_splitkv_smallq_dispatch.h" #include "ck_tiled_fmha_fwd_setting.h" -#include "ck_tiled_fmha_params.h" -#include "ck_tiled_headdim_switch.h" - -template < - typename ScalarType, - bool kHasCausalMask, - bool kHasBias, - bool kHasDropout, - ck_tile::index_t MaxK> -struct batched_infer_causalmask_bias_dropout_dispatch { - template - using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - FmhaFwdShape, - false, // kIsGroupMode - FmhaMask, - FmhaTraits>; - - static void Run(BatchedForwardParams& param, hipStream_t stream) { - const bool has_local_attention = (param.window_size > 0) ? 
true : false; - - BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { - constexpr bool has_masking = kHasCausalMask || USE_LOCAL_ATTENTION; - - using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - - using FmhaShape = FmhaFwdShape; - using FmhaTilePartitioner = ck_tile::FmhaFwdTilePartitioner; - constexpr ck_tile::index_t occupancy = - (MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2); - - constexpr auto kBiasEnum = kHasBias - ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS - : ck_tile::BlockAttentionBiasEnum::NO_BIAS; - - const bool pad_seqlen_q = !(param.M % FmhaShape::kM0 == 0); - const bool pad_seqlen_k = - (param.N == 0) || !(param.N % FmhaShape::kN0 == 0); - const bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); - const bool pad_headdim_q = !(param.K % FmhaShape::kK0BlockLength == 0); - - // usually headdim_q and headdim_v are same, consider them together to - // determine whether to do padding saving some compiling time - const bool pad_headdim = (pad_headdim_q || pad_headdim_v); - - const bool use_async_pipeline = - (!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) && - (MaxK <= 128)); - - if (!use_async_pipeline) { - BOOL_SWITCH_3( - pad_seqlen_q, - kPadSeqLenQ, - pad_seqlen_k, - kPadSeqLenK, - pad_headdim, - kPadHeadDim, - [&] { - using FmhaTraits = ck_tile::TileFmhaTraits< - kPadSeqLenQ, - kPadSeqLenK, - kPadHeadDim, // kPadHeadDimQ, - kPadHeadDim, // kPadHeadDimV, - kBiasEnum, - false, // kHasBiasGrad place-holder - false, // kStoreLSE - kHasDropout, - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaPipelineProblem = - FmhaPipelineProblemTemp; - - using FmhaPipeline = - ck_tile::BlockFmhaPipelineQRKSVS; - - using FmhaEpilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - kPadSeqLenQ, - kPadHeadDim>>; - - using FmhaKernel = ck_tile::FmhaFwdKernel< - FmhaTilePartitioner, - FmhaPipeline, - FmhaEpilogue>; - - RunWithKernel(param, stream); - }); - } else { - BOOL_SWITCH(pad_seqlen_k, kPadSeqLenK, [&] { - using FmhaTraits = ck_tile::TileFmhaTraits< - true, // kPadSeqLenQ, - kPadSeqLenK, - true, // kPadHeadDimQ, - true, // kPadHeadDimV, - kBiasEnum, - false, // kHasBiasGrad place-holder - false, // kStoreLSE - kHasDropout, - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaPipelineProblem = - FmhaPipelineProblemTemp; - - using FmhaPipeline = - ck_tile::BlockFmhaPipelineQRKSVSAsync; - - using FmhaEpilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - true, - true>>; - - using FmhaKernel = ck_tile:: - FmhaFwdKernel; - - RunWithKernel(param, stream); - }); - }; - }); - }; - - template - static void RunWithKernel(BatchedForwardParams& param, hipStream_t stream) { - const auto kargs = [&] { - return FmhaKernel::MakeKargs( - param.q_ptr, - param.k_ptr, - param.v_ptr, - param.attn_bias_ptr, - nullptr, // rand_val_ptr - nullptr, // lse_ptr - param.out_ptr, - param.M, // seqlen_q - param.N, // seqlen_k - param.K, // hdim_q - param.Kv, // hdim_v - param.Hq, // nhead_q - param.Hq / param.Hkv, // nhead_ratio_qk - param.scale, - 1.0f, // scale_p - 1.0f, // scale_o - param.q_strides[1], // q, k, v, bias, randval, out tensor seq-dim - // stride - param.k_strides[1], - param.v_strides[1], - param.attn_bias_strides[2], - 0, // stride_randval - param.out_strides[1], - param.q_strides[2], // q, k, v, bias, randval, lse, out tensor - // head-dim stride - param.k_strides[2], - param.v_strides[2], - param.attn_bias_strides[1], - 0, // nhead_stride_randval - 
0, // nhead_stride_lse - param.out_strides[2], - param.q_strides[0], // q, k, v, bias, randval, lse, out tensor - // batch-dim stride - param.k_strides[0], - param.v_strides[0], - param.attn_bias_strides[0], - 0, // batch_stride_randval - 0, // batch_stride_lse - param.out_strides[0], - (param.window_size > 0) ? param.window_size - 1 - : -1, // window_left_size - (param.custom_mask_type == 0) ? -1 : 0, // window_right_size - param.custom_mask_type, - param.dropout_prob, // dropout ratio - false, // is_store_randval - {param.philox_seed, param.philox_offset}); - }(); - - dim3 kGridSize = FmhaKernel::GridSize(param.B, param.Hq, param.M, param.Kv); - constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; - - (void)ck_tile::launch_kernel( - ck_tile::stream_config{stream, false}, - ck_tile::make_kernel( - FmhaKernel{}, kGridSize, kBlockSize, 0, kargs)); - }; -}; +#include "ck_tiled_fmha_fwd_splitkv_smallq_selector.h" +#include "ck_tiled_fmha_seqlen_q_switch.h" template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasDropout, ck_tile::index_t MaxK> -void run_batched_infer_causalmask_bias_dropout_dispatch( +void run_batched_infer_mask_bias_dropout_dispatch( BatchedForwardParams& param, hipStream_t stream) { - batched_infer_causalmask_bias_dropout_dispatch< - ScalarType, - kHasCausalMask, - kHasBias, - kHasDropout, - MaxK>::Run(param, stream); + // currently split-kv implementation does not support dropout + if constexpr (!kHasDropout) { + if (param.use_split_kv) { + if (use_splitkv_smallq(param.M, std::max(param.K, param.Kv))) { + batched_infer_splitkv_smallq_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK>::Run(param, stream); + } else { + FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] { + batched_infer_splitkv_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK, + MaxSeqlenQ>::Run(param, stream); + }); + } + } else { + if (get_fmha_fwd_mtile(param.B, param.Hq, param.M) == 128) + batched_infer_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + else + batched_infer_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 64>::Run(param, stream); + } + } else { + // at present, dropout of fwd kernel requires 32x32 WarpTile + batched_infer_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + } }; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp index 23b04d935f..dca87ca6c7 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bf16.cpp @@ -16,15 +16,17 @@ void batched_infer_bf16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_batched_infer_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - 
run_batched_infer_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h new file mode 100644 index 0000000000..ed49eac35e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_dispatch.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_setting.h" +#include "ck_tiled_fmha_params.h" +#include "ck_tiled_headdim_switch.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + bool kHasDropout, + ck_tile::index_t MaxK, + ck_tile::index_t MTile> +struct batched_infer_mask_bias_dropout_dispatch { + template + using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + typename FmhaFwdShape::Type, + false, // kIsGroupMode + FmhaMask, + FmhaTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaShape = typename FmhaFwdShape::Type; + constexpr ck_tile::index_t occupancy = + (MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2); + + constexpr auto kBiasEnum = kHasBias + ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + const bool pad_seqlen_q = !(param.M % FmhaShape::kM0 == 0); + const bool pad_seqlen_k = + (param.N == 0) || !(param.N % FmhaShape::kN0 == 0); + const bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); + const bool pad_headdim_q = !(param.K % FmhaShape::kSubQKHeaddim == 0); + + // usually headdim_q and headdim_v are same, consider them together to + // determine whether to do padding saving some compiling time + const bool pad_headdim = (pad_headdim_q || pad_headdim_v); + + const bool use_async_pipeline = + (!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) && + (MaxK <= 128)); + + if (!use_async_pipeline) { + BOOL_SWITCH_3( + pad_seqlen_q, + kPadSeqLenQ, + pad_seqlen_k, + kPadSeqLenK, + pad_headdim, + kPadHeadDim, + [&] { + using FmhaTraits = ck_tile::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ, + kPadHeadDim, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + kHasDropout, + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaPipelineQRKSVS; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDim>>; + + using FmhaKernel = + ck_tile::FmhaFwdKernel; + + RunWithKernel(param, stream); + }); + } else { + BOOL_SWITCH(pad_seqlen_k, kPadSeqLenK, [&] { + using FmhaTraits = ck_tile::TileFmhaTraits< + true, // kPadSeqLenQ, + kPadSeqLenK, + true, // kPadHeadDimQ, + true, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + kHasDropout, + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaPipelineQRKSVSAsync; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, + true>>; + + using FmhaKernel = ck_tile::FmhaFwdKernel; + + RunWithKernel(param, stream); + }); + }; + }; + + template + static void RunWithKernel(BatchedForwardParams& param, hipStream_t stream) { + const auto kargs = [&] { + return FmhaKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // rand_val_ptr + nullptr, // lse_ptr + param.out_ptr, + param.M, // seqlen_q + param.N, // seqlen_k + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.scale, + 1.0f, // scale_p + 1.0f, // scale_o + param.q_strides[1], // q, k, v, bias, randval, out tensor seq-dim + // stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + 0, // stride_randval + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, randval, lse, out tensor + // head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + 0, // nhead_stride_randval + 0, // nhead_stride_lse + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, randval, lse, out tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + 0, // batch_stride_randval + 0, // batch_stride_lse + param.out_strides[0], + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type, + param.dropout_prob, // dropout ratio + false, // is_store_randval + std::make_pair(param.philox_seed, param.philox_offset)); + }(); + + dim3 kGridSize = + FmhaKernel::GridSize(param.B, param.Hq, param.M, param.Kv, false); + constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp index 4e1d99e8ec..2d899e9378 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp @@ -16,15 +16,17 @@ void batched_infer_fp16(BatchedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_batched_infer_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_batched_infer_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_dispatch.h new file mode 100644 index 0000000000..1e8e70e398 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_dispatch.h @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK, + ck_tile::index_t MaxSeqlenQ> +struct batched_infer_splitkv_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVShape::Type, + false, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + false, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + const bool pad_seqlen_q = !(param.M % FmhaTileShape::kM0 == 0); + const bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + const bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + + // usually headdim_q and headdim_v are same, consider them together to + // determine whether to do padding saving some compiling time + const bool pad_headdim = (pad_headdim_q || pad_headdim_v); + + const bool has_uneven_splits = + !(param.N % (param.num_kv_splits * FmhaTileShape::kN0) == 0); + + BOOL_SWITCH_3( + pad_seqlen_q, + kPadSeqLenQ, + pad_headdim, + kPadHeadDim, + has_uneven_splits, + kHasUnevenSplits, + [&] { + constexpr bool kPadSeqLenK = kHasUnevenSplits ? 
true : false; + + if (param.num_kv_splits > 1) { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ, + kPadHeadDim, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + kHasUnevenSplits, + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ, + kPadHeadDim, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + kHasUnevenSplits, + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + }; + + if (param.num_kv_splits > 1) { + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + + constexpr ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + const bool pad_seqlen_q = !(param.M % kM0 == 0); + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH_2( + pad_seqlen_q, kPadSeqLenQ, pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH( + param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + FmhaPipelineProblem>; + + using FmhaEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem< + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = ck_tile:: + FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr, not used + 0, 
// batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_acc_strides[2], + param.q_strides[2], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + param.lse_acc_strides[2], + param.out_acc_strides[3], + param.q_strides[0], // q, k, v, bias, lse_acc, out_acc tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + param.lse_acc_strides[1], + param.out_acc_strides[1], + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // lse_ptr + param.out_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr, not used + 0, // batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, lse, out tensor head-dim + // stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + 0, // nhead_stride_lse + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, lse, out tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + 0, // batch_stride_lse + param.out_strides[0], + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.B, param.Hq, param.M, param.Kv, param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + nullptr, // lse_ptr, not used + param.out_ptr, + param.B, // batches + param.M, // seqlen_q + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[2], // row_stride_o_acc + param.out_strides[1], // row_stride_o + param.lse_acc_strides[2], // head_stride_lse_acc + param.out_acc_strides[3], // head_stride_o_acc + 0, // head_stride_lse, // not used + param.out_strides[2], // head_stride_o + param.lse_acc_strides[1], // batch_stride_lse_acc + param.out_acc_strides[1], // batch_stride_o_acc + 0, // batch_stride_lse, not used + param.out_strides[0], // batch_stride_o + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0]); // split_stride_out_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.B, param.Hq, param.M, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_smallq_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_smallq_dispatch.h new file mode 100644 index 0000000000..9ef7c24424 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_splitkv_smallq_dispatch.h @@ -0,0 +1,370 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_smallq_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK> +struct batched_infer_splitkv_smallq_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVSmallQShape::Type, + false, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + false, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + const bool pad_seqlen_q = !(param.M % FmhaTileShape::kM0 == 0); + const bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + const bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + + // usually headdim_q and headdim_v are same, consider them together to + // determine whether to do padding saving some compiling time + const bool pad_headdim = (pad_headdim_q || pad_headdim_v); + + const bool has_uneven_splits = + !(param.N % (param.num_kv_splits * FmhaTileShape::kN0) == 0); + + BOOL_SWITCH_3( + pad_seqlen_q, + kPadSeqLenQ, + pad_headdim, + kPadHeadDim, + has_uneven_splits, + kHasUnevenSplits, + [&] { + constexpr bool kPadSeqLenK = kHasUnevenSplits ? 
true : false; + + if (param.num_kv_splits > 1) { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ, + kPadHeadDim, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + kHasUnevenSplits, + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDim, // kPadHeadDimQ, + kPadHeadDim, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + kHasUnevenSplits, + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + }; + + if (param.num_kv_splits > 1) { + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + + constexpr ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + const bool pad_seqlen_q = !(param.M % kM0 == 0); + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH_2( + pad_seqlen_q, kPadSeqLenQ, pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH( + param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + FmhaPipelineProblem>; + + using FmhaEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem< + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = ck_tile:: + FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, 
// block_table_ptr, not used + 0, // batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_acc_strides[2], + param.q_strides[2], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + param.lse_acc_strides[2], + param.out_acc_strides[3], + param.q_strides[0], // q, k, v, bias, lse_acc, out_acc tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + param.lse_acc_strides[1], + param.out_acc_strides[1], + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // lse_ptr + param.out_ptr, + param.B, // batch + param.M, // seqlen_q + param.N, // seqlen_k + nullptr, // seqlen_k_ptr, not used + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr, not used + 0, // batch_stride_block_table, not used + 0, // page_table_size, not used + nullptr, // cache_batch_idx, not used + param.scale, + 1.0f, // scale_p + param.q_strides[1], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, lse, out tensor head-dim + // stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + 0, // nhead_stride_lse + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, lse, out tensor + // batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + 0, // batch_stride_lse + param.out_strides[0], + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.B, param.Hq, param.M, param.Kv, param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + BatchedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + nullptr, // lse_ptr, not used + param.out_ptr, + param.B, // batches + param.M, // seqlen_q + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[2], // row_stride_o_acc + param.out_strides[1], // row_stride_o + param.lse_acc_strides[2], // head_stride_lse_acc + param.out_acc_strides[3], // head_stride_o_acc + 0, // head_stride_lse, // not used + param.out_strides[2], // head_stride_o + param.lse_acc_strides[1], // batch_stride_lse_acc + param.out_acc_strides[1], // batch_stride_o_acc + 0, // batch_stride_lse, not used + param.out_strides[0], // batch_stride_o + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0]); // split_stride_out_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.B, param.Hq, param.M, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h index 9e2ba48187..ccf6b1bdc2 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_bwd_setting.h @@ -70,6 +70,14 @@ struct FmhaBwdBlockTile<64> { using gemm4_warps = ck_tile::sequence<1, 4, 1>; // default for gemm4 }; +template <> +struct FmhaBwdBlockTile<96> { + using tile_lengths = ck_tile::sequence<16, 128, 96, 16, 96, 16, 32, 128, 128>; + using gemm02_warps = ck_tile::sequence<1, 4, 1>; // default for gemm0/gemm2 + using gemm13_warps = ck_tile::sequence<4, 1, 1>; // default for gemm1/gemm3 + using gemm4_warps = ck_tile::sequence<1, 4, 1>; // default for gemm4 +}; + template <> struct FmhaBwdBlockTile<128> { using tile_lengths = @@ -123,6 +131,20 @@ struct FmhaBwdShape<64> : ck_tile::TileFmhaBwdShape< typename FmhaBwdBlockTile<64>::gemm4_warps, FmhaBwdWarpTile2> {}; +template <> +struct FmhaBwdShape<96> : ck_tile::TileFmhaBwdShape< + typename FmhaBwdBlockTile<96>::tile_lengths, + typename FmhaBwdBlockTile<96>::gemm02_warps, + FmhaBwdWarpTile2, + typename FmhaBwdBlockTile<96>::gemm13_warps, + FmhaBwdWarpTile3, + typename FmhaBwdBlockTile<96>::gemm02_warps, + FmhaBwdWarpTile2, + typename FmhaBwdBlockTile<96>::gemm13_warps, + FmhaBwdWarpTile3, + typename FmhaBwdBlockTile<96>::gemm4_warps, + FmhaBwdWarpTile2> {}; + template <> struct FmhaBwdShape<128> : ck_tile::TileFmhaBwdShape< typename FmhaBwdBlockTile<128>::tile_lengths, @@ -151,7 +173,7 @@ struct FmhaBwdShape<256> : ck_tile::TileFmhaBwdShape< typename FmhaBwdBlockTile<256>::gemm4_warps, FmhaBwdWarpTile2> {}; -template +template 
struct FmhaBwdPipelineEnumSelector { static constexpr ck_tile::BlockFmhaBwdPipelineEnum value = ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h index ddd91a6864..922bdd05d6 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h @@ -8,113 +8,164 @@ #include #include -#include +#include "ck_fmha_util.h" +#include "ck_tiled_fmha_fwd_type_config.h" -template -struct FmhaFwdTypeConfig; +template +struct FmhaFwdBlockTile; -template <> -struct FmhaFwdTypeConfig { - using QDataType = ck_tile::fp16_t; - using KDataType = ck_tile::fp16_t; - using VDataType = ck_tile::fp16_t; - using BiasDataType = ck_tile::fp16_t; - using RandValOutputDataType = unsigned short; - using LSEDataType = - float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::fp16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::fp16_t; +// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0) +// +template +struct FmhaFwdBlockTile<32, MTile> { + using type = ck_tile::sequence<64, 64, 16, 32, 32, 32>; + using gemm0_warps = ck_tile::sequence<2, 1, 1>; + using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; -template <> -struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; - using BiasDataType = ck_tile::bf16_t; - using RandValOutputDataType = unsigned short; - using LSEDataType = - float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; +template struct FmhaFwdBlockTile<32>; + +template +struct FmhaFwdBlockTile<64, MTile> { + using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template -struct FmhaFwdBlockTile; +template struct FmhaFwdBlockTile<64>; -template <> -struct FmhaFwdBlockTile<32> { - using type = ck_tile::sequence<128, 64, 16, 32, 32, 32>; - using gemm0_warps = ck_tile::sequence<2, 1, 1>; - using gemm1_warps = ck_tile::sequence<2, 1, 1>; +template +struct FmhaFwdBlockTile<96, MTile> { + using type = ck_tile::sequence<128, 128, 32, 128, 32, 96>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; +template struct FmhaFwdBlockTile<96>; + template <> -struct FmhaFwdBlockTile<64> { - using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +struct FmhaFwdBlockTile<128, 64> { + using type = ck_tile::sequence<64, 128, 32, 128, 32, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; template <> -struct FmhaFwdBlockTile<128> { +struct FmhaFwdBlockTile<128, 128> { using type = ck_tile::sequence<128, 128, 32, 128, 32, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = 
ck_tile::sequence<4, 1, 1>; }; -template <> -struct FmhaFwdBlockTile<256> { +template +struct FmhaFwdBlockTile<256, MTile> { using type = ck_tile::sequence<128, 128, 32, 256, 32, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -using FmhaFwdWarpTile = ck_tile::sequence<32, 32, 16>; +template struct FmhaFwdBlockTile<256>; -static constexpr bool IsVLayoutRowMajor = true; +using FmhaFwdWarpTile1 = ck_tile::sequence<32, 32, 16>; +using FmhaFwdWarpTile2 = ck_tile::sequence<16, 16, 16>; -template +template struct FmhaFwdShape; -template <> -struct FmhaFwdShape<32> : ck_tile::TileFmhaShape< - typename FmhaFwdBlockTile<32>::type, - typename FmhaFwdBlockTile<32>::gemm0_warps, - FmhaFwdWarpTile, - typename FmhaFwdBlockTile<32>::gemm1_warps, - FmhaFwdWarpTile, - IsVLayoutRowMajor> {}; +template +struct FmhaFwdShape<32, MTile> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdBlockTile<32>::type, + typename FmhaFwdBlockTile<32>::gemm0_warps, + FmhaFwdWarpTile1, + typename FmhaFwdBlockTile<32>::gemm1_warps, + FmhaFwdWarpTile1, + IsVLayoutRowMajor>; +}; -template <> -struct FmhaFwdShape<64> : ck_tile::TileFmhaShape< - typename FmhaFwdBlockTile<64>::type, - typename FmhaFwdBlockTile<64>::gemm0_warps, - FmhaFwdWarpTile, - typename FmhaFwdBlockTile<64>::gemm1_warps, - FmhaFwdWarpTile, - IsVLayoutRowMajor> {}; +template struct FmhaFwdShape<32, 64>; +template struct FmhaFwdShape<32, 128>; + +template +struct FmhaFwdShape<64, MTile> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdBlockTile<64>::type, + typename FmhaFwdBlockTile<64>::gemm0_warps, + FmhaFwdWarpTile1, + typename FmhaFwdBlockTile<64>::gemm1_warps, + FmhaFwdWarpTile1, + IsVLayoutRowMajor>; +}; + +template struct FmhaFwdShape<64, 64>; +template struct FmhaFwdShape<64, 128>; + +template +struct FmhaFwdShape<96, MTile> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdBlockTile<96>::type, + typename FmhaFwdBlockTile<96>::gemm0_warps, + FmhaFwdWarpTile1, + typename FmhaFwdBlockTile<96>::gemm1_warps, + FmhaFwdWarpTile1, + IsVLayoutRowMajor>; +}; + +template struct FmhaFwdShape<96, 64>; +template struct FmhaFwdShape<96, 128>; template <> -struct FmhaFwdShape<128> : ck_tile::TileFmhaShape< - typename FmhaFwdBlockTile<128>::type, - typename FmhaFwdBlockTile<128>::gemm0_warps, - FmhaFwdWarpTile, - typename FmhaFwdBlockTile<128>::gemm1_warps, - FmhaFwdWarpTile, - IsVLayoutRowMajor> {}; +struct FmhaFwdShape<128, 64> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdBlockTile<128, 64>::type, + typename FmhaFwdBlockTile<128, 64>::gemm0_warps, + FmhaFwdWarpTile2, + typename FmhaFwdBlockTile<128, 64>::gemm1_warps, + FmhaFwdWarpTile2, + IsVLayoutRowMajor>; +}; template <> -struct FmhaFwdShape<256> : ck_tile::TileFmhaShape< - typename FmhaFwdBlockTile<256>::type, - typename FmhaFwdBlockTile<256>::gemm0_warps, - FmhaFwdWarpTile, - typename FmhaFwdBlockTile<256>::gemm1_warps, - FmhaFwdWarpTile, - IsVLayoutRowMajor> {}; +struct FmhaFwdShape<128, 128> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdBlockTile<128, 128>::type, + typename FmhaFwdBlockTile<128, 128>::gemm0_warps, + FmhaFwdWarpTile1, + typename FmhaFwdBlockTile<128, 128>::gemm1_warps, + FmhaFwdWarpTile1, + IsVLayoutRowMajor>; +}; + +template +struct FmhaFwdShape<256, MTile> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdBlockTile<256>::type, + typename FmhaFwdBlockTile<256>::gemm0_warps, + FmhaFwdWarpTile1, + typename FmhaFwdBlockTile<256>::gemm1_warps, + FmhaFwdWarpTile1, + 
IsVLayoutRowMajor>;
+};
+
+template struct FmhaFwdShape<256, 64>;
+template struct FmhaFwdShape<256, 128>;
+
+static int get_fmha_fwd_mtile(
+    int num_batches,
+    int num_heads,
+    int max_seqlen_q) {
+  int num_SMs = get_number_of_cu();
+  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+
+  int batch_nhead_mblocks =
+      num_batches * num_heads * ceildiv(max_seqlen_q, 128);
+
+  if (batch_nhead_mblocks >= 0.8 * num_SMs)
+    return 128;
+
+  return 64;
+};
+
+static int get_fmha_fwd_least_mtile() {
+  return 64;
+};
diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_selector.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_selector.h
new file mode 100644
index 0000000000..daa281c28d
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_selector.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include
+#include
+#include "ck_fmha_util.h"
+#include "ck_tiled_fmha_fwd_setting.h"
+#include "ck_tiled_fmha_fwd_splitkv_setting.h"
+#include "ck_tiled_fmha_fwd_splitkv_smallq_setting.h"
+#include "ck_tiled_fmha_seqlen_q_switch.h"
+
+// generate the i-th candidate value of num_splits to consider; the candidates
+// form the sequence 1, 2, 4, 8, 16, 32, 64, 96, 128, 160, ...
+static int generate_splits_list(int i) {
+  if (i <= 0)
+    return 1;
+
+  if (i <= 5)
+    return 1 << (i - 1);
+  else
+    return (i - 5) * 32;
+};
+
+static std::pair<bool, int> get_num_kv_splits_heuristic(
+    int num_batches,
+    int num_heads,
+    int max_seqlen_q,
+    int max_headdim,
+    int max_splits) {
+  int num_SMs = get_number_of_cu();
+  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+
+  int mtile_size_for_pipeline_default = get_fmha_fwd_least_mtile();
+  int mtile_size_for_splitkv = 64;
+  int mtile_size_for_splitkv_smallq = 16;
+
+  // get mtile_size_for_splitkv
+  mtile_size_for_splitkv =
+      get_mtile_size_for_splitkv(max_seqlen_q, max_headdim);
+
+  // get mtile_size_for_splitkv_smallq
+  mtile_size_for_splitkv_smallq =
+      get_mtile_size_for_splitkv_smallq(max_headdim);
+
+  if (max_seqlen_q >= mtile_size_for_pipeline_default) {
+    int batch_nhead_mblocks = num_batches * num_heads *
+        ceildiv(max_seqlen_q, mtile_size_for_pipeline_default);
+
+    if (batch_nhead_mblocks >= 0.8f * num_SMs)
+      return std::make_pair(false, 1);
+  }
+
+  bool use_splitkv = true;
+
+  // the m-tile size is the block size used to divide seqlen_q;
+  // we first try to use the normal split-kv kernel
+  int mtile_size = mtile_size_for_splitkv;
+  int batch_nhead_mblocks =
+      num_batches * num_heads * ceildiv(max_seqlen_q, mtile_size);
+
+  // fall back to the splitkv-smallq kernel to avoid wasted computation and to
+  // improve CU occupancy
+  if (max_seqlen_q <= mtile_size_for_splitkv_smallq)
+    mtile_size = mtile_size_for_splitkv_smallq;
+
+  batch_nhead_mblocks =
+      num_batches * num_heads * ceildiv(max_seqlen_q, mtile_size);
+
+  // If we already have enough workgroups to fill all the SMs, just use 1 split
+  if (batch_nhead_mblocks >= 0.9f * num_SMs) {
+    return std::make_pair(use_splitkv, 1);
+  }
+
+  max_splits = std::min({max_splits, num_SMs});
+
+  int max_check = 1;
+
+  while (generate_splits_list(max_check) <= max_splits)
+    max_check++;
+
+  int num_splits = 2;
+  for (int i = 2; i < max_check; i++) {
+    num_splits = generate_splits_list(i);
+
+    if (batch_nhead_mblocks * num_splits >= num_SMs)
+      break;
+  };
+
+  return
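+  // Illustrative sketch only; the numbers below are assumptions, not values
+  // taken from this file. On a device whose get_number_of_cu() reports 304
+  // CUs, a decode-like call with num_batches = 2, num_heads = 8,
+  // max_seqlen_q = 1 and max_headdim = 128 selects the smallq m-tile of 16,
+  // giving 16 work-tiles; assuming max_splits >= 32, the loop above then
+  // walks the candidates 2, 4, 8, 16 and stops at num_splits = 32, since
+  // 16 * 32 >= 304, so the pair (true, 32) is returned here.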
std::make_pair(use_splitkv, num_splits); +} diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_setting.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_setting.h new file mode 100644 index 0000000000..82e0c2c403 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_setting.h @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include "ck_tiled_fmha_fwd_type_config.h" +#include "ck_tiled_fmha_seqlen_q_switch.h" + +template +struct FmhaFwdSplitKVBlockTile; + +// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0) + +template +struct FmhaFwdSplitKVBlockTile<32, MaxSeqLenQ> { + using type = ck_tile::sequence<32, 64, 16, 32, 32, 32>; + using gemm0_warps = ck_tile::sequence<2, 1, 1>; + using gemm1_warps = ck_tile::sequence<2, 1, 1>; +}; + +template struct FmhaFwdSplitKVBlockTile<32>; + +template +struct FmhaFwdSplitKVBlockTile<64, MaxSeqLenQ> { + using type = ck_tile::sequence<32, 64, 32, 64, 32, 64>; + using gemm0_warps = ck_tile::sequence<2, 1, 1>; + using gemm1_warps = ck_tile::sequence<2, 1, 1>; +}; + +template struct FmhaFwdSplitKVBlockTile<64>; + +template +struct FmhaFwdSplitKVBlockTile<96, MaxSeqLenQ> { + using type = ck_tile::sequence<64, 128, 32, 128, 32, 96>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template struct FmhaFwdSplitKVBlockTile<96>; + +template <> +struct FmhaFwdSplitKVBlockTile<128, 32> { + using type = ck_tile::sequence<32, 128, 32, 128, 32, 128>; + using gemm0_warps = ck_tile::sequence<2, 1, 1>; + using gemm1_warps = ck_tile::sequence<2, 1, 1>; +}; + +template <> +struct FmhaFwdSplitKVBlockTile<128, 64> { + using type = ck_tile::sequence<64, 128, 32, 128, 32, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template +struct FmhaFwdSplitKVBlockTile<256, MaxSeqLenQ> { + using type = ck_tile::sequence<64, 128, 32, 256, 32, 256>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template struct FmhaFwdSplitKVBlockTile<256>; + +using FmhaFwdSplitKVWarpTile = ck_tile::sequence<16, 16, 16>; + +template +struct FmhaFwdSplitKVShape; + +template +struct FmhaFwdSplitKVShape<32, MaxSeqLenQ> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVBlockTile<32>::type, + typename FmhaFwdSplitKVBlockTile<32>::gemm0_warps, + FmhaFwdSplitKVWarpTile, + typename FmhaFwdSplitKVBlockTile<32>::gemm1_warps, + FmhaFwdSplitKVWarpTile, + IsVLayoutRowMajor>; +}; + +template struct FmhaFwdSplitKVShape<32, 32>; +template struct FmhaFwdSplitKVShape<32, 64>; + +template +struct FmhaFwdSplitKVShape<64, MaxSeqLenQ> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVBlockTile<64>::type, + typename FmhaFwdSplitKVBlockTile<64>::gemm0_warps, + FmhaFwdSplitKVWarpTile, + typename FmhaFwdSplitKVBlockTile<64, MaxSeqLenQ>::gemm1_warps, + FmhaFwdSplitKVWarpTile, + IsVLayoutRowMajor>; +}; + +template struct FmhaFwdSplitKVShape<64, 32>; +template struct FmhaFwdSplitKVShape<64, 64>; + +template +struct FmhaFwdSplitKVShape<96, MaxSeqLenQ> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVBlockTile<96>::type, + typename FmhaFwdSplitKVBlockTile<96>::gemm0_warps, + 
FmhaFwdSplitKVWarpTile, + typename FmhaFwdSplitKVBlockTile<96, MaxSeqLenQ>::gemm1_warps, + FmhaFwdSplitKVWarpTile, + IsVLayoutRowMajor>; +}; + +template struct FmhaFwdSplitKVShape<96, 32>; +template struct FmhaFwdSplitKVShape<96, 64>; + +template <> +struct FmhaFwdSplitKVShape<128, 32> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVBlockTile<128, 32>::type, + typename FmhaFwdSplitKVBlockTile<128, 32>::gemm0_warps, + FmhaFwdSplitKVWarpTile, + typename FmhaFwdSplitKVBlockTile<128, 32>::gemm1_warps, + FmhaFwdSplitKVWarpTile, + IsVLayoutRowMajor>; +}; + +template <> +struct FmhaFwdSplitKVShape<128, 64> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVBlockTile<128, 64>::type, + typename FmhaFwdSplitKVBlockTile<128, 64>::gemm0_warps, + FmhaFwdSplitKVWarpTile, + typename FmhaFwdSplitKVBlockTile<128, 64>::gemm1_warps, + FmhaFwdSplitKVWarpTile, + IsVLayoutRowMajor>; +}; + +template +struct FmhaFwdSplitKVShape<256, MaxSeqLenQ> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVBlockTile<256>::type, + typename FmhaFwdSplitKVBlockTile<256>::gemm0_warps, + FmhaFwdSplitKVWarpTile, + typename FmhaFwdSplitKVBlockTile<256>::gemm1_warps, + FmhaFwdSplitKVWarpTile, + IsVLayoutRowMajor>; +}; + +template struct FmhaFwdSplitKVShape<256, 32>; +template struct FmhaFwdSplitKVShape<256, 64>; + +template +int fwd_splitkv_get_mtile_size() { + using FmhaTileShape = typename FmhaFwdSplitKVShape::Type; + + return FmhaTileShape::kM0; +}; + +static int get_mtile_size_for_splitkv(int max_seqlen_q, int max_headdim) { + int mtile_size_for_splitkv = 64; + + FMHA_FWD_SEQLEN_Q_SWITCH(max_seqlen_q, MaxSeqLenQ, [&] { + if (max_headdim <= 32) { + mtile_size_for_splitkv = fwd_splitkv_get_mtile_size<32, MaxSeqLenQ>(); + } else if (max_headdim <= 64) { + mtile_size_for_splitkv = fwd_splitkv_get_mtile_size<64, MaxSeqLenQ>(); + } else if (max_headdim <= 96) { + mtile_size_for_splitkv = fwd_splitkv_get_mtile_size<96, MaxSeqLenQ>(); + } else if (max_headdim <= 128) { + mtile_size_for_splitkv = fwd_splitkv_get_mtile_size<128, MaxSeqLenQ>(); + } else { + mtile_size_for_splitkv = fwd_splitkv_get_mtile_size<256, MaxSeqLenQ>(); + }; + }); + + return mtile_size_for_splitkv; +} diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_selector.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_selector.h new file mode 100644 index 0000000000..fec619fed3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_selector.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include "ck_tiled_fmha_fwd_splitkv_setting.h" +#include "ck_tiled_fmha_fwd_splitkv_smallq_setting.h" + +/// This method determines whether to use normal or smallq splitkv kernel +static bool use_splitkv_smallq(int max_seqlen_q, int max_headdim) { + int mtile_size_for_splitkv_smallq = + get_mtile_size_for_splitkv_smallq(max_headdim); + + // resort to splitkv-smallq kernel for avoiding wasting of computation + if (max_seqlen_q <= mtile_size_for_splitkv_smallq) + return true; + + return false; +} diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_setting.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_setting.h new file mode 100644 index 0000000000..0688fa0dbb --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_splitkv_smallq_setting.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include "ck_tiled_fmha_fwd_type_config.h" + +template +struct FmhaFwdSplitKVSmallQBlockTile; + +// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0) + +template <> +struct FmhaFwdSplitKVSmallQBlockTile<32> { + using type = ck_tile::sequence<16, 64, 16, 32, 32, 32>; + using gemm0_warps = ck_tile::sequence<1, 2, 1>; + using gemm1_warps = ck_tile::sequence<1, 2, 1>; +}; + +template <> +struct FmhaFwdSplitKVSmallQBlockTile<64> { + using type = ck_tile::sequence<16, 64, 32, 64, 32, 64>; + using gemm0_warps = ck_tile::sequence<1, 4, 1>; + using gemm1_warps = ck_tile::sequence<1, 4, 1>; +}; + +template <> +struct FmhaFwdSplitKVSmallQBlockTile<96> { + using type = ck_tile::sequence<16, 64, 32, 128, 32, 96>; + using gemm0_warps = ck_tile::sequence<1, 4, 1>; + using gemm1_warps = ck_tile::sequence<1, 4, 1>; +}; + +template <> +struct FmhaFwdSplitKVSmallQBlockTile<128> { + using type = ck_tile::sequence<16, 64, 64, 128, 64, 128>; + using gemm0_warps = ck_tile::sequence<1, 4, 1>; + using gemm1_warps = ck_tile::sequence<1, 4, 1>; +}; + +template <> +struct FmhaFwdSplitKVSmallQBlockTile<256> { + using type = ck_tile::sequence<16, 64, 64, 256, 64, 256>; + using gemm0_warps = ck_tile::sequence<1, 4, 1>; + using gemm1_warps = ck_tile::sequence<1, 4, 1>; +}; + +using FmhaFwdSplitKVSmallQWarpTile0 = ck_tile::sequence<16, 16, 16>; +using FmhaFwdSplitKVSmallQWarpTile1 = ck_tile::sequence<16, 16, 16>; + +template +struct FmhaFwdSplitKVSmallQShape; + +template <> +struct FmhaFwdSplitKVSmallQShape<32> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVSmallQBlockTile<32>::type, + typename FmhaFwdSplitKVSmallQBlockTile<32>::gemm0_warps, + FmhaFwdSplitKVSmallQWarpTile0, + typename FmhaFwdSplitKVSmallQBlockTile<32>::gemm1_warps, + FmhaFwdSplitKVSmallQWarpTile1, + IsVLayoutRowMajor>; +}; + +template <> +struct FmhaFwdSplitKVSmallQShape<64> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVSmallQBlockTile<64>::type, + typename FmhaFwdSplitKVSmallQBlockTile<64>::gemm0_warps, + FmhaFwdSplitKVSmallQWarpTile0, + typename FmhaFwdSplitKVSmallQBlockTile<64>::gemm1_warps, + FmhaFwdSplitKVSmallQWarpTile1, + IsVLayoutRowMajor>; +}; + +template <> +struct FmhaFwdSplitKVSmallQShape<96> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVSmallQBlockTile<96>::type, + typename FmhaFwdSplitKVSmallQBlockTile<96>::gemm0_warps, + FmhaFwdSplitKVSmallQWarpTile0, + 
typename FmhaFwdSplitKVSmallQBlockTile<96>::gemm1_warps, + FmhaFwdSplitKVSmallQWarpTile1, + IsVLayoutRowMajor>; +}; + +template <> +struct FmhaFwdSplitKVSmallQShape<128> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVSmallQBlockTile<128>::type, + typename FmhaFwdSplitKVSmallQBlockTile<128>::gemm0_warps, + FmhaFwdSplitKVSmallQWarpTile0, + typename FmhaFwdSplitKVSmallQBlockTile<128>::gemm1_warps, + FmhaFwdSplitKVSmallQWarpTile1, + IsVLayoutRowMajor>; +}; + +template <> +struct FmhaFwdSplitKVSmallQShape<256> { + using Type = ck_tile::TileFmhaShape< + typename FmhaFwdSplitKVSmallQBlockTile<256>::type, + typename FmhaFwdSplitKVSmallQBlockTile<256>::gemm0_warps, + FmhaFwdSplitKVSmallQWarpTile0, + typename FmhaFwdSplitKVSmallQBlockTile<256>::gemm1_warps, + FmhaFwdSplitKVSmallQWarpTile1, + IsVLayoutRowMajor>; +}; + +template +int fwd_splitkv_smallq_get_mtile_size() { + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + + return FmhaTileShape::kM0; +}; + +static int get_mtile_size_for_splitkv_smallq(int max_headdim) { + int mtile_size_for_splitkv_smallq = 16; + + if (max_headdim <= 32) { + mtile_size_for_splitkv_smallq = fwd_splitkv_smallq_get_mtile_size<32>(); + } else if (max_headdim <= 64) { + mtile_size_for_splitkv_smallq = fwd_splitkv_smallq_get_mtile_size<64>(); + } else if (max_headdim <= 96) { + mtile_size_for_splitkv_smallq = fwd_splitkv_smallq_get_mtile_size<96>(); + } else if (max_headdim <= 128) { + mtile_size_for_splitkv_smallq = fwd_splitkv_smallq_get_mtile_size<128>(); + } else { + mtile_size_for_splitkv_smallq = fwd_splitkv_smallq_get_mtile_size<256>(); + }; + + return mtile_size_for_splitkv_smallq; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_type_config.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_type_config.h new file mode 100644 index 0000000000..72e4a5e1e6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_type_config.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig { + using QDataType = ck_tile::fp16_t; + using KDataType = ck_tile::fp16_t; + using VDataType = ck_tile::fp16_t; + using BiasDataType = ck_tile::fp16_t; + using RandValOutputDataType = unsigned short; + using LSEDataType = + float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp16_t; +}; + +template <> +struct FmhaFwdTypeConfig { + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = unsigned short; + using LSEDataType = + float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +static constexpr bool IsVLayoutRowMajor = true; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h index 82d9920f6d..dc7909a576 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward.h @@ -7,7 +7,8 @@ #pragma once #include -#include +#include +#include #include #include @@ -17,12 +18,12 @@ template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasBiasGrad, bool kHasDropout, ck_tile::index_t MaxK> -struct grouped_backward_causalmask_bias_dropout_dispatch { +struct grouped_backward_mask_bias_dropout_dispatch { using FmhaBlockDropout = typename FmhaBwdBlockDropoutMaker::dropout; @@ -57,7 +58,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { static void Run(GroupedBackwardParams& param, hipStream_t stream) { { constexpr ck_tile::index_t kBlockSize = 64; - bool pad_headdim_v = !(param.Kv % FmhaBwdShape::kVHeaddim == 0); + bool pad_headdim_v = !(param.Kv % MaxK == 0); constexpr bool kPadSeqLenQ = true; @@ -73,7 +74,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, kBlockSize, - FmhaBwdShape::kVHeaddim, + MaxK, // kVHeaddim true, // kIsGroupMode FmhaOGradDotOTraits_>; @@ -89,85 +90,75 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { }; { - const bool has_local_attention = (param.window_size > 0) ? true : false; - - BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { - constexpr ck_tile::index_t occupancy = 1; - constexpr bool has_masking = kHasCausalMask || USE_LOCAL_ATTENTION; - const bool has_dropout = (param.dropout_prob > 0.0f); - - using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - - constexpr auto kBiasEnum = kHasBias - ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS - : ck_tile::BlockAttentionBiasEnum::NO_BIAS; - - constexpr bool kPadSeqLenQ = true; - constexpr bool kPadSeqLenK = true; - - const bool pad_headdim_q = - !(param.K % FmhaBwdShape::kQKHeaddim == 0); - const bool pad_headdim_v = - !(param.Kv % FmhaBwdShape::kVHeaddim == 0); - - // usually headdim_q and headdim_v are same, consider them together - // to determine whether to do padding saving some compiling time - const bool pad_headdim = (pad_headdim_q || pad_headdim_v); - - BOOL_SWITCH(pad_headdim, kPadHeadDim, [&] { - using FmhaBwdTraits_ = ck_tile::TileFmhaTraits< - kPadSeqLenQ, - kPadSeqLenK, - kPadHeadDim, // kPadHeadDimQ, - kPadHeadDim, // kPadHeadDimV, - kBiasEnum, - kHasBiasGrad, - false, // kStoreLSE - false, // place-holder for kHasDropout, not used actually - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaBwdPipelineProblem = - FmhaBwdPipelineProblemTemp; - - constexpr auto FmhaBwdPipelineEnum_ = - FmhaBwdPipelineEnumSelector:: - value; - - using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker< - FmhaBwdPipelineEnum_, - FmhaBwdPipelineProblem>::pipeline; - - using FmhaBwdKGradEpilogue_ = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig::KGradDataType, - kPadSeqLenK, - kPadHeadDim>>; - - using FmhaBwdVGradEpilogue_ = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig::VGradDataType, - kPadSeqLenK, - kPadHeadDim>>; - - using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel< - FmhaBwdPipeline_, - FmhaBwdKGradEpilogue_, - FmhaBwdVGradEpilogue_>; - - RunWithBwdDQDKDVKernel(param, stream); - }); - }); + constexpr ck_tile::index_t occupancy = 1; + const bool has_dropout = (param.dropout_prob > 0.0f); + + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + constexpr auto kBiasEnum = kHasBias + ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + const bool pad_headdim_q = + !(param.K % FmhaBwdShape::kQKHeaddim == 0); + const bool pad_headdim_v = + !(param.Kv % FmhaBwdShape::kVHeaddim == 0); + + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaBwdTraits_ = ck_tile::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + kHasBiasGrad, + false, // kStoreLSE + false, // place-holder for kHasDropout, not used actually + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaBwdPipelineProblem = + FmhaBwdPipelineProblemTemp; + + constexpr auto FmhaBwdPipelineEnum_ = + FmhaBwdPipelineEnumSelector::value; + + using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker< + FmhaBwdPipelineEnum_, + FmhaBwdPipelineProblem>::pipeline; + + using FmhaBwdKGradEpilogue_ = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + kPadSeqLenK, + kPadHeadDimQ>>; + + using FmhaBwdVGradEpilogue_ = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + kPadSeqLenK, + kPadHeadDimV>>; + + using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel< + FmhaBwdPipeline_, + FmhaBwdKGradEpilogue_, + FmhaBwdVGradEpilogue_>; + + RunWithBwdDQDKDVKernel(param, stream); + }); }; if constexpr (NeedConvertGradQ) { constexpr ck_tile::index_t kBlockSize = 128; const bool pad_seqlen_q = true; - const bool pad_headdim_q = - !(param.K % FmhaBwdShape::kQKHeaddim == 0); + const bool pad_headdim_q = !(param.K % MaxK == 0); BOOL_SWITCH_2( pad_seqlen_q, kPadSeqLenQ, pad_headdim_q, kPadHeadDimQ, [&] { @@ -186,7 +177,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { kBlockSize, 64, // kM0 1, // kN0, no use - FmhaBwdShape::kQKHeaddim, + MaxK, // kQKHeaddim true, // kIsGroupMode false, // kIsDeterministic FmhaBwdConvertQGradTraits_>; @@ -292,7 +283,7 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size param.custom_mask_type, param.dropout_prob, // dropout ratio - {param.philox_seed, param.philox_offset}); + std::make_pair(param.philox_seed, param.philox_offset)); }(); dim3 kGridSize = FmhaBwdDQDKDVKernel::GridSize( @@ -339,17 +330,17 @@ struct grouped_backward_causalmask_bias_dropout_dispatch { template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasBiasGrad, bool kHasDropout, ck_tile::index_t MaxK> -void run_grouped_backward_causalmask_bias_dropout_dispatch( +void run_grouped_backward_mask_bias_dropout_dispatch( GroupedBackwardParams& param, hipStream_t stream) { - grouped_backward_causalmask_bias_dropout_dispatch< + grouped_backward_mask_bias_dropout_dispatch< ScalarType, - kHasCausalMask, + kHasMask, kHasBias, kHasBiasGrad, kHasDropout, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp index 7b77442be6..dd18cb4d4b 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_bf16.cpp @@ -25,16 +25,18 @@ void grouped_backward_bf16(GroupedBackwardParams& param, hipStream_t stream) { [&] { if constexpr (kHasBias || !kHasBiasGrad) { FMHA_BWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_grouped_backward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, kHasBias, kHasBiasGrad, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_grouped_backward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp index be47bbdbb1..f5f2a954e8 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_backward_fp16.cpp @@ -25,16 +25,18 @@ void grouped_backward_fp16(GroupedBackwardParams& param, hipStream_t stream) { [&] { if constexpr (kHasBias || !kHasBiasGrad) { FMHA_BWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_grouped_backward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, kHasBias, kHasBiasGrad, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_grouped_backward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h index 519a5ea89e..5d19d6cc0e 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h @@ -6,186 +6,69 @@ */ #pragma once -#include -#include -#include -#include - -#include 
"ck_tiled_bool_switch.h" +#include #include "ck_tiled_fmha_fwd_setting.h" -#include "ck_tiled_fmha_params.h" - -template < - typename ScalarType, - bool kHasCausalMask, - bool kHasBias, - bool kHasDropout, - ck_tile::index_t MaxK> -struct grouped_forward_causalmask_bias_dropout_dispatch { - template - using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - FmhaFwdShape, - true, // kIsGroupMode - FmhaMask, - FmhaTraits>; - - static void Run(GroupedForwardParams& param, hipStream_t stream) { - const bool has_local_attention = (param.window_size > 0) ? true : false; - - BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { - constexpr bool has_masking = kHasCausalMask || USE_LOCAL_ATTENTION; - - using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - - using FmhaFwdShape_ = FmhaFwdShape; - - constexpr ck_tile::index_t occupancy = (MaxK == 64) ? 3 - : (MaxK == 256) ? 1 - : 2; - - constexpr auto kBiasEnum = kHasBias - ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS - : ck_tile::BlockAttentionBiasEnum::NO_BIAS; - - constexpr bool kPadSeqLenQ = true; - constexpr bool kPadSeqLenK = true; - - const bool pad_headdim_q = - !(param.K % FmhaFwdShape_::kK0BlockLength == 0); - const bool pad_headdim_v = !(param.Kv % FmhaFwdShape_::kN1 == 0); - - BOOL_SWITCH_2( - pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { - using FmhaFwdTraits_ = ck_tile::TileFmhaTraits< - kPadSeqLenQ, - kPadSeqLenK, - kPadHeadDimQ, - kPadHeadDimV, - kBiasEnum, - false, // kHasBiasGrad place-holder - true, // kStoreLSE - kHasDropout, - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaPipelineProblem = - FmhaPipelineProblemTemp; - - using FmhaFwdPipeline_ = - ck_tile::BlockFmhaPipelineQRKSVS; - - using FmhaFwdEpilogue_ = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - kPadSeqLenQ, - kPadHeadDimV>>; - - if (param.seqlen_k_dev_ptr != - nullptr) { // seqlen_k of batches are padded - using FmhaTilePartitioner = - ck_tile::FmhaFwdTilePartitioner_HBS; - using FmhaFwdKernel_ = ck_tile::FmhaFwdKernel< - FmhaTilePartitioner, - FmhaFwdPipeline_, - FmhaFwdEpilogue_>; - - RunWithKernel(param, stream); - } else { - using FmhaTilePartitioner = - ck_tile::FmhaFwdTilePartitioner_SHB; - using FmhaFwdKernel_ = ck_tile::FmhaFwdKernel< - FmhaTilePartitioner, - FmhaFwdPipeline_, - FmhaFwdEpilogue_>; - - RunWithKernel(param, stream); - } - }); - }); - }; - - template - static void RunWithKernel(GroupedForwardParams& param, hipStream_t stream) { - const auto kargs = [&] { - return FmhaFwdKernel::MakeKargs( - param.q_ptr, - param.k_ptr, - param.v_ptr, - param.attn_bias_ptr, - nullptr, // rand_val_ptr - param.logsumexp_ptr, - param.out_ptr, - param.seqstart_q_dev_ptr, - param.seqstart_k_dev_ptr, - param.seqlen_k_dev_ptr, - param.K, // hdim_q - param.Kv, // hdim_v - param.Hq, // nhead_q - param.Hq / param.Hkv, // nhead_ratio_qk - param.scale, - 1.0f, // scale_p - 1.0f, // scale_o - param.q_strides[0], // q, k, v, bias, randval, out tensor seq-dim - // stride - param.k_strides[0], - 
param.v_strides[0], - param.attn_bias_strides[2], - 0, // stride_randval - param.out_strides[0], - param.q_strides[1], // q, k, v, bias, randval, lse, out tensor - // head-dim stride - param.k_strides[1], - param.v_strides[1], - param.attn_bias_strides[1], - 0, // nhead_stride_randval - param.lse_strides[0], - param.out_strides[1], - (param.window_size > 0) ? param.window_size - 1 - : -1, // window_left_size - (param.custom_mask_type == 0) ? -1 : 0, // window_right_size - param.custom_mask_type, - param.dropout_prob, - false, // is_store_randval - {param.philox_seed, param.philox_offset}); - }(); - - dim3 kGridSize = FmhaFwdKernel::GridSize( - param.num_batches, param.Hq, param.max_seqlen_q, param.Kv); - constexpr dim3 kBlockSize = FmhaFwdKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = FmhaFwdKernel::kBlockPerCu; - - (void)ck_tile::launch_kernel( - ck_tile::stream_config{stream, false}, - ck_tile::make_kernel( - FmhaFwdKernel{}, kGridSize, kBlockSize, 0, kargs)); - }; -}; +#include "ck_tiled_fmha_fwd_splitkv_smallq_selector.h" +#include "ck_tiled_fmha_grouped_forward_dispatch.h" +#include "ck_tiled_fmha_grouped_forward_splitkv_dispatch.h" +#include "ck_tiled_fmha_grouped_forward_splitkv_smallq_dispatch.h" +#include "ck_tiled_fmha_seqlen_q_switch.h" template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasDropout, ck_tile::index_t MaxK> -void run_grouped_forward_causalmask_bias_dropout_dispatch( +void run_grouped_forward_mask_bias_dropout_dispatch( GroupedForwardParams& param, hipStream_t stream) { - grouped_forward_causalmask_bias_dropout_dispatch< - ScalarType, - kHasCausalMask, - kHasBias, - kHasDropout, - MaxK>::Run(param, stream); + // currently split-kv implementation does not support dropout + if constexpr (!kHasDropout) { + if (param.use_split_kv) { + if (use_splitkv_smallq(param.max_seqlen_q, std::max(param.K, param.Kv))) { + grouped_forward_splitkv_smallq_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK>::Run(param, stream); + } else { + FMHA_FWD_SEQLEN_Q_SWITCH(param.max_seqlen_q, MaxSeqlenQ, [&] { + grouped_forward_splitkv_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK, + MaxSeqlenQ>::Run(param, stream); + }); + } + } else { + if (get_fmha_fwd_mtile(param.num_batches, param.Hq, param.max_seqlen_q) == + 128) + grouped_forward_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + else + grouped_forward_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 64>::Run(param, stream); + } + } else { + // at present, dropout of fwd kernel requires 32x32 WarpTile + grouped_forward_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + } }; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp index 28d75ddc56..bc8d28a930 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bf16.cpp @@ -17,15 +17,17 @@ void grouped_forward_bf16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - 
run_grouped_forward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_grouped_forward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_dispatch.h new file mode 100644 index 0000000000..920c093e33 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_dispatch.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_setting.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + bool kHasDropout, + ck_tile::index_t MaxK, + ck_tile::index_t MTile> +struct grouped_forward_mask_bias_dropout_dispatch { + template + using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + typename FmhaFwdShape::Type, + true, // kIsGroupMode + FmhaMask, + FmhaTraits>; + + static void Run(GroupedForwardParams& param, hipStream_t stream) { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaFwdShape_ = typename FmhaFwdShape::Type; + + constexpr ck_tile::index_t occupancy = (MaxK == 64) ? 3 + : (MaxK == 256) ? 1 + : 2; + + constexpr auto kBiasEnum = kHasBias + ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + const bool pad_headdim_q = !(param.K % FmhaFwdShape_::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.Kv % FmhaFwdShape_::kN1 == 0); + + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaFwdTraits_ = ck_tile::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + kHasDropout, + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaFwdPipeline_ = + ck_tile::BlockFmhaPipelineQRKSVS; + + using FmhaFwdEpilogue_ = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaFwdKernel_ = + ck_tile::FmhaFwdKernel; + + RunWithKernel(param, stream); + }); + }; + + template + static void RunWithKernel(GroupedForwardParams& param, hipStream_t stream) { + const auto kargs = [&] { + return FmhaFwdKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // rand_val_ptr + param.logsumexp_ptr, + param.out_ptr, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.scale, + 1.0f, // scale_p + 1.0f, // scale_o + param.q_strides[0], // q, k, v, bias, randval, out tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + 0, // stride_randval + param.out_strides[0], + param.q_strides[1], // q, k, v, bias, randval, lse, out tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + 0, // nhead_stride_randval + param.lse_strides[0], + param.out_strides[1], + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type, + param.dropout_prob, + false, // is_store_randval + std::make_pair(param.philox_seed, param.philox_offset)); + }(); + + dim3 kGridSize = FmhaFwdKernel::GridSize( + param.num_batches, + param.Hq, + param.max_seqlen_q, + param.Kv, + param.seqlen_k_dev_ptr != nullptr); + constexpr dim3 kBlockSize = FmhaFwdKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp index 31e28bad6d..ecd80de2bc 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp @@ -17,15 +17,17 @@ void grouped_forward_fp16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_grouped_forward_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_grouped_forward_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_dispatch.h new file mode 100644 index 0000000000..eacfd6bc1a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_dispatch.h @@ -0,0 +1,336 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK, + ck_tile::index_t MaxSeqlenQ> +struct grouped_forward_splitkv_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVShape::Type, + true, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + true, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(GroupedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + const bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + true, // kHasUnevenSplits + occupancy>; + + if (param.num_kv_splits > 1) { + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaFwdPipeline_ = + ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using FmhaFwdEpilogue_ = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaFwdKernel_ = ck_tile:: + FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaFwdPipeline_ = + ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using FmhaFwdEpilogue_ = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaFwdKernel_ = ck_tile:: + FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + }; + + if (param.num_kv_splits > 1) { + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + + constexpr 
ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr bool kPadSeqLenQ = true; + + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH(pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH(param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr + 0, // batch_stride_block_table + 0, // page_block_size + false, // is_gappy + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_acc_strides[1], + param.q_strides[1], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + param.lse_acc_strides[1], + param.out_acc_strides[2], + 0, // batch_stride_k, not used, only used for paged-kvcache + 0, // batch_stride_v, not used, only used for paged-kvcache + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr + 0, // batch_stride_block_table + 0, // page_block_size + false, // is_gappy + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_strides[0], + param.q_strides[1], // q, k, v, bias, lse, out tensor head-dim + // stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + param.lse_strides[0], + param.out_strides[1], + 0, // batch_stride_k, not used, only used for paged-kvcache + 0, // batch_stride_v, not used, only used for paged-kvcache + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.num_batches, + param.Hq, + param.max_seqlen_q, + param.Kv, + param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[1], // row_stride_o_acc, + param.out_strides[0], // row_stride_o, + param.lse_acc_strides[1], // nhead_stride_lse_acc + param.out_acc_strides[2], // nhead_stride_o_acc, + param.lse_strides[0], // nhead_stride_lse, + param.out_strides[1], // nhead_stride_o, + param.lse_acc_strides[0], // split_stride_lse_acc, + param.out_acc_strides[0]); // split_stride_o_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.num_batches, param.Hq, param.max_seqlen_q, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_smallq_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_smallq_dispatch.h new file mode 100644 index 0000000000..4f92d2bdf4 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_splitkv_smallq_dispatch.h @@ -0,0 +1,333 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_smallq_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK> +struct grouped_forward_splitkv_smallq_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVSmallQShape::Type, + true, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + true, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(GroupedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + const bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + false, // kIsPagedKV + true, // kHasUnevenSplits + occupancy>; + + if (param.num_kv_splits > 1) { + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaFwdPipeline_ = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + FmhaPipelineProblem>; + + using FmhaFwdEpilogue_ = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaFwdKernel_ = ck_tile:: + FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaFwdPipeline_ = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + FmhaPipelineProblem>; + + using FmhaFwdEpilogue_ = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaFwdKernel_ = ck_tile:: + FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + }; + + if (param.num_kv_splits > 1) { + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + + constexpr ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr bool kPadSeqLenQ = true; + + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH(pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH(param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr + 0, // batch_stride_block_table + 0, // 
page_block_size + false, // is_gappy + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_acc_strides[1], + param.q_strides[1], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + param.lse_acc_strides[1], + param.out_acc_strides[2], + 0, // batch_stride_k, not used, only used for paged-kvcache + 0, // batch_stride_v, not used, only used for paged-kvcache + param.lse_acc_strides[0], // split_stride_lse_acc + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + nullptr, // block_table_ptr + 0, // batch_stride_block_table + 0, // page_block_size + false, // is_gappy + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_strides[0], + param.q_strides[1], // q, k, v, bias, lse, out tensor head-dim + // stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + param.lse_strides[0], + param.out_strides[1], + 0, // batch_stride_k, not used, only used for paged-kvcache + 0, // batch_stride_v, not used, only used for paged-kvcache + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.num_batches, + param.Hq, + param.max_seqlen_q, + param.Kv, + param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[1], // row_stride_o_acc, + param.out_strides[0], // row_stride_o, + param.lse_acc_strides[1], // nhead_stride_lse_acc + param.out_acc_strides[2], // nhead_stride_o_acc, + param.lse_strides[0], // nhead_stride_lse, + param.out_strides[1], // nhead_stride_o, + param.lse_acc_strides[0], // split_stride_lse_acc, + param.out_acc_strides[0]); // split_stride_o_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.num_batches, param.Hq, param.max_seqlen_q, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h index 3805108c1e..539e33215e 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h @@ -6,231 +6,69 @@ */ #pragma once -#include -#include -#include -#include - -#include "ck_tiled_bool_switch.h" +#include #include "ck_tiled_fmha_fwd_setting.h" -#include "ck_tiled_fmha_params.h" -#include "ck_tiled_headdim_switch.h" - -template < - typename ScalarType, - bool kHasCausalMask, - bool kHasBias, - bool kHasDropout, - ck_tile::index_t MaxK> -struct grouped_infer_causalmask_bias_dropout_dispatch { - template - using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - FmhaFwdShape, - true, // kIsGroupMode - FmhaMask, - FmhaTraits>; - - static void Run(GroupedForwardParams& param, hipStream_t stream) { - const bool has_local_attention = (param.window_size > 0) ? true : false; - - BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { - constexpr bool has_masking = kHasCausalMask || USE_LOCAL_ATTENTION; - - using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; - - using FmhaShape = FmhaFwdShape; - constexpr ck_tile::index_t occupancy = - (MaxK == 64) ? 3 : ((MaxK == 256) ? 
1 : 2); - - constexpr auto kBiasEnum = kHasBias - ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS - : ck_tile::BlockAttentionBiasEnum::NO_BIAS; - - constexpr bool kPadSeqLenQ = true; - constexpr bool kPadSeqLenK = true; - - bool pad_headdim_q = !(param.K % FmhaShape::kK0BlockLength == 0); - bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); - const bool use_async_pipeline = - (!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) && - (MaxK <= 128)); - - if (!use_async_pipeline) { - BOOL_SWITCH_2( - pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { - using FmhaTraits = ck_tile::TileFmhaTraits< - kPadSeqLenQ, - kPadSeqLenK, - kPadHeadDimQ, - kPadHeadDimV, - kBiasEnum, - false, // kHasBiasGrad place-holder - false, // kStoreLSE - kHasDropout, - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaPipelineProblem = - FmhaPipelineProblemTemp; - - using FmhaPipeline = - ck_tile::BlockFmhaPipelineQRKSVS; - - using FmhaEpilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - kPadSeqLenQ, - kPadHeadDimV>>; - - if (param.seqlen_k_dev_ptr != - nullptr) { // seqlen_k of batches are padded - using FmhaTilePartitioner = - ck_tile::FmhaFwdTilePartitioner_HBS; - using FmhaKernel = ck_tile::FmhaFwdKernel< - FmhaTilePartitioner, - FmhaPipeline, - FmhaEpilogue>; - - RunWithKernel(param, stream); - } else { - using FmhaTilePartitioner = - ck_tile::FmhaFwdTilePartitioner_SHB; - using FmhaKernel = ck_tile::FmhaFwdKernel< - FmhaTilePartitioner, - FmhaPipeline, - FmhaEpilogue>; - - RunWithKernel(param, stream); - } - }); - } else { - using FmhaTraits = ck_tile::TileFmhaTraits< - true, // kPadSeqLenQ, - kPadSeqLenK, - true, // kPadHeadDimQ, - true, // kPadHeadDimV, - kBiasEnum, - false, // kHasBiasGrad place-holder - false, // kStoreLSE - kHasDropout, - false, // kDoFp8StaticQuant place-holder - occupancy>; - - using FmhaPipelineProblem = - FmhaPipelineProblemTemp; - - using FmhaPipeline = - ck_tile::BlockFmhaPipelineQRKSVSAsync; - - using FmhaEpilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - true, - true>>; - - if (param.seqlen_k_dev_ptr != - nullptr) { // seqlen_k of batches are padded - using FmhaTilePartitioner = - ck_tile::FmhaFwdTilePartitioner_HBS; - using FmhaKernel = ck_tile:: - FmhaFwdKernel; - - RunWithKernel(param, stream); - } else { - using FmhaTilePartitioner = - ck_tile::FmhaFwdTilePartitioner_SHB; - using FmhaKernel = ck_tile:: - FmhaFwdKernel; - - RunWithKernel(param, stream); - } - } - }); - }; - - template - static void RunWithKernel(GroupedForwardParams& param, hipStream_t stream) { - const auto kargs = [&] { - return FmhaKernel::MakeKargs( - param.q_ptr, - param.k_ptr, - param.v_ptr, - param.attn_bias_ptr, - nullptr, // rand_val_ptr - nullptr, // lse_ptr - param.out_ptr, - param.seqstart_q_dev_ptr, - param.seqstart_k_dev_ptr, - param.seqlen_k_dev_ptr, - param.K, // hdim_q - param.Kv, // hdim_v - param.Hq, // nhead_q - param.Hq / param.Hkv, // nhead_ratio_qk - param.scale, - 1.0f, // scale_p - 1.0f, // scale_o - param.q_strides[0], // q, k, v, bias, randval, out tensor seq-dim - // stride - param.k_strides[0], - param.v_strides[0], - param.attn_bias_strides[2], - 0, // stride_randval - param.out_strides[0], - param.q_strides[1], // q, k, v, bias, randval, lse, out tensor - // head-dim stride - param.k_strides[1], - param.v_strides[1], - param.attn_bias_strides[1], - 0, // nhead_stride_randval - 0, // nhead_stride_lse - param.out_strides[1], - (param.window_size > 0) 
? param.window_size - 1 - : -1, // window_left_size - (param.custom_mask_type == 0) ? -1 : 0, // window_right_size - param.custom_mask_type, - param.dropout_prob, - false, // is_store_randval - {param.philox_seed, param.philox_offset}); - }(); - - dim3 kGridSize = FmhaKernel::GridSize( - param.num_batches, param.Hq, param.max_seqlen_q, param.Kv); - constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; - - (void)ck_tile::launch_kernel( - ck_tile::stream_config{stream, false}, - ck_tile::make_kernel( - FmhaKernel{}, kGridSize, kBlockSize, 0, kargs)); - }; -}; +#include "ck_tiled_fmha_fwd_splitkv_smallq_selector.h" +#include "ck_tiled_fmha_grouped_infer_dispatch.h" +#include "ck_tiled_fmha_grouped_infer_splitkv_dispatch.h" +#include "ck_tiled_fmha_grouped_infer_splitkv_smallq_dispatch.h" +#include "ck_tiled_fmha_seqlen_q_switch.h" template < typename ScalarType, - bool kHasCausalMask, + bool kHasMask, bool kHasBias, bool kHasDropout, ck_tile::index_t MaxK> -void run_grouped_infer_causalmask_bias_dropout_dispatch( +void run_grouped_infer_mask_bias_dropout_dispatch( GroupedForwardParams& param, hipStream_t stream) { - grouped_infer_causalmask_bias_dropout_dispatch< - ScalarType, - kHasCausalMask, - kHasBias, - kHasDropout, - MaxK>::Run(param, stream); + // currently split-kv implementation does not support dropout + if constexpr (!kHasDropout) { + if (param.use_split_kv) { + if (use_splitkv_smallq(param.max_seqlen_q, std::max(param.K, param.Kv))) { + grouped_infer_splitkv_smallq_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK>::Run(param, stream); + } else { + FMHA_FWD_SEQLEN_Q_SWITCH(param.max_seqlen_q, MaxSeqlenQ, [&] { + grouped_infer_splitkv_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + MaxK, + MaxSeqlenQ>::Run(param, stream); + }); + } + } else { + if (get_fmha_fwd_mtile(param.num_batches, param.Hq, param.max_seqlen_q) == + 128) + grouped_infer_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + else + grouped_infer_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 64>::Run(param, stream); + } + } else { + // at present, dropout of fwd kernel requires 32x32 WarpTile + grouped_infer_mask_bias_dropout_dispatch< + ScalarType, + kHasMask, + kHasBias, + kHasDropout, + MaxK, + 128>::Run(param, stream); + } }; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp index 090227c1db..e740b7308b 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bf16.cpp @@ -16,15 +16,17 @@ void grouped_infer_bf16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_grouped_infer_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_grouped_infer_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 
|| + param.window_size > 0) + run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h new file mode 100644 index 0000000000..6cda6e8233 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_dispatch.h @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_setting.h" +#include "ck_tiled_fmha_params.h" +#include "ck_tiled_headdim_switch.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + bool kHasDropout, + ck_tile::index_t MaxK, + ck_tile::index_t MTile> +struct grouped_infer_mask_bias_dropout_dispatch { + template + using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + typename FmhaFwdShape::Type, + true, // kIsGroupMode + FmhaMask, + FmhaTraits>; + + static void Run(GroupedForwardParams& param, hipStream_t stream) { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaShape = typename FmhaFwdShape::Type; + constexpr ck_tile::index_t occupancy = + (MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2); + + constexpr auto kBiasEnum = kHasBias + ? 
ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + bool pad_headdim_q = !(param.K % FmhaShape::kSubQKHeaddim == 0); + bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); + const bool use_async_pipeline = + (!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) && + (MaxK <= 128)); + + if (!use_async_pipeline) { + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaTraits = ck_tile::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + kHasDropout, + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaPipelineQRKSVS; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = + ck_tile::FmhaFwdKernel; + + RunWithKernel(param, stream); + }); + } else { + using FmhaTraits = ck_tile::TileFmhaTraits< + true, // kPadSeqLenQ, + kPadSeqLenK, + true, // kPadHeadDimQ, + true, // kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + kHasDropout, + false, // kDoFp8StaticQuant place-holder + occupancy>; + + using FmhaPipelineProblem = FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaPipelineQRKSVSAsync; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, + true>>; + + using FmhaKernel = ck_tile::FmhaFwdKernel; + + RunWithKernel(param, stream); + } + }; + + template + static void RunWithKernel(GroupedForwardParams& param, hipStream_t stream) { + const auto kargs = [&] { + return FmhaKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // rand_val_ptr + nullptr, // lse_ptr + param.out_ptr, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.scale, + 1.0f, // scale_p + 1.0f, // scale_o + param.q_strides[0], // q, k, v, bias, randval, out tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + 0, // stride_randval + param.out_strides[0], + param.q_strides[1], // q, k, v, bias, randval, lse, out tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + 0, // nhead_stride_randval + 0, // nhead_stride_lse + param.out_strides[1], + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? 
-1 : 0, // window_right_size + param.custom_mask_type, + param.dropout_prob, + false, // is_store_randval + std::make_pair(param.philox_seed, param.philox_offset)); + }(); + + dim3 kGridSize = FmhaKernel::GridSize( + param.num_batches, + param.Hq, + param.max_seqlen_q, + param.Kv, + param.seqlen_k_dev_ptr != nullptr); + constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp index 62c774ff59..fd0110cb96 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp @@ -16,15 +16,17 @@ void grouped_infer_fp16(GroupedForwardParams& param, hipStream_t stream) { const bool has_dropout = (param.dropout_prob > 0.0f); BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] { FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { - if (param.custom_mask_type == 0) - run_grouped_infer_causalmask_bias_dropout_dispatch< + if (param.custom_mask_type == 0 && param.window_size <= 0) + run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, kHasBias, kHasDropout, MaxK>(param, stream); - else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) - run_grouped_infer_causalmask_bias_dropout_dispatch< + else if ( + param.custom_mask_type == 1 || param.custom_mask_type == 2 || + param.window_size > 0) + run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, kHasBias, diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_dispatch.h new file mode 100644 index 0000000000..2c0160f3ae --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_dispatch.h @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
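// Illustrative sketch, not part of this change: how custom_mask_type and the
// local-attention window_size seen in the fp16/bf16 dispatchers above map onto
// the kHasMask template flag and the (window_left, window_right) pair handed
// to MakeKargs. Both helper names are hypothetical.
inline bool fmha_needs_mask(int custom_mask_type, int window_size) {
  // Only "no mask type and no local window" runs the unmasked instances.
  return !(custom_mask_type == 0 && window_size <= 0);
}

struct FmhaWindowPair {
  ck_tile::index_t left;  // -1 means unlimited extent to the left
  ck_tile::index_t right; // causal variants keep zero extent to the right
};

inline FmhaWindowPair fmha_window_pair(int custom_mask_type, int window_size) {
  return {(window_size > 0) ? window_size - 1 : -1,
          (custom_mask_type == 0) ? -1 : 0};
}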
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK, + ck_tile::index_t MaxSeqlenQ> +struct grouped_infer_splitkv_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVShape::Type, + true, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + true, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(GroupedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + + bool is_paged_kv = param.use_paged_kvcache; + + BOOL_SWITCH_3( + pad_headdim_q, + kPadHeadDimQ, + pad_headdim_v, + kPadHeadDimV, + is_paged_kv, + kIsPagedKV, + [&] { + if (param.num_kv_splits > 1) { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kIsPagedKV, + true, // kHasUnevenSplits + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kIsPagedKV, + true, // kHasUnevenSplits + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< + FmhaPipelineProblem>; + + using 
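// Illustrative sketch, not part of this change: what separates the two
// branches above. With num_kv_splits > 1 the split-KV kernel stores a per-split
// LSE and writes partial outputs in OaccDataType for the combine pass; with a
// single split it skips the LSE and writes the final ODataType directly.
// Restated as a hypothetical trait, assuming FmhaFwdTypeConfig is parameterized
// by the scalar type (as in ck_tiled_fmha_fwd_type_config.h).
#include <type_traits>

template <typename ScalarType, bool kMultipleSplits>
struct SplitKVOutputChoice {
  static constexpr bool kStoreLSE = kMultipleSplits;
  using ODataType = std::conditional_t<
      kMultipleSplits,
      typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
      typename FmhaFwdTypeConfig<ScalarType>::ODataType>;
};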
FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + }; + + if (param.num_kv_splits > 1) { + using FmhaTileShape = + typename FmhaFwdSplitKVShape::Type; + + constexpr ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr bool kPadSeqLenQ = true; + + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH(pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH(param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + param.use_paged_kvcache ? param.block_table_ptr : nullptr, + param.use_paged_kvcache ? param.batch_stride_block_table : 0, + param.use_paged_kvcache ? param.page_block_size : 0, + param.use_paged_kvcache ? param.is_gappy : false, + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_acc_strides[1], + param.q_strides[1], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + param.lse_acc_strides[1], + param.out_acc_strides[2], + param.use_paged_kvcache ? param.k_strides[0] * param.page_block_size + : 0, // batch_stride_k + param.use_paged_kvcache ? param.v_strides[0] * param.page_block_size + : 0, // batch_stride_v + param.lse_acc_strides[0], // split_stride_l + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // lse_ptr, + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + param.use_paged_kvcache ? param.block_table_ptr : nullptr, + param.use_paged_kvcache ? 
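// Illustrative sketch, not part of this change: the batch_stride_k /
// batch_stride_v arguments built above. With a paged KV cache each page of
// page_block_size tokens acts as one batch entry of K/V, so the stride is the
// per-token sequence stride times the page size; without paging the argument
// is unused and passed as 0. The helper name is hypothetical.
inline ck_tile::index_t paged_kv_batch_stride(
    bool use_paged_kvcache,
    ck_tile::index_t seq_stride, // k_strides[0] or v_strides[0]
    ck_tile::index_t page_block_size) {
  return use_paged_kvcache ? seq_stride * page_block_size : 0;
}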
param.batch_stride_block_table : 0, + param.use_paged_kvcache ? param.page_block_size : 0, + param.use_paged_kvcache ? param.is_gappy : false, + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_strides[0], + param.q_strides[1], // q, k, v, bias, lse, out tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + 0, // nhead_stride_lse + param.out_strides[1], + param.use_paged_kvcache ? param.k_strides[0] * param.page_block_size + : 0, // batch_stride_k + param.use_paged_kvcache ? param.v_strides[0] * param.page_block_size + : 0, // batch_stride_v + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.num_batches, + param.Hq, + param.max_seqlen_q, + param.Kv, + param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + nullptr, // lse_ptr, not used + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[1], // row_stride_o_acc, + param.out_strides[0], // row_stride_o, + param.lse_acc_strides[1], // nhead_stride_lse_acc + param.out_acc_strides[2], // nhead_stride_o_acc, + 0, // nhead_stride_lse, + param.out_strides[1], // nhead_stride_o, + param.lse_acc_strides[0], // split_stride_lse_acc, + param.out_acc_strides[0]); // split_stride_o_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.num_batches, param.Hq, param.max_seqlen_q, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_smallq_dispatch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_smallq_dispatch.h new file mode 100644 index 0000000000..916c2ab11e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_splitkv_smallq_dispatch.h @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
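// Illustrative sketch, not part of this change: the split accumulators read by
// the combine kernel above, viewed through the strides carried in
// GroupedForwardParams (see the layout comments added to
// ck_tiled_fmha_params.h later in this patch). Field names are hypothetical.
struct GroupedSplitAccumulators {
  void* lse_acc_ptr; // [num_kv_splits, nhead_q, total_seqlen_q], fully contiguous
  void* out_acc_ptr; // [num_kv_splits, total_seqlen_q, nhead_q, hdim_v], last dim contiguous
  // lse_acc_strides = {split_stride_lse_acc, nhead_stride_lse_acc}
  // out_acc_strides = {split_stride_o_acc, row_stride_o_acc, nhead_stride_o_acc}
};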
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_fwd_splitkv_smallq_setting.h" +#include "ck_tiled_fmha_num_kv_split_switch.h" +#include "ck_tiled_fmha_params.h" + +template < + typename ScalarType, + bool kHasMask, + bool kHasBias, + ck_tile::index_t MaxK> +struct grouped_infer_splitkv_smallq_mask_bias_dropout_dispatch { + template < + typename FmhaFwdSplitKVTraits, + typename FmhaMask, + typename ODataType> + using FmhaFwdSplitKVPipelineProblemTemp = + ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + ODataType, + typename FmhaFwdSplitKVSmallQShape::Type, + true, // kIsGroupMode + FmhaMask, + FmhaFwdSplitKVTraits>; + + template + using FmhaSplitKVCombinePipelineProblemTemp = + ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + MaxK, // headdim_v + true, // kIsGroupMode + kN1, + FmhaSplitKVCombineTraits>; + + static void Run(GroupedForwardParams& param, hipStream_t stream) { + { + using FmhaMask = ck_tile::SimplifiedGenericAttentionMask; + + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr auto kBiasEnum = kHasBias + ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS + : ck_tile::BlockAttentionBiasEnum::NO_BIAS; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0); + bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0); + + bool is_paged_kv = param.use_paged_kvcache; + + BOOL_SWITCH_3( + pad_headdim_q, + kPadHeadDimQ, + pad_headdim_v, + kPadHeadDimV, + is_paged_kv, + kIsPagedKV, + [&] { + if (param.num_kv_splits > 1) { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + true, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kIsPagedKV, + true, // kHasUnevenSplits + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::OaccDataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } else { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + kBiasEnum, + false, // kHasBiasGrad place-holder + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kIsPagedKV, + true, // kHasUnevenSplits + occupancy>; + + using ODataType = + typename FmhaFwdTypeConfig::ODataType; + using FmhaPipelineProblem = FmhaFwdSplitKVPipelineProblemTemp< + FmhaTraits, + FmhaMask, + ODataType>; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS< + 
FmhaPipelineProblem>; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + ODataType, + false, + false>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + }); + }; + + if (param.num_kv_splits > 1) { + using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape::Type; + + constexpr ck_tile::index_t kN1 = 32; + constexpr ck_tile::index_t kM0 = + ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes< + typename FmhaFwdTypeConfig::OaccDataType, + kN1>::kM0; + + constexpr ck_tile::index_t occupancy = -1; + + constexpr bool kPadSeqLenQ = true; + + const bool pad_headdim_v = !(param.Kv % kN1 == 0); + + BOOL_SWITCH(pad_headdim_v, kPadHeadDimV, [&] { + FMHA_FWD_NUM_KV_SPLITS_SWITCH(param.num_kv_splits, kLogMaxSplits, [&] { + using FmhaTraits = ck_tile::TileFmhaFwdSplitKVCombineTraits< + kPadSeqLenQ, + kPadHeadDimV, + false, // kStoreLSE + false, // kDoFp8StaticQuant place-holder + kLogMaxSplits, + -1>; + + using FmhaPipelineProblem = + FmhaSplitKVCombinePipelineProblemTemp; + + using FmhaPipeline = + ck_tile::BlockFmhaFwdSplitKVCombinePipeline; + + using FmhaEpilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using FmhaKernel = + ck_tile::FmhaFwdSplitKVCombineKernel; + + RunWithSplitKVCombineKernel(param, stream); + }); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + if (param.num_kv_splits > 1) + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_acc_ptr, + param.out_acc_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + param.use_paged_kvcache ? param.block_table_ptr : nullptr, + param.use_paged_kvcache ? param.batch_stride_block_table : 0, + param.use_paged_kvcache ? param.page_block_size : 0, + param.use_paged_kvcache ? param.is_gappy : false, + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out_acc tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_acc_strides[1], + param.q_strides[1], // q, k, v, bias, lse_acc, out_acc tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + param.lse_acc_strides[1], + param.out_acc_strides[2], + param.use_paged_kvcache ? param.k_strides[0] * param.page_block_size + : 0, // batch_stride_k + param.use_paged_kvcache ? param.v_strides[0] * param.page_block_size + : 0, // batch_stride_v + param.lse_acc_strides[0], // split_stride_l + param.out_acc_strides[0], // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + else + return FmhaFwdSplitKVKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // lse_ptr, + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq, // nhead_q + param.Hq / param.Hkv, // nhead_ratio_qk + param.num_kv_splits, // num_splits + param.use_paged_kvcache ? 
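// Illustrative sketch, not part of this change: every Run*Kernel helper in
// these dispatch headers ends with the same launch sequence. A hypothetical
// generic wrapper is shown below; the template arguments of
// ck_tile::make_kernel are elided here just as they appear in the kernels
// above, so treat the exact call signature as an assumption.
template <typename Kernel, typename Kargs>
void launch_fmha_kernel(hipStream_t stream, dim3 grids, Kargs kargs) {
  constexpr dim3 blocks = Kernel::BlockSize();
  constexpr ck_tile::index_t block_per_cu = Kernel::kBlockPerCu;
  (void)block_per_cu; // normally forwarded as a make_kernel template argument
  (void)ck_tile::launch_kernel(
      ck_tile::stream_config{stream, false},
      ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
}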
param.block_table_ptr : nullptr, + param.use_paged_kvcache ? param.batch_stride_block_table : 0, + param.use_paged_kvcache ? param.page_block_size : 0, + param.use_paged_kvcache ? param.is_gappy : false, + param.scale, + 1.0f, // scale_p + param.q_strides[0], // q, k, v, bias, out tensor seq-dim + // stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_strides[0], + param.q_strides[1], // q, k, v, bias, lse, out tensor + // head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + 0, // nhead_stride_lse + param.out_strides[1], + param.use_paged_kvcache ? param.k_strides[0] * param.page_block_size + : 0, // batch_stride_k + param.use_paged_kvcache ? param.v_strides[0] * param.page_block_size + : 0, // batch_stride_v + 0, // split_stride_lse_acc + 0, // split_stride_out_acc + (param.window_size > 0) ? param.window_size - 1 + : -1, // window_left_size + (param.custom_mask_type == 0) ? -1 : 0, // window_right_size + param.custom_mask_type); + }(); + + dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize( + param.num_batches, + param.Hq, + param.max_seqlen_q, + param.Kv, + param.num_kv_splits); + constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaFwdSplitKVKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithSplitKVCombineKernel( + GroupedForwardParams& param, + hipStream_t stream) { + const auto kargs = [&] { + return FmhaSplitKVCombineKernel::MakeKargs( + param.logsumexp_acc_ptr, + param.out_acc_ptr, + nullptr, // lse_ptr, not used + param.out_ptr, + param.num_batches, + param.seqstart_q_dev_ptr, + param.Kv, + param.num_kv_splits, + 1.0f, + param.out_acc_strides[1], // row_stride_o_acc, + param.out_strides[0], // row_stride_o, + param.lse_acc_strides[1], // nhead_stride_lse_acc + param.out_acc_strides[2], // nhead_stride_o_acc, + 0, // nhead_stride_lse, + param.out_strides[1], // nhead_stride_o, + param.lse_acc_strides[0], // split_stride_lse_acc, + param.out_acc_strides[0]); // split_stride_o_acc + }(); + + dim3 kGridSize = FmhaSplitKVCombineKernel::GridSize( + param.num_batches, param.Hq, param.max_seqlen_q, param.Kv); + constexpr dim3 kBlockSize = FmhaSplitKVCombineKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = + FmhaSplitKVCombineKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + FmhaSplitKVCombineKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_num_kv_split_switch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_num_kv_split_switch.h new file mode 100644 index 0000000000..db9a1afbc4 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_num_kv_split_switch.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + +#define FMHA_FWD_NUM_KV_SPLITS_SWITCH(NUM_SPLITS, CONST_NAME, ...) 
\ + [&] { \ + if (NUM_SPLITS <= 8) { \ + constexpr ck_tile::index_t CONST_NAME = 3; \ + __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 16) { \ + constexpr ck_tile::index_t CONST_NAME = 4; \ + __VA_ARGS__(); \ + } else { \ + throw std::runtime_error("num-splits not supported!"); \ + } \ + }() diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h index ce86f6df40..67f0afdf19 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h @@ -51,6 +51,20 @@ struct BatchedForwardParams : public BatchedInferParams { // completely contiguous void* logsumexp_ptr; + + // used by the splitkv forward kernel + int num_kv_splits; + + bool use_split_kv; + + // PBHM mode strides, completely contiguous + std::array lse_acc_strides; + + // PBMHK mode strides + std::array out_acc_strides; + + void* logsumexp_acc_ptr; + void* out_acc_ptr; }; struct GroupedInferParams { @@ -89,10 +103,15 @@ struct GroupedInferParams { int window_size; // local-attention void* out_ptr; + + bool use_paged_kvcache; + bool is_gappy; + void* block_table_ptr; + int page_block_size; + int batch_stride_block_table; }; struct GroupedForwardParams : public GroupedInferParams { - bool use_dropout; bool compute_logsumexp; float dropout_prob; @@ -105,6 +124,21 @@ struct GroupedForwardParams : public GroupedInferParams { // completely contiguous void* logsumexp_ptr; + + // used by the splitkv forward kernel + int num_kv_splits; + + bool use_split_kv; + + // PHM mode strides, completely contiguous, unpadded layout where M is + // concatten total seqlen_q for all batches + std::array lse_acc_strides; + + // PMHK mode strides, last-dim contiguous + std::array out_acc_strides; + + void* logsumexp_acc_ptr; + void* out_acc_ptr; }; struct BatchedBackwardParams { diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_seqlen_q_switch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_seqlen_q_switch.h new file mode 100644 index 0000000000..c8356a0a89 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_seqlen_q_switch.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + +#define FMHA_FWD_SEQLEN_Q_SWITCH(SEQLEN_Q, CONST_NAME, ...) 
\ + [&] { \ + if (SEQLEN_Q <= 32) { \ + constexpr ck_tile::index_t CONST_NAME = 32; \ + __VA_ARGS__(); \ + } else { \ + constexpr ck_tile::index_t CONST_NAME = 64; \ + __VA_ARGS__(); \ + } \ + }() diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h index ce99023c94..498e17f91d 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h @@ -23,6 +23,9 @@ } else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \ constexpr ck_tile::index_t CONST_NAME = 64; \ __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 96 && HEAD_DIM2 <= 96) { \ + constexpr ck_tile::index_t CONST_NAME = 96; \ + __VA_ARGS__(); \ } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ constexpr ck_tile::index_t CONST_NAME = 128; \ __VA_ARGS__(); \ @@ -39,6 +42,9 @@ } else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \ constexpr ck_tile::index_t CONST_NAME = 64; \ __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 96 && HEAD_DIM2 <= 96) { \ + constexpr ck_tile::index_t CONST_NAME = 96; \ + __VA_ARGS__(); \ } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ constexpr ck_tile::index_t CONST_NAME = 128; \ __VA_ARGS__(); \ @@ -57,6 +63,9 @@ } else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \ constexpr ck_tile::index_t CONST_NAME = 64; \ __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 96 && HEAD_DIM2 <= 96) { \ + constexpr ck_tile::index_t CONST_NAME = 96; \ + __VA_ARGS__(); \ } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ constexpr ck_tile::index_t CONST_NAME = 128; \ __VA_ARGS__(); \ @@ -76,6 +85,9 @@ } else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \ constexpr ck_tile::index_t CONST_NAME = 64; \ __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 96 && HEAD_DIM2 <= 96) { \ + constexpr ck_tile::index_t CONST_NAME = 96; \ + __VA_ARGS__(); \ } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ constexpr ck_tile::index_t CONST_NAME = 128; \ __VA_ARGS__(); \ diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h index 715d5e4bdf..801960a432 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h @@ -7,27 +7,30 @@ #include #include #include -#include - -template < - ck_tile::index_t MPerBlockTile, - ck_tile::index_t NPerBlockTile, - ck_tile::index_t KPerBlockTile, - typename RandValOutputDataType, - bool kIsGroupMode> +#include + +template struct FmhaRandUniformKernel { - static constexpr ck_tile::index_t kBlockSize = 256; + using BlockTile = ck_tile::sequence<128, 64, 32>; + using WarpTile = ck_tile::sequence<32, 32, 8>; + using BlockWarps = ck_tile::sequence<4, 1, 1>; + + using BlockGemmTileShape = + ck_tile::TileGemmShape; + + static constexpr ck_tile::index_t kBlockSize = + BlockGemmTileShape::NumWarps * ck_tile::get_warp_size(); static constexpr ck_tile::index_t kBlockPerCu = 1; __device__ static constexpr auto GetBlockGemm() { using namespace ck_tile; - using BlockGemmProblem_ = ck_tile::BlockGemmPipelineProblem< + using BlockGemmProblem_ = ck_tile::BlockGemmProblem< ck_tile::fp16_t, ck_tile::fp16_t, float, kBlockSize, - ck_tile::TileGemmShape>; + BlockGemmTileShape>; // using the default policy, which use M32xN32xK8 warp_tile return ck_tile::BlockGemmARegBSmemCRegV2{}; diff --git a/xformers/csrc/attention/hip_fmha/generate_instances.py b/xformers/csrc/attention/hip_fmha/generate_instances.py index 53dd8143c2..d769b8b358 100644 
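A short usage sketch (illustrative only, not part of this change) for the
compile-time switch macros touched above; fmha_select is a hypothetical caller,
and the lambdas only check the constants the macros expose:

#include "ck_tiled_fmha_num_kv_split_switch.h"
#include "ck_tiled_fmha_seqlen_q_switch.h"
#include "ck_tiled_headdim_switch.h"

void fmha_select(int hdim_q, int hdim_v, int real_seqlen_q, int num_kv_splits) {
  // With the new 96 bucket, head dims such as 80 now pick MaxK = 96 instead of 128.
  FMHA_FWD_HEADDIM_SWITCH(hdim_q, hdim_v, kMaxK, [&] {
    static_assert(kMaxK == 32 || kMaxK == 64 || kMaxK == 96 || kMaxK == 128 ||
                      kMaxK == 256,
                  "head-dim buckets");
  });
  // The query length is bucketed to a MaxSeqlenQ tile of 32 or 64.
  FMHA_FWD_SEQLEN_Q_SWITCH(real_seqlen_q, kMaxSeqlenQ, [&] {
    static_assert(kMaxSeqlenQ == 32 || kMaxSeqlenQ == 64, "two buckets only");
  });
  // The split count maps to log2 of the supported maximum: 3 for up to 8
  // splits, 4 for up to 16; anything larger throws std::runtime_error.
  FMHA_FWD_NUM_KV_SPLITS_SWITCH(num_kv_splits, kLogMaxSplits, [&] {
    static_assert(kLogMaxSplits == 3 || kLogMaxSplits == 4, "at most 16 splits");
  });
}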
--- a/xformers/csrc/attention/hip_fmha/generate_instances.py +++ b/xformers/csrc/attention/hip_fmha/generate_instances.py @@ -27,16 +27,16 @@ """ FMHA_INFER_INSTANCE_TEMPLATE = """ -{extern}template void run_{mode}_infer_causalmask_bias_dropout_dispatch< +{extern}template void run_{mode}_infer_mask_bias_dropout_dispatch< {dtype}, - {has_causalmask}, + {has_mask}, {has_bias}, {has_dropout}, {max_k}>({cap_mode}ForwardParams& param, hipStream_t stream); """ FMHA_INFER_INSTANCE_FNAME = ( - "fmha_{mode}_infer_{dtype_str}_{has_or_no_causalmask_str}_" + "fmha_{mode}_infer_{dtype_str}_{has_or_no_mask_str}_" "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" ) @@ -46,16 +46,16 @@ """ FMHA_FORWARD_INSTANCE_TEMPLATE = """ -{extern}template void run_{mode}_forward_causalmask_bias_dropout_dispatch< +{extern}template void run_{mode}_forward_mask_bias_dropout_dispatch< {dtype}, - {has_causalmask}, + {has_mask}, {has_bias}, {has_dropout}, {max_k}>({cap_mode}ForwardParams& param, hipStream_t stream); """ FMHA_FORWARD_INSTANCE_FNAME = ( - "fmha_{mode}_forward_{dtype_str}_{has_or_no_causalmask_str}_" + "fmha_{mode}_forward_{dtype_str}_{has_or_no_mask_str}_" "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" ) @@ -65,9 +65,9 @@ """ FMHA_BACKWARD_INSTANCE_TEMPLATE = """ -{extern}template void run_{mode}_backward_causalmask_bias_dropout_dispatch< +{extern}template void run_{mode}_backward_mask_bias_dropout_dispatch< {dtype}, - {has_causalmask}, + {has_mask}, {has_bias}, {has_bias_grad}, {has_dropout}, @@ -75,7 +75,7 @@ """ FMHA_BACKWARD_INSTANCE_FNAME = ( - "fmha_{mode}_backward_{dtype_str}_{has_or_no_causalmask_str}_" + "fmha_{mode}_backward_{dtype_str}_{has_or_no_mask_str}_" "{has_or_no_bias_str}_{has_or_no_biasgrad_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" ) @@ -83,9 +83,9 @@ BOOL_MAP = {True: "true", False: "false"} -BOOL_MAP_CAUSALMASK = { - True: "has_causalmask", - False: "no_causalmask", +BOOL_MAP_MASK = { + True: "has_mask", + False: "no_mask", } BOOL_MAP_BIAS = { @@ -106,6 +106,7 @@ INT_MAP_MAX_K = { 32: "maxk_32", 64: "maxk_64", + 96: "maxk_96", 128: "maxk_128", 256: "maxk_256", } @@ -129,16 +130,14 @@ def create_infer_instances(instance_dir: Path, headdims: List) -> None: for mode in ["batched", "grouped"]: for dtype in ["fp16", "bf16"]: - for has_causalmask in [True, False]: + for has_mask in [True, False]: for has_bias in [True, False]: for has_dropout in [True, False]: for max_k in headdims: fname = FMHA_INFER_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, - has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[ - has_causalmask - ], + has_or_no_mask_str=BOOL_MAP_MASK[has_mask], has_or_no_bias_str=BOOL_MAP_BIAS[has_bias], has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], @@ -153,7 +152,7 @@ def create_infer_instances(instance_dir: Path, headdims: List) -> None: extern="", mode=mode, dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], + has_mask=BOOL_MAP[has_mask], has_bias=BOOL_MAP[has_bias], has_dropout=BOOL_MAP[has_dropout], max_k=max_k, @@ -185,12 +184,12 @@ def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None: for max_k in headdims: for has_bias in [True, False]: for has_dropout in [True, False]: - for has_causalmask in [True, False]: + for has_mask in [True, False]: infer_instance = FMHA_INFER_INSTANCE_TEMPLATE.format( extern="extern ", mode=mode, dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], + has_mask=BOOL_MAP[has_mask], has_bias=BOOL_MAP[has_bias], 
has_dropout=BOOL_MAP[has_dropout], max_k=max_k, @@ -202,16 +201,14 @@ def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None: def create_forward_instances(instance_dir: Path, headdims: List) -> None: for mode in ["batched", "grouped"]: for dtype in ["fp16", "bf16"]: - for has_causalmask in [True, False]: + for has_mask in [True, False]: for has_bias in [True, False]: for has_dropout in [True, False]: for max_k in headdims: fname = FMHA_FORWARD_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, - has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[ - has_causalmask - ], + has_or_no_mask_str=BOOL_MAP_MASK[has_mask], has_or_no_bias_str=BOOL_MAP_BIAS[has_bias], has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], @@ -226,7 +223,7 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None: extern="", mode=mode, dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], + has_mask=BOOL_MAP[has_mask], has_bias=BOOL_MAP[has_bias], has_dropout=BOOL_MAP[has_dropout], max_k=max_k, @@ -258,13 +255,13 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None: for max_k in headdims: for has_bias in [True, False]: for has_dropout in [True, False]: - for has_causalmask in [True, False]: + for has_mask in [True, False]: forward_instance = ( FMHA_FORWARD_INSTANCE_TEMPLATE.format( extern="extern ", mode=mode, dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], + has_mask=BOOL_MAP[has_mask], has_bias=BOOL_MAP[has_bias], has_dropout=BOOL_MAP[has_dropout], max_k=max_k, @@ -277,7 +274,7 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None: def create_backward_instances(instance_dir: Path, headdims: List) -> None: for mode in ["batched", "grouped"]: for dtype in ["fp16", "bf16"]: - for has_causalmask in [True, False]: + for has_mask in [True, False]: for has_bias, has_bias_grad in [ [True, False], [True, True], @@ -288,9 +285,7 @@ def create_backward_instances(instance_dir: Path, headdims: List) -> None: fname = FMHA_BACKWARD_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, - has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[ - has_causalmask - ], + has_or_no_mask_str=BOOL_MAP_MASK[has_mask], has_or_no_bias_str=BOOL_MAP_BIAS[has_bias], has_or_no_biasgrad_str=BOOL_MAP_BIASGRAD[has_bias_grad], has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], @@ -306,7 +301,7 @@ def create_backward_instances(instance_dir: Path, headdims: List) -> None: extern="", mode=mode, dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], + has_mask=BOOL_MAP[has_mask], has_bias=BOOL_MAP[has_bias], has_bias_grad=BOOL_MAP[has_bias_grad], has_dropout=BOOL_MAP[has_dropout], @@ -343,13 +338,13 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None: [False, False], ]: for has_dropout in [True, False]: - for has_causalmask in [True, False]: + for has_mask in [True, False]: backward_instance = ( FMHA_BACKWARD_INSTANCE_TEMPLATE.format( extern="extern ", mode=mode, dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], + has_mask=BOOL_MAP[has_mask], has_bias=BOOL_MAP[has_bias], has_bias_grad=BOOL_MAP[has_bias_grad], has_dropout=BOOL_MAP[has_dropout], @@ -368,9 +363,11 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None: disable_hd256 = True if disable_hd256: - headdims = [32, 64, 128] + headdims_fwd = [32, 64, 96, 128] + headdims_bwd = [32, 64, 96, 128] else: - headdims = [32, 64, 128, 256] + headdims_fwd = [32, 64, 96, 
128, 256] + headdims_bwd = [32, 64, 96, 128, 256] this_dir = os.path.dirname(__file__) output_dir = Path(this_dir) / "instances" @@ -382,9 +379,9 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None: file_path = os.path.join(output_dir, ff) os.remove(file_path) - create_infer_instances(output_dir, headdims) - create_infer_instances_ref(output_dir, headdims) - create_forward_instances(output_dir, headdims) - create_forward_instances_ref(output_dir, headdims) - create_backward_instances(output_dir, headdims) - create_backward_instances_ref(output_dir, headdims) + create_infer_instances(output_dir, headdims_fwd) + create_infer_instances_ref(output_dir, headdims_fwd) + create_forward_instances(output_dir, headdims_fwd) + create_forward_instances_ref(output_dir, headdims_fwd) + create_backward_instances(output_dir, headdims_bwd) + create_backward_instances_ref(output_dir, headdims_bwd) diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index b129b07194..d6b447d173 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index 58aaac8016..c319629872 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 73360d7dc6..6161fc4ae4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 7f99b48199..08c3ec38a2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..12c1aa463c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index b5b258196e..8bea77809d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, - false, true, + false, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index 1829f50f2d..5ed35bbef6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 74501e0072..672d36fe11 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template 
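A sketch (illustrative only, not part of this change) of the body that each of
these renamed or newly generated instance files carries, annotated to show how
the filename fields map onto the dispatcher's template arguments; the #include
lines are left out here because they are also elided in the patch text above:

template void run_batched_backward_mask_bias_dropout_dispatch<
    ck_tile::bf16_t, // scalar type     -> "bf16"
    true,            // mask enabled    -> "has_mask"
    true,            // bias enabled    -> "has_bias"
    false,           // bias gradient   -> "no_biasgrad"
    true,            // dropout enabled -> "has_dropout"
    96>(BatchedBackwardParams& param, hipStream_t stream); // -> "maxk_96"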
void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 573d9bf4b8..b70134c681 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, - false, true, + false, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..e2301db5ec --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 8689b5389f..c132e77e64 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, + false, true, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index 070e8b2c0b..aac5a1aaf8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 2a5977be38..a4d5950050 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< 
+template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, + false, true, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 62a1c9d0b5..aa88585bc2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, - true, false, + true, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..3e99fd87db --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index f4f3ac89c2..8c95d9392c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, false, + false, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index 4067c8e5ac..25e054c6ce 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, false, + false, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index e6b8fd85f2..cec2dec8bf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template 
void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, + true, false, false, - true, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index d8fd52d7aa..fe59c183f4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, false, + false, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..9c1dd943e7 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 67bf8995c8..7603478867 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, - true, false, false, + true, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index 71b1586ac3..a085a7ab08 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, true, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index c3dd3d5fe3..1e0a77cfd4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template 
void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, - true, true, false, + false, + true, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 4c2c0672ea..ec28f459b8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..aefdd4d6af --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 68bac14f28..d580e1549e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index 2a72588f19..6a2ffe01cf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index db6ee679cb..2fbc707a50 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, false, + false, 
32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 2028826784..8a8ac48042 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..ddd9e4ff7e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_instances_ref.h index 06f82124ae..607048cbad 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_instances_ref.h @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -19,7 +19,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -27,7 +27,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -35,7 +35,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -43,7 +43,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -51,7 +51,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -59,7 +59,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -67,7 +67,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -75,7 +75,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, 
false, @@ -83,7 +83,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -91,7 +91,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -99,7 +99,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -107,7 +107,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -115,7 +115,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -123,7 +123,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -131,7 +131,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -139,7 +139,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -147,7 +147,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -155,7 +155,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -163,7 +163,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -171,7 +171,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -179,7 +179,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -187,7 +187,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -195,7 +195,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -203,7 +203,103 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t 
stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -211,7 +307,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -219,7 +315,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -227,7 +323,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -235,7 +331,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -243,7 +339,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -251,7 +347,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -259,7 +355,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -267,7 +363,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -275,7 +371,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -283,7 +379,7 @@ extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -291,7 +387,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -299,7 +395,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -307,7 +403,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -315,7 +411,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -323,7 +419,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -331,7 +427,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -339,7 +435,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -347,7 +443,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -355,7 +451,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -363,7 +459,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -371,7 +467,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -379,7 +475,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -387,7 +483,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index b831c919df..6901b50c17 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, true, - false, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index fd52bcc4de..efa38d5329 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 504c22609f..0d21552eee 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, - false, true, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 490659b74c..8366fe3350 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..f57bb62706 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 0d902e1203..b481351c79 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, - false, false, true, + true, + false, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index 4bc3b5a836..470a8ee444 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, false, - false, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 331b791409..1a58c63720 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" 
-template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, false, - false, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index a820ad76c3..f5c4d3df3b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, - false, true, + false, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..2e8451901e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index f9e140aaef..8d3e5e0ad2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index 13dfd5a096..69492777b2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, + true, false, true, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 5688539e83..b25b805768 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void 
run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 1c3a956d4c..1f8ac812df 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, - true, false, + true, false, + true, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..247dd491cf --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index fbd6b8b48b..d66ebd7d54 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index b64b16b8da..f71f0a98fd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index ea7baeea2c..3d001ec57c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, + true, false, false, 
32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index e79dd63df7..4ffb7f4193 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..cf9da51fd8 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 35a9684053..e0e5c1093b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index 14d9356112..cb039bd893 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 783c741b66..e988f88a63 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 7ddd65d116..6d4f8e8832 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..7bc8fbb70e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 69e6983446..b40590e752 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index 5fa39c8804..9e543ce456 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index fed439c709..d4b4d3d25a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 6a955e9821..78d157c8b0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..c26216d39e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index b4df2bf407..80f5cbafaa 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index 545a779553..e09b3ada17 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 1da7bae3a8..c7bb811828 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< 
ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 4c3cf7ff66..3184149372 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..fe54bed624 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 1cbafbf70d..4285510a6a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index cf89aa7bd8..86410bafac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, - false, true, + false, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 8b498600a2..2c91e6152a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template 
void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, true, true, + false, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 75fef6ab41..8855ffd887 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..cc4e57f2d3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 836e9428ee..2d98de9388 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index f1e9009d1a..89b21aa7c6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, - true, false, + true, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index bbc4eea829..648a99f443 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template 
void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 2d804bd5df..fc4e72b84a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..6c25ae5b80 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index bdf72b91aa..e77b97fd84 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index 2588185d9e..304bdea6ac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 087b8e1c80..2aaaa250bb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void 
run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index d01cb1e375..82cf516785 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..744858265b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index d1bdf1fa57..71f2f421e0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index b8c8eb5b31..8b84758423 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 60553e4057..70ceb95945 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include 
"ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index dafd1d5d2b..54a97cc2c1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..0b5415c041 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 99a2823b48..217d876bcc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index acceefffbd..303b93b077 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index ac3a2a5fdb..74d455fff4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 5a281913f3..2783b3be1d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..11f72a7b4c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_instances_ref.h index d47f8cc1ec..1655e42ce5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_instances_ref.h @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -19,7 +19,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -27,7 +27,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -35,7 +35,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -43,7 +43,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -51,7 +51,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -59,7 +59,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -67,7 +67,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -75,7 +75,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, 
false, @@ -83,7 +83,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -91,7 +91,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -99,7 +99,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -107,7 +107,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 32>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -115,7 +115,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -123,7 +123,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -131,7 +131,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -139,7 +139,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -147,7 +147,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -155,7 +155,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -163,7 +163,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -171,7 +171,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -179,7 +179,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -187,7 +187,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -195,7 +195,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -203,7 +203,103 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 64>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t 
stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); + +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -211,7 +307,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -219,7 +315,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -227,7 +323,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -235,7 +331,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -243,7 +339,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -251,7 +347,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -259,7 +355,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -267,7 +363,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -275,7 +371,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -283,7 +379,7 @@ extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch< true, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -291,7 +387,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -299,7 +395,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 128>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -307,7 +403,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -315,7 +411,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -323,7 +419,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -331,7 +427,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -339,7 +435,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -347,7 +443,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -355,7 +451,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -363,7 +459,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void 
run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -371,7 +467,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -379,7 +475,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< true, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -387,7 +483,7 @@ extern template void run_batched_backward_causalmask_bias_dropout_dispatch< false, 256>(BatchedBackwardParams& param, hipStream_t stream); -extern template void run_batched_backward_causalmask_bias_dropout_dispatch< +extern template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index 68ffee4bf8..6748c1b011 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index 4d84693d6f..ecc6392b9c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 9511965063..c9280ecea9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, + false, true, true, true, - false, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 7ddd6efd88..4a3fb67186 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..f54fd36354 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index dd6ef7d002..110394c34d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index daee392159..161304b8ed 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index dc19712620..6ec124e26a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template 
void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index e9c8d75e34..8d8fa202e0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..29c9fb6a4c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 3b85cea79a..671d37710a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 128>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index f261d64baf..6ba00de55c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 256>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 635f9f1a23..367d9f6e26 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include 
"ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 32>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 919a01fb9c..643f6ad5bd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 64>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..4832c97990 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index bc25646dca..3712d8cd6a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index a324ea3d19..ad905cbdf9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 8ffe3a4c36..777bef0160 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index 0d3ab043e3..b748de7b95 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..dbb567a280 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 64c0c14fb6..d76eae7cff 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index 2d0e3efaaf..37ded4ac11 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 003201abf5..0cfc315f8b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index a6570b6bfc..2e95e9082f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..f1d3f39d00 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index a23a7087d1..4a65054c8f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index 274405d533..fb57f88653 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 46a8e8a4d4..3cb6b9d3e2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 5bdd29dbdb..53052e40d4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_backward.h" -template void run_batched_backward_causalmask_bias_dropout_dispatch< +template void run_batched_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..494f10a720 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_backward.h" + +template void run_batched_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 96>(BatchedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp index 189677f41b..a60963f802 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp index 39881bd0de..cfe158f63e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp index a24b8868a8..f83330c354 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp index 849a6633b5..d218b55775 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..1ab50df932 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp index c49a96edbb..88664056e3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp index 4e3144c61e..52327df1ce 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, - false, true, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp index b949c55579..e7576d0c4c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp index c485fdfcd0..eeaf62d6fc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..ae7317559d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp index 68345b50d9..a1544c50a6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp index f362ff83b1..565a51e164 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, - true, false, + true, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp index 1654eb5354..5a33c64489 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp index fef0b43b9e..40bfebada2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..96287c4882 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp index eae1bef147..8e071fc747 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp index 3fea67a9df..406c49d6d1 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp index d2eeed0208..0bf56df8c0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp index 0b5b5e9acd..83ba77748f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..43a36ce652 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_instances_ref.h index 8fab725be7..dd1a636a6b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_batched_forward.h" -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern 
template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void 
run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp index b0898e658f..967c68daa5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp index aee8358c14..3bbc694732 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp index 62205efbdc..f4e5f5eb7f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, - false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp index 3e28448d41..71569c47c6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..fa01afbfb2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 
@@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp index 20e880ae32..0e385e642b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp index 2d9e145b8a..3375f54543 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp index e9e1d8c03d..4cff079b20 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void 
run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp index 296c93e84d..489bad0fad 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..0b955693c0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp index 87d8256c23..65d7b902a9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp index 521469e26c..972ad19835 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp index 12c05851be..ea7a9926ab 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp index 77e509f0c5..9111ebbbbc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..5038f0028e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp index ffcd7f0d89..55d50683d0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp index a0fbb353fc..be72e76d24 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp index 729e834bf8..96d9f212de 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp index b2ee36ac21..247d27508f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..8fbe1f0ce6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp index e9c50c43e0..8a22e0a124 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp index 98ad34421e..b523959364 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp index df8cb489a0..3f8d2ea4a0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp 
similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp index 9ff6b63464..c73e76ba54 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..cb6f657839 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp index 8e5fc2b224..3721e1206e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp index 0a32ecd5e4..6449266a26 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include 
#include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, - false, true, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp index 8b10f11921..98a23c5da4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, true, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp index 89b57dc002..c12921f2f4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..3b347a64bd --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp index 286ce1f10a..498c653437 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp index 8489a8255f..fd696a20b2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, - true, false, + true, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp index 5caa44509a..2660e9f956 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp index 7b45b7050a..ffd777b0f3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..03e08c45a1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp index 9b5b928f7f..fe81619104 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp index 1b36a0d252..0fc54fd688 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp index 785ecd397e..ca9c1aeb5f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp index 82199beb7a..bf77caa3a8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..2e56a95123 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_instances_ref.h index d697669727..f4fb71af63 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_batched_forward.h" -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern 
template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void 
run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_forward_causalmask_bias_dropout_dispatch< +extern template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp index 1af052fb63..fa4ca05fd6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp index 5616cdc520..078fc9a96c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp index 0ab15f4316..722424784f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, + false, true, true, - false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp index 988a2fe2bc..c13355df47 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..63141d2382 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 
@@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp index e18cda6c98..640a324464 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp index ed23610a9a..b1d2f9261b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp index 2e512e089b..6be825ead5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void 
run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp index cfd204f045..82b2d2a37d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..518d809847 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp index ea683ccd0a..5ceff03a83 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp index c17397faf7..ec115bde5d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp index 6483bd6da2..e237d7a1d9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp index 607227078c..d22f8e5e7c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..ada24fa386 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp index f161893bda..bf94d16cac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp index c37fb70c92..91f8252bc0 100644 
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp index f05aca856e..2849c4a01e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp index cd0f3d4ffc..bfb2727b55 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_forward.h" -template void run_batched_forward_causalmask_bias_dropout_dispatch< +template void run_batched_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..b2c4b3fc95 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp index ad22843e37..c969aaddd4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp index a457b90f34..4b5c1722f2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp index 51d21df17c..82155df9dc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp index 0c2a21bf6e..0f037342f1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..4199f8dfc1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp index 4e33efc722..4a02de28c9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp index f3eb7b0ec0..33f3521253 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void 
run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp index d8db2ebe22..251f3435c7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp index 72e7fb412e..db0bcc4905 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..84d693dcd4 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp index 0b4ed8294a..4964bfa57b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp index 2e752c9418..d1afa4f97b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp index 68366ee2f8..b53ce42583 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp index 9d0c50e134..10fecb0b1f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..9683175ce5 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp index 25c006c093..99ecd3f153 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp index 77ab1fc3e3..9fe1f47000 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include 
"ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp index 15311470c6..9cb5037ff4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp index 4c98864b26..688e746c30 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..9d345eb620 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_instances_ref.h index 003d768942..a0a632332c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_batched_infer.h" -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< 
+extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, 
false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp index db28d72f40..384ed6c7d6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp index 228bb5397c..1d14ec3223 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp index d0152e1600..38bb1e4898 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp index 8cb88dd943..9e01187176 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..94a7b0ecf6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp index d20c61ee11..f9eee86a38 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp index 0410708e11..662850493a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp index d837f7b54e..809d7fb2fa 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp index 7462600fb3..2b015348a3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..23badfdcb3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp index 8129cbf852..1eb945d8c9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp index 
3d6e897a47..bdae23c5f5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp index c264d95adc..abcd6e5054 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp index fb8e9fb0a5..f91e7d396a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..6633c2a2d8 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp index 65d1fd39a7..606f3e51d0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp index c0ea4369af..f37c3155a5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp index b46f0c0c8a..d05287595e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp index 8051de4d96..931c73fb80 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..222818766a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp index c1ee8c7693..48d3a2c3f0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp index 46a38e82df..71e0a40272 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void 
run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp index 6040d41cd9..2914d3566b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp index db5d5d577c..1dc4f4cefe 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..49089a5a2c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp index ccc0a02543..83ee3847ac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp index d81ff0d38e..f6d3cd1f9a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp index 48b74b2bc4..44e794f26d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp index fda07f6cda..2b8d9371b1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..cda89d9882 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp index 43069dd547..b83806efab 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp index bf8afd4242..c22ec1891b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void 
run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp index 351f5ea1d4..39d5af11cb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp index d06dc1f10c..1333e0e3a2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..c6dd68fbcf --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp index 609b4981c9..a8c94892ae 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp index 5fca4f4eea..37abd037a6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp index fe3a2e2bc3..d45e9747ac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp index d077701b99..4a5b32f1d6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..3aded97795 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_instances_ref.h index 266b3643ee..5b63c0083b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_batched_infer.h" -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void 
run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 32>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 64>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + 
false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 128>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, 
false, true, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 256>(BatchedForwardParams& param, hipStream_t stream); -extern template void run_batched_infer_causalmask_bias_dropout_dispatch< +extern template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp index 37f18fd7d1..215574613c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp index dd5ec21185..fda3a851a9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp index 3afe1c2f86..3a461d75b8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp index e9ddc972d5..f5de5ab9fc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..6199c05109 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp index 501a83e9ae..8ca40c295b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp index d0b619f604..9ea1c82aae 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp index af0bc1c85a..7e6fdd12f1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp index 578454c52f..4eeeafdda9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..cba6c7eb6e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp index df91366da4..a46736ec72 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp index 4c292918bd..477836c7c5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp index 9dc31e3ea8..81dba703d5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp index 2bbd4f3dd4..92dd14a639 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..c2780682c6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp index d20d225cd6..4488da3605 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp index ce76fd765e..f38d36564e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ 
b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp index ca44ac6b0d..9025bd9b97 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp index 5d7589a162..8aa5368312 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_batched_infer.h" -template void run_batched_infer_causalmask_bias_dropout_dispatch< +template void run_batched_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..3ef3ae0ad2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_batched_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index c22b793d35..52258dd70d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index f4b7a307aa..f18614fa08 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index c5b1454c5a..ba78d65d3b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, 
true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index c8c71960df..7258831cee 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..c37c77d554 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index de55b8e88c..bd10c628ac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index 18d1940620..99903f6560 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, - false, true, + false, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 1ba22ae616..fb92ebe6fc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template 
void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, true, + false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 71ac1de6fa..59249a8b03 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..db4d2ce297 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index f2baaf01df..bbe5fc4a71 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index 577c43def5..91f7af8f29 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, - true, false, + true, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 8e87f044df..33467b58f8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template 
void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index dbe7c0560f..628ad56249 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..979c39e34a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 295e3f4034..67f3bb857b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, + true, false, false, - true, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index 07b019af4c..5fc15b960f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, false, + false, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 8b878747f6..be106ab035 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index ac1bccc146..1bc566b34a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, false, + false, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..f17c75ecbc --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 42818cfa92..6ab1929abb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, - true, true, false, + false, + true, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index e23b3c60b9..9153f0a6dd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 485b647757..f9d2de3cd8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template 
void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, - true, true, false, + false, + true, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index dfbcd25bec..02e6479f99 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, true, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..7352541275 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 8650510c3c..cdf8c64d07 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, false, + false, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index b85fa82e9b..ea0cdd8794 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 842c071d96..4b20062e26 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void 
run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, false, + false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index e8e862d54d..262fe65ae7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..342bccf249 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_instances_ref.h index 870b4dda9f..77fd2adfd4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_instances_ref.h @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -19,7 +19,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -27,7 +27,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -35,7 +35,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -43,7 +43,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -51,7 +51,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -59,7 +59,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -67,7 +67,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -75,7 +75,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, 
false, @@ -83,7 +83,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -91,7 +91,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -99,7 +99,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -107,7 +107,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -115,7 +115,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -123,7 +123,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -131,7 +131,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -139,7 +139,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -147,7 +147,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -155,7 +155,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -163,7 +163,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -171,7 +171,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -179,7 +179,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -187,7 +187,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -195,7 +195,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -203,7 +203,103 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t 
stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -211,7 +307,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -219,7 +315,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -227,7 +323,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -235,7 +331,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -243,7 +339,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -251,7 +347,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -259,7 +355,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -267,7 +363,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -275,7 +371,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -283,7 +379,7 @@ extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -291,7 +387,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -299,7 +395,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -307,7 +403,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -315,7 +411,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -323,7 +419,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -331,7 +427,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -339,7 +435,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -347,7 +443,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, @@ -355,7 +451,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, @@ -363,7 +459,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -371,7 +467,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, @@ -379,7 +475,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, @@ -387,7 +483,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index 76a4e7dcb7..1ec85b39bd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index a4b3c633d0..11e98efd9f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 9ffa70e780..28a019accc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, true, - false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 07813b2c57..ea25b5eaff 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..a5e8ac4541 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 65b67988ae..fb21b6429d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, - false, true, + false, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index 81616d6af3..90046688f1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, - false, true, + false, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 08af2d6677..8bee1bacd7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template 
void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, - false, false, true, + true, + false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 1871a6cbed..b8a6e10e65 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, false, - false, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..1f0d4e2d28 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 7a293a9735..fb7617cf96 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, - true, false, + true, false, + true, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index dc5f5c749a..649682a521 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, - true, false, + true, false, + true, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 9fc0a6c625..b7ef701393 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include 
"ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 4d2d7e78dd..f043077872 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, + true, false, true, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..7f5cc32bf8 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index 43fc95070c..20f2299474 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, + true, false, false, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index 261017c529..0c5b0899d2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 86d8d4776c..a10ed99695 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void 
run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, + true, false, false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index 1bf3602e38..1778c650af 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..7f18e6c0d9 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 302c566e73..90eaf9020c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index c3f030c5f3..6041d88106 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 070e741168..f4f4a74a29 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 8011c547d1..723dad8b4d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..725fb3b751 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index 249bf2a54b..a213e1feea 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index 9fed2aefc2..55be37bff0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 224d5f1bc3..8d4e8157c4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index 43fea8dee1..2a11628eaf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..37c739e6d6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_bf16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index dc70813fc6..be282c1692 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index 10ae8c3026..16c1a56335 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 4fdbb099c2..0d126762fb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< 
ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index e5d4365a19..bba62020d6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..b4973f6d4f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index e028d1bee9..d397432a8b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index ccd459e844..576f4ec43c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, - false, true, + false, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index de40300749..9ec9c32a5e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template 
void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, true, true, + false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 28fcbfad6e..0e1421f0ba 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..1cfbb64a6a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 34b227fad6..936aceb179 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index 3c47d406b5..2601c44b53 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, - true, false, + true, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 20033dee2f..db40de8e14 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template 
void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index c9dece923d..520aef06c0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..e11bd53369 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index c0d222f058..db1a8fe044 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index 8d32e0b35b..9a7ae39f16 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index fe11f7f00e..57b874c858 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index 45ba2ddd3e..c542a2c255 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, + true, false, false, - true, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..1d22178487 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 7c5978f3fe..a4f08bb7be 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index 1dd5dfa0f7..9d24093276 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 69ebd58335..3596811967 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include 
"ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 3218e1606b..a958635127 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, - true, true, false, + false, + true, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..792825647a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index e8e20cb4d5..7fb1932394 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index 81668563ec..a81fe6db2e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 1961a1a295..e4940345d3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index ba07be603b..dad5ec5274 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..c0e01a73b9 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_instances_ref.h index 367ca6bcfe..61472494f2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_instances_ref.h @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -19,7 +19,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -27,7 +27,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -35,7 +35,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -43,7 +43,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -51,7 +51,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -59,7 +59,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -67,7 +67,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -75,7 +75,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, 
false, @@ -83,7 +83,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -91,7 +91,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -99,7 +99,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -107,7 +107,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 32>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -115,7 +115,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -123,7 +123,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -131,7 +131,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -139,7 +139,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -147,7 +147,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -155,7 +155,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -163,7 +163,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -171,7 +171,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -179,7 +179,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -187,7 +187,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -195,7 +195,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -203,7 +203,103 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 64>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t 
stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); + +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -211,7 +307,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -219,7 +315,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -227,7 +323,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -235,7 +331,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -243,7 +339,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -251,7 +347,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -259,7 +355,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -267,7 +363,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -275,7 +371,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -283,7 +379,7 @@ extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< true, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -291,7 +387,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -299,7 +395,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 128>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -307,7 +403,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -315,7 +411,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -323,7 +419,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -331,7 +427,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -339,7 +435,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -347,7 +443,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, @@ -355,7 +451,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, @@ -363,7 +459,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void 
run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -371,7 +467,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, @@ -379,7 +475,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< true, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, @@ -387,7 +483,7 @@ extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< false, 256>(GroupedBackwardParams& param, hipStream_t stream); -extern template void run_grouped_backward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp index 15e2f31d8f..70837e9b2b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp index 00effd83ca..3ad63b3fb7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp index 1651af366a..d2ec293abe 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, + false, true, true, true, - false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp index 756c1dc187..6f988aedf5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..170b7dc080 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp index 831e8b9ac2..060a6b875a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp index d7aeb937ff..4093a812e7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp index 2659f809d0..ef3521c8bd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template 
void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp index 4668340309..9f76e20d90 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, - false, true, + false, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..6274a56bb5 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_has_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp index 3b71014f6b..6b97237665 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 128>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp index 09ac8a84e2..fc9b10b1a1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 256>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp index 62df2f2dd3..c166a7bd48 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,10 +11,10 @@ #include #include 
"ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 32>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp index 07514352b4..30cc3c575d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_has_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,10 +11,10 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, - true, false, + true, false, + true, 64>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..2f4058c055 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp index dc7f41755f..dd172a8cd6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp index 8d13665117..4eb6cba1aa 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp index 07e60021b8..34a1a45a03 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp index d562c03844..15691115b3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..5ea99eb70f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_has_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp index 3b38e48f68..9e72f65f20 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp index cc9c0e3771..143c79b972 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp index 7237f3cab9..e7935d54b7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp index 7f7b87b465..0b911129cb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..e2ff64c3dc --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_has_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp index fca2defab5..ee07981f0e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp index 247d2933f7..5e47962a51 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp index 952d91a05e..8936424612 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp index df612447ff..b8d022181c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_causalmask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_backward.h" -template void run_grouped_backward_causalmask_bias_dropout_dispatch< +template void run_grouped_backward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..835604b023 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_backward_fp16_no_mask_no_bias_no_biasgrad_no_dropout_maxk_96.cpp @@ -0,0 +1,20 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_backward.h" + +template void run_grouped_backward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 96>(GroupedBackwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp index 436b35249b..e221a4df68 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp index 673ace243f..7708b6be81 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp index 12f2dce035..f500369249 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp index b05db1117d..7af9ce737b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..90ed257288 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp index ac8a014bc1..63d87a7ceb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp index b98a212b3c..5ec5b2076d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, - false, true, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp index 249011ee13..0202533758 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, true, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp index 58357d0f8e..d49d2b41de 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..8945954299 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp index 6b03e2ffd8..acc3e80445 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp index 2bb41cd3bc..ef243b0dc5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, - true, false, + true, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp index ba57b065d5..23a3d60725 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp index 6b5463311d..2048527030 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..9866d6a0b2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp index 4b833c8f83..ce742afc08 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp index 3e07c10500..8170a8859c 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp index 2b9b0559fd..33515ab436 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp index f43d7b41cc..c1bfa5227f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..c0602f9c08 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_instances_ref.h index 4b1740f1a7..ea0947de21 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern 
template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void 
run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp index 222d1ed50c..9f5253947f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp index bcad83e85f..83474e1d76 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp index 8c17a20b72..8e8b152379 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, + false, true, true, - false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp index 15ac9062f8..c542571932 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..a5a67b1ad6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 
@@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp index 1da0732d8f..48a41626a1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp index 4891094bce..41c9d6f57c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp index 276962324e..553b1fc8ba 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void 
run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp index 2e552a9973..dfe68ffcad 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..810e671500 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp index c1b145ccd0..2d72bcb6a3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp index ea2ee50829..eda1008bf9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp index d20de70d8a..c101072938 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp index 6bad209f7c..a67bb0844f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..71182531ad --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp index 85f9097f59..4910d1463e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp index 456ae223ab..ab647a2e7e 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp index 51cbbf71d7..f8c7491ae4 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp index 0614b84a2a..c4cd4e7b88 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..9203a02a35 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp index 6db568b7c8..1d130ea119 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp index 7c14a9f97a..e9525bfd6a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp index 3ad15a89ca..601415d752 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp 
similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp index a0431622e7..571780c49b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..608cf7b582 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp index 3c5f652c7d..3841dadae0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp index f765d967b0..3ed3b86656 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include 
#include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, - false, true, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp index b7f09b7c36..8f45feab8c 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, true, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp index 1f3b70c843..8690683e49 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..e8ae22495d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp index 1ce7084261..4a985fb011 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp index 562298f722..3420d3aa50 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, - true, false, + true, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp index 65a976a9a2..74849113c9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp index 30b56e1b19..1303aa9b43 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..213703efeb --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp index 2b747e5e28..5ef755ddf6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp index 0d7c558cd3..24c5729743 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp index 3efca37987..6a6952ec63 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp index dae892ab78..434dcc2693 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..1ecdd0f832 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_instances_ref.h index 2ac28a5200..e4327e83e5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern 
template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void 
run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_forward_causalmask_bias_dropout_dispatch< +extern template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp index b0918f6838..3f5f2707fc 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, 
diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp index 432cdd9783..3a24dd4611 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp index 9daf7f6c68..b20dcc77ec 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, + false, true, true, - false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp index 8c6ad2498e..e93471b9a2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..cbfcdfa07d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 
@@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp index d2020485ee..4fd11b41bf 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp index a29929b80d..5b83a321c2 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp index d5f3cdffe1..ece97ea1d8 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void 
run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp index 6a7482d692..a9af6a8ded 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..60f4f7d652 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp index 22ece82890..94bfe75ea5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp index d5a7778e5d..31136ded22 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp index bc5553560a..0e79cea140 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp index 4b74c49ef9..c4e8677838 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..77d6057173 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp index fc5604b5e4..25c0c1ac25 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp index f8741ae4f8..d7d3a36219 100644 
--- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp index 8c4e8581b5..a49ac26ee6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp index b29ac4d4f3..fc7ddced9e 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_forward.h" -template void run_grouped_forward_causalmask_bias_dropout_dispatch< +template void run_grouped_forward_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..2942d3e91a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_forward_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp index 52e1d5d711..d50935b1d7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp index 055b769f9f..e985ad8805 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp index 9ce3756a6f..8f88cf8e63 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp index 46d4e69b75..bcf4508b97 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..e6bbaad9e8 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp index 5f11a042f3..82b400f0c6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp index 3134e1c4ca..a3325e6686 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void 
run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp index f858eccb53..cca4cc5431 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp index 5da3272f08..e033986a24 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..cb80ff6e05 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp index ed632d7ea6..2f257ffd73 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp index d336cc52d2..a772490804 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp index 7095195dd6..94b83ea16b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp index 312a64a29d..1e0258d11a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..b8aecbef49 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp index 5949924e4c..5c5052773a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp index 4ed0179061..f5267d11a7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include 
"ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp index d5df909462..17549b1ff6 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp index 8be8afd5ee..49b14547cd 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - false, true, false, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..30db8093b3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_instances_ref.h index aa5c84146c..6022b79cc3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< 
+extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, 
false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, true, false, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp index 95eb7e0ed8..e5fb64fac3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp index e9c361bd0a..4eec28e4df 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp index 5530bb928f..d26e0d4771 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp index 0a55926151..b9498adfc1 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..48530caca9 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp index 4416036397..d09cd5a863 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp index 39e2f9fed8..acb1b14fef 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp index 6172df88a2..1924525a47 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp index 41681f1805..818af21711 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, - false, true, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..a1236ed698 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp index 5747867dc4..b73fbd3e60 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp index 
f54dadca5a..8e40965635 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp index a6b637a297..92db0a3bac 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp index 47abe27d92..affb5a980b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, - true, false, false, + true, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..75ff69dfec --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp index 98625d1428..7efc0e9203 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp index 9d3d732888..c1493d3e44 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp index bb537cfe2c..315429ef08 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp index 66769f244c..8cce00c824 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::bf16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..86f93c2b3a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_bf16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp index 4c35127f9b..cbbd746a8f 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp index 12a2a61052..960634ed47 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void 
run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp index 885584ef4b..d3bbeeaea0 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp index a11af5773c..0fda8f6a47 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..9eac3a46b5 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp index 8d1f0fb7f9..91a3b3aec9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp index 50577f7f96..8859657b71 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp index 07fcfd2eb6..ab8ee4823b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from 
xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp index dc3690344b..dea721a634 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..d843caa1ac --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp index b3727732a0..edecb5ee5b 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp index b8cb896222..5aabfa102d 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void 
run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp index a4c2cacf19..d4b2a56bd7 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp index 2b36d6f33f..5c6b91be17 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..90175276f0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp index 2f63665845..40d3950944 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp index aed425ba5c..0abf5b79ba 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp index c3678b42f5..afa07836b9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp index 7481a9b9aa..03fa1e82b9 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - false, true, false, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..5efcef2c86 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_instances_ref.h b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_instances_ref.h index f3a5d8501a..c38d01ca60 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_instances_ref.h +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_instances_ref.h @@ -11,224 +11,280 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void 
run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 32>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 64>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + 
false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, true, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, false, 128>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, 
false, true, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, true, false, false, 256>(GroupedForwardParams& param, hipStream_t stream); -extern template void run_grouped_infer_causalmask_bias_dropout_dispatch< +extern template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp index ffb1b36d60..db687f5110 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp index db5416d92f..d78135bea3 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp index d5cce31a76..fd4fea5d62 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp index bb3ad0e570..c1c4742435 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_has_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, true, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..37d18699ee --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp index f6282217df..33dd36ae2a 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp index 0564af6ec1..4ed97869a3 100644 --- 
a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp index afbe9a21f5..8317354c85 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp index 99e9133dce..f761773b84 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_has_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, - false, true, + false, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..3d80d5fd9c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_has_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + true, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp index f3827c2401..f9ab0be1fa 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_128.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp index 6627919bb5..f4f7fee792 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_256.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_256.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp index 793fc5c902..a510dfb2b5 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_32.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp index 2d50423e73..9d8b8e8987 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_has_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_64.cpp @@ -11,9 +11,9 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, - true, false, false, + true, 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp new file mode 100644 index 0000000000..15788edbf7 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_has_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp index 637d40bc17..3287d5e4ba 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_128.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_128.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp index ca8cb1bed3..b7f99432ce 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_256.cpp +++ 
b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_256.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp index 61f1540aeb..f6d6340842 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_32.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_32.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp similarity index 87% rename from xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp rename to xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp index cad791039f..44f3b7d0cb 100644 --- a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_causalmask_no_bias_no_dropout_maxk_64.cpp +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_64.cpp @@ -11,7 +11,7 @@ #include #include "ck_tiled_fmha_grouped_infer.h" -template void run_grouped_infer_causalmask_bias_dropout_dispatch< +template void run_grouped_infer_mask_bias_dropout_dispatch< ck_tile::fp16_t, false, false, diff --git a/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp new file mode 100644 index 0000000000..b6e94978f6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/fmha_grouped_infer_fp16_no_mask_no_bias_no_dropout_maxk_96.cpp @@ -0,0 +1,19 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ */ + +#include +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_mask_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + 96>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py index a4defb17c3..5d11377117 100644 --- a/xformers/ops/fmha/ck.py +++ b/xformers/ops/fmha/ck.py @@ -6,7 +6,6 @@ from dataclasses import replace from enum import Enum -from functools import partial from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union import torch @@ -28,6 +27,9 @@ LowerTriangularFromBottomRightMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, ) from .common import ( AttentionBwOpBase, @@ -35,7 +37,6 @@ Context, Gradients, Inputs, - _attn_bias_apply, check_lastdim_alignment_stride1, ) @@ -50,7 +51,13 @@ def _get_seqlen_info( attn_bias = inp.attn_bias if isinstance( attn_bias, - (BlockDiagonalMask, BlockDiagonalPaddedKeysMask, BlockDiagonalGappyKeysMask), + ( + BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, + BlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + ), ): attn_bias.k_seqinfo.to(inp.query.device) attn_bias.q_seqinfo.to(inp.query.device) @@ -134,6 +141,7 @@ def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int attn_bias.BlockDiagonalCausalFromBottomRightMask, BlockDiagonalCausalWithOffsetPaddedKeysMask, BlockDiagonalCausalLocalAttentionFromBottomRightMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, ), ): return int(_CustomMaskType.CausalFromBottomRight) @@ -165,6 +173,9 @@ class FwOp(AttentionFwOpBase): attn_bias.BlockDiagonalCausalFromBottomRightMask, attn_bias.BlockDiagonalCausalLocalAttentionMask, BlockDiagonalCausalLocalAttentionFromBottomRightMask, + PagedBlockDiagonalPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, ) SUPPORTS_DROPOUT = True @@ -187,6 +198,7 @@ class FwOp(AttentionFwOpBase): _TEST_K: List[int] = [ 32, # 64x64 kernel + 96, 128, # 64x128 kernel 256, # 64x128 with accumulation in gmem ] @@ -197,62 +209,50 @@ def apply( ) -> Tuple[torch.Tensor, Optional[Context]]: if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: raise NotImplementedError("Unsupported attn_bias type") - if inp.query.ndim in [3, 4]: + if inp.query.ndim in [1, 2, 3]: + raise NotImplementedError("Unsupported number of dimensions") + if inp.query.ndim in [4]: return cls.apply_bmhk(inp, needs_gradient=needs_gradient) assert inp.query.ndim == 5, f"query has shape {inp.query.shape}" ctx: Optional[Context] = None - # XXX: Hackfix for BMGHK with H=1 - # In that case we don't want to run G different streams because it adds - # some overhead - if inp.query.ndim == 5 and inp.query.shape[3] == 1: - slice_op = partial(torch.squeeze, dim=3) - inp = replace( - inp, - query=slice_op(inp.query), - key=slice_op(inp.key), - value=slice_op(inp.value), - attn_bias=_attn_bias_apply( - inp.attn_bias, partial(torch.squeeze, dim=2) - ), + + # consider for expanded 5-D inputted + if inp.key.stride()[3] == 0: + assert ( + inp.value.stride()[3] == 0 + ), "key and value should be expanded in the same way" + k_shape = inp.key.size() + k_stride = inp.key.stride() + key = inp.key.as_strided( + (k_shape[0], k_shape[1], k_shape[2], k_shape[4]), + (k_stride[0], k_stride[1], k_stride[2], k_stride[4]), ) - out, ctx = 
cls.apply_bmhk(inp, needs_gradient=needs_gradient) - out = out.unsqueeze(3) - if ctx is not None: - ctx = replace(ctx, lse=ctx.lse.unsqueeze(1), out=out) - return out, ctx - - # Workaround until this is properly implemented in C++ - # run each head group in a different stream - n_groups = inp.key.shape[2] - main_stream = torch.cuda.current_stream() - streams = [main_stream] + [ - torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1) - ] - outs = [] - for group, stream in enumerate(streams): - stream.wait_stream(main_stream) - with torch.cuda.stream(stream): - query = inp.query[:, :, group] - key = inp.key[:, :, group] - value = inp.value[:, :, group] - bias = _attn_bias_apply( - inp.attn_bias, partial(torch.select, dim=1, index=group) - ) - outs.append( - cls.apply_bmhk( - replace(inp, query=query, key=key, value=value, attn_bias=bias), - needs_gradient=needs_gradient, - ) - ) - for s in streams[1:]: - main_stream.wait_stream(s) - out = torch.stack([o[0] for o in outs], dim=2) - if needs_gradient: - ctx = Context( - out=out, - lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore - op_bw=outs[0][1].op_bw, # type: ignore + v_shape = inp.value.size() + v_stride = inp.value.stride() + value = inp.value.as_strided( + (v_shape[0], v_shape[1], v_shape[2], v_shape[4]), + (v_stride[0], v_stride[1], v_stride[2], v_stride[4]), ) + else: + key = inp.key.flatten(2, 3) + value = inp.value.flatten(2, 3) + + [_, _, G, Hq, _] = inp.query.shape + attn_bias_replace = inp.attn_bias + if isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.ndim != 0: + attn_bias_replace = inp.attn_bias.flatten(1, 2) + inp = replace( + inp, + query=inp.query.flatten(2, 3), + key=key, + value=value, + attn_bias=attn_bias_replace, + ) + out, ctx = cls.apply_bmhk(inp, needs_gradient=needs_gradient) + out = out.unflatten(2, (G, Hq)) + if ctx is not None: + lse = ctx.lse.unflatten(1, (G, Hq)) + ctx = replace(ctx, lse=lse, out=out) return out, ctx @classmethod @@ -281,6 +281,8 @@ def apply_bmhk( ( BlockDiagonalGappyKeysMask, BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, ), ) else None @@ -297,6 +299,28 @@ def apply_bmhk( ) else None ), + block_tables=( + inp.attn_bias.block_tables + if isinstance( + inp.attn_bias, + ( + PagedBlockDiagonalPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + ), + ) + else None + ), + page_size=( + inp.attn_bias.page_size + if isinstance( + inp.attn_bias, + ( + PagedBlockDiagonalPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + ), + ) + else None + ), ) ctx: Optional[Context] = None @@ -356,6 +380,7 @@ class BwOp(AttentionBwOpBase): _TEST_K: List[int] = [ 32, # 64x64 kernel 64, + 96, 128, # 64x128/128x128 kernel 256, ]
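
Editor's note on the ck.py hunk above: when key/value were broadcast along the per-group head axis (stride 0 at dim 3 of a BMGHK tensor), the new code drops that axis with `as_strided` instead of flattening, so no copy of the expanded heads is materialized. The snippet below is a minimal sketch of that idea, not the xformers implementation; shapes and the `key` tensor are made up for illustration.

```python
import torch

B, M, G, Hq, K = 2, 16, 1, 8, 64
# One KV head per group, expanded to match the number of query heads:
# the expanded axis (dim 3) gets stride 0.
key = torch.randn(B, M, G, 1, K).expand(B, M, G, Hq, K)
assert key.stride()[3] == 0

if key.stride()[3] == 0:
    # Drop the expanded head axis as a view: the single underlying KV head
    # is re-exposed with shape (B, M, G, K), without copying data.
    s, st = key.size(), key.stride()
    key_bmhk = key.as_strided((s[0], s[1], s[2], s[4]),
                              (st[0], st[1], st[2], st[4]))
else:
    # Non-expanded case: fold groups and heads into one head axis.
    key_bmhk = key.flatten(2, 3)
```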
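The same hunk also replaces the previous per-group multi-stream workaround for BMGHK inputs with a single BMHK call: groups and query heads are folded into one head axis, and the output and LSE are unfolded afterwards. A hedged sketch of that reshaping follows; `fake_bmhk_attention` is a placeholder stand-in for the BMHK kernel, not an xformers API.

```python
import torch

def fake_bmhk_attention(q, k, v):
    # Stand-in for the BMHK kernel: output has q's shape,
    # LSE has shape (B, H, M).
    return q.clone(), q.new_zeros(q.shape[0], q.shape[2], q.shape[1])

B, M, G, Hq, K = 2, 16, 2, 4, 64
q = torch.randn(B, M, G, Hq, K)
k = torch.randn(B, M, G, Hq, K)
v = torch.randn(B, M, G, Hq, K)

# Fold (G, Hq) into a single head axis before the kernel call...
out, lse = fake_bmhk_attention(q.flatten(2, 3), k.flatten(2, 3), v.flatten(2, 3))
# ...and unfold it again afterwards, mirroring the patched FwOp.apply.
out = out.unflatten(2, (G, Hq))   # (B, M, G, Hq, K)
lse = lse.unflatten(1, (G, Hq))   # (B, G, Hq, M)
```

Compared with launching one stream per head group, this keeps a single kernel launch and avoids the stream-synchronization overhead the removed code had to manage.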