diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b6be3b968..e01f0c0fa8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - When using the most recent version of Flash-Attention, it is no longer possible to mix it with the cutlass backend. In other words, it is no longer possible to use the cutlass Fw with the flash Bw.
 ### Removed
 - fMHA: Removed `decoder` and `small_k` backends
+- profiler: Removed the `DetectSlowOpsProfiler` profiler
 ## [0.0.27.post2] - 2024-07-26
 Pre-built binary wheels require PyTorch 2.4.0
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index 04a72c9edb..b3b2a643db 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -11,13 +11,12 @@ import torch
 import torch.nn as nn
 from torch.nn.attention import SDPBackend, sdpa_kernel
-from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode
+from torch.utils._python_dispatch import _get_current_dispatch_mode
 import xformers.ops as xops
 import xformers.ops.fmha as fmha
 import xformers.profiler
 from xformers.profiler import profile_analyzer
-from xformers.profiler.slow_ops_profiler import GemmOpComputeFlops, flop_mapping
 cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@@ -31,61 +30,6 @@
 )
-class GEMMShapeDispatcher(TorchDispatchMode):
-    def __init__(self) -> None:
-        super().__init__()
-        self.mnk = (0, 0, 0)
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        if func._overloadpacket in flop_mapping:
-            compute_flops = flop_mapping[func._overloadpacket]
-            if isinstance(compute_flops, GemmOpComputeFlops):
-                self.mnk = compute_flops._get_mnk(args)
-        return func(*args)
-
-
-def test_gemm_flops() -> None:
-    M, N, K = 13, 17, 53
-
-    a = torch.empty([M, K])
-    b = torch.empty([K, N])
-    x = torch.empty([K])
-
-    with GEMMShapeDispatcher() as disp:
-        a @ b
-        assert disp.mnk == (M, N, K)
-    with GEMMShapeDispatcher() as disp:
-        a @ x
-        assert disp.mnk == (M, 1, K)
-    with GEMMShapeDispatcher() as disp:
-        torch.nn.functional.linear(a, b.transpose(0, 1))
-        assert disp.mnk == (M, N, K)
-    with GEMMShapeDispatcher() as disp:
-        torch.addmm(torch.empty([1, 1]), a, b)
-        assert disp.mnk == (M, N, K)
-
-    B = 3
-    ba = torch.empty([B, M, K])
-    bb = torch.empty([B, K, N])
-    with GEMMShapeDispatcher() as disp:
-        ba @ bb
-        assert disp.mnk == (B * M, N, K)
-    with GEMMShapeDispatcher() as disp:
-        ba @ bb[:1]
-        assert disp.mnk == (B * M, N, K)
-    with GEMMShapeDispatcher() as disp:
-        ba[:1] @ bb
-        assert disp.mnk == (B * M, N, K)
-    with GEMMShapeDispatcher() as disp:
-        ba @ bb[0]
-        assert disp.mnk == (B * M, N, K)
-    with GEMMShapeDispatcher() as disp:
-        torch.addbmm(torch.empty([1, 1]), ba, bb)
-        assert disp.mnk == (B * M, N, K)
-
-
 @cuda_only
 def test_profiler_dispatcher_stream_workaround() -> None:
     x = torch.zeros([10, 10], device="cuda")
diff --git a/xformers/ops/common.py b/xformers/ops/common.py
index 7dab42caa5..f31fde7332 100644
--- a/xformers/ops/common.py
+++ b/xformers/ops/common.py
@@ -39,11 +39,6 @@ def is_available(cls) -> bool:
             return False
         return True
-    @classmethod
-    def operator_flop(cls, *inputs) -> int:
-        """Calculate number of FLOP given inputs to `OPERATOR`"""
-        return -1
-
 OPERATORS_REGISTRY: List[Type[BaseOperator]] = []
 FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {}
diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py
index 444326f93a..f004d2bff0 100644
--- a/xformers/ops/fmha/ck.py
+++ b/xformers/ops/fmha/ck.py
@@ -316,30 +316,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
         _check_bias_alignment(reasons, d.attn_bias)
         return reasons
-    @classmethod
-    # type: ignore
-    def operator_flop(
-        cls,
-        q,
-        k,
-        v,
-        b,
-        seqstart_q,
-        seqstart_k,
-        max_seqlen_q_,
-        compute_lse,
-        custom_mask_type,
-        *a,
-    ) -> int:
-        return cls.attn_operator_flop(
-            q,
-            k,
-            v,
-            causal=custom_mask_type > 0,
-            seqstart_k=seqstart_k,
-            seqstart_q=seqstart_q,
-        )
-
 @register_operator
 class BwOp(AttentionBwOpBase):
@@ -478,33 +454,3 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
             grad_bias = None
         return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
-
-    @classmethod
-    # type: ignore
-    def operator_flop(
-        cls,
-        dO,
-        q,
-        k,
-        v,
-        b,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        logsumexp,
-        output,
-        dropout_p,
-        rng_seed,
-        rng_offset,
-        custom_mask_type,
-        scale,
-    ) -> int:
-        return cls.attn_operator_flop(
-            q,
-            k,
-            v,
-            seqstart_q=cu_seqlens_q,
-            seqstart_k=cu_seqlens_k,
-            causal=custom_mask_type > 0,
-        )
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index 0848aab2a1..c862a50a99 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -427,49 +427,6 @@ def apply(
     ) -> Tuple[torch.Tensor, Optional[Context]]:
         raise NotImplementedError()
-    @classmethod
-    def attn_operator_flop(
-        cls,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        causal: bool = False,
-        seqstart_k: Optional[torch.Tensor] = None,
-        seqstart_q: Optional[torch.Tensor] = None,
-    ) -> int:
-        """
-        Computes total flops for the attention
-        Assumes inputs in format BMHK
-        """
-        assert query.ndim == 4
-
-        if seqstart_q is not None:
-            seqstart_q_py = seqstart_q.tolist()
-        else:
-            seqstart_q_py = [0, query.shape[1]]
-        if seqstart_k is not None:
-            seqstart_k_py = seqstart_k.tolist()
-        else:
-            seqstart_k_py = [0, key.shape[1]]
-
-        total_flop = 0
-        for q_start, q_end, k_start, k_end in zip(
-            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
-        ):
-            num_q = q_end - q_start
-            num_kv = k_end - k_start
-            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
-            # Q @ K.transpose
-            total_flop += num_q * num_kv * query.shape[-1] * 2
-            # (ignore softmax)
-            # attn @ V
-            total_flop += num_q * key.shape[-1] * num_kv * 2
-        # Multiply by num_heads and batches
-        total_flop = total_flop * value.shape[2] * value.shape[0]
-        if causal:
-            total_flop //= 2
-        return total_flop
-
 class AttentionBwOpBase(AttentionOpBase):
     # NOTE on tolerances: These are tested for `scales => (1/32)**0.5`
@@ -508,56 +465,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
     def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
         raise NotImplementedError()
-    @classmethod
-    def attn_operator_flop(
-        cls,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        causal: bool = False,
-        seqstart_k: Optional[torch.Tensor] = None,
-        seqstart_q: Optional[torch.Tensor] = None,
-    ) -> int:
-        """
-        Computes total flops for the attention
-        Assumes inputs in format BMHK
-        """
-        assert query.ndim == 4
-
-        if seqstart_q is not None:
-            seqstart_q_py = seqstart_q.tolist()
-        else:
-            seqstart_q_py = [0, query.shape[1]]
-        if seqstart_k is not None:
-            seqstart_k_py = seqstart_k.tolist()
-        else:
-            seqstart_k_py = [0, key.shape[1]]
-
-        total_flop = 0
-        for q_start, q_end, k_start, k_end in zip(
-            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
-        ):
-            num_q = q_end - q_start
-            num_kv = k_end - k_start
-            Kqk = query.shape[-1]
-            Kv = value.shape[-1]
-            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
-            # att = Q @ K.transpose
-            total_flop += num_q * num_kv * Kqk * 2
-            # att @ dO
-            total_flop += num_kv * num_q * Kv * 2
-            # dov = dO @ V
-            total_flop += num_q * Kv * num_kv * 2
-            # dov @ K
-            total_flop += num_q * Kqk * num_kv * 2
-            # dov @ Q
-            total_flop += num_q * Kqk * num_kv * 2
-        # Multiply by num_heads and batches
-        total_flop = total_flop * value.shape[2] * value.shape[0]
-        if causal:
-            total_flop //= 2
-        return total_flop
-
 AttentionOp = Tuple[
     Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
diff --git a/xformers/ops/fmha/cutlass.py b/xformers/ops/fmha/cutlass.py
index ac0369d0f7..f26252340a 100644
--- a/xformers/ops/fmha/cutlass.py
+++ b/xformers/ops/fmha/cutlass.py
@@ -332,30 +332,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
         _check_bias_alignment(reasons, d.attn_bias)
         return reasons
-    @classmethod
-    # type: ignore
-    def operator_flop(
-        cls,
-        q,
-        k,
-        v,
-        b,
-        seqstart_q,
-        seqstart_k,
-        max_seqlen_q_,
-        compute_lse,
-        custom_mask_type,
-        *a,
-    ) -> int:
-        return cls.attn_operator_flop(
-            q,
-            k,
-            v,
-            causal=custom_mask_type > 0,
-            seqstart_k=seqstart_k,
-            seqstart_q=seqstart_q,
-        )
-
 @register_operator
 class BwOp(AttentionBwOpBase):
@@ -492,33 +468,3 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
             grad_bias = None
         return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
-
-    @classmethod
-    # type: ignore
-    def operator_flop(
-        cls,
-        dO,
-        q,
-        k,
-        v,
-        b,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        logsumexp,
-        output,
-        dropout_p,
-        rng_seed,
-        rng_offset,
-        custom_mask_type,
-        scale,
-    ) -> int:
-        return cls.attn_operator_flop(
-            q,
-            k,
-            v,
-            seqstart_q=cu_seqlens_q,
-            seqstart_k=cu_seqlens_k,
-            causal=custom_mask_type > 0,
-        )
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index c0159f9eda..b298ce8a36 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -706,31 +706,6 @@ def apply(
             ctx.rng_state = rng_state
         return (out, ctx)
-    @classmethod
-    # type: ignore
-    def operator_flop(
-        cls,
-        query,
-        key,
-        value,
-        cu_seq_lens_q,
-        cu_seq_lens_k,
-        max_seq_len_q,
-        max_seq_len_k,
-        p,
-        softmax_scale,
-        causal,
-        return_softmax,
-    ) -> int:
-        return cls.attn_operator_flop(
-            query.unsqueeze(0),
-            key.unsqueeze(0),
-            value.unsqueeze(0),
-            causal=causal,
-            seqstart_k=cu_seq_lens_k,
-            seqstart_q=cu_seq_lens_q,
-        )
-
 @register_operator
 class BwOp(AttentionBwOpBase):
@@ -849,33 +824,3 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
             grads.dk = grads.dk.reshape(dk_shape)
             grads.dv = grads.dv.reshape(dv_shape)
         return grads
-
-    @classmethod
-    # type: ignore
-    def operator_flop(
-        cls,
-        grad,
-        query,
-        key,
-        value,
-        out,
-        lse,
-        dq,
-        dk,
-        dv,
-        cu_seq_lens_q,
-        cu_seq_lens_k,
-        max_seq_len_q,
-        max_seq_len_k,
-        p,
-        softmax_scale,
-        causal,
-    ) -> int:
-        return cls.attn_operator_flop(
-            query.unsqueeze(0),
-            key.unsqueeze(0),
-            value.unsqueeze(0),
-            causal=causal,
-            seqstart_k=cu_seq_lens_k,
-            seqstart_q=cu_seq_lens_q,
-        )
diff --git a/xformers/ops/swiglu_op.py b/xformers/ops/swiglu_op.py
index 4f6938d846..d764196803 100644
--- a/xformers/ops/swiglu_op.py
+++ b/xformers/ops/swiglu_op.py
@@ -42,15 +42,6 @@ class DualGemmSiluOp(BaseOperator):
     OPERATOR_CATEGORY = "swiglu"
     NAME = "dual_gemm_silu"
-    @classmethod
-    # type: ignore
-    def operator_flop(
-        cls, x: torch.Tensor, w1: torch.Tensor, b1, w2: torch.Tensor, b2
-    ) -> int:
-        """NOTE: we neglect the impact of biases / pointwises"""
-        M, N, K = x.shape[0], w1.shape[0], w1.shape[1]
-        return M * N * K * 2 * 2
-
 @register_operator
 class GemmFusedSumOp(BaseOperator):
@@ -58,12 +49,6 @@ class GemmFusedSumOp(BaseOperator):
     OPERATOR_CATEGORY = "swiglu"
     NAME = "gemm_fused_operand_sum"
-    @classmethod
-    # type: ignore
-    def operator_flop(cls, a: torch.Tensor, b: torch.Tensor, out1, out2) -> int:
-        M, N, K = a.shape[0], b.shape[1], a.shape[1]
-        return M * N * K * 2
-
 class _SwiGLUDecomposedFunc(torch.autograd.Function):
     """
diff --git a/xformers/profiler/__init__.py b/xformers/profiler/__init__.py
index ae14c7ddbc..fe9d0a4926 100644
--- a/xformers/profiler/__init__.py
+++ b/xformers/profiler/__init__.py
@@ -5,7 +5,6 @@
 from .api import profile, step
 from .profiler import MemSnapshotsProfiler, NsightProfiler, PyTorchProfiler
-from .slow_ops_profiler import DetectSlowOpsProfiler
 __all__ = [
     "profile",
@@ -13,5 +12,4 @@
     "MemSnapshotsProfiler",
     "PyTorchProfiler",
     "NsightProfiler",
-    "DetectSlowOpsProfiler",
 ]
diff --git a/xformers/profiler/api.py b/xformers/profiler/api.py
index 1dc7ef9d81..02722dc4ac 100644
--- a/xformers/profiler/api.py
+++ b/xformers/profiler/api.py
@@ -15,7 +15,6 @@
     _Profiler,
 )
 from .profiler_dcgm import DCGMProfiler  # noqa: F401
-from .slow_ops_profiler import DetectSlowOpsProfiler  # noqa: F401
 DEFAULT_SCHEDULE = (
     (MemSnapshotsProfiler, 0, 2),
@@ -25,10 +24,6 @@
     # TODO: Found issues where this can take minutes to
     # start, as it flushes previous values
     # (DCGMProfiler, 9, 11),
-    # TODO: There are some issues in PyTorch stable
-    # which are now fixed on main, but might break this profiler
-    # https://github.com/pytorch/pytorch/issues/94403
-    # (DetectSlowOpsProfiler, 9, 10),
 )
@@ -62,7 +57,6 @@ def profile(
             module=model,
             schedule=[
                 (MemSnapshotsProfiler, 0, 2),
-                (DetectSlowOpsProfiler, 2, 4),
                 (NsightProfiler, 4, 6),
                 (PyTorchProfiler, 6, 20),
             ]
diff --git a/xformers/profiler/slow_ops_profiler.py b/xformers/profiler/slow_ops_profiler.py
deleted file mode 100644
index 1f0c64dc5b..0000000000
--- a/xformers/profiler/slow_ops_profiler.py
+++ /dev/null
@@ -1,510 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-import itertools
-import json
-import math
-from collections import defaultdict
-from dataclasses import dataclass, field
-from functools import partial
-from typing import Any, Dict, List, Set, Tuple
-
-import torch.cuda.memory
-import torch.cuda.nvtx
-import torch.profiler
-import torch.utils.hooks
-from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily
-from torch.utils._pytree import tree_map
-
-from ..ops.common import FUNC_TO_XFORMERS_OPERATOR
-from .device_limits import get_device_limits
-from .profiler import _Profiler
-
-
-class TorchFuncMockNoDispatch:
-    """
-    Wraps a method to call it without the custom
-    pytorch dispatcher
-    """
-
-    def __init__(self, pt_impl):
-        self.pt_impl = pt_impl
-
-    def __get__(self, obj, c):
-        return partial(self, obj)
-
-    def __call__(self, obj, *args, **kwargs):
-        with _pop_mode_temporarily():
-            return self.pt_impl(obj, *args, **kwargs)
-
-
-class DispatcherWithoutBrokenFuncs(TorchDispatchMode):
-    TENSOR_FUNCS_NO_DISPATCH = [
-        # Can't convert Stream argument to Python object
-        # https://github.com/pytorch/pytorch/issues/94403
-        "record_stream"
-    ]
-
-    def __enter__(self) -> None:
-        self._pt_impls = {}
-        for k in self.TENSOR_FUNCS_NO_DISPATCH:
-            impl = getattr(torch.Tensor, k)
-            self._pt_impls[k] = impl
-            setattr(torch.Tensor, k, TorchFuncMockNoDispatch(impl))
-        return super().__enter__()
-
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        for k in self.TENSOR_FUNCS_NO_DISPATCH:
-            setattr(torch.Tensor, k, self._pt_impls[k])
-        return super().__exit__(exc_type, exc_val, exc_tb)
-
-
-def get_shape(i):
-    return i.shape
-
-
-def prod(x):
-    res = 1
-    for i in x:
-        res *= i
-    return res
-
-
-class GemmOpComputeFlops:
-    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
-        return (prod(inputs[0].shape[:-1]), inputs[1].shape[1], inputs[0].shape[-1])
-
-    def __call__(self, inputs: List[Any], outputs: List[Any]) -> float:
-        return 2 * prod(self._get_mnk(inputs))
-
-    def op_suffix(self, inputs: List[Any]) -> str:
-        m, n, k = self._get_mnk(inputs)
-        return f"_{m}x{n}x{k}"
-
-
-class GemmOpComputeFlopsLinear(GemmOpComputeFlops):
-    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
-        return (prod(inputs[0].shape[:-1]), inputs[1].shape[0], inputs[0].shape[-1])
-
-
-class GemmOpComputeFlopsMv(GemmOpComputeFlops):
-    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
-        return (prod(inputs[0].shape[:-1]), 1, inputs[0].shape[-1])
-
-
-class GemmOpComputeFlopsBmm(GemmOpComputeFlops):
-    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
-        a, b = inputs[0], inputs[1]
-        assert a.ndim == 3
-        assert b.ndim == 3
-        bs = max(inputs[0].shape[0], inputs[1].shape[0])
-        return (bs * a.shape[1], b.shape[-1], b.shape[-2])
-
-
-class GemmOpComputeFlopsAddmm(GemmOpComputeFlops):
-    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
-        return super()._get_mnk(inputs[1:])
-
-
-class GemmOpComputeFlopsAddbmm(GemmOpComputeFlopsBmm):
-    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
-        return super()._get_mnk(inputs[1:])
-
-
-def conv_flop_count(
-    x_shape: List[int],
-    w_shape: List[int],
-    out_shape: List[int],
-    transposed: bool = False,
-) -> float:
-    """
-    Count flops for convolution. Note only multiplication is
-    counted. Computation for addition and bias is ignored.
-    Flops for a transposed convolution are calculated as
-    flops = (x_shape[2:] * prod(w_shape) * batch_size).
-    Args:
-        x_shape (list(int)): The input shape before convolution.
-        w_shape (list(int)): The filter shape.
-        out_shape (list(int)): The output shape after convolution.
-        transposed (bool): is the convolution transposed
-    Returns:
-        int: the number of flops
-    """
-    batch_size = x_shape[0]
-    conv_shape = (x_shape if transposed else out_shape)[2:]
-    flop = batch_size * prod(w_shape) * prod(conv_shape)
-    return flop
-
-
-def conv_flop(inputs: List[Any], outputs: List[Any]):
-    """
-    Count flops for convolution.
-    """
-    x, w = inputs[:2]
-    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
-    transposed = inputs[6]
-
-    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
-
-
-def transpose_shape(shape):
-    return [shape[1], shape[0]] + list(shape[2:])
-
-
-def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
-    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
-    output_mask = inputs[-1]
-    fwd_transposed = inputs[7]
-    flop_count = 0.0
-
-    if output_mask[0]:
-        grad_input_shape = get_shape(outputs[0])
-        flop_count += conv_flop_count(
-            grad_out_shape, w_shape, grad_input_shape, not fwd_transposed
-        )
-    if output_mask[1]:
-        grad_weight_shape = get_shape(outputs[1])
-        flop_count += conv_flop_count(
-            transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed
-        )
-
-    return flop_count
-
-
-def tensor_storage_size_in_mem(x: torch.Tensor):
-    total = 1
-    for dim_sz, stride in zip(x.shape, x.stride()):
-        if stride >= 1:
-            total *= dim_sz
-    return total
-
-
-def get_size(inputs: List[Any]):
-    total_bytes = 0
-
-    def process(x) -> None:
-        nonlocal total_bytes
-        if isinstance(x, torch.Tensor):
-            total_bytes += tensor_storage_size_in_mem(x) * x.element_size()
-
-    tree_map(process, inputs)
-    return total_bytes
-
-
-def operation_memory_rw_bytes(inputs: List[Any], outputs: List[Any]):
-    size_input, size_output = get_size(inputs), get_size(outputs)
-    return size_input + size_output
-
-
-def output_read_from_input(inputs: List[Any], outputs: List[Any]):
-    size_input, size_output = get_size(inputs), get_size(outputs)
-    return size_output + min(size_input, size_output)
-
-
-def output_total_size(inputs: List[Any], outputs: List[Any]):
-    return get_size(outputs)
-
-
-def input_total_size(inputs: List[Any], outputs: List[Any]):
-    return get_size(inputs)
-
-
-def guess_flops_unknown_op(inputs: List[Any], outputs: List[Any]):
-    # Approximation that isn't too bad
-    total_elements = 0
-
-    def process(x) -> None:
-        nonlocal total_elements
-        if isinstance(x, torch.Tensor):
-            total_elements += x.numel()
-
-    tree_map(process, inputs)
-    tree_map(process, outputs)
-    return total_elements / 2
-
-
-def no_flop(inputs: List[Any], outputs: List[Any]):
-    return 0
-
-
-def no_io(inputs: List[Any], outputs: List[Any]):
-    return 0
-
-
-aten = torch.ops.aten
-NO_FLOPS_NO_IO_OPS = [
-    aten.permute,
-    aten.view,
-    aten.view_as,
-    aten.detach,
-    aten.t,
-    aten.transpose,
-    aten.expand,
-    aten._unsafe_view,
-    aten.select,
-    aten.split,
-    aten.split_with_sizes,
-    aten.empty,
-    aten.empty_strided,
-    aten.empty_like,
-    aten.is_same_size,
-]
-NO_FLOPS_OPS = [
-    aten._reshape_alias,
-    aten.reshape,
-    aten.clone,
-    aten.cat,
-    aten.select_backward,
-    aten.slice,
-    aten.slice_backward,
-    aten.ones,
-    aten.ones_like,
-    aten.zeros_like,
-    aten.zero_,
-    aten.zeros,
-    aten.masked_fill,
-    aten.masked_fill_,
-]
-
-flop_mapping = {
-    aten.mv: GemmOpComputeFlopsMv(),  # mat-vec
-    aten.mm: GemmOpComputeFlops(),
-    aten.matmul: GemmOpComputeFlops(),
-    aten.addmm: GemmOpComputeFlopsAddmm(),
-    aten.bmm: GemmOpComputeFlopsBmm(),
-    aten.addbmm: GemmOpComputeFlopsAddbmm(),
-    aten.linear: GemmOpComputeFlopsLinear(),
-    aten.convolution: conv_flop,
-    aten._convolution: conv_flop,
-    aten.convolution_backward: conv_backward_flop,
-    # Operations with 0 flop
-    **{op: no_flop for op in NO_FLOPS_OPS},
-    **{op: no_flop for op in NO_FLOPS_NO_IO_OPS},
-}
-io_mapping = {
-    aten.clone: output_read_from_input,
-    aten.cat: output_read_from_input,
-    aten.slice: output_read_from_input,
-    aten.ones_like: output_total_size,
-    aten.zeros_like: output_total_size,
-    aten.zero_: input_total_size,
-    **{op: no_io for op in NO_FLOPS_NO_IO_OPS}
-    # TODO: Check how this is implemented in PT
-    # aten.slice_backward: no_flop,
-    # aten.select_backward: no_flop,
-}
-
-
-@dataclass
-class _OpInfo:
-    flop_count: float = 0.0
-    time_ms: float = 0.0
-    io_bytes: int = 0
-    is_exact_flop: bool = True
-    op_name: str = ""
-    op_suffix: str = ""
-    stacktrace: Tuple[str, ...] = field(default_factory=tuple)
-    ev_start: torch.cuda.Event = field(
-        default_factory=lambda: torch.cuda.Event(enable_timing=True)
-    )
-    ev_end: torch.cuda.Event = field(
-        default_factory=lambda: torch.cuda.Event(enable_timing=True)
-    )
-
-    # Hardware limits for this operation (inf if unknown)
-    hardware_tflops_limit: float = math.inf
-    hardware_membw_limit: float = math.inf
-
-    @property
-    def time_membound_ms(self) -> float:
-        assert self.time_ms > 0.0
-        if self.io_bytes == 0:
-            return 0.0
-        return min(self.time_ms, 1000 * self.io_bytes / self.hardware_membw_limit)
-
-    @property
-    def time_computebound_ms(self) -> float:
-        assert self.time_ms > 0.0
-        tflop = self.flop_count / (1000**4)
-        if tflop == 0.0:
-            return 0.0
-        return min(self.time_ms, 1000 * tflop / self.hardware_tflops_limit)
-
-    def finalize(self) -> None:
-        self.time_ms = self.ev_start.elapsed_time(self.ev_end)
-
-
-@dataclass
-class _OpInfoAggregated:
-    is_exact_flop: bool = True
-    total_flop_count: float = 0.0
-    total_io_bytes: int = 0
-    total_time_ms: float = 0.0
-    total_time_membound_ms: float = 0.0
-    total_time_computebound_ms: float = 0.0
-    num: int = 0
-    stacktraces: List[Tuple[str, ...]] = field(default_factory=list)
-
-    def add(self, op: _OpInfo) -> None:
-        self.total_flop_count += op.flop_count
-        self.total_time_ms += op.time_ms
-        self.total_io_bytes += op.io_bytes
-        self.total_time_membound_ms += op.time_membound_ms
-        self.total_time_computebound_ms += op.time_computebound_ms
-        self.num += 1
-        self.is_exact_flop = op.is_exact_flop
-        self.stacktraces.append(op.stacktrace)
-
-    def as_dict(self, **kwargs) -> Dict[str, Any]:
-        mem_bound = min(1, self.total_time_membound_ms / self.total_time_ms)
-        tflops = self.total_flop_count / (self.total_time_ms / 1000) / (1000**4)
-        compute_bound = min(1, self.total_time_computebound_ms / self.total_time_ms)
-        return {
-            "is_exact_flop": self.is_exact_flop,
-            "total_flop_count": self.total_flop_count,
-            "total_time_ms": self.total_time_ms,
-            "total_io_bytes": self.total_io_bytes,
-            "num": self.num,
-            "Tflops": tflops,
-            "mem_bound": mem_bound,
-            "compute_bound": compute_bound,
-            **kwargs,
-        }
-
-
-class DetectSlowOpsProfiler(DispatcherWithoutBrokenFuncs):
-    """
-    Inspired from https://fb.workplace.com/groups/pytorch.dev/permalink/1054537595124720/
-    """
-
-    def __init__(self, main_profiler: _Profiler) -> None:
-        self.main_profiler = main_profiler
-        self.trace: List[_OpInfo] = []
-        self.temp_disabled = False
-
-    def _hardware_tflops_membw_limit(
-        self, args: Tuple[Any, ...], outputs: Tuple[Any, ...]
-    ) -> Tuple[float, float]:
-        device = None
-        dtypes: List[torch.dtype] = []
-        for a in itertools.chain(outputs, args):
-            if isinstance(a, torch.Tensor):
-                if device is None:
-                    device = a.device
-                dtypes.append(a.dtype)
-        limits = get_device_limits(device)
-        if not limits:
-            return (math.inf, math.inf)
-        dtypes = [dt for dt in dtypes if dt in limits.gemm_tflops]
-        if not dtypes or device is None:
-            return (math.inf, math.inf)
-        dtype = dtypes[0]
-        if torch.is_autocast_enabled() and dtype is torch.float32:
-            dtype = torch.get_autocast_gpu_dtype()
-        return limits.gemm_tflops[dtype], limits.gmem_bandwidth
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        func_packet = func._overloadpacket
-        if self.temp_disabled or func_packet.__name__ in [
-            "_record_function_exit",
-            "_record_function_enter_new",
-        ]:
-            return func(*args, **kwargs)
-
-        op = _OpInfo()
-        op.ev_start.record()
-        out = func(*args, **kwargs)
-        op.ev_end.record()
-
-        (
-            op.hardware_tflops_limit,
-            op.hardware_membw_limit,
-        ) = self._hardware_tflops_membw_limit(
-            args, out if isinstance(out, tuple) else (out,)
-        )
-        op.op_name = func_packet.__name__
-        # Prevent functions called by flop counting ops to be recorded
-        self.temp_disabled = True
-        flop_count = -1
-        compute_flops = None
-        if func_packet in FUNC_TO_XFORMERS_OPERATOR:
-            flop_count = FUNC_TO_XFORMERS_OPERATOR[func_packet].operator_flop(
-                *args, **kwargs
-            )
-        if flop_count == -1:
-            compute_flops = flop_mapping.get(func_packet, guess_flops_unknown_op)
-            flop_count = compute_flops(args, out if isinstance(out, tuple) else (out,))
-            if isinstance(compute_flops, GemmOpComputeFlops):
-                op.op_name += compute_flops.op_suffix(args)
-
-        compute_io = io_mapping.get(func_packet, operation_memory_rw_bytes)
-        op.io_bytes = compute_io(args, out if isinstance(out, tuple) else (out,))
-        self.temp_disabled = False
-
-        op.stacktrace = tuple(self.main_profiler.parents)
-        op.flop_count = flop_count
-        op.is_exact_flop = compute_flops is not guess_flops_unknown_op
-        self.trace.append(op)
-
-        return out
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        super().__exit__(exc_type, exc_val, exc_tb)
-        torch.cuda.synchronize()  # Wait for the events to be recorded
-        for op in self.trace:
-            op.finalize()
-        self.save_json()
-
-    def step(self) -> None:
-        pass
-
-    def save_json(self) -> None:
-        # Aggregate data at the module + op level
-        all_paths: Set[Tuple[str, ...]] = set()
-        per_module_data: Dict[Tuple[str, ...], _OpInfoAggregated] = defaultdict(
-            _OpInfoAggregated
-        )
-        per_op_data: Dict[str, _OpInfoAggregated] = defaultdict(_OpInfoAggregated)
-        for op in self.trace:
-            all_paths.add(op.stacktrace)
-        for op in self.trace:
-            for i in range(len(op.stacktrace)):
-                if op.stacktrace[: i + 1] in all_paths:
-                    per_module_data[op.stacktrace[: i + 1]].add(op)
-            per_op_data[op.op_name].add(op)
-
-        # Generate JSON
-        all_data = []
-        for stacktrace, agg_info in per_module_data.items():
-            all_data.append(
-                agg_info.as_dict(
-                    agg="module", path=stacktrace, name=stacktrace[-1], op=""
-                )
-            )
-        for op_name, agg_info in per_op_data.items():
-            # Find the most common path
-            paths_count: Dict[Tuple[str, ...], int] = defaultdict(int)
-            agg_info.stacktraces.sort()  # In case of a draw, let's always return the same
-            for p in agg_info.stacktraces:
-                paths_count[p] += 1
-            maxp = agg_info.stacktraces[0]
-            for p, count in paths_count.items():
-                if count > paths_count[maxp]:
-                    maxp = p
-            all_data.append(
-                agg_info.as_dict(
-                    agg="opname",
-                    path=f"{'.'.join(maxp)} (x{paths_count[maxp]})",
-                    name="",
-                    op=op_name,
-                )
-            )
-
-        filename = self.main_profiler._create_output_filename("ops.json")
-        self.main_profiler.summary.append(("OpsSummary", str(filename)))
-        with open(filename, "w+") as f:
-            json.dump(all_data, f)
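
Note on the removed FLOP accounting: the deleted `attn_operator_flop` helpers priced attention as two GEMMs per query/key block (`Q @ K.transpose` and `attn @ V`), at 2*M*N*K FLOP per GEMM, multiplied by batch and heads and halved under a causal mask. A minimal standalone sketch of that formula for the fixed-length BMHK case follows; the helper name is illustrative and not part of the xformers API.

import torch


def attention_fwd_flops(
    query: torch.Tensor,  # [B, Mq, H, K]  -- BMHK layout, as the removed helper assumed
    key: torch.Tensor,    # [B, Mkv, H, K]
    value: torch.Tensor,  # [B, Mkv, H, Kv]
    causal: bool = False,
) -> int:
    # Forward-attention FLOP using the 2*M*N*K-per-GEMM convention,
    # ignoring softmax, mirroring the deleted attn_operator_flop.
    B, Mq, H, K = query.shape
    Mkv, Kv = key.shape[1], value.shape[-1]
    flop = Mq * Mkv * K * 2    # Q @ K.transpose
    flop += Mq * Mkv * Kv * 2  # attn @ V
    flop *= B * H              # multiply by num_heads and batches
    return flop // 2 if causal else flop


if __name__ == "__main__":
    q = k = v = torch.empty([2, 1024, 8, 64])
    print(attention_fwd_flops(q, k, v))  # 4294967296 (~4.3 GFLOP)

For the generic GEMM and convolution counting previously handled by `flop_mapping`, PyTorch's built-in `torch.utils.flop_counter.FlopCounterMode` may serve as a replacement.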