Profiler: Misc bug fixes (fairinternal/xformers#1176)
- Profiler: Fix computation of FLOPS for attention when using xFormers
- Profiler: Fix MFU/HFU calculation when multiple dtypes are used

Co-authored-by: danthe3rd <danthe3rd>

__original_commit__ = fairinternal/xformers@8c994ef
danthe3rd authored and xFormers Bot committed Aug 7, 2024
1 parent 926f410 commit ad445f8
Showing 4 changed files with 69 additions and 55 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.0.28] - TBD
### Added
### Improved
- Profiler: Fix computation of FLOPS for attention when using xFormers
- Profiler: Fix MFU/HFU calculation when multiple dtypes are used
### Removed

## [0.0.27.post2] - 2024-07-26
20 changes: 11 additions & 9 deletions tests/test_profiler.py
@@ -225,9 +225,10 @@ def test_analyze_prof(dtype) -> None:
@pytest.mark.parametrize("causal", [True, False], ids=["causal", ""])
@cuda_only
def test_analyze_prof_sdpa(dtype, enable_flash: bool, causal: bool) -> None:
B, N = 64, 128
x = torch.ones([B, 1, N, 128], dtype=dtype, device="cuda", requires_grad=True)
fw_flops = 2 * 2 * B * N * N * 128
B, M, H, K = 64, 256, 3, 128
x = torch.ones([B, H, M, K], dtype=dtype, device="cuda", requires_grad=True)
fw_flops = 2 * 2 * M * M * K
fw_flops *= B * H
if causal:
fw_flops //= 2
with sdpa_kernel(
@@ -252,16 +253,17 @@ def test_analyze_prof_sdpa(dtype, enable_flash: bool, causal: bool) -> None:
@cuda_only
def test_analyze_prof_memeff(op, causal: bool) -> None:
dtype = torch.float16
B, N = 64, 128
x = torch.ones([B, 1, N, 128], dtype=dtype, device="cuda", requires_grad=True)
device_sm = torch.cuda.get_device_capability(x.device)
if device_sm < op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY:
pytest.skip(f"Requires sm{op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY}")
fw_flops = 2 * 2 * B * N * N * 128
B, M, H, K = 64, 256, 3, 128
x = torch.ones([B, M, H, K], dtype=dtype, device="cuda", requires_grad=True)
fw_flops = 2 * 2 * M * M * K
fw_flops *= B * H
bias = None
if causal:
bias = fmha.attn_bias.LowerTriangularMask()
fw_flops //= 2
device_sm = torch.cuda.get_device_capability(x.device)
if device_sm < op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY:
pytest.skip(f"Requires sm{op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY}")
with assert_flops("memory_efficient_attention", match=fw_flops):
y = xops.memory_efficient_attention(x, x, x, attn_bias=bias, op=op)
with assert_flops("memory_efficient_attention BW", match=fw_flops * 5 // 2):
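For reference, a minimal sketch of the FLOP model these updated tests assert against: the forward pass does two matmuls (Q·Kᵀ and attn·V) of 2·M·M·K FLOPs each per batch/head pair, a causal mask halves that, and the backward pass is counted as 5/2 of the forward. The helper name `expected_attention_flops` is illustrative and not part of the test suite.

```python
def expected_attention_flops(
    B: int, H: int, M: int, K: int, causal: bool = False, bwd: bool = False
) -> int:
    # Two matmuls in the forward pass (Q @ K^T and attn @ V),
    # each costing 2 * M * M * K FLOPs per (batch, head) pair.
    flops = 2 * 2 * M * M * K
    flops *= B * H
    if causal:
        # Roughly half of the attention matrix is masked out.
        flops //= 2
    if bwd:
        # The backward pass is counted as 5/2 of the forward,
        # matching the `fw_flops * 5 // 2` assertions above.
        flops = flops * 5 // 2
    return flops


# Same numbers as the updated tests: B, M, H, K = 64, 256, 3, 128
assert expected_attention_flops(64, 3, 256, 128) == 2 * 2 * 256 * 256 * 128 * 64 * 3
```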
77 changes: 40 additions & 37 deletions xformers/profiler/profile_analyzer.py
@@ -20,8 +20,11 @@ def __init__(self, e: torch._C._autograd._KinetoEvent) -> None:
self._kineto_event = e


def _attention_flops(queries, values, causal: bool) -> int:
def _attention_flops(queries, values, causal: bool, fmt: str = "BHMK") -> int:
assert isinstance(causal, bool)
assert fmt in ["BMHK", "BHMK"]
if fmt == "BMHK":
queries, values = [[x[0], x[2], x[1], x[3]] for x in [queries, values]]
*B, N, K = queries
*B, Nv, Kv = values
if causal: # NOTE: Causal from bottom right
@@ -58,23 +61,39 @@ def _replace_if_needed(
op_name = e.name()
flops = None

FMT_BMHK = dict(fmt="BMHK")
ATTN_OPS = {
getattr(lib, op).default.name(): (getattr(lib, op), is_bwd)
for lib, op, is_bwd in [
(torch.ops.aten, "scaled_dot_product_attention", False),
(torch.ops.xformers_flash, "flash_fwd", False),
(torch.ops.xformers, "efficient_attention_forward_cutlass", False),
(torch.ops.aten, "_efficient_attention_forward", False),
(torch.ops.aten, "_scaled_dot_product_flash_attention_backward", True),
(torch.ops.aten, "_scaled_dot_product_efficient_attention_backward", True),
(torch.ops.xformers_flash, "flash_bwd", True),
(torch.ops.xformers, "efficient_attention_backward_cutlass", True),
(torch.ops.aten, "_efficient_attention_backward", True),
getattr(lib, op).default.name(): (getattr(lib, op), is_bwd, kwargs)
for lib, op, is_bwd, kwargs in [
(torch.ops.aten, "scaled_dot_product_attention", False, {}),
(torch.ops.xformers_flash, "flash_fwd", False, FMT_BMHK),
(
torch.ops.xformers,
"efficient_attention_forward_cutlass",
False,
FMT_BMHK,
),
(torch.ops.aten, "_efficient_attention_forward", False, FMT_BMHK),
(torch.ops.aten, "_scaled_dot_product_flash_attention_backward", True, {}),
(
torch.ops.aten,
"_scaled_dot_product_efficient_attention_backward",
True,
{},
),
(torch.ops.xformers_flash, "flash_bwd", True, FMT_BMHK),
(
torch.ops.xformers,
"efficient_attention_backward_cutlass",
True,
FMT_BMHK,
),
(torch.ops.aten, "_efficient_attention_backward", True, FMT_BMHK),
]
if hasattr(lib, op)
}
if op_name in ATTN_OPS.keys():
op, is_bwd = ATTN_OPS[op_name]
op, is_bwd, kwargs = ATTN_OPS[op_name]
shapes = e.shapes()
concrete_inputs = e.concrete_inputs()
try:
@@ -85,6 +104,7 @@ def _replace_if_needed(
shapes[_get_arg_idx(op, "query")],
shapes[_get_arg_idx(op, "value")],
is_causal,
**kwargs,
)
if is_bwd:
flops = flops * 5 // 2
@@ -95,24 +115,6 @@ def _replace_if_needed(
return e


# BW compat with older PT versions
if "start_ns" in dir(torch._C._autograd._KinetoEvent):

def _start_ns(e) -> int:
return e.start_ns()

def _duration_ns(e) -> int:
return e.duration_ns()

else:

def _start_ns(e) -> int:
return e.start_us() * 1000

def _duration_ns(e) -> int:
return e.duration_us() * 1000


@dataclass
class AnalyzedTrace:
operations_per_dtype_fw: Dict[torch.dtype, float]
@@ -169,19 +171,20 @@ def from_profile(
def _find_parent_op(
e: torch._C._autograd._KinetoEvent,
) -> torch._C._autograd._KinetoEvent:
e_range = [_start_ns(e), _start_ns(e) + _duration_ns(e)]
e_range = [e.start_ns(), e.start_ns() + e.duration_ns()]
candidate = e
for parent in all_ops:
if parent.device_type() != e.device_type():
continue
if parent.start_thread_id() != e.start_thread_id():
continue
p_range = [_start_ns(parent), _start_ns(parent) + _duration_ns(parent)]
p_range = [parent.start_ns(), parent.start_ns() + parent.duration_ns()]
if not (p_range[0] < e_range[0] < e_range[1] < p_range[1]):
continue
# We take the longest parent with flops
if parent.flops() > 0 and _duration_ns(candidate) < _duration_ns(
parent
if (
parent.flops() > 0
and candidate.duration_ns() < parent.duration_ns()
):
candidate = parent
return candidate
@@ -221,8 +224,8 @@ def _find_parent_op(
for op in events:
if op.device_type().name != "CUDA":
continue
begin_ns = min(begin_ns, _start_ns(op))
end_ns = max(end_ns, _start_ns(op) + _duration_ns(op))
begin_ns = min(begin_ns, op.start_ns())
end_ns = max(end_ns, op.start_ns() + op.duration_ns())

return AnalyzedTrace(
operations_per_dtype_fw=operations_per_dtype_fw,
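To illustrate the new `fmt` argument above: the xFormers and `_efficient_attention_*` ops take tensors in BMHK layout (batch, sequence, heads, head dim), while `_attention_flops` unpacks shapes as `*B, N, K`, i.e. BHMK, so BMHK shapes get their middle two dimensions swapped before counting. A self-contained sketch under those assumptions; the helper name `_flops_from_shapes` and the simplified causal handling are illustrative, not the analyzer's exact code:

```python
from typing import Sequence


def _flops_from_shapes(query_shape: Sequence[int], value_shape: Sequence[int],
                       causal: bool, fmt: str = "BHMK") -> int:
    """Rough FLOP count for attention given query/value shapes (sketch)."""
    assert fmt in ("BMHK", "BHMK")
    if fmt == "BMHK":
        # Swap sequence length and heads so shapes read (batch, heads, seq, dim).
        query_shape, value_shape = [
            [s[0], s[2], s[1], s[3]] for s in (query_shape, value_shape)
        ]
    *B, N, K = query_shape   # B collects the batch and head dims
    *_, Nv, Kv = value_shape
    flops = 2 * N * Nv * (K + Kv)  # Q @ K^T plus attn @ V
    for b in B:
        flops *= b
    if causal:
        # Simplification: the real code notes the mask is causal from the
        # bottom right, which matters when N != Nv.
        flops //= 2
    return flops


# The xFormers cutlass kernels report shapes in BMHK; with the test shapes
# both layouts yield the same count:
# _flops_from_shapes([64, 256, 3, 128], [64, 256, 3, 128], causal=False, fmt="BMHK")
# == 2 * 2 * 256 * 256 * 128 * 64 * 3
```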
25 changes: 16 additions & 9 deletions xformers/profiler/profiler.py
@@ -59,6 +59,7 @@ class PyTorchProfiler:

def __init__(self, main_profiler: "_Profiler") -> None:
self.main_profiler = main_profiler
self.num_steps = 0
self.pytorch_profiler = torch.profiler.profile(
on_trace_ready=self._on_trace,
profile_memory=True,
@@ -95,16 +96,18 @@ def _analyze_trace(self, prof: torch.profiler.profiler.profile) -> None:
if limits is not None:
for dtype, tflops in limits.gemm_tflops.items():
hw_flops[dtype] = tflops * (1000**4)
total_flops = 0.0
total_hfu = 0.0
total_mfu = 0.0
for dtype in results.operations_per_dtype_fw.keys():
total_flops += results.compute_num_ops(dtype) / results.total_time_s
total_hfu += results.compute_hfu(hw_flops)
total_mfu += results.compute_mfu(hw_flops)
total_hfu = results.compute_hfu(hw_flops)
total_mfu = results.compute_mfu(hw_flops)
total_flop = sum(
results.compute_num_ops(dtype)
for dtype in results.operations_per_dtype_fw.keys()
)
s = self.main_profiler.summary
s.append(("Step time (ms)", f"{int(results.total_time_s * 1000)}"))
s.append(("TFlops", f"{total_flops / (1000**4):0.1f}"))
s.append(
("Step time (ms)", f"{int(results.total_time_s * 1000 / self.num_steps)}")
)
s.append(("TFlop/step", f"{total_flop / (self.num_steps * 1000**4):0.1f}"))
s.append(("TFlops", f"{total_flop / (results.total_time_s * 1000**4):0.1f}"))
s.append(("HFU", f"{total_hfu:0.3f}"))
s.append(("MFU", f"{total_mfu:0.3f}"))

@@ -118,6 +121,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def step(self) -> None:
self.pytorch_profiler.step()
self.num_steps += 1


class PyTorchProfiler_CUDAOnly(PyTorchProfiler):
@@ -267,6 +271,9 @@ def update_profilers_on_step(self) -> None:
o = p.object
p.object = None
logging.info(f"Shutting down {p.cls.__name__} profiler...")
# Make sure the profiler's `step` function is called
# $N times when we do $N steps with this profiler.
o.step()
o.__exit__(None, None, None)

def _create_output_filename(self, filename: str) -> Path:
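To make the MFU/HFU fix concrete, a hedged sketch of the corrected aggregation: FLOPs are summed across all dtypes before dividing by time, and utilization adds up each dtype's fraction of its own hardware peak, rather than the old loop that accumulated a per-dtype rate and re-added the full HFU/MFU once per dtype. The function `summarize` and its arguments are illustrative, not the profiler's API:

```python
from typing import Dict


def summarize(flop_per_dtype: Dict[str, float], peak_flops_per_dtype: Dict[str, float],
              total_time_s: float, num_steps: int) -> Dict[str, float]:
    """Aggregate FLOPs across dtypes, then derive the reported metrics (sketch)."""
    # Sum the work first: a run mixing fp16 and fp32 contributes one total,
    # instead of inflating the figures by the number of dtypes seen.
    total_flop = sum(flop_per_dtype.values())
    # Utilization: each dtype's work is measured against that dtype's own
    # hardware peak, and the fractions are added up.
    mfu = sum(
        flop / (total_time_s * peak_flops_per_dtype[dtype])
        for dtype, flop in flop_per_dtype.items()
    )
    return {
        "Step time (ms)": total_time_s * 1000 / num_steps,
        "TFlop/step": total_flop / (num_steps * 1000**4),
        "TFlops": total_flop / (total_time_s * 1000**4),
        "MFU": mfu,
    }


# Example: 120 TFlop of fp16 plus 5 TFlop of fp32 over 2 s and 10 steps, on
# hardware with 312 fp16 TFlops and 19.5 fp32 TFlops peak (A100-like numbers).
print(summarize({"fp16": 120e12, "fp32": 5e12},
                {"fp16": 312e12, "fp32": 19.5e12},
                total_time_s=2.0, num_steps=10))
```

The `num_steps` counter added in `step()` above is what turns these totals into the new per-step figures.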
