
Commit b7c5a3d

profiler: Removed 'DetectSlowOpsProfiler' profiler
ghstack-source-id: 8b40b657d2da0d8c173954652a1b86266e8fced1
Pull Request resolved: fairinternal/xformers#1193

__original_commit__ = fairinternal/xformers@cd80105
danthe3rd authored and xFormers Bot committed Aug 22, 2024
1 parent 1f6242d commit b7c5a3d
Showing 11 changed files with 2 additions and 851 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- When using the most recent version of Flash-Attention, it is no longer possible to mix it with the cutlass backend. In other words, it is no longer possible to use the cutlass Fw with the flash Bw.
### Removed
- fMHA: Removed `decoder` and `small_k` backends
- profiler: Removed `DetectSlowOpsProfiler` profiler

## [0.0.27.post2] - 2024-07-26
Pre-built binary wheels require PyTorch 2.4.0
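For context on the backend-mixing note above: callers can pass an explicit (forward, backward) operator pair to `memory_efficient_attention`, and the pairing the changelog rules out is the cutlass forward with the flash backward. The sketch below is illustrative only (not part of this diff) and assumes a CUDA device with the operator classes exposed under `xformers.ops.fmha`:

```python
# Illustrative sketch (not part of this commit): the explicit forward/backward
# pairing that, per the changelog, recent Flash-Attention builds no longer allow.
import torch

import xformers.ops as xops
from xformers.ops import fmha

# BMHK layout: batch=1, seqlen=128, heads=8, head_dim=64.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Explicit (FwOp, BwOp) pair: cutlass forward + flash backward.
# With a recent Flash-Attention install this combination is no longer supported.
out = xops.memory_efficient_attention(q, k, v, op=(fmha.cutlass.FwOp, fmha.flash.BwOp))
out.sum().backward()
```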
58 changes: 1 addition & 57 deletions tests/test_profiler.py
@@ -11,13 +11,12 @@
import torch
import torch.nn as nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode
from torch.utils._python_dispatch import _get_current_dispatch_mode

import xformers.ops as xops
import xformers.ops.fmha as fmha
import xformers.profiler
from xformers.profiler import profile_analyzer
from xformers.profiler.slow_ops_profiler import GemmOpComputeFlops, flop_mapping

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")

@@ -31,61 +30,6 @@
)


class GEMMShapeDispatcher(TorchDispatchMode):
def __init__(self) -> None:
super().__init__()
self.mnk = (0, 0, 0)

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func._overloadpacket in flop_mapping:
compute_flops = flop_mapping[func._overloadpacket]
if isinstance(compute_flops, GemmOpComputeFlops):
self.mnk = compute_flops._get_mnk(args)
return func(*args)


def test_gemm_flops() -> None:
M, N, K = 13, 17, 53

a = torch.empty([M, K])
b = torch.empty([K, N])
x = torch.empty([K])

with GEMMShapeDispatcher() as disp:
a @ b
assert disp.mnk == (M, N, K)
with GEMMShapeDispatcher() as disp:
a @ x
assert disp.mnk == (M, 1, K)
with GEMMShapeDispatcher() as disp:
torch.nn.functional.linear(a, b.transpose(0, 1))
assert disp.mnk == (M, N, K)
with GEMMShapeDispatcher() as disp:
torch.addmm(torch.empty([1, 1]), a, b)
assert disp.mnk == (M, N, K)

B = 3
ba = torch.empty([B, M, K])
bb = torch.empty([B, K, N])
with GEMMShapeDispatcher() as disp:
ba @ bb
assert disp.mnk == (B * M, N, K)
with GEMMShapeDispatcher() as disp:
ba @ bb[:1]
assert disp.mnk == (B * M, N, K)
with GEMMShapeDispatcher() as disp:
ba[:1] @ bb
assert disp.mnk == (B * M, N, K)
with GEMMShapeDispatcher() as disp:
ba @ bb[0]
assert disp.mnk == (B * M, N, K)
with GEMMShapeDispatcher() as disp:
torch.addbmm(torch.empty([1, 1]), ba, bb)
assert disp.mnk == (B * M, N, K)


@cuda_only
def test_profiler_dispatcher_stream_workaround() -> None:
x = torch.zeros([10, 10], device="cuda")
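The deleted `test_gemm_flops` exercised the removed profiler's GEMM accounting (`GemmOpComputeFlops` / `flop_mapping`), which attributes `2*M*N*K` FLOPs to an `(M, K) @ (K, N)` product intercepted with a `TorchDispatchMode`. For context, here is a minimal self-contained sketch of that idea using only public PyTorch dispatch hooks and covering just `mm`/`bmm`; it is not xformers code:

```python
# Minimal sketch (not xformers code): count GEMM FLOPs with a TorchDispatchMode,
# using the same 2*M*N*K convention as the removed GemmOpComputeFlops helper.
import torch
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten


class GemmFlopCounter(TorchDispatchMode):
    def __init__(self) -> None:
        super().__init__()
        self.flops = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        packet = func._overloadpacket
        if packet == aten.mm:  # (M, K) @ (K, N)
            a, b = args[0], args[1]
            self.flops += 2 * a.shape[0] * b.shape[1] * a.shape[1]
        elif packet == aten.bmm:  # (B, M, K) @ (B, K, N)
            a, b = args[0], args[1]
            self.flops += 2 * a.shape[0] * a.shape[1] * b.shape[2] * a.shape[2]
        return func(*args, **kwargs)


if __name__ == "__main__":
    M, N, K = 13, 17, 53
    with GemmFlopCounter() as counter:
        torch.empty(M, K) @ torch.empty(K, N)
    assert counter.flops == 2 * M * N * K
```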
5 changes: 0 additions & 5 deletions xformers/ops/common.py
@@ -39,11 +39,6 @@ def is_available(cls) -> bool:
return False
return True

@classmethod
def operator_flop(cls, *inputs) -> int:
"""Calculate number of FLOP given inputs to `OPERATOR`"""
return -1


OPERATORS_REGISTRY: List[Type[BaseOperator]] = []
FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {}
54 changes: 0 additions & 54 deletions xformers/ops/fmha/ck.py
@@ -316,30 +316,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
_check_bias_alignment(reasons, d.attn_bias)
return reasons

@classmethod
# type: ignore
def operator_flop(
cls,
q,
k,
v,
b,
seqstart_q,
seqstart_k,
max_seqlen_q_,
compute_lse,
custom_mask_type,
*a,
) -> int:
return cls.attn_operator_flop(
q,
k,
v,
causal=custom_mask_type > 0,
seqstart_k=seqstart_k,
seqstart_q=seqstart_q,
)


@register_operator
class BwOp(AttentionBwOpBase):
@@ -478,33 +454,3 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
grad_bias = None

return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)

@classmethod
# type: ignore
def operator_flop(
cls,
dO,
q,
k,
v,
b,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
logsumexp,
output,
dropout_p,
rng_seed,
rng_offset,
custom_mask_type,
scale,
) -> int:
return cls.attn_operator_flop(
q,
k,
v,
seqstart_q=cu_seqlens_q,
seqstart_k=cu_seqlens_k,
causal=custom_mask_type > 0,
)
93 changes: 0 additions & 93 deletions xformers/ops/fmha/common.py
@@ -427,49 +427,6 @@ def apply(
) -> Tuple[torch.Tensor, Optional[Context]]:
raise NotImplementedError()

@classmethod
def attn_operator_flop(
cls,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
causal: bool = False,
seqstart_k: Optional[torch.Tensor] = None,
seqstart_q: Optional[torch.Tensor] = None,
) -> int:
"""
Computes total flops for the attention
Assumes inputs in format BMHK
"""
assert query.ndim == 4

if seqstart_q is not None:
seqstart_q_py = seqstart_q.tolist()
else:
seqstart_q_py = [0, query.shape[1]]
if seqstart_k is not None:
seqstart_k_py = seqstart_k.tolist()
else:
seqstart_k_py = [0, key.shape[1]]

total_flop = 0
for q_start, q_end, k_start, k_end in zip(
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
):
num_q = q_end - q_start
num_kv = k_end - k_start
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
# Q @ K.transpose
total_flop += num_q * num_kv * query.shape[-1] * 2
# (ignore softmax)
# attn @ V
total_flop += num_q * key.shape[-1] * num_kv * 2
# Multiply by num_heads and batches
total_flop = total_flop * value.shape[2] * value.shape[0]
if causal:
total_flop //= 2
return total_flop


class AttentionBwOpBase(AttentionOpBase):
# NOTE on tolerances: These are tested for `scales => (1/32)**0.5`
@@ -508,56 +465,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
raise NotImplementedError()

@classmethod
def attn_operator_flop(
cls,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
causal: bool = False,
seqstart_k: Optional[torch.Tensor] = None,
seqstart_q: Optional[torch.Tensor] = None,
) -> int:
"""
Computes total flops for the attention
Assumes inputs in format BMHK
"""
assert query.ndim == 4

if seqstart_q is not None:
seqstart_q_py = seqstart_q.tolist()
else:
seqstart_q_py = [0, query.shape[1]]
if seqstart_k is not None:
seqstart_k_py = seqstart_k.tolist()
else:
seqstart_k_py = [0, key.shape[1]]

total_flop = 0
for q_start, q_end, k_start, k_end in zip(
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
):
num_q = q_end - q_start
num_kv = k_end - k_start
Kqk = query.shape[-1]
Kv = value.shape[-1]
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
# att = Q @ K.transpose
total_flop += num_q * num_kv * Kqk * 2
# att @ dO
total_flop += num_kv * num_q * Kv * 2
# dov = dO @ V
total_flop += num_q * Kv * num_kv * 2
# dov @ K
total_flop += num_q * Kqk * num_kv * 2
# dov @ Q
total_flop += num_q * Kqk * num_kv * 2
# Multiply by num_heads and batches
total_flop = total_flop * value.shape[2] * value.shape[0]
if causal:
total_flop //= 2
return total_flop


AttentionOp = Tuple[
Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
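Taken together, the two deleted `attn_operator_flop` helpers count, for batch size B, H heads, per-sequence query/key lengths M_i and N_i, and head dimensions K_qk (query/key) and K_v (value), roughly the totals below; both are halved under a causal mask, and note that the forward helper uses the key head dim for the `attn @ V` term as well:

```latex
\mathrm{FLOPs}_{\mathrm{fwd}} \;\approx\; 2\,B\,H \sum_i M_i N_i \,(K_{qk} + K_{qk})
  \;=\; 4\,B\,H \sum_i M_i N_i K_{qk},
\qquad
\mathrm{FLOPs}_{\mathrm{bwd}} \;\approx\; 2\,B\,H \sum_i M_i N_i \,(3 K_{qk} + 2 K_{v}).
```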
54 changes: 0 additions & 54 deletions xformers/ops/fmha/cutlass.py
@@ -332,30 +332,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
_check_bias_alignment(reasons, d.attn_bias)
return reasons

@classmethod
# type: ignore
def operator_flop(
cls,
q,
k,
v,
b,
seqstart_q,
seqstart_k,
max_seqlen_q_,
compute_lse,
custom_mask_type,
*a,
) -> int:
return cls.attn_operator_flop(
q,
k,
v,
causal=custom_mask_type > 0,
seqstart_k=seqstart_k,
seqstart_q=seqstart_q,
)


@register_operator
class BwOp(AttentionBwOpBase):
@@ -492,33 +468,3 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
grad_bias = None

return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)

@classmethod
# type: ignore
def operator_flop(
cls,
dO,
q,
k,
v,
b,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
logsumexp,
output,
dropout_p,
rng_seed,
rng_offset,
custom_mask_type,
scale,
) -> int:
return cls.attn_operator_flop(
q,
k,
v,
seqstart_q=cu_seqlens_q,
seqstart_k=cu_seqlens_k,
causal=custom_mask_type > 0,
)
55 changes: 0 additions & 55 deletions xformers/ops/fmha/flash.py
@@ -706,31 +706,6 @@ def apply(
ctx.rng_state = rng_state
return (out, ctx)

@classmethod
# type: ignore
def operator_flop(
cls,
query,
key,
value,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
causal,
return_softmax,
) -> int:
return cls.attn_operator_flop(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
causal=causal,
seqstart_k=cu_seq_lens_k,
seqstart_q=cu_seq_lens_q,
)


@register_operator
class BwOp(AttentionBwOpBase):
@@ -849,33 +824,3 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
grads.dk = grads.dk.reshape(dk_shape)
grads.dv = grads.dv.reshape(dv_shape)
return grads

@classmethod
# type: ignore
def operator_flop(
cls,
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
causal,
) -> int:
return cls.attn_operator_flop(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
causal=causal,
seqstart_k=cu_seq_lens_k,
seqstart_q=cu_seq_lens_q,
)
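Each removed `operator_flop` wrapper above forwards the varlen offsets (`cu_seq_lens_q` / `cu_seq_lens_k`, i.e. `seqstart_q` / `seqstart_k`) to `attn_operator_flop`, which recovers per-sequence lengths by pairing consecutive offsets. A small standalone sketch of that bookkeeping follows; the helper name `segment_lengths` is hypothetical, not xformers API:

```python
# Standalone sketch of the seqstart -> per-sequence-length bookkeeping used by
# the removed FLOP helpers. `segment_lengths` is a hypothetical name.
from typing import List, Optional, Tuple

import torch


def segment_lengths(
    seqstart_q: Optional[torch.Tensor],
    seqstart_k: Optional[torch.Tensor],
    default_q: int,
    default_k: int,
) -> List[Tuple[int, int]]:
    # With no offsets, treat the whole batch as a single sequence.
    q_off = seqstart_q.tolist() if seqstart_q is not None else [0, default_q]
    k_off = seqstart_k.tolist() if seqstart_k is not None else [0, default_k]
    # Consecutive offsets delimit one variable-length sequence each.
    return [
        (q_end - q_start, k_end - k_start)
        for q_start, q_end, k_start, k_end in zip(q_off, q_off[1:], k_off, k_off[1:])
    ]


# Three packed sequences: query lengths 3, 5, 2 and key lengths 4, 4, 6.
print(segment_lengths(torch.tensor([0, 3, 8, 10]), torch.tensor([0, 4, 8, 14]), 10, 14))
# -> [(3, 4), (5, 4), (2, 6)]
```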
