Skip to content

Commit

Permalink
Post-Merge CI (#612)
Browse files Browse the repository at this point in the history
* remove on push for Integration Tests

* rename

* add post merge test

* save

* dtype params

* skip bad config

* fix more stuff
  • Loading branch information
micmelesse authored Jul 16, 2024
1 parent 18930eb commit d26ef1d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: AMD Perf Kernel Tests
name: AMD Perf Kernel Integration Tests

on:
workflow_dispatch:
Expand All @@ -7,8 +7,6 @@ on:
merge_group:
branches: [main_perf]
types: [checks_requested]
push:
branches: [main_perf]

concurrency:
group: ${{ github.ref }}
Expand Down Expand Up @@ -36,8 +34,8 @@ jobs:
changed_files=$(git diff --name-only origin/${{ github.base_ref }} ${{ github.sha }})
echo "Changed files:"
echo "$changed_files"
if echo "$changed_files" | grep -v "^python/perf-kernels/"; then
echo "Changes detected outside of the python/perf-kernels directory. Failing the workflow."
if echo "$changed_files" | grep -vE "^python/perf-kernels/|^\.github/workflows/amd_"; then
echo "Changes detected outside of the python/perf-kernels directory or .github/workflows/amd_ files. Failing the workflow."
exit 1
fi
Expand Down
92 changes: 92 additions & 0 deletions .github/workflows/amd_perf_kernel_postmerge_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
name: AMD Perf Kernel Post-Merge Tests

on:
workflow_dispatch:
push:
branches: [main_perf, micmelesse/post_merge_ci]

concurrency:
group: ${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main_perf' }}

permissions: read-all

env:
TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
TRITON_DISABLE_LINE_INFO: 1

jobs:
Runner-Preparation-AMD:
runs-on: ubuntu-latest
timeout-minutes: 30
outputs:
matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/triton" ]; then
echo '::set-output name=matrix-HIP::[["self-hosted", "rocm.gfx90a"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
PostMerge-Tests-AMD:
needs: Runner-Preparation-AMD
if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != ''
runs-on: ${{ matrix.runner }}
timeout-minutes: 30
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
container:
image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # Ensure the entire history is fetched for rebase
- name: Add upstream remote
run: |
git config --global --add safe.directory /__w/triton/triton
if [ $(git remote | grep -c upstream) -eq 0 ]; then
git remote add upstream https://github.com/triton-lang/triton.git
fi
git fetch upstream
- name: Rebase onto upstream/main
run: |
git config --global user.email "[email protected]"
git config --global user.name "Github Actions Post-Merge CI Script"
git rebase upstream/main || { echo "Rebase failed"; exit 1; }
- name: Show Git Log
run: |
echo "Git log after rebase from upstream/main to HEAD:"
git log $(git rev-parse upstream/main~2)..HEAD --oneline --graph --decorate
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Clear cache
run: |
rm -rf ~/.triton
mkdir -p ~/.triton
ls -alh ~/.triton
- name: Update PATH
run: |
echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH
- name: Install pip dependencies
run: |
python3 -m pip install --upgrade pip
python3 -m pip install lit matplotlib pandas
- name: Install Triton
run: |
echo "PATH is '$PATH'"
pip uninstall -y triton
cd python
pip install -v -e .
- name: Run Perf Kernels Unit Tests
run: |
pytest -vvv ./python/perf-kernels/flash-attention.py
- name: Run Perf Kernels Benchmark
run: |
python ./python/perf-kernels/flash-attention.py
11 changes: 6 additions & 5 deletions python/perf-kernels/flash-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
# num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
Expand Down Expand Up @@ -1166,15 +1166,16 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('use_bias', [True])
def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16):
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
torch.manual_seed(20)
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd')
if causal:
input_metadata.need_causal()
if use_bias:
bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda")
bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda")
input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K)
else:
bias = None
Expand All @@ -1197,7 +1198,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
# this by converting the NaNs to 0s, which is what they should be out of the softmax.
nan_mask = torch.isnan(p)
p[nan_mask == 1] = 0
ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)

Expand Down

0 comments on commit d26ef1d

Please sign in to comment.