Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #17330: Add stride for amax_o/s for fp8 cudnn fused attention
Imported from GitHub PR #17330 As per requirement of cudnn graph API, the amax_s and amax_o has to be set stride. Otherwise, the following error will be hit. ``` xla/service/gpu/tests/gpu_fused_mha_test.cc:1348 Value of: RunAndCompareTwoModules(hlo_string, hlo_string_ref, ErrorSpec{1e-2, 1e-2}) Actual: false (INTERNAL: Tensor 'sdpa_fp8::Amax_O' strides not set. in xla/stream_executor/cuda/cuda_dnn.cc(8232): 'graph_.validate()' ) Copybara import of the project: -- 01c0ede by shuw <[email protected]>: Add strides for amax_o/s at graph building which is required by cudnn-fe. Add tests for bnth and btnh layouts. -- 16b83a2 by Shu Wang <[email protected]>: Split into multiple lines. -- 77a8e91 by shuw <[email protected]>: Improve after review 1 Merging this change closes #17330 COPYBARA_INTEGRATE_REVIEW=#17330 from wenscarl:sdpa_fp8_amax_stride 77a8e91 PiperOrigin-RevId: 679474160
- Loading branch information