From ad445f85abd50648d95bcd9cdc40cd3cd8fa98b8 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:59:16 +0000 Subject: [PATCH] Profiler: Misc bug fixes (fairinternal/xformers#1176) - Profiler: Fix computation of FLOPS for the attention when using xFormers - Profiler: Fix MFU/HFU calculation when multiple dtypes are used Co-authored-by: danthe3rd __original_commit__ = fairinternal/xformers@8c994eff39f52ea8d9f60b66552b6107cd56661b --- CHANGELOG.md | 2 + tests/test_profiler.py | 20 +++---- xformers/profiler/profile_analyzer.py | 77 ++++++++++++++------------- xformers/profiler/profiler.py | 25 +++++---- 4 files changed, 69 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ce8489ef1..992513244d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.0.28] - TBD ### Added ### Improved +- Profiler: Fix computation of FLOPS for the attention when using xFormers +- Profiler: Fix MFU/HFU calculation when multiple dtypes are used ### Removed ## [0.0.27.post2] - 2024-07-26 diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 253b98397b..04a72c9edb 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -225,9 +225,10 @@ def test_analyze_prof(dtype) -> None: @pytest.mark.parametrize("causal", [True, False], ids=["causal", ""]) @cuda_only def test_analyze_prof_sdpa(dtype, enable_flash: bool, causal: bool) -> None: - B, N = 64, 128 - x = torch.ones([B, 1, N, 128], dtype=dtype, device="cuda", requires_grad=True) - fw_flops = 2 * 2 * B * N * N * 128 + B, M, H, K = 64, 256, 3, 128 + x = torch.ones([B, H, M, K], dtype=dtype, device="cuda", requires_grad=True) + fw_flops = 2 * 2 * M * M * K + fw_flops *= B * H if causal: fw_flops //= 2 with sdpa_kernel( @@ -252,16 +253,17 @@ def test_analyze_prof_sdpa(dtype, enable_flash: bool, causal: bool) -> None: @cuda_only def test_analyze_prof_memeff(op, causal: bool) -> None: dtype = torch.float16 - B, N = 64, 128 - x = torch.ones([B, 1, N, 128], dtype=dtype, device="cuda", requires_grad=True) - device_sm = torch.cuda.get_device_capability(x.device) - if device_sm < op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY: - pytest.skip(f"Requires sm{op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY}") - fw_flops = 2 * 2 * B * N * N * 128 + B, M, H, K = 64, 256, 3, 128 + x = torch.ones([B, M, H, K], dtype=dtype, device="cuda", requires_grad=True) + fw_flops = 2 * 2 * M * M * K + fw_flops *= B * H bias = None if causal: bias = fmha.attn_bias.LowerTriangularMask() fw_flops //= 2 + device_sm = torch.cuda.get_device_capability(x.device) + if device_sm < op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY: + pytest.skip(f"Requires sm{op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY}") with assert_flops("memory_efficient_attention", match=fw_flops): y = xops.memory_efficient_attention(x, x, x, attn_bias=bias, op=op) with assert_flops("memory_efficient_attention BW", match=fw_flops * 5 // 2): diff --git a/xformers/profiler/profile_analyzer.py b/xformers/profiler/profile_analyzer.py index fff2e99ea8..07b991843b 100644 --- a/xformers/profiler/profile_analyzer.py +++ b/xformers/profiler/profile_analyzer.py @@ -20,8 +20,11 @@ def __init__(self, e: torch._C._autograd._KinetoEvent) -> None: self._kineto_event = e -def _attention_flops(queries, values, causal: bool) -> int: +def _attention_flops(queries, values, causal: bool, fmt: str = "BHMK") -> int: assert isinstance(causal, bool) + assert fmt in ["BMHK", "BHMK"] + if fmt == 
"BMHK": + queries, values = [[x[0], x[2], x[1], x[3]] for x in [queries, values]] *B, N, K = queries *B, Nv, Kv = values if causal: # NOTE: Causal from bottom right @@ -58,23 +61,39 @@ def _replace_if_needed( op_name = e.name() flops = None + FMT_BMHK = dict(fmt="BMHK") ATTN_OPS = { - getattr(lib, op).default.name(): (getattr(lib, op), is_bwd) - for lib, op, is_bwd in [ - (torch.ops.aten, "scaled_dot_product_attention", False), - (torch.ops.xformers_flash, "flash_fwd", False), - (torch.ops.xformers, "efficient_attention_forward_cutlass", False), - (torch.ops.aten, "_efficient_attention_forward", False), - (torch.ops.aten, "_scaled_dot_product_flash_attention_backward", True), - (torch.ops.aten, "_scaled_dot_product_efficient_attention_backward", True), - (torch.ops.xformers_flash, "flash_bwd", True), - (torch.ops.xformers, "efficient_attention_backward_cutlass", True), - (torch.ops.aten, "_efficient_attention_backward", True), + getattr(lib, op).default.name(): (getattr(lib, op), is_bwd, kwargs) + for lib, op, is_bwd, kwargs in [ + (torch.ops.aten, "scaled_dot_product_attention", False, {}), + (torch.ops.xformers_flash, "flash_fwd", False, FMT_BMHK), + ( + torch.ops.xformers, + "efficient_attention_forward_cutlass", + False, + FMT_BMHK, + ), + (torch.ops.aten, "_efficient_attention_forward", False, FMT_BMHK), + (torch.ops.aten, "_scaled_dot_product_flash_attention_backward", True, {}), + ( + torch.ops.aten, + "_scaled_dot_product_efficient_attention_backward", + True, + {}, + ), + (torch.ops.xformers_flash, "flash_bwd", True, FMT_BMHK), + ( + torch.ops.xformers, + "efficient_attention_backward_cutlass", + True, + FMT_BMHK, + ), + (torch.ops.aten, "_efficient_attention_backward", True, FMT_BMHK), ] if hasattr(lib, op) } if op_name in ATTN_OPS.keys(): - op, is_bwd = ATTN_OPS[op_name] + op, is_bwd, kwargs = ATTN_OPS[op_name] shapes = e.shapes() concrete_inputs = e.concrete_inputs() try: @@ -85,6 +104,7 @@ def _replace_if_needed( shapes[_get_arg_idx(op, "query")], shapes[_get_arg_idx(op, "value")], is_causal, + **kwargs, ) if is_bwd: flops = flops * 5 // 2 @@ -95,24 +115,6 @@ def _replace_if_needed( return e -# BW compat with older PT versions -if "start_ns" in dir(torch._C._autograd._KinetoEvent): - - def _start_ns(e) -> int: - return e.start_ns() - - def _duration_ns(e) -> int: - return e.duration_ns() - -else: - - def _start_ns(e) -> int: - return e.start_us() * 1000 - - def _duration_ns(e) -> int: - return e.duration_us() * 1000 - - @dataclass class AnalyzedTrace: operations_per_dtype_fw: Dict[torch.dtype, float] @@ -169,19 +171,20 @@ def from_profile( def _find_parent_op( e: torch._C._autograd._KinetoEvent, ) -> torch._C._autograd._KinetoEvent: - e_range = [_start_ns(e), _start_ns(e) + _duration_ns(e)] + e_range = [e.start_ns(), e.start_ns() + e.duration_ns()] candidate = e for parent in all_ops: if parent.device_type() != e.device_type(): continue if parent.start_thread_id() != e.start_thread_id(): continue - p_range = [_start_ns(parent), _start_ns(parent) + _duration_ns(parent)] + p_range = [parent.start_ns(), parent.start_ns() + parent.duration_ns()] if not (p_range[0] < e_range[0] < e_range[1] < p_range[1]): continue # We take the longest parent with flops - if parent.flops() > 0 and _duration_ns(candidate) < _duration_ns( - parent + if ( + parent.flops() > 0 + and candidate.duration_ns() < parent.duration_ns() ): candidate = parent return candidate @@ -221,8 +224,8 @@ def _find_parent_op( for op in events: if op.device_type().name != "CUDA": continue - begin_ns = min(begin_ns, 
_start_ns(op)) - end_ns = max(end_ns, _start_ns(op) + _duration_ns(op)) + begin_ns = min(begin_ns, op.start_ns()) + end_ns = max(end_ns, op.start_ns() + op.duration_ns()) return AnalyzedTrace( operations_per_dtype_fw=operations_per_dtype_fw, diff --git a/xformers/profiler/profiler.py b/xformers/profiler/profiler.py index ae4826dd1a..e50ab9284d 100644 --- a/xformers/profiler/profiler.py +++ b/xformers/profiler/profiler.py @@ -59,6 +59,7 @@ class PyTorchProfiler: def __init__(self, main_profiler: "_Profiler") -> None: self.main_profiler = main_profiler + self.num_steps = 0 self.pytorch_profiler = torch.profiler.profile( on_trace_ready=self._on_trace, profile_memory=True, @@ -95,16 +96,18 @@ def _analyze_trace(self, prof: torch.profiler.profiler.profile) -> None: if limits is not None: for dtype, tflops in limits.gemm_tflops.items(): hw_flops[dtype] = tflops * (1000**4) - total_flops = 0.0 - total_hfu = 0.0 - total_mfu = 0.0 - for dtype in results.operations_per_dtype_fw.keys(): - total_flops += results.compute_num_ops(dtype) / results.total_time_s - total_hfu += results.compute_hfu(hw_flops) - total_mfu += results.compute_mfu(hw_flops) + total_hfu = results.compute_hfu(hw_flops) + total_mfu = results.compute_mfu(hw_flops) + total_flop = sum( + results.compute_num_ops(dtype) + for dtype in results.operations_per_dtype_fw.keys() + ) s = self.main_profiler.summary - s.append(("Step time (ms)", f"{int(results.total_time_s * 1000)}")) - s.append(("TFlops", f"{total_flops / (1000**4):0.1f}")) + s.append( + ("Step time (ms)", f"{int(results.total_time_s * 1000 / self.num_steps)}") + ) + s.append(("TFlop/step", f"{total_flop / (self.num_steps * 1000**4):0.1f}")) + s.append(("TFlops", f"{total_flop / (results.total_time_s * 1000**4):0.1f}")) s.append(("HFU", f"{total_hfu:0.3f}")) s.append(("MFU", f"{total_mfu:0.3f}")) @@ -118,6 +121,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def step(self) -> None: self.pytorch_profiler.step() + self.num_steps += 1 class PyTorchProfiler_CUDAOnly(PyTorchProfiler): @@ -267,6 +271,9 @@ def update_profilers_on_step(self) -> None: o = p.object p.object = None logging.info(f"Shutting down {p.cls.__name__} profiler...") + # Make sure the profiler's `step` function is called + # $N times when we do $N steps with this profiler. + o.step() o.__exit__(None, None, None) def _create_output_filename(self, filename: str) -> Path:
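
Note: the following is a minimal reference sketch, not part of the patch, of the attention FLOP accounting that the updated tests and `_attention_flops` rely on. The function name `attention_flops` and its arguments are illustrative only; the constants (two matmuls of ~2*M*M*K FLOPs each, halving under a causal mask, and a 5/2 factor for the backward pass) come from the test expectations and `profile_analyzer.py` above.

    def attention_flops(B: int, M: int, H: int, K: int,
                        causal: bool, bwd: bool = False) -> int:
        """Approximate FLOPs of scaled-dot-product attention with equal
        query/key/value sequence length M and head dimension K."""
        # Two matmuls (Q @ K^T and P @ V), each ~2 * M * M * K FLOPs,
        # repeated for every batch element and head.
        flops = 2 * 2 * M * M * K
        flops *= B * H
        if causal:
            # A (bottom-right) causal mask skips roughly half of the work.
            flops //= 2
        if bwd:
            # The backward pass is counted as 5/2 of the forward pass,
            # matching `flops = flops * 5 // 2` in profile_analyzer.py.
            flops = flops * 5 // 2
        return flops

    # Example with the shapes used in the updated tests: B, M, H, K = 64, 256, 3, 128
    fw = attention_flops(64, 256, 3, 128, causal=False)
    assert fw == 2 * 2 * 256 * 256 * 128 * 64 * 3
    assert attention_flops(64, 256, 3, 128, causal=False, bwd=True) == fw * 5 // 2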