-
Notifications
You must be signed in to change notification settings - Fork 609
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Have a single _FusedSequenceParallel class handle all dtypes (fairint…
…ernal/xformers#1144) * have a single _FusedSequenceParallel class handle all dtypes * Fix flake8 linter + refactor * Added type annotation * use dtype.itemsize * use uint8 for opaque bytes * remove the extra parentheses * remove useless buffer_metadata: total_num_bytes computation is cheap * remove paranthesis * remove paranthesis * use the same sequence number for each dtype * using staging.view(dtype) * simplified uint8 handling in linear_and_reducescatter * using linters versions from requirements-test.txt * removed useless pair of parentheses * added scattered_inputs elements dtype consistency check * fixed my_matmul multiline call formatting mishap * black format fix for _ensure_staging_is_large_enough call * add test showcasing handling multiple dtypes * refactored fused and non fused output comparison * put subbatch_dims line back __original_commit__ = fairinternal/xformers@4b4d8e7
- Loading branch information
Showing
2 changed files
with
111 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters