diff --git a/.github/workflows/test_iree.yml b/.github/workflows/test_iree.yml index 572e694ef..0a24d2800 100644 --- a/.github/workflows/test_iree.yml +++ b/.github/workflows/test_iree.yml @@ -146,12 +146,15 @@ jobs: run: | source ${VENV_DIR}/bin/activate python3 iree_tests/download_remote_files.py --root-dir pytorch/models + python3 iree_tests/download_remote_files.py --root-dir sharktank - - name: "Running real weights model tests" - if: ${{ !cancelled() }} + - name: "Running real weight model tests" + if: "matrix.models-config-file != '' && !cancelled()" run: | source ${VENV_DIR}/bin/activate - pytest iree_tests/pytorch/models \ + pytest \ + iree_tests/pytorch/models \ + iree_tests/sharktank \ -n 4 \ -rpfE \ -k real_weights \ diff --git a/iree_tests/README.md b/iree_tests/README.md index d186bee60..e7a8e7f15 100644 --- a/iree_tests/README.md +++ b/iree_tests/README.md @@ -413,6 +413,25 @@ Then, run the runner with the appropriate command line args (vmfb path, device f You should have all the artifacts needed to add to this TestSuite at that point. Make sure to follow to follow appendix instructions to convert between different file types for weights and mlir. +### SHARK Tank models + +These test cases are exported from https://github.com/nod-ai/sharktank. 
+ +#### Steps to add test cases + +* Follow the instructions in https://github.com/nod-ai/sharktank/blob/main/docs/model_cookbook.md +* Convert the exported `.mlir` to `.mlirbc`: + + ```bash + iree-ir-tool cp file.mlir --emit-bytecode -o file.mlirbc + ``` + +* Create a test_cases.json file with parameters, inputs, and outputs + * Parameters can come from Hugging Face by using the URL from the "download file" link + * TODO: inputs and outputs should be exportable from sharktank/shortfin + (or a script here - need to run the tokenizer and optionally populate the + KV cache for some models) + ## Appendix ### Working with .mlirbc files diff --git a/iree_tests/configs/models_gpu_rocm_gfx90a.json b/iree_tests/configs/models_gpu_rocm_gfx90a.json index 187e9ccf2..85c726dab 100644 --- a/iree_tests/configs/models_gpu_rocm_gfx90a.json +++ b/iree_tests/configs/models_gpu_rocm_gfx90a.json @@ -18,6 +18,9 @@ "expected_compile_failures": [ "pytorch/models/opt-125M", // TODO(#17344): need to regenerate .mlirbc "pytorch/models/resnet50", + // error: 'builtin.module' op failed to run transform dialect passes + // (might need to drop the iree-codegen-transform-dialect-library flag) + "sharktank/llama/open-llama-3b-v2-f16" ], "expected_run_failures": [] } diff --git a/iree_tests/sharktank/llama/open-llama-3b-v2-f16/open-llama-3b-v2-f16.mlirbc b/iree_tests/sharktank/llama/open-llama-3b-v2-f16/open-llama-3b-v2-f16.mlirbc new file mode 100644 index 000000000..8328e208b --- /dev/null +++ b/iree_tests/sharktank/llama/open-llama-3b-v2-f16/open-llama-3b-v2-f16.mlirbc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2047dcb7bdeaab647953245fd150d088f836172ad9f8775d12bd6118d68f1e7 +size 5578312 diff --git a/iree_tests/sharktank/llama/open-llama-3b-v2-f16/real_weights_prefill_data_flags.txt b/iree_tests/sharktank/llama/open-llama-3b-v2-f16/real_weights_prefill_data_flags.txt new file mode 100644 index 000000000..ef378861a --- /dev/null +++
b/iree_tests/sharktank/llama/open-llama-3b-v2-f16/real_weights_prefill_data_flags.txt @@ -0,0 +1,6 @@ +--parameters=model=open-llama-3b-v2-f16.gguf +--function=prefill_bs4 +--input=4x1xi64=0 +--input=4xi64=1 +--input=4x1xi64=0,1,2,3 +--input=1x2662400xf16 diff --git a/iree_tests/sharktank/llama/open-llama-3b-v2-f16/test_cases.json b/iree_tests/sharktank/llama/open-llama-3b-v2-f16/test_cases.json new file mode 100644 index 000000000..9fec50fd7 --- /dev/null +++ b/iree_tests/sharktank/llama/open-llama-3b-v2-f16/test_cases.json @@ -0,0 +1,13 @@ +{ + "file_format": "test_cases_v0", + "test_cases": [ + { + "name": "real_weights_prefill", + "runtime_flagfile": "real_weights_prefill_data_flags.txt", + "remote_files": [ + "https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf/resolve/main/open-llama-3b-v2-f16.gguf", + // TODO: files for real inputs and real expected outputs + ] + } + ] +}