feat(torch-extras): Add flash-attn 3 beta
Eta0 committed Jul 24, 2024
1 parent 82434b6 commit a45d087
Showing 4 changed files with 7 additions and 51 deletions.
8 changes: 2 additions & 6 deletions .github/configurations/torch-base.yml
@@ -1,9 +1,5 @@
-cuda: [ 12.4.1, 12.3.2, 12.2.2, 12.0.1, 11.8.0 ]
-os: [ ubuntu22.04, ubuntu20.04 ]
-exclude:
-  # Not a supported combination
-  - cuda: 11.8.0
-    os: ubuntu22.04
+cuda: [ 12.4.1 ]
+os: [ ubuntu22.04 ]
 include:
   - torch: 2.3.1
     vision: 0.18.1
41 changes: 0 additions & 41 deletions .github/configurations/torch-nccl.yml
@@ -5,47 +5,6 @@ image:
     os: ubuntu22.04
     nccl: 2.21.5-1
     nccl-tests-hash: 85f9143
-  - cuda: 12.3.2
-    cudnn: cudnn9
-    os: ubuntu22.04
-    nccl: 2.20.3-1
-    nccl-tests-hash: 85f9143
-  - cuda: 12.2.2
-    cudnn: cudnn8
-    os: ubuntu22.04
-    nccl: 2.19.3-1
-    nccl-tests-hash: 85f9143
-  - cuda: 12.0.1
-    cudnn: cudnn8
-    os: ubuntu22.04
-    nccl: 2.18.5-1
-    nccl-tests-hash: 85f9143
-  # Ubuntu 20.04
-  - cuda: 12.4.1
-    cudnn: cudnn
-    os: ubuntu20.04
-    nccl: 2.21.5-1
-    nccl-tests-hash: 85f9143
-  - cuda: 12.3.2
-    cudnn: cudnn9
-    os: ubuntu20.04
-    nccl: 2.20.3-1
-    nccl-tests-hash: 85f9143
-  - cuda: 12.2.2
-    cudnn: cudnn8
-    os: ubuntu20.04
-    nccl: 2.21.5-1
-    nccl-tests-hash: 85f9143
-  - cuda: 12.0.1
-    cudnn: cudnn8
-    os: ubuntu20.04
-    nccl: 2.19.3-1
-    nccl-tests-hash: 85f9143
-  - cuda: 11.8.0
-    cudnn: cudnn8
-    os: ubuntu20.04
-    nccl: 2.16.5-1
-    nccl-tests-hash: 868dc3d
 include:
   - torch: 2.3.1
     vision: 0.18.1
4 changes: 2 additions & 2 deletions .github/workflows/torch.yml
@@ -25,7 +25,7 @@ on:
       cuda-arch-support:
         required: false
         type: string
-        default: "7.0 7.5 8.0 8.6 8.9 9.0+PTX"
+        default: "9.0+PTX"
       image-name:
         required: false
         type: string
@@ -71,7 +71,7 @@ on:
         required: false
         description: "Space-separated list of CUDA architectures to support"
         type: string
-        default: "7.0 7.5 8.0 8.6 8.9 9.0+PTX"
+        default: "9.0+PTX"
       image-name:
         required: false
         description: "Custom name under which to publish the resulting container"
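The default CUDA architecture list narrows from Volta through Hopper ("7.0 7.5 8.0 8.6 8.9 9.0+PTX") to Hopper alone: SM 9.0 is the H100's compute capability, and the flash-attn 3 beta ships Hopper-only kernels. As a rough sketch (the workflow's actual plumbing is not shown in this diff), such a list conventionally reaches a PyTorch CUDA extension build through the TORCH_CUDA_ARCH_LIST environment variable read by torch.utils.cpp_extension:

    # Illustrative only; assumes cuda-arch-support is forwarded as
    # TORCH_CUDA_ARCH_LIST, the standard variable consulted by
    # torch.utils.cpp_extension when compiling CUDA kernels.
    export TORCH_CUDA_ARCH_LIST="9.0+PTX"   # SM 9.0 = Hopper; +PTX embeds PTX for newer GPUs
    pip install -v --no-build-isolation .   # compiles kernels only for the listed architectures

Dropping the pre-Hopper targets also shortens build times, since every listed architecture multiplies the kernel compilation work.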
5 changes: 3 additions & 2 deletions torch-extras/Dockerfile
@@ -2,7 +2,7 @@
 
 ARG BASE_IMAGE
 ARG DEEPSPEED_VERSION="0.14.2"
-ARG FLASH_ATTN_VERSION="2.5.9.post1"
+ARG FLASH_ATTN_VERSION="1899c970c8639e82e6b8a78408f4041425e9f900"
 ARG APEX_COMMIT="23c1f86520e22b505e8fdfcf6298273dff2d93d8"
 ARG XFORMERS_VERSION="0.0.26.post1"
 
@@ -11,7 +11,7 @@ WORKDIR /git
 ARG FLASH_ATTN_VERSION
 RUN git clone --recurse-submodules --shallow-submodules -j8 --depth 1 \
     --filter=blob:none --also-filter-submodules \
-    https://github.com/Dao-AILab/flash-attention -b v${FLASH_ATTN_VERSION}
+    https://github.com/Dao-AILab/flash-attention -b ${FLASH_ATTN_VERSION}
 
 FROM alpine/git:2.36.3 as apex-downloader
 WORKDIR /git
@@ -168,6 +168,7 @@ RUN --mount=type=bind,from=flash-attn-downloader,source=/git/flash-attention,tar
     ( \
         for EXT_DIR in $(realpath -s -e \
             . \
+            hopper \
            csrc/ft_attention \
            csrc/fused_dense_lib \
            csrc/fused_softmax \
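The new hopper entry adds flash-attention's hopper/ subdirectory, which contains the FlashAttention-3 beta extension, to the set of directories the build loop compiles alongside the existing csrc/ extensions. A minimal smoke test for the resulting image, assuming the hopper build installs the flash_attn_interface module as described in upstream's FA3 beta notes:

    # Hedged smoke test: upstream's FA3 beta builds a separate
    # flash_attn_interface package from hopper/; check that it imports.
    python -c "from flash_attn_interface import flash_attn_func; print('flash-attn 3 import OK')"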

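FLASH_ATTN_VERSION, meanwhile, switches from a release tag (2.5.9.post1, previously cloned as v${FLASH_ATTN_VERSION}) to a bare commit hash on the flash-attn 3 development line, which is why the v prefix disappears from the clone command. One caveat: git clone -b resolves branch and tag names rather than raw commit IDs, so when a build needs an arbitrary commit, a fetch-and-checkout sequence like the sketch below is the usual alternative (GitHub allows fetching reachable commits by hash):

    # Sketch of pinning a shallow clone to an exact commit; not this
    # Dockerfile's literal step. git clone -b <name> resolves branches
    # and tags only, so a raw SHA is fetched explicitly instead.
    git init flash-attention && cd flash-attention
    git remote add origin https://github.com/Dao-AILab/flash-attention
    git fetch --depth 1 origin 1899c970c8639e82e6b8a78408f4041425e9f900
    git checkout FETCH_HEAD
    git submodule update --init --recursive --depth 1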