
Commit

Remove _check_large_shapes checking in fmha/ck.py (#1067)
qianfengz authored Jul 14, 2024
1 parent f53c36e commit 2456ea3
Showing 1 changed file with 0 additions and 17 deletions.
xformers/ops/fmha/ck.py (0 additions, 17 deletions)

@@ -102,21 +102,6 @@ def _check_bias_alignment(
        )
 
 
-def _check_large_shapes(reasons: List[str], inp: Inputs) -> None:
-    """CK kernel throws "Memory access fault by GPU node-2" when B * T >= 2**20, might be some index overflow.
-    To reproduce, remove this function and run benchmark_mem_eff_attention with ParlAI model shape (256, 4096, 16, 64).
-    This needs further debugging, for now let's not support such shapes.
-    """
-    b_t_limit = 1024**2
-    q_too_large = inp.query.shape[0] * inp.query.shape[1] >= b_t_limit
-    k_too_large = inp.key.shape[0] * inp.key.shape[1] >= b_t_limit
-    v_too_large = inp.value.shape[0] * inp.value.shape[1] >= b_t_limit
-    if q_too_large or k_too_large or v_too_large:
-        reasons.append(
-            "Input is too large: product of first two dimensions of q/k/v must be < 2**20"
-        )
-
-
 class _CustomMaskType(int, Enum):
     """
     (Matches CustomMaskType in C++.)
@@ -325,7 +310,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
        check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
        _check_bias_alignment(reasons, d.attn_bias)
-       _check_large_shapes(reasons, d)
        return reasons
 
    @classmethod
@@ -416,7 +400,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
                    f"(shape: {tuple(attn_bias_tensor.shape)}"
                    f"/ expected: {expected_bias_shape})"
                )
-       _check_large_shapes(reasons, d)
 
        return reasons
 
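For context, the removed guard rejected any input whose product of the first two dimensions (batch times sequence length) reached 2**20. Below is a minimal sketch of the ParlAI-style shape named in the removed docstring, which that check used to refuse. It assumes the public xformers.ops.memory_efficient_attention API on a ROCm build where the CK backend is selected; at these sizes it is a benchmark-scale run, not a quick unit test.

import torch
import xformers.ops as xops

# ParlAI model shape from the removed docstring: B=256, M=4096, H=16, K=64.
# B * M = 256 * 4096 = 1_048_576 = 2**20, exactly the limit the old check enforced,
# so before this commit ck.py's not_supported_reasons reported:
#   "Input is too large: product of first two dimensions of q/k/v must be < 2**20"
B, M, H, K = 256, 4096, 16, 64
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# After this commit the CK operator no longer rejects this shape up front.
out = xops.memory_efficient_attention(q, k, v)
print(out.shape)  # torch.Size([256, 4096, 16, 64])

As the removed docstring notes, the same shape can also be exercised through benchmark_mem_eff_attention.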
