Skip to content

Commit

Permalink
Aggressively trim test bloat (#346)
Browse files Browse the repository at this point in the history
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

1. Disable the test for experimental kernels
2. Reduce the size of tensor if the tests takes too long
3. Remove redundant tests that are testing the same thing

Make sure unit test time < 5 mins

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
  • Loading branch information
ByronHsu authored Nov 4, 2024
1 parent e68b291 commit a2dfa3c
Show file tree
Hide file tree
Showing 13 changed files with 58 additions and 283 deletions.
50 changes: 25 additions & 25 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,28 @@ jobs:
run: |
modal run dev.modal.unit_tests
convergence-tests:
runs-on: ubuntu-latest
needs: [checkstyle]

env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install modal
- name: Run convergence tests
run: |
modal run dev.modal.conv_tests
# convergence-tests:
# runs-on: ubuntu-latest
# needs: [checkstyle]

# env:
# MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
# MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}

# steps:
# - name: Checkout code
# uses: actions/checkout@v3

# - name: Set up Python
# uses: actions/setup-python@v3
# with:
# python-version: '3.10'

# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# pip install modal

# - name: Run convergence tests
# run: |
# modal run dev.modal.conv_tests
2 changes: 1 addition & 1 deletion dev/modal/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 5)
def liger_unit_test():
import subprocess

Expand Down
165 changes: 5 additions & 160 deletions test/transformers/test_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,26 +170,14 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
@pytest.mark.parametrize(
"B, T, V",
[
(2, 4096, 32000), # llama2, mistral
(2, 4096, 32000), # llama2, mistral
(1, 4096, 128256), # llama3
# # weird shapes
(3, 423, 32000),
(2, 4096, 32000), # llama
(3, 423, 32000), # weird shapes
],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
Expand All @@ -199,24 +187,9 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-7,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
liger_ce = LigerCrossEntropyLoss(reduction=reduction)
_test_correctness_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol)
Expand All @@ -233,12 +206,8 @@ def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(0.1, torch.bfloat16, 1e-8, 5e-2),
(1.0, torch.bfloat16, 1e-8, 5e-2),
(10.0, torch.bfloat16, 1e-7, 5e-2),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
Expand All @@ -248,9 +217,7 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
@pytest.mark.parametrize(
"B, T, V, ignore_index",
[
(2, 4096, 32000, -100), # llama2, mistral
(2, 4096, 32000, 2), # llama2, mistral
(1, 4096, 128256, -300), # llama3
(2, 4096, 32000, 2),
# weird shapes
(3, 423, 32000, -123),
],
Expand All @@ -259,15 +226,6 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
Expand All @@ -277,24 +235,9 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_ignore_index(
B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
):
Expand All @@ -307,25 +250,14 @@ def test_correctness_with_ignore_index(
@pytest.mark.parametrize(
"B, T, V, label_smoothing",
[
(2, 4096, 32000, 0.1), # llama2, mistral
(2, 4096, 32000, 0.1), # llama2, mistral
(1, 4096, 128256, 0.1), # llama3
(2, 4096, 32000, 0.1),
# weird shapes
(3, 423, 32000, 0.1),
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
Expand All @@ -335,24 +267,9 @@ def test_correctness_with_ignore_index(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_label_smoothing_once(
B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
Expand All @@ -365,25 +282,14 @@ def test_correctness_with_label_smoothing_once(
@pytest.mark.parametrize(
"B, T, V, ignore_index, label_smoothing",
[
(2, 4096, 32000, 1, 0.1), # llama2, mistral
(2, 4096, 32000, -100, 0.2), # llama2, mistral
(1, 4096, 128256, 2, 0.1), # llama3
(2, 4096, 32000, 1, 0.1),
# weird shapes
(3, 423, 32000, -300, 0.2),
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
Expand All @@ -393,24 +299,9 @@ def test_correctness_with_label_smoothing_once(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-6,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_label_smoothing_with_ignore_index_once(
B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
Expand All @@ -427,8 +318,6 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
"B, T, V",
[
(2, 4096, 32000), # llama2, mistral
(2, 4096, 32000), # llama2, mistral
(1, 4096, 128256), # llama3
# # weird shapes
(3, 423, 32000),
],
Expand All @@ -449,52 +338,8 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
(1.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol):
liger_ce = LigerCrossEntropyLoss(reduction=reduction)
_test_correctness_not_last_layer_once(
liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol
)


#############################################################################
# Test full pass of the liger cross entropy loss to ensure it doesn't crash
#############################################################################


def _full_pass_once(B, T, V, reduction):

liger_ce = LigerCrossEntropyLoss(reduction=reduction)

_input = torch.randn(
B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16
)
target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1)

output = liger_ce(_input, target)
output.backward()


@pytest.mark.parametrize(
"B, T, V",
[
(
8,
8192,
128256,
), # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64
(8, 16384, 128256), # _input = 32GB, total = ~64GB
],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000,
reason="Needs 64GB+ GPU memory.",
)
def test_large_no_exception(B, T, V, reduction):
# The large inputs were hitting cuda illegal memory access because of
# https://github.com/triton-lang/triton/issues/1058
_full_pass_once(B, T, V, reduction)
1 change: 1 addition & 0 deletions test/transformers/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
SLEEP_SECONDS = 0.1


@pytest.mark.skip(reason="LigerEmbedding is under experimentation")
@pytest.mark.parametrize(
"num_embeddings, embedding_dim, padding_idx",
[
Expand Down
16 changes: 4 additions & 12 deletions test/transformers/test_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,8 @@ def forward(self, x, y):
@pytest.mark.parametrize(
"B, T, H, V",
[
# (2, 4, 512, 512), # The test does not work on some CI GPUs. Issue #160
(8, 2048, 4096, 32000), # llama2, mistral
# Comment out to speed up testing
# (4, 2048, 4096, 128256), # llama3 8B
# (4, 1024, 8192, 128256), # llama3 70B
(4, 423, 8192, 32000), # random shape
(8, 128, 1024, 4096),
(4, 47, 31, 123), # random shape
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -233,12 +229,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
@pytest.mark.parametrize(
"B, T, H, V",
[
(2, 4, 512, 512), # The test does not work on some CI GPUs. Issue #160
(8, 2048, 4096, 32000), # llama2, mistral
# Comment out to speed up testing
(4, 2048, 4096, 128256), # llama3 8B
(4, 1024, 8192, 128256), # llama3 70B
(4, 423, 8192, 32000), # random shape
(8, 128, 1024, 4096),
(4, 47, 31, 123), # random shape
],
)
@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit a2dfa3c

Please sign in to comment.