From 2456ea3b8cc273fcc3be84e8a8d8871f5e4b04a9 Mon Sep 17 00:00:00 2001
From: Qianfeng
Date: Mon, 15 Jul 2024 02:39:58 +0800
Subject: [PATCH] Remove _check_large_shapes checking in fmha/ck.py (#1067)

---
 xformers/ops/fmha/ck.py | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py
index 39a0895533..5fd37ec1c6 100644
--- a/xformers/ops/fmha/ck.py
+++ b/xformers/ops/fmha/ck.py
@@ -102,21 +102,6 @@ def _check_bias_alignment(
     )
 
 
-def _check_large_shapes(reasons: List[str], inp: Inputs) -> None:
-    """CK kernel throws "Memory access fault by GPU node-2" when B * T >= 2**20, might be some index overflow.
-    To reproduce, remove this function and run benchmark_mem_eff_attention with ParlAI model shape (256, 4096, 16, 64).
-    This needs further debugging, for now let's not support such shapes.
-    """
-    b_t_limit = 1024**2
-    q_too_large = inp.query.shape[0] * inp.query.shape[1] >= b_t_limit
-    k_too_large = inp.key.shape[0] * inp.key.shape[1] >= b_t_limit
-    v_too_large = inp.value.shape[0] * inp.value.shape[1] >= b_t_limit
-    if q_too_large or k_too_large or v_too_large:
-        reasons.append(
-            "Input is too large: product of first two dimensions of q/k/v must be < 2**20"
-        )
-
-
 class _CustomMaskType(int, Enum):
     """
     (Matches CustomMaskType in C++.)
@@ -325,7 +310,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
         check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
         check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
         _check_bias_alignment(reasons, d.attn_bias)
-        _check_large_shapes(reasons, d)
         return reasons
 
     @classmethod
@@ -416,7 +400,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
                 f"(shape: {tuple(attn_bias_tensor.shape)}"
                 f"/ expected: {expected_bias_shape})"
             )
-        _check_large_shapes(reasons, d)
         return reasons
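
For context, below is a minimal, self-contained sketch (not part of the patch) that replicates the guard this commit deletes. FakeTensor is a hypothetical stand-in for torch.Tensor so the snippet runs without torch or xformers; running it shows why the ParlAI benchmark shape (256, 4096, 16, 64) cited in the removed docstring was rejected before this change, since 256 * 4096 == 2**20 sits exactly at the limit.

# Sketch only: replicates the removed _check_large_shapes logic.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeTensor:
    shape: Tuple[int, ...]  # only .shape is needed by the check


def check_large_shapes(
    reasons: List[str], query: FakeTensor, key: FakeTensor, value: FakeTensor
) -> None:
    """Same logic as the removed check: reject inputs with B * T >= 2**20."""
    b_t_limit = 1024**2
    if any(t.shape[0] * t.shape[1] >= b_t_limit for t in (query, key, value)):
        reasons.append(
            "Input is too large: product of first two dimensions of q/k/v must be < 2**20"
        )


# ParlAI model shape from the removed docstring: B=256, T=4096, H=16, K=64.
qkv = FakeTensor(shape=(256, 4096, 16, 64))
reasons: List[str] = []
check_large_shapes(reasons, qkv, qkv, qkv)
print(reasons)  # non-empty before this patch; after it, no such check runs at all

After this patch, not_supported_reasons no longer appends this reason, so such shapes are dispatched to the CK kernel rather than being rejected up front.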