forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LinalgExt] Masked Attention Implementation (iree-org#18525)
Enables float/boolean mask as parameters and created linalg generic ops to apply masking. This image (https://imgur.com/a/1MePgcy) elaborates on the main files changed and how they enable masked attention: - Blue boxes represent changed .cpp and .td files to enable/pass/decompose the mask - Yellow boxes represent the different op classes - Red boxes represent test mlir files pertaining to certain .cpp/.td implementations or ops For quick reference, AggregateOpInterfaceImpl.cpp contains the bulk of the actual mask decomposition (QK += mask) And for clarification, TileAttention.cpp only holds the convertToOnlineAttentionOp and getTileAttentionIndexingMaps functions; TilingInterfaceImpl.cpp contains the main tiling capabilities in the form of AttentionOp::getTiledImplementation and OnlineAttentionOp::getTiledImplementation. Updated version of iree-org#18461. This version was created to include scale affine map and enable fused attention (incorporated https://github.com/IanWood1/iree/tree/raikonen/sdpa_mask). - To that end, many modifications in tests are for adding the scale affine map (without much functionality change) - For tiling and decomposition tests, most functionality tests are included in "tiling.mlir" and "decompose_online_attention.mlir". On the other hand, the "tile_attention.mlir and "decompose_attention.mlir" are old paths intended to be be retired and deprecate soon. Hence, no major tests were added it there. Test directory for numerical verification: https://github.com/rohan-tan-bhowmik/iree-masked-attention-test --------- Signed-off-by: Stanley Winata <[email protected]> Co-authored-by: Stanley Winata <[email protected]> Co-authored-by: Ian Wood <[email protected]>
- Loading branch information
1 parent
891f438
commit 9ee061d
Showing
30 changed files
with
1,211 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.