Skip to content

Commit

Permalink
PR #17330: Add stride for amax_o/s for fp8 cudnn fused attention
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17330

As per requirement of cudnn graph API, the amax_s and amax_o has to be set stride. Otherwise, the following error will be hit.
```
xla/service/gpu/tests/gpu_fused_mha_test.cc:1348
Value of: RunAndCompareTwoModules(hlo_string, hlo_string_ref, ErrorSpec{1e-2, 1e-2})
  Actual: false (INTERNAL: Tensor 'sdpa_fp8::Amax_O' strides not set.
in xla/stream_executor/cuda/cuda_dnn.cc(8232): 'graph_.validate()' )
Copybara import of the project:

--
01c0ede by shuw <[email protected]>:

Add strides for amax_o/s at graph building which is required by cudnn-fe. Add tests for bnth and btnh layouts.

--
16b83a2 by Shu Wang <[email protected]>:

Split into multiple lines.
--
77a8e91 by shuw <[email protected]>:

Improve after review 1

Merging this change closes #17330

COPYBARA_INTEGRATE_REVIEW=#17330 from wenscarl:sdpa_fp8_amax_stride 77a8e91
PiperOrigin-RevId: 679474160
  • Loading branch information
wenscarl authored and Google-ML-Automation committed Sep 27, 2024
1 parent 82c2770 commit f7c4b2c
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 91 deletions.
Loading

0 comments on commit f7c4b2c

Please sign in to comment.