Commit

bugfix: fix AOT mode unittests (#665)
Follow up of #657
yzh119 authored Dec 16, 2024
1 parent b1b1fb8 commit d9d8eb1
Showing 10 changed files with 235 additions and 225 deletions.
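Every file below receives the same fix: the module-scoped `warmup_jit` fixture used to fall through into the JIT warmup (and a second `yield`) even when prebuilt AOT kernels were available, which broke the AOT-mode unittests; the warmup now lives in an `else:` branch so the fixture yields exactly once on either path. A condensed sketch of the fixed fixture is shown here, using the argument lists from tests/test_alibi.py; the fixture decorator arguments and the `jit_utils` import path are assumptions for illustration, since neither appears in the hunks below.

import pytest
import torch

import flashinfer
# Assumed import path for the shared test helpers that build the JIT spec lists.
from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args


@pytest.fixture(autouse=True, scope="module")  # decorator arguments assumed; not shown in the hunks
def warmup_jit():
    if flashinfer.jit.has_prebuilt_ops:
        # AOT build: kernels are prebuilt, so skip the JIT warmup entirely.
        yield
    else:
        try:
            flashinfer.jit.parallel_load_modules(
                jit_decode_attention_func_args(
                    [torch.float16],  # q_dtypes
                    [torch.float16],  # kv_dtypes
                    [128, 256],       # head_dims
                    [0, 2],           # pos_encoding_modes
                    [False],          # use_sliding_windows
                    [False],          # use_logits_soft_caps
                )
                + jit_prefill_attention_func_args(
                    [torch.float16],  # q_dtypes
                    [torch.float16],  # kv_dtypes
                    [128, 256],       # head_dims
                    [0, 2],           # pos_encoding_modes
                    [False],          # use_sliding_windows
                    [False],          # use_logits_soft_caps
                    [False],          # allow_fp16_qk_reductions
                )
            )
        except Exception as e:
            # Abort the whole test session if warmup fails.
            pytest.exit(str(e))
        finally:
            yield

Each diff applies exactly this restructuring, varying only the dtype, head-dim, and positional-encoding lists passed to the helper builders.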
tests/test_alibi.py: 25 additions, 24 deletions
@@ -27,31 +27,32 @@
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
         yield
-    try:
-        flashinfer.jit.parallel_load_modules(
-            jit_decode_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0, 2],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False],  # use_logits_soft_caps
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_decode_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0, 2],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                )
+                + jit_prefill_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0, 2],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                    [False],  # allow_fp16_qk_reductions
+                )
             )
-            + jit_prefill_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0, 2],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False],  # use_logits_soft_caps
-                [False],  # allow_fp16_qk_reductions
-            )
-        )
-    except Exception as e:
-        # abort the test session if warmup fails
-        pytest.exit(str(e))
-    finally:
-        yield
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield


 @pytest.mark.parametrize("seq_len", [1, 9, 81, 729])
tests/test_batch_decode_kernels.py: 25 additions, 24 deletions
@@ -25,31 +25,32 @@
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
         yield
-    try:
-        flashinfer.jit.parallel_load_modules(
-            jit_decode_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0, 1, 2],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False, True],  # use_logits_soft_caps
-            )
-            + jit_prefill_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0, 1, 2],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False, True],  # use_logits_soft_caps
-                [False],  # allow_fp16_qk_reductions
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_decode_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0, 1, 2],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False, True],  # use_logits_soft_caps
+                )
+                + jit_prefill_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0, 1, 2],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False, True],  # use_logits_soft_caps
+                    [False],  # allow_fp16_qk_reductions
+                )
             )
-        )
-    except Exception as e:
-        # abort the test session if warmup fails
-        pytest.exit(str(e))
-    finally:
-        yield
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield


 @pytest.mark.parametrize("batch_size", [12, 17])
tests/test_batch_prefill_kernels.py: 17 additions, 16 deletions
@@ -25,23 +25,24 @@
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
         yield
-    try:
-        flashinfer.jit.parallel_load_modules(
-            jit_prefill_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0, 1, 2],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False, True],  # use_logits_soft_caps
-                [False],  # allow_fp16_qk_reductions
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_prefill_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0, 1, 2],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False, True],  # use_logits_soft_caps
+                    [False],  # allow_fp16_qk_reductions
+                )
             )
-        )
-    except Exception as e:
-        # abort the test session if warmup fails
-        pytest.exit(str(e))
-    finally:
-        yield
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield


 @pytest.mark.parametrize("batch_size", [12, 17])
tests/test_block_sparse.py: 25 additions, 24 deletions
@@ -27,31 +27,32 @@
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
         yield
-    try:
-        flashinfer.jit.parallel_load_modules(
-            jit_decode_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False],  # use_logits_soft_caps
-            )
-            + jit_prefill_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False],  # use_logits_soft_caps
-                [False],  # allow_fp16_qk_reductions
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_decode_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                )
+                + jit_prefill_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                    [False],  # allow_fp16_qk_reductions
+                )
             )
-        )
-    except Exception as e:
-        # abort the test session if warmup fails
-        pytest.exit(str(e))
-    finally:
-        yield
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield


 def bsr_attention_ref(
tests/test_logits_cap.py: 25 additions, 24 deletions
@@ -27,31 +27,32 @@
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
         yield
-    try:
-        flashinfer.jit.parallel_load_modules(
-            jit_decode_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False, True],  # use_logits_soft_caps
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_decode_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False, True],  # use_logits_soft_caps
+                )
+                + jit_prefill_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False, True],  # use_logits_soft_caps
+                    [False],  # allow_fp16_qk_reductions
+                )
             )
-            + jit_prefill_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [128, 256],  # head_dims
-                [0],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False, True],  # use_logits_soft_caps
-                [False],  # allow_fp16_qk_reductions
-            )
-        )
-    except Exception as e:
-        # abort the test session if warmup fails
-        pytest.exit(str(e))
-    finally:
-        yield
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield


 def attention_logits_soft_cap_torch(q, k, v, soft_cap):
tests/test_non_contiguous_decode.py: 25 additions, 24 deletions
@@ -9,31 +9,32 @@
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
         yield
-    try:
-        flashinfer.jit.parallel_load_modules(
-            jit_decode_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [64, 128, 256],  # head_dims
-                [0],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False],  # use_logits_soft_caps
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_decode_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [64, 128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                )
+                + jit_prefill_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [64, 128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                    [False],  # allow_fp16_qk_reductions
+                )
             )
-            + jit_prefill_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [64, 128, 256],  # head_dims
-                [0],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False],  # use_logits_soft_caps
-                [False],  # allow_fp16_qk_reductions
-            )
-        )
-    except Exception as e:
-        # abort the test session if warmup fails
-        pytest.exit(str(e))
-    finally:
-        yield
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield


 @pytest.mark.parametrize("batch_size", [1, 19, 99])
tests/test_non_contiguous_prefill.py: 17 additions, 16 deletions
@@ -25,23 +25,24 @@
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
         yield
-    try:
-        flashinfer.jit.parallel_load_modules(
-            jit_prefill_attention_func_args(
-                [torch.float16],  # q_dtypes
-                [torch.float16],  # kv_dtypes
-                [64, 128, 256],  # head_dims
-                [0],  # pos_encoding_modes
-                [False],  # use_sliding_windows
-                [False],  # use_logits_soft_caps
-                [False],  # allow_fp16_qk_reductions
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_prefill_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [torch.float16],  # kv_dtypes
+                    [64, 128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                    [False],  # allow_fp16_qk_reductions
+                )
             )
-        )
-    except Exception as e:
-        # abort the test session if warmup fails
-        pytest.exit(str(e))
-    finally:
-        yield
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield


 @pytest.mark.parametrize("seq_len", [1, 7, 127, 999, 3579])
