Skip to content

Commit

Permalink
Add rowwise fp8 matmul
Browse files Browse the repository at this point in the history
ghstack-source-id: 176054681fccd1427c24c04b4ea0e01980bbed4e
Pull Request resolved: fairinternal/xformers#1148

__original_commit__ = fairinternal/xformers@f31779c
  • Loading branch information
lw authored and xFormers Bot committed Jul 3, 2024
1 parent a9e2e7b commit 5dbdc2e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 31 deletions.
17 changes: 2 additions & 15 deletions .github/workflows/conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ on:
# https://github.com/actions/runner/issues/1182

env:
# Some of the kernels compiled here need at least CUDA compute capability 5.0.
TORCH_CUDA_ARCH_LIST: "6.0+PTX 6.1 7.0 7.5 8.0+PTX"
MAX_JOBS: 3 # Avoids OOMs
XFORMERS_BUILD_TYPE: "Release"
XFORMERS_PACKAGE_FROM: "conda-${{ github.ref_name }}"
Expand Down Expand Up @@ -53,6 +51,8 @@ jobs:
env:
# alias for the current python version
PY: /opt/conda/bin/python
# Some of the kernels compiled here need at least CUDA compute capability 5.0.
# The H100 target (9.0a) is appended only for CUDA 11.8 and newer. `format()`
# is used because `join()` ignores its separator argument when the first
# argument is a plain string, which would silently drop the 9.0a suffix.
# NOTE(review): other steps in this workflow have read
# `matrix.config.cuda_short_version`; confirm `inputs.cuda_short_version` is
# actually defined here.
TORCH_CUDA_ARCH_LIST: ${{ format('6.0+PTX 7.0 7.5 8.0+PTX{0}', fromJSON(inputs.cuda_short_version) >= 118 && ' 9.0a' || '') }}

timeout-minutes: 360
defaults:
Expand All @@ -75,19 +75,6 @@ jobs:
if: startsWith(github.ref, 'refs/tags/v')
run: echo "XFORMERS_CONDA_TAG=main" >> $GITHUB_ENV
- run: echo "${XFORMERS_CONDA_TAG}"
- name: Add H100 if nvcc 11.08+
shell: python
run: |
import os
import sys
print(sys.version)
cuda_short_version = "${{ matrix.config.cuda_short_version }}"
arch_list = os.environ["TORCH_CUDA_ARCH_LIST"]
if cuda_short_version not in ["116", "117"]:
arch_list += " 9.0"
with open(os.environ['GITHUB_ENV'], "r+") as fp:
fp.write("TORCH_CUDA_ARCH_LIST=" + arch_list + "\n")
- run: echo "${TORCH_CUDA_ARCH_LIST}"
- name: Free up disk space
run: |
df -h /
Expand Down
16 changes: 1 addition & 15 deletions .github/workflows/wheels_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ on:

env:
# Some of the kernels compiled here need at least CUDA compute capability 5.0.
TORCH_CUDA_ARCH_LIST: "6.0+PTX 6.1 7.0 7.5 8.0+PTX"
# The H100 target (9.0a) is appended only for CUDA 11.8 and newer. `format()`
# is used because `join()` ignores its separator argument when the first
# argument is a plain string, which would silently drop the 9.0a suffix.
TORCH_CUDA_ARCH_LIST: ${{ format('6.0+PTX 7.0 7.5 8.0+PTX{0}', fromJSON(inputs.cuda_short_version) >= 118 && ' 9.0a' || '') }}
MAX_JOBS: 4
DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
XFORMERS_BUILD_TYPE: "Release"
Expand Down Expand Up @@ -78,20 +78,6 @@ jobs:
- run: echo "TORCH_ORG_S3_PATH=${{ steps.cuda_info.outputs.TORCH_ORG_S3_PATH }}"
- run: echo "PUBLISH_PYPI=${{ steps.cuda_info.outputs.PUBLISH_PYPI }}"

- name: Add H100 if nvcc 11.08+
shell: python
run: |
import os
import sys
print(sys.version)
cuda_short_version = "${{ inputs.cuda_short_version }}"
arch_list = os.environ["TORCH_CUDA_ARCH_LIST"]
if cuda_short_version not in ["116", "117"]:
arch_list += " 9.0"
with open(os.environ['GITHUB_ENV'], "r+") as fp:
fp.write("TORCH_CUDA_ARCH_LIST=" + arch_list + "\n")
- run: echo "${TORCH_CUDA_ARCH_LIST}"

- name: Recursive checkout
uses: actions/checkout@v3
with:
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include third_party/flash-attention/version.txt
recursive-include xformers/csrc *
recursive-include third_party/sputnik *
recursive-include third_party/cutlass/include *
recursive-include third_party/cutlass/tools/util/include *
recursive-include third_party/cutlass/examples *
recursive-include third_party/flash-attention/csrc *
recursive-include third_party/flash-attention/flash_attn *
Expand Down
10 changes: 9 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ def get_extensions():
source_cuda = list(set(source_cuda) - set(fmha_source_cuda))

cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
cutlass_util_dir = os.path.join(
this_dir, "third_party", "cutlass", "tools", "util", "include"
)
cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
if not os.path.exists(cutlass_dir):
raise RuntimeError(
Expand Down Expand Up @@ -337,7 +340,12 @@ def get_extensions():
cuda_version = get_cuda_version(CUDA_HOME)
extension = CUDAExtension
sources += source_cuda
include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir]
include_dirs += [
sputnik_dir,
cutlass_dir,
cutlass_util_dir,
cutlass_examples_dir,
]
nvcc_flags = [
"-DHAS_PYTORCH",
"--use_fast_math",
Expand Down

0 comments on commit 5dbdc2e

Please sign in to comment.