diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml new file mode 100644 index 00000000000..d923c27734b --- /dev/null +++ b/.azuredevops/rocm-ci.yml @@ -0,0 +1,43 @@ +resources: + repositories: + - repository: pipelines_repo + type: github + endpoint: ROCm + name: ROCm/ROCm + +variables: +- group: common +- template: /.azuredevops/variables-global.yml@pipelines_repo + +trigger: + batch: true + branches: + include: + - develop + paths: + exclude: + - .githooks + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + +pr: + autoCancel: true + branches: + include: + - develop + paths: + exclude: + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + drafts: false + +jobs: + - template: ${{ variables.CI_COMPONENT_PATH }}/AMDMIGraphX.yml@pipelines_repo diff --git a/.github/workflows/config.md b/.github/workflows/config.md new file mode 100644 index 00000000000..326c99ca483 --- /dev/null +++ b/.github/workflows/config.md @@ -0,0 +1,29 @@ +#=====ROCM INFO===== +ROCM_VERSION : '6.0.2' +#default ROCm version to be used +ROCM_BASE_IMAGE : 'rocm/dev-ubuntu-20.04' +#base image from dockerhub to be used +ROCM_BUILT_IMAGE : 'rocm-migraphx' +#name of the docker image built upon ROCm base +USE_NAVI : '0' +#disable NAVI in image build +OVERWRITE_EXISTING : 'true' +#building new ROCm image overwrites old with same version + +#=====REPOS INFO===== +ORGANIZATION_REPO : 'AMD' +BENCHMARK_UTILS_REPO : 'ROCm/migraphx-benchmark-utils' +PERFORMANCE_REPORTS_REPO : 'ROCm/migraphx-reports' +PERFORMANCE_BACKUP_REPO : 'migraphx-benchmark/performance-backup' + +#=====PERFORMANCE SCRIPT PARAMETERS===== +RESULTS_TO_COMPARE : '10' +#number of previous performance results to be used in calculations +CALCULATION_METHOD_FLAG : '-r' +#calculation method used in reporting, -m for Max value; -s for Std dev; -r for Threshold file +PERFORMANCE_TEST_TIMEOUT : '30m' +#timeout for each model after which test is aborted + +#===== W A R N I N G ===== +#VARIABLE NAMES NOT TO BE CHANGED, VALUES ONLY! +#VALUES MUST BE ENGLOSED IN SINGLE QUOTES! \ No newline at end of file diff --git a/.github/workflows/performance.yaml b/.github/workflows/performance.yaml index 42e92e6dc2b..98e79ac3efe 100644 --- a/.github/workflows/performance.yaml +++ b/.github/workflows/performance.yaml @@ -5,7 +5,7 @@ on: branches: [develop] types: [opened, synchronize, closed] schedule: - - cron: "0 6 * * 1-6" + - cron: "0 7 * * 1-6" workflow_dispatch: inputs: @@ -47,18 +47,53 @@ concurrency: cancel-in-progress: true jobs: - release: + get_config: + runs-on: ubuntu-latest + outputs: + rocm_version: ${{ steps.read_config.outputs.rocm_version }} + utils_repo: ${{ steps.read_config.outputs.utils_repo }} + reports_repo: ${{ steps.read_config.outputs.reports_repo }} + backup_repo: ${{ steps.read_config.outputs.backup_repo }} + repo_org: ${{ steps.read_config.outputs.repo_org }} + perf_number: ${{ steps.read_config.outputs.perf_number }} + perf_flag: ${{ steps.read_config.outputs.perf_flag }} + perf_timeout: ${{ steps.read_config.outputs.perf_timeout }} + steps: + - name: checkout + uses: actions/checkout@v4.1.1 + - name: read_config + id: read_config + run: | + ROCM_VERSION=$(grep 'ROCM_VERSION' .github/workflows/config.md | cut -d "'" -f2) + BENCHMARK_UTILS_REPO=$(grep 'BENCHMARK_UTILS_REPO' .github/workflows/config.md | cut -d "'" -f2) + PERFORMANCE_REPORTS_REPO=$(grep 'PERFORMANCE_REPORTS_REPO' .github/workflows/config.md | cut -d "'" -f2) + PERFORMANCE_BACKUP_REPO=$(grep 'PERFORMANCE_BACKUP_REPO' .github/workflows/config.md | cut -d "'" -f2) + ORGANIZATION_REPO=$(grep 'ORGANIZATION_REPO' .github/workflows/config.md | cut -d "'" -f2) + RESULTS_TO_COMPARE=$(grep 'RESULTS_TO_COMPARE' .github/workflows/config.md | cut -d "'" -f2) + CALCULATION_METHOD_FLAG=$(grep 'CALCULATION_METHOD_FLAG' .github/workflows/config.md | cut -d "'" -f2) + PERFORMANCE_TEST_TIMEOUT=$(grep 'PERFORMANCE_TEST_TIMEOUT' .github/workflows/config.md | cut -d "'" -f2) + echo "rocm_version=$ROCM_VERSION" >> $GITHUB_OUTPUT + echo "utils_repo=$BENCHMARK_UTILS_REPO" >> $GITHUB_OUTPUT + echo "reports_repo=$PERFORMANCE_REPORTS_REPO" >> $GITHUB_OUTPUT + echo "backup_repo=$PERFORMANCE_BACKUP_REPO" >> $GITHUB_OUTPUT + echo "repo_org=$ORGANIZATION_REPO" >> $GITHUB_OUTPUT + echo "perf_number=$RESULTS_TO_COMPARE" >> $GITHUB_OUTPUT + echo "perf_flag=$CALCULATION_METHOD_FLAG" >> $GITHUB_OUTPUT + echo "perf_timeout=$PERFORMANCE_TEST_TIMEOUT" >> $GITHUB_OUTPUT + + call_reusable: + needs: get_config uses: ROCm/migraphx-benchmark/.github/workflows/perf-test.yml@main with: - rocm_release: ${{ github.event.inputs.rocm_release || '6.0.2' }} - result_number: ${{ github.event.inputs.result_number || '10' }} - flags: ${{ github.event.inputs.flags || '-r' }} - performance_reports_repo: ${{ github.event.inputs.performance_reports_repo || 'ROCm/migraphx-reports' }} - performance_backup_repo: ${{ github.event.inputs.performance_backup_repo || 'migraphx-benchmark/performance-backup' }} - benchmark_utils_repo: ${{ github.event.inputs.benchmark_utils_repo || 'ROCm/migraphx-benchmark-utils' }} - organization: ${{ github.event.inputs.organization || 'AMD' }} - model_timeout: ${{ github.event.inputs.model_timeout || '30m' }} + rocm_release: ${{ github.event.inputs.rocm_release || needs.get_config.outputs.rocm_version }} + benchmark_utils_repo: ${{ github.event.inputs.benchmark_utils_repo || needs.get_config.outputs.utils_repo }} + performance_reports_repo: ${{ github.event.inputs.performance_reports_repo || needs.get_config.outputs.reports_repo }} + performance_backup_repo: ${{ github.event.inputs.performance_backup_repo || needs.get_config.outputs.backup_repo }} + organization: ${{ github.event.inputs.organization || needs.get_config.outputs.repo_org }} + result_number: ${{ github.event.inputs.result_number || needs.get_config.outputs.perf_number }} + flags: ${{ github.event.inputs.flags || needs.get_config.outputs.perf_flag }} + model_timeout: ${{ github.event.inputs.model_timeout || needs.get_config.outputs.perf_timeout }} secrets: gh_token: ${{ secrets.MIGRAPHX_BOT_TOKEN }} mail_user: ${{ secrets.MAIL_USERNAME }} - mail_pass: ${{ secrets.MAIL_PASSWORD }} + mail_pass: ${{ secrets.MAIL_PASSWORD }} \ No newline at end of file diff --git a/.github/workflows/rocm-image-release.yaml b/.github/workflows/rocm-image-release.yaml index 95ad3325ffc..ec98ccd9100 100644 --- a/.github/workflows/rocm-image-release.yaml +++ b/.github/workflows/rocm-image-release.yaml @@ -18,6 +18,10 @@ on: description: Docker image name for rocm Docker build required: true default: "rocm-migraphx" + branch_name: + description: branch to use for building base ROCm image + required: true + default: "develop" build_navi: description: Build navi number required: true @@ -35,6 +39,7 @@ jobs: benchmark-utils_repo: ${{ github.event.inputs.benchmark-utils_repo || 'ROCm/migraphx-benchmark-utils' }} base_image: ${{ github.event.inputs.base_image || 'rocm/dev-ubuntu-20.04' }} docker_image: ${{ github.event.inputs.docker_image || 'rocm-migraphx' }} + branch_name: ${{ github.event.inputs.branch_name || 'develop' }} build_navi: ${{ github.event.inputs.build_navi || '0' }} overwrite: ${{ github.event.inputs.overwrite == 'true' }} secrets: diff --git a/.github/workflows/sync_rocMLIR.yaml b/.github/workflows/sync_rocMLIR.yaml new file mode 100644 index 00000000000..9f473859c7c --- /dev/null +++ b/.github/workflows/sync_rocMLIR.yaml @@ -0,0 +1,90 @@ +name: rocMLIR sync with extended accuracy + +on: + schedule: + - cron: '0 7 * * sun' + pull_request: + branches: [rocMLIR-sync-*] + types: [synchronize, closed] + workflow_dispatch: + inputs: + rocm_release: + type: string + description: ROCm release version + required: true + default: '6.0.2' + base_image: + type: string + description: Base image for ROCm Docker build + required: true + default: 'rocm/dev-ubuntu-20.04' + docker_image: + type: string + description: Docker image name for rocm Docker build + required: true + default: 'rocm-migraphx' + build_navi: + type: string + description: Build navi number + required: true + default: '0' + benchmark_utils_repo: + type: string + description: Repository where benchmark utils are stored + required: true + default: 'ROCm/migraphx-benchmark-utils' + performance_reports_repo: + description: Repository where performance reports are stored + required: true + default: 'ROCm/migraphx-reports' + organization: + type: string + description: Organization based on which location of files will be different + required: true + default: 'AMD' + +jobs: + get_config: + runs-on: ubuntu-latest + outputs: + rocm_version: ${{ steps.read_config.outputs.rocm_version }} + rocm_base_image: ${{ steps.read_config.outputs.rocm_base_image }} + rocm_built_image: ${{ steps.read_config.outputs.rocm_built_image }} + use_navi: ${{ steps.read_config.outputs.use_navi }} + utils_repo: ${{ steps.read_config.outputs.utils_repo }} + reports_repo: ${{ steps.read_config.outputs.reports_repo }} + repo_org: ${{ steps.read_config.outputs.repo_org }} + steps: + - name: checkout + uses: actions/checkout@v4.1.1 + - name: read_config + id: read_config + run: | + ROCM_VERSION=$(grep 'ROCM_VERSION' .github/workflows/config.md | cut -d "'" -f2) + ROCM_BASE_IMAGE=$(grep 'ROCM_BASE_IMAGE' .github/workflows/config.md | cut -d "'" -f2) + ROCM_BUILT_IMAGE=$(grep 'ROCM_BUILT_IMAGE' .github/workflows/config.md | cut -d "'" -f2) + BENCHMARK_UTILS_REPO=$(grep 'BENCHMARK_UTILS_REPO' .github/workflows/config.md | cut -d "'" -f2) + PERFORMANCE_REPORTS_REPO=$(grep 'PERFORMANCE_REPORTS_REPO' .github/workflows/config.md | cut -d "'" -f2) + ORGANIZATION_REPO=$(grep 'ORGANIZATION_REPO' .github/workflows/config.md | cut -d "'" -f2) + USE_NAVI=$(grep 'USE_NAVI' .github/workflows/config.ymd | cut -d "'" -f2) + echo "rocm_version=$ROCM_VERSION" >> $GITHUB_OUTPUT + echo "rocm_base_image=$ROCM_BASE_IMAGE" >> $GITHUB_OUTPUT + echo "rocm_built_image=$ROCM_BUILT_IMAGE" >> $GITHUB_OUTPUT + echo "use_navi=$USE_NAVI" >> $GITHUB_OUTPUT + echo "utils_repo=$BENCHMARK_UTILS_REPO" >> $GITHUB_OUTPUT + echo "reports_repo=$PERFORMANCE_REPORTS_REPO" >> $GITHUB_OUTPUT + echo "repo_org=$ORGANIZATION_REPO" >> $GITHUB_OUTPUT + + call_reusable: + needs: get_config + uses: ROCm/migraphx-benchmark/.github/workflows/rocMLIR_sync.yml@main + with: + rocm_release: ${{ github.event.inputs.rocm_release || needs.get_config.outputs.rocm_version }} + base_image: ${{ github.event.inputs.base_image || needs.get_config.outputs.rocm_base_image }} + docker_image: ${{ github.event.inputs.docker_image || needs.get_config.outputs.rocm_built_image }} + build_navi: ${{ github.event.inputs.build_navi || needs.get_config.outputs.use_navi }} + benchmark_utils_repo: ${{ github.event.inputs.benchmark_utils_repo || needs.get_config.outputs.utils_repo }} + performance_reports_repo: ${{ github.event.inputs.performance_reports_repo || needs.get_config.outputs.reports_repo }} + organization: ${{ github.event.inputs.organization || needs.get_config.outputs.repo_org }} + secrets: + gh_token: ${{ secrets.MIGRAPHX_BOT_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/weekly_master_sync.yaml b/.github/workflows/weekly_master_sync.yaml new file mode 100644 index 00000000000..732d6a08d63 --- /dev/null +++ b/.github/workflows/weekly_master_sync.yaml @@ -0,0 +1,23 @@ +name: Master weekly sync + +on: + schedule: + - cron: '0 15 * * sun' + workflow_dispatch: + +jobs: + SyncAndMerge: + name: Sync master and merge develop + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v4.1.1 + with: + ref: develop + fetch-depth: '0' + + - name: Merge Fast Forward Only + run: | + git checkout master + git merge origin/develop --ff-only + git push origin HEAD diff --git a/CMakeLists.txt b/CMakeLists.txt index 69b433e871a..d3a2fde80be 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -169,6 +169,9 @@ rocm_enable_clang_tidy( -cert-dcl51-cpp -cert-err33-c -cert-str34-c + # We seed random numbers with constants for reproducibility + -cert-msc32-c + -cert-msc51-cpp # Disable all alpha checks by default -clang-analyzer-alpha* # Enable some alpha checks @@ -340,11 +343,14 @@ if(MIGRAPHX_USE_ROCBLAS) list(APPEND PACKAGE_DEPENDS rocblas) endif() +rocm_package_add_deb_dependencies(SHARED_DEPENDS "hip-dev") +rocm_package_add_rpm_dependencies(SHARED_DEPENDS "hip-devel") + rocm_create_package( NAME MIGraphX DESCRIPTION "AMD's graph optimizer" MAINTAINER "AMDMIGraphX Maintainer " LDCONFIG PTH - DEPENDS miopen-hip ${DEPENDS_HIP_RUNTIME} hip-base half ${PACKAGE_DEPENDS} + DEPENDS miopen-hip ${DEPENDS_HIP_RUNTIME} half ${PACKAGE_DEPENDS} ) diff --git a/Dockerfile b/Dockerfile index 657f90e7e02..cc5f14a7ca0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,7 +38,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libpython3.8 \ wget \ rocm-device-libs \ - hip-base \ + hip-dev \ libnuma-dev \ miopen-hip \ rocblas \ diff --git a/Jenkinsfile b/Jenkinsfile index 935af40b6e8..de6fa059a0b 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -155,13 +155,13 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags_cxx}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DGPU_TARGETS='${gpu_targets}'") } } -}, ck_hiprtc: rocmnode('mi100+') { cmake_build -> - stage('CK hipRTC') { - withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1', 'MIGRAPHX_DISABLE_MLIR=1']) { - def gpu_targets = getgputargets() - cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'") - } - } +//}, ck_hiprtc: rocmnode('mi100+') { cmake_build -> +// stage('CK hipRTC') { +// withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1', 'MIGRAPHX_DISABLE_MLIR=1']) { +// def gpu_targets = getgputargets() +// cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'") +// } +// } }, clang_asan: rocmnode('nogpu') { cmake_build -> stage('Clang ASAN') { def sanitizers = "undefined,address" diff --git a/dev-requirements.txt b/dev-requirements.txt index cf520e3cb59..c9fce73e910 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -26,5 +26,5 @@ facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake ccache@v4.1 -DENABLE_TESTING=OFF pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11 danmar/cppcheck@bb2711c22a0be09efe7f1a8da3030876471026c8 -DHAVE_RULES=1 # 2.11 -RadeonOpenCompute/rocm-cmake@5a34e72d9f113eb5d028e740c2def1f944619595 --build +RadeonOpenCompute/rocm-cmake@a83c5075d85f1fd28d657a9277eb21c834d76f3f --build -f requirements.txt diff --git a/docs/dev/onnx_operators.rst b/docs/dev/onnx_operators.rst index 1ac83dabd3b..9bfc87108fb 100644 --- a/docs/dev/onnx_operators.rst +++ b/docs/dev/onnx_operators.rst @@ -210,7 +210,14 @@ Operator Support Matrix | | | | shape is not | | | | | supported | +--------------------------+-----------+-----------------+------------------------------+ -| Einsum | 👷 | 👷 | | +| Einsum | ✅ | Any | more than 1 diagonal per | +| | | | input is not supported | +| | | | e.g. ``iijj->ij`` | +| | | | | +| | | | batch diagonal where batches | +| | | | are not the leading dims is | +| | | | not supported | +| | | | e.g. ``ii...->i...`` | +--------------------------+-----------+-----------------+------------------------------+ | Elu | ✅ | FP8, FP16, | | | | | FP32, FP64 | | diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 1e667aaa891..19a63badcf7 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.1.1 +rocm-docs-core==1.2.0 sphinx-collapse diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 4a6de8f6fe5..b95758e80a4 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -88,11 +88,11 @@ pyyaml==6.0 # myst-parser # rocm-docs-core # sphinx-external-toc -requests==2.31.0 +requests==2.32.0 # via # pygithub # sphinx -rocm-docs-core==1.1.1 +rocm-docs-core==1.2.0 # via -r requirements.in smmap==5.0.0 # via gitdb diff --git a/hip-clang.docker b/hip-clang.docker index 6b090607b1e..73a3e8edbba 100755 --- a/hip-clang.docker +++ b/hip-clang.docker @@ -27,7 +27,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- software-properties-common \ wget \ rocm-device-libs \ - hip-base \ + hip-dev \ libnuma-dev \ miopen-hip \ rocblas \ diff --git a/requirements.txt b/requirements.txt index 49b42573add..915e644c4c7 100755 --- a/requirements.txt +++ b/requirements.txt @@ -24,8 +24,8 @@ google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off nlohmann/json@v3.8.0 ROCm/half@rocm-5.6.0 -pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build +pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@e50d72fc6ab9a7a792d92a1ba7db6db45e4c508c -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/rocMLIR@3612396bca1139abf25e2ed0085fe481d275af89 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index eedc7a27501..876f69d2595 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,6 +67,7 @@ add_library(migraphx instruction.cpp json.cpp layout_nhwc.cpp + lexing.cpp load_save.cpp make_op.cpp memory_coloring.cpp @@ -269,6 +270,7 @@ register_op(migraphx HEADER migraphx/builtin.hpp OPERATORS builtin::literal buil rocm_clang_tidy_check(migraphx) migraphx_generate_export_header(migraphx) rocm_install_targets( + PRIVATE TARGETS migraphx INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include diff --git a/src/convert_to_json.cpp b/src/convert_to_json.cpp index f3872fd22e0..0dd823d1eb7 100644 --- a/src/convert_to_json.cpp +++ b/src/convert_to_json.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,66 +23,17 @@ */ #include #include -#include -#include #include #include #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -using token = std::pair; -using lexer = std::function; - -template -auto lex_while(P p) -{ - return [=](const char* start, const char* end) { - return std::find_if(start, end, [&](char c) { return not p(c); }); - }; -} - -template -auto lex_if(P p) -{ - return [=](const char* start, const char*) { - if(p(*start)) - return start + 1; - return start; - }; -} - -std::vector tokenize(const char* start, const char* end, const std::vector& lexers) -{ - std::vector result; - while(start != end) - { - bool error = true; - for(const auto& l : lexers) - { - const auto* next = l(start, end); - if(next != start) - { - result.emplace_back(start, next); - start = next; - error = false; - break; - } - } - - if(error) - { - MIGRAPHX_THROW("TOKENIZE: no token found!"); - } - } - - return result; -} - -std::vector json_tokenize(const std::string& s) +std::vector json_tokenize(const std::string& s) { std::vector lexers; @@ -133,7 +84,7 @@ std::string convert_to_json(const std::string& str) for(auto& token : tokens) { - std::string s(token.first, token.second); + std::string s(token); if(starts_with(s, "#") or starts_with(s, "//")) continue; if(std::isalpha(s.front()) != 0 and diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index a13bee3c7f1..46e3934908f 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -88,7 +89,7 @@ static void create_pointwise_modules(module_pass_manager& mpm) { pointwise_inputs.push_back(input); param_map[input] = - pm->add_parameter("x" + std::to_string(i), shape{input->get_shape().type()}); + pm->add_parameter(param_name(i), shape{input->get_shape().type()}); i++; } else @@ -133,7 +134,7 @@ static module::with_inputs append_pointwise_module(instruction_ref ins, instruct for(auto i : range(inputs.size())) { auto input = inputs[i]; - auto param = pm.get_parameter("x" + std::to_string(i)); + auto param = pm.get_parameter(param_name(i)); assert(param != pm.end()); input_map[input] = param; } @@ -141,7 +142,7 @@ static module::with_inputs append_pointwise_module(instruction_ref ins, instruct for(auto i : range(output->inputs().size())) { auto input = output->inputs()[i]; - auto param = xm->get_parameter("x" + std::to_string(i)); + auto param = xm->get_parameter(param_name(i)); assert(param != xm->end()); if(input == ins) { @@ -156,7 +157,7 @@ static module::with_inputs append_pointwise_module(instruction_ref ins, instruct else { map_ins[param] = - pm.add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()}); + pm.add_parameter(param_name(inputs.size()), {input->get_shape().type()}); inputs.push_back(input); input_map[input] = map_ins[param]; } diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 8b7571d28bc..37dcfc0e028 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -22,16 +22,16 @@ * THE SOFTWARE. */ #include -#include +#include #include #include #include -#include -#include #include -#include -#include +#include #include +#include +#include +#include #include #include #include @@ -59,6 +59,8 @@ struct fused_reduce const auto* sm = mods.front(); if(sm->get_output_shapes().size() != 1) MIGRAPHX_THROW("Only one output supported"); + if(not sm->bypass()) + MIGRAPHX_THROW("fused_reduce: bypass flag is not set"); auto names = sm->get_parameter_names(); check_shapes{inputs, *this}.has(names.size()).same_ndims(); std::sort(names.begin(), names.end()); @@ -67,7 +69,7 @@ struct fused_reduce if(not equal(names, inputs, [&](const auto& name, const auto& input) { return shapes.at(name).lens() == input.lens(); })) - MIGRAPHX_THROW("Dimenstion does not match the submodule."); + MIGRAPHX_THROW("Input dimension does not match the submodule."); return shape::from_permutation(sm->get_output_shapes().front().type(), sm->get_output_shapes().front().lens(), @@ -78,6 +80,17 @@ struct fused_reduce }; MIGRAPHX_REGISTER_OP(fused_reduce); +/* + * Predicate matcher checks that input and output shapes have the same rank. This is assumed + * for broadcast instructions for these fusions. + */ +MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins) +{ + auto input_shape = ins->inputs().front()->get_shape(); + auto output_shape = ins->get_shape(); + return input_shape.ndim() == output_shape.ndim(); +} + static void insert_params(module_ref sm, const std::vector& inputs, std::unordered_map& map_ins) @@ -227,7 +240,8 @@ template static auto match_broadcast(Ms... ms) { return match::skip(match::name("contiguous"))( - match::name("multibroadcast")(match::arg(0)(ms...), match::used_once()) + match::name("multibroadcast")( + match::arg(0)(ms...), match::used_once(), input_output_ndim_match()) .bind("broadcast")) .bind("final_broadcast"); } @@ -257,19 +271,19 @@ struct find_pointwise_reduce { auto matcher() const { + // fused_reduce instruction with pointwise inputs. return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise")); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto reduce = r.result; - auto input = r.instructions["pointwise"]; - + auto input = r.instructions["pointwise"]; const auto* pm = input->module_inputs().front(); const auto* old_rm = reduce->module_inputs().front(); + auto* rm = mpm.create_module(pm->name() + ":" + old_rm->name()); rm->set_bypass(); - std::unordered_map map_ins; // Insert pointwise auto rins = insert_ins_in_submodule(rm, input, map_ins).front(); @@ -414,6 +428,7 @@ struct reduce_reshape : rewrite_reshapes_base auto dims = base_dims(inputs); auto* oldm = ins->module_inputs().front(); auto* sm = mpm.create_module(oldm->name() + "_reshape"); + sm->set_bypass(); insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) { if(contains(sop.name(), "reduce")) return make_op(sop.name(), {{"axes", axes}}); diff --git a/src/include/migraphx/lexing.hpp b/src/include/migraphx/lexing.hpp new file mode 100644 index 00000000000..61d34b8e0c6 --- /dev/null +++ b/src/include/migraphx/lexing.hpp @@ -0,0 +1,64 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_LEXING_HPP +#define MIGRAPHX_GUARD_MIGRAPHLIB_LEXING_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +using lexer = std::function; + +template +inline auto lex_while(P p) +{ + return [=](const char* start, const char* end) { + return std::find_if(start, end, [&](char c) { return not p(c); }); + }; +} + +template +inline auto lex_if(P p) +{ + return [=](const char* start, const char*) { + if(p(*start)) + return start + 1; + return start; + }; +} + +MIGRAPHX_EXPORT std::function +lex_equal(const std::string& s); + +MIGRAPHX_EXPORT std::vector +tokenize(const char* start, const char* end, const std::vector& lexers); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 6e7e57c67bf..6956f5916e8 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -330,9 +330,27 @@ struct matcher_result }); } + void debug_print() const + { + for(const auto& it : ins_map) + { + std::cout << it.first << ": \n"; + it.second->debug_print(); + } + } + private: std::unordered_map ins_map; }; + + void debug_print() const + { + std::cout << "matcher_container: \n instructions:"; + instructions.debug_print(); + std::cout << " result: \n"; + result->debug_print(); + } + instruction_container instructions; instruction_ref result; }; diff --git a/src/include/migraphx/onnx.hpp b/src/include/migraphx/onnx.hpp index c0d7637c093..b9bf42a4f0e 100644 --- a/src/include/migraphx/onnx.hpp +++ b/src/include/migraphx/onnx.hpp @@ -58,6 +58,9 @@ struct onnx_options int64_t limit_max_iterations = std::numeric_limits::max(); /// Use dynamic output for operators when available bool use_dyn_output = false; + /// Path to use for the external data if it is stored at different location compared to onnx + /// file + std::string external_data_path = ""; }; /// Create a program from an onnx file diff --git a/src/include/migraphx/op/mod.hpp b/src/include/migraphx/op/mod.hpp index f1a48e3c58f..38f947a3587 100644 --- a/src/include/migraphx/op/mod.hpp +++ b/src/include/migraphx/op/mod.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,7 +38,6 @@ struct mod : binary { auto a = base_attributes(); a["commutative"] = false; - a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})"; return a; } auto apply() const diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index 1552c28300b..3506fccbb30 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -33,7 +33,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -std::string param_name(std::size_t i, const std::string& prefix = "x"); +MIGRAPHX_EXPORT std::string param_name(std::size_t i, const std::string& prefix = "x"); void sort_params(std::vector& params); diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index bd6ed1ec27d..e44413a09a1 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -74,13 +74,16 @@ struct rewrite_reshapes { auto reshape = match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once()); - auto skip_contiguous = [](auto... ms) { - return match::arg(0)(match::skip( - match::name("contiguous", "multibroadcast")(match::used_once()))(ms...)); + auto skip_contiguous_broadcast = + match::skip(match::name("contiguous", "multibroadcast")(match::used_once())); + auto skip_contiguous_broadcast_arg = [&](auto... ms) { + return match::arg(0)(skip_contiguous_broadcast(ms...)); }; auto pointwise = match::name(op1)(match::used_once()); - auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape"); - return match::name(op2)(match::any_of[match::inputs()](reshape_pointwise)); + auto reshape_pointwise = + reshape(skip_contiguous_broadcast_arg(pointwise.bind("x"))).bind("reshape"); + return match::name(op2)(match::any_of[match::inputs()]( + skip_contiguous_broadcast(reshape_pointwise).bind("input"))); } template @@ -107,17 +110,33 @@ struct rewrite_reshapes return x_ins == input; } + static std::optional is_broadcasted(instruction_ref start, instruction_ref last) + { + auto broadcast_ins = + find_input_if(start, last, [&](auto i) { return i->name() == "multibroadcast"; }); + bool result = broadcast_ins != last; + if(result and not match_input(broadcast_ins, last)) + return nullopt; + return result; + } + void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; auto x_ins = r.instructions["x"]; auto reshape_ins = r.instructions["reshape"]; + auto input_ins = r.instructions["input"]; - auto broadcast_ins = find_input_if( - reshape_ins, x_ins, [&](auto i) { return i->name() == "multibroadcast"; }); - const bool has_broadcast = broadcast_ins != x_ins; - if(has_broadcast and not match_input(broadcast_ins, x_ins)) + const auto has_broadcast_before_reshape = is_broadcasted(reshape_ins, x_ins); + const auto has_broadcast_after_reshape = is_broadcasted(input_ins, reshape_ins); + if(not has_broadcast_before_reshape.has_value()) + return; + if(not has_broadcast_after_reshape.has_value()) + return; + if(*has_broadcast_after_reshape and *has_broadcast_before_reshape) return; + const bool has_broadcast = + *has_broadcast_after_reshape or *has_broadcast_before_reshape; auto dims1 = T::base_dims(ins); auto dims2 = T::base_dims(x_ins); @@ -153,7 +172,7 @@ struct rewrite_reshapes auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { - if(input == reshape_ins) + if(input == input_ins) return new_x_ins; return reshape_input(ins)(input); }); diff --git a/src/include/migraphx/tf.hpp b/src/include/migraphx/tf.hpp index 3ffbddbce30..2f8fa536fb8 100644 --- a/src/include/migraphx/tf.hpp +++ b/src/include/migraphx/tf.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -45,6 +45,15 @@ struct tf_options MIGRAPHX_TF_EXPORT program parse_tf(const std::string& name, const tf_options& options = tf_options{}); +/// Create a program from an tf buffer +MIGRAPHX_TF_EXPORT program parse_tf_buffer(const std::string& buffer, + const tf_options& options = tf_options{}); + +/// Create a program from tf buffer +MIGRAPHX_TF_EXPORT program parse_tf_buffer(const void* data, + std::size_t size, + const tf_options& options = tf_options{}); + MIGRAPHX_TF_EXPORT std::vector get_tf_operators(); } // namespace MIGRAPHX_INLINE_NS diff --git a/src/instruction.cpp b/src/instruction.cpp index cf1c80121d6..4a5537c32d2 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/lexing.cpp b/src/lexing.cpp new file mode 100644 index 00000000000..d17de024270 --- /dev/null +++ b/src/lexing.cpp @@ -0,0 +1,71 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::function lex_equal(const std::string& s) +{ + return [=](const char* start, const char* end) { + auto n = end - start; + if(n < s.size()) + return start; + if(std::equal(start, start + s.size(), s.data())) + return start + s.size(); + return start; + }; +} + +std::vector +tokenize(const char* start, const char* end, const std::vector& lexers) +{ + std::vector result; + while(start != end) + { + bool error = true; + for(const auto& l : lexers) + { + const auto* next = l(start, end); + if(next != start) + { + result.emplace_back(start, next - start); + start = next; + error = false; + break; + } + } + + if(error) + { + MIGRAPHX_THROW("TOKENIZE: no token found!"); + } + } + + return result; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp index 37c1721ac12..493ad845a8e 100644 --- a/src/onnx/include/migraphx/onnx/onnx_parser.hpp +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -45,6 +45,7 @@ struct onnx_parser { std::string filename; fs::path path; + std::string external_data_path; using attribute_map = std::unordered_map; struct node_info { diff --git a/src/onnx/onnx.cpp b/src/onnx/onnx.cpp index d8e10d0d6f9..a2c3db80b50 100644 --- a/src/onnx/onnx.cpp +++ b/src/onnx/onnx.cpp @@ -41,6 +41,7 @@ template program parse_onnx_from(const onnx_options& options, Ts&&... xs) { onnx::onnx_parser parser; + parser.external_data_path = options.external_data_path; parser.map_input_dims = options.map_input_dims; parser.dim_params = options.dim_params; parser.map_dyn_input_dims = options.map_dyn_input_dims; diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index ef00d241947..7bfaaa10213 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -514,7 +514,15 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const { nbytes = std::stoul(t.external_data().at(2).value()); } - auto raw_buffer = read_buffer(path / data_file, offset, nbytes); + std::vector raw_buffer; + if(not external_data_path.empty()) + { + raw_buffer = read_buffer(fs::path{external_data_path} / data_file, offset, nbytes); + } + else + { + raw_buffer = read_buffer(path / data_file, offset, nbytes); + } std::string s(raw_buffer.begin(), raw_buffer.end()); return create_literal(type, dims, s.data()); } diff --git a/src/onnx/parse_einsum.cpp b/src/onnx/parse_einsum.cpp new file mode 100644 index 00000000000..b80c6cc327b --- /dev/null +++ b/src/onnx/parse_einsum.cpp @@ -0,0 +1,768 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_einsum : op_parser +{ + using int_mat = std::vector>; + + struct equation_info + { + bool explicit_form = false; + std::vector input_terms; + std::string output_term; + std::map label_count; + std::vector>> duplicates; + size_t ellipsis_ndim = 0; + }; + + std::vector operators() const { return {{"Einsum"}}; } + + instruction_ref parse(const op_desc&, + const onnx_parser&, + const onnx_parser::node_info& info, + const std::vector& args) const + { + if(not contains(info.attributes, "equation")) + MIGRAPHX_THROW("Equation attribute is required"); + std::string equation = info.attributes.at("equation").s(); + + const equation_info eq_info = analyze_equation(equation, args); + + auto terms = eq_info.input_terms; + terms.push_back(eq_info.output_term); + const auto map_mat = make_mapping_matrix(terms, eq_info.label_count, eq_info.ellipsis_ndim); + + // Holds the mapping matrix representations of the two terms being processed + // cur_pair[0] acts as the accumulator for previously processed inputs + // cur_pair[1] holds the representation for the current input + // As operations are added to the einsum graph, cur_pair gets manipulated + int_mat cur_pair = make_matrix(2, map_mat[0].size(), -1); + + instruction_ref cur_op; + std::optional last_op; + // Perform a left fold on the inputs + for(auto arg_idx = 0; arg_idx < args.size(); ++arg_idx) + { + cur_op = args[arg_idx]; + cur_pair[1] = map_mat[arg_idx]; + + cur_op = preprocess_input( + info, cur_op, eq_info.duplicates[arg_idx], map_mat, arg_idx, cur_pair); + + if(last_op) + cur_op = process_pair(info, *last_op, cur_op, map_mat, arg_idx, cur_pair); + + last_op = cur_op; + cur_pair[0] = cur_pair[1]; + } + + return finalize_output(info, cur_op, map_mat, cur_pair); + } + + // Equation Parsing + + equation_info analyze_equation(std::string_view equation, + const std::vector& args) const + { + equation_info eq_info = parse_equation(equation); + + eq_info.ellipsis_ndim = validate_input_terms(eq_info.input_terms, args); + if(not eq_info.output_term.empty()) + validate_output_term(eq_info.output_term, eq_info.label_count, eq_info.ellipsis_ndim); + else if(not eq_info.explicit_form) + eq_info.output_term = generate_output_term(eq_info.label_count, eq_info.ellipsis_ndim); + + eq_info.duplicates = find_duplicates(eq_info.input_terms); + + return eq_info; + } + + // Equation: Input Output + // Input: Term | Term ',' Input + // Output: '->' TermOpt | epsilon + // TermOpt: Term | epsilon + // Term: Labels | LabelsOpt '...' LabelsOpt + // LabelsOpt: Labels | epsilon + // Labels: [a-zA-Z]+ + equation_info parse_equation(std::string_view equation) const + { + equation_info ret; + + std::vector lexers; + lexers.push_back(lex_while(&isspace)); + lexers.push_back(lex_while(&isalpha)); + lexers.push_back(lex_equal("->")); + lexers.push_back(lex_equal("...")); + lexers.push_back(lex_equal(",")); + + auto tokens = tokenize(equation.data(), equation.data() + equation.length(), lexers); + + std::string term; + bool has_ellipsis = false; + + for(const auto& token : tokens) + { + if(std::isspace(token.front()) != 0) + continue; + + if(std::isalpha(token.front()) != 0) + { + term += token; + if(not ret.explicit_form) + { + for(auto c : token) + ++ret.label_count[c]; + } + } + else if(token == "->") + { + if(ret.explicit_form) + MIGRAPHX_THROW("Einsum equation has multiple '->' symbols"); + + if(term.empty()) + MIGRAPHX_THROW("No term specified before '->' symbol"); + + ret.explicit_form = true; + has_ellipsis = false; + ret.input_terms.push_back(term); + term.clear(); + } + else if(token == "...") + { + if(has_ellipsis) + MIGRAPHX_THROW("Ellipsis can only appear once per einsum equation term"); + + has_ellipsis = true; + term += "*"; + } + else if(token == ",") + { + if(ret.explicit_form) + MIGRAPHX_THROW("Einsum equation can't have a ',' symbol in the output"); + + if(term.empty()) + MIGRAPHX_THROW("No term specified before ',' symbol"); + + has_ellipsis = false; + ret.input_terms.push_back(term); + term.clear(); + } + } + + if(ret.explicit_form) + ret.output_term = term; + else if(not term.empty()) + ret.input_terms.push_back(term); + else + MIGRAPHX_THROW("Last input term is missing"); + + return ret; + } + + size_t validate_input_terms(const std::vector& input_terms, + const std::vector& args) const + { + if(input_terms.size() != args.size()) + MIGRAPHX_THROW("Number of terms in the input equation - " + + std::to_string(input_terms.size()) + + " does not match the number of inputs " + std::to_string(args.size())); + + auto global_ellipsis_dims = 0u; + for(auto i = 0u; i < args.size(); ++i) + { + const auto& term = input_terms[i]; + const auto dims = args[i]->get_shape().lens(); + const auto rank = dims.size(); + + auto current_dim = 0u; + for(const auto l : term) + { + if(l == '*') + { + const auto ellipsis_dims = rank - term.size() + 1; + if(global_ellipsis_dims > 0 and ellipsis_dims != global_ellipsis_dims) + MIGRAPHX_THROW("Every occurrence of ellipsis in the equation must " + "represent the same number of dimensions"); + global_ellipsis_dims = ellipsis_dims; + current_dim += ellipsis_dims; + } + else + ++current_dim; + } + + if(current_dim != rank) + MIGRAPHX_THROW("Number of labels in " + std::to_string(i + 1) + ". input_term (" + + term + ") does not match the rank (" + std::to_string(rank) + + ") of corresponding input"); + } + + return global_ellipsis_dims; + } + + void validate_output_term(std::string_view output_term, + const std::map& label_count, + size_t ellipsis_ndim) const + { + std::string_view::iterator it = + std::find_if(output_term.begin(), output_term.end(), [&](auto l) { + return not contains(label_count, l) and l != '*'; + }); + if(it != output_term.end()) + MIGRAPHX_THROW("Output term contains label " + std::to_string(*it) + + ", which is not present in any of the input terms"); + + if(ellipsis_ndim != 0 and not contains(output_term, "*")) + MIGRAPHX_THROW( + "Output term does not contain ellipsis (...) even though an input term does"); + } + + // Creates output term when the equation is in implicit mode. + // The created output term must contain the alphabetically sorted sequence of labels appearing + // exactly once in the equation. + // If ellipsis are present in the left hand side of the equation, the ellipsis dimensions are + // set to the beginning of the output term. + std::string generate_output_term(const std::map& label_count, + size_t ellipsis_ndim) const + { + std::string output_term = ellipsis_ndim == 0 ? "" : "*"; + output_term = transform_accumulate( + label_count.begin(), label_count.end(), output_term, std::plus<>(), [](const auto& p) { + if(p.second == 1) + return std::string{p.first}; + else + return std::string{}; + }); + + return output_term; + } + + // Creates a matrix representation of the equation. + // + // Rows correspond to equation terms, in order of appearance. + // + // Columns represent the unique labels contained in the equation, ordered alphabetically. If + // ellipses are present in the equation, they are represented by the final N columns(N being the + // number of dimensions covered by and ellipsis). + // Labels not present in a given term are signified by -1. + // Labels present in a given term are signified by the input axis they represent. + // + // e.g. For equation "...ik,kj...->ij...", assuming ... cover two dimensions, the resulting + // matrix is: + // +-------+----+----+----+---+---+ + // | | i | j | k | * | * | + // +-------+----+----+----+---+---+ + // | ...ik | 2 | -1 | 3 | 0 | 1 | + // | kj... | -1 | 1 | 0 | 2 | 3 | + // | ij... | 0 | 1 | -1 | 2 | 3 | + // +-------+----+----+----+---+---+ + int_mat make_mapping_matrix(const std::vector& terms, + const std::map& label_count, + size_t ellipsis_ndim) const + { + std::map label_to_column; + + auto it = label_count.begin(); + for(auto i = 0; i < label_count.size(); ++i) + label_to_column[(it++)->first] = i; + + int_mat map_mat = make_matrix(terms.size(), label_count.size() + ellipsis_ndim, -1); + + for(auto i = 0; i < terms.size(); ++i) + { + const auto& term = terms[i]; + int col_id = 0; + for(const auto l : term) + { + if(l == '*') + { + std::iota(map_mat[i].end() - ellipsis_ndim, map_mat[i].end(), col_id); + col_id += ellipsis_ndim; + } + else + map_mat[i][label_to_column[l]] = col_id++; + } + } + + return map_mat; + } + + // Finds the duplicated labels in each of the terms and stores the axes on which they occur. + // + // e.g. For equation "iikjj,jkj", the result is a vector containing the two following maps: + // result[0]: {'i': [0, 1], 'j': [3, 4]} + // result[1]: {'j': [0, 2]} + std::vector>> + find_duplicates(const std::vector& terms) const + { + std::vector>> duplicates; + for(const auto& term : terms) + { + std::map> duplicate_axes; + for(auto i = 0; i < term.size(); ++i) + duplicate_axes[term[i]].push_back(i); + + erase_if(duplicate_axes, [](const auto& p) { return p.second.size() < 2; }); + duplicates.push_back(duplicate_axes); + } + + return duplicates; + } + + // Graph Building + + instruction_ref preprocess_input(const onnx_parser::node_info& info, + instruction_ref op, + const std::map>& duplicates, + const int_mat& map_mat, + size_t input_idx, + int_mat& cur_pair) const + { + if(not duplicates.empty()) + { + std::vector> diag; + diag.reserve(duplicates.size()); + std::transform(duplicates.begin(), + duplicates.end(), + std::back_inserter(diag), + [](const auto& d) { return d.second; }); + + op = gather_diagonal(info, cur_pair, op, diag); + } + + // Unsqueeze the input shape in the dimensions marked as -1 in the mapping_matrix + // Transpose the input shape so the labels are in alphabetical order + op = transpose_unsqueeze(info, cur_pair, op); + + std::vector red; + // Check if a given label appears in any of the subsequent mapping matrix terms(this + // includes the output). If does not, it is reduced and marked as -1 in cur_pair. + for(int d = 0; d < map_mat[0].size(); ++d) + { + bool all_neg_one = all_of(extract_column(map_mat, d, input_idx + 1, map_mat.size()), + [](auto i) { return i == -1; }); + if(all_neg_one and cur_pair[1][d] != -1 and cur_pair[0][d] == -1) + red.push_back(d); + } + + return apply_reduce_sum_op(info, op, red, cur_pair[1]); + } + + instruction_ref gather_diagonal(const onnx_parser::node_info& info, + int_mat& cur_pair, + instruction_ref op, + const int_mat& diag) const + { + if(diag.size() != 1) + MIGRAPHX_THROW( + "Parsing of equations with more than one duplicated labels per input term is not " + "implemented"); + + const auto& op_lens = op->get_shape().lens(); + + int first_axis = diag[0][0]; + const std::vector& axes = diag[0]; + if(not all_of(axes, [&](int a) { return op_lens[first_axis] == op_lens[a]; })) + MIGRAPHX_THROW("All duplicate labels have to be the same dimension"); + + std::vector batch_axes = set_difference(arange(0, op_lens.size()), axes); + if(not all_of(batch_axes, [&](int ba) { return ba < axes.front(); })) + MIGRAPHX_THROW( + "Parsing of equations with duplicated labels and batch axes that are not " + "the outer-most axes, is not implemented"); + + size_t batch_size = calc_dim(batch_axes, op_lens); + + std::vector indices; + for(size_t batch = 0; batch < batch_size; ++batch) + { + for(size_t i = 0; i < op_lens[first_axis]; ++i) + { + std::vector index(axes.size(), i); + indices.insert(indices.end(), index.begin(), index.end()); + } + } + + std::vector indices_lens{op_lens[first_axis], axes.size()}; + if(batch_size > 1) + indices_lens.insert(indices_lens.begin(), batch_size); + + auto indices_arg = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int64_type, indices_lens}, indices}); + + op = info.add_instruction( + migraphx::make_op("gathernd", {{"batch_dims", batch_axes.size()}}), op, indices_arg); + + // compute output row + std::replace_if( + cur_pair[1].begin(), + cur_pair[1].end(), + [&](auto r) { return contains(axes, r); }, + first_axis); + + for(auto t : range(axes.begin() + 1, axes.end())) + { + std::transform(cur_pair[1].begin(), + cur_pair[1].end(), + cur_pair[1].begin(), + [t](auto r) { return r > t ? r - 1 : r; }); + } + + return op; + } + + instruction_ref process_pair(const onnx_parser::node_info& info, + instruction_ref op1, + instruction_ref op2, + const int_mat& map_mat, + size_t input_idx, + int_mat& cur_pair) const + { + // Label is present in current two terms and somewhere in subsequent terms + std::vector batch_axes; + // Label is present in only left term + std::vector left_only; + // Label is present in only right term + std::vector right_only; + // Label is present in current two terms, but not in the subsequent terms + std::vector sum_axes; + + auto not_neg_one = [](auto i) { return i != -1; }; + // Categorize axes according to label distribution in equation + for(int d = 0; d < map_mat[0].size(); ++d) + { + // The label is present in both terms of cur_pair + if(all_of(extract_column(cur_pair, d, 0, cur_pair.size()), not_neg_one)) + { + // The label is present in at least one of the subsequent terms + if(any_of(extract_column(map_mat, d, input_idx + 1, map_mat.size()), not_neg_one)) + batch_axes.push_back(d); + else + sum_axes.push_back(d); + } + // The label is missing in one or both of the cur_pair + else + { + if(cur_pair[0][d] >= 0) + left_only.push_back(d); + else if(cur_pair[1][d] >= 0) + right_only.push_back(d); + else + batch_axes.push_back(d); + } + } + + // Permute the inputs so batch_axes are outermost axes and sum_axes are innermost axes + auto&& perm = concat_vectors(batch_axes, left_only, right_only, sum_axes); + std::vector perm64(perm.begin(), perm.end()); + op1 = apply_transpose_op(info, op1, perm64, cur_pair[0]); + op2 = apply_transpose_op(info, op2, perm64, cur_pair[1]); + + auto new_batch_axes = arange(0, batch_axes.size()); + auto new_sum_axes = arange(perm.size() - sum_axes.size(), perm.size()); + + auto common_labels = set_union(new_batch_axes, new_sum_axes); + std::tie(op1, op2) = apply_broadcast_op(info, op1, op2, common_labels); + + auto op = batch_dot(info, cur_pair, op1, op2, new_batch_axes, new_sum_axes); + + return apply_transpose_op(info, op, invert_permutation(perm64), cur_pair[1]); + } + + instruction_ref batch_dot(const onnx_parser::node_info& info, + int_mat& cur_pair, + instruction_ref op1, + instruction_ref op2, + const std::vector& batch_axes, + const std::vector& sum_axes) const + { + auto op1_lens = op1->get_shape().lens(); + auto op2_lens = op2->get_shape().lens(); + + std::vector dims1{static_cast(calc_dim(batch_axes, op1_lens)), + -1, + static_cast(calc_dim(sum_axes, op1_lens))}; + std::vector dims2{static_cast(calc_dim(batch_axes, op2_lens)), + -1, + static_cast(calc_dim(sum_axes, op2_lens))}; + + op1 = info.add_instruction(make_op("reshape", {{"dims", dims1}}), op1); + op2 = info.add_instruction(make_op("reshape", {{"dims", dims2}}), op2); + op2 = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), op2); + instruction_ref op = info.add_instruction(make_op("dot"), op1, op2); + + std::vector new_lens(op1_lens.size(), 1); + std::transform(op1_lens.begin(), + op1_lens.begin() + (new_lens.size() - sum_axes.size()), + op2_lens.begin(), + new_lens.begin(), + [](auto len1, auto len2) { return std::max(len1, len2); }); + + op = info.add_instruction(make_op("reshape", {{"dims", new_lens}}), op); + + // compute output row + std::transform(cur_pair[0].begin(), + cur_pair[0].end(), + cur_pair[1].begin(), + cur_pair[1].begin(), + [](int lhs, int rhs) { return std::max(lhs, rhs); }); + for(int a : sum_axes) + cur_pair[1][a] = -1; + + return op; + } + + instruction_ref finalize_output(const onnx_parser::node_info& info, + instruction_ref op, + const int_mat& map_mat, + int_mat& cur_pair) const + { + if(any_of(map_mat.back(), [](auto i) { return i >= 0; })) + { + cur_pair[1] = map_mat.back(); + std::vector red; + for(int d = 0; d < map_mat[0].size(); ++d) + { + if(cur_pair[0][d] > 0 and cur_pair[1][d] == -1) + red.push_back(d); + } + + op = apply_reduce_sum_op(info, op, red, cur_pair[1]); + } + + return squeeze_transpose(info, cur_pair, op, map_mat.back()); + } + + // Permutes the labels so they are in alphabetical order and expands the input dimensions to + // match the number of unique labels in the entire equation. + instruction_ref transpose_unsqueeze(const onnx_parser::node_info& info, + int_mat& cur_pair, + instruction_ref op) const + { + std::vector perm; + std::vector unsq_axes; + + for(auto i = 0; i < cur_pair[1].size(); ++i) + { + if(cur_pair[1][i] == -1) // unsqueeze the dimensions corresponding to the missing labels + unsq_axes.push_back(i); + else // permute the rest + perm.push_back(cur_pair[1][i]); + } + + std::vector perm64(perm.begin(), perm.end()); + op = apply_transpose_op(info, op, perm64, perm); + + // compute output row + for(auto axis : unsq_axes) + { + perm.insert(perm.begin() + axis, -1); + } + cur_pair[1] = perm; + + return info.add_instruction(make_op("unsqueeze", {{"axes", unsq_axes}}), op); + } + + // Reverts the effects of transpose_unsqueeze (adjusts the output so it fits the equation) + instruction_ref squeeze_transpose(const onnx_parser::node_info& info, + int_mat& cur_pair, + instruction_ref op, + std::vector row_output) const + { + std::vector sq_axes; + std::vector perm; + + for(auto i = 0; i < row_output.size(); ++i) + { + if(row_output[i] == -1) // squeeze the dimensions corresponding to the missing labels + sq_axes.push_back(i); + else // permute the rest + perm.push_back(row_output[i]); + } + + op = info.add_instruction(make_op("squeeze", {{"axes", sq_axes}}), op); + + if(not perm.empty()) + { + std::vector perm64(perm.begin(), perm.end()); + op = apply_transpose_op(info, op, invert_permutation(perm64), perm); + // compute output row + for(auto axis : sq_axes) + { + perm.insert(perm.begin() + axis, -1); + } + cur_pair[1] = perm; + } + + return op; + } + + instruction_ref apply_transpose_op(const onnx_parser::node_info& info, + instruction_ref op, + const std::vector& perm, + std::vector& row) const + { + op = info.add_instruction(make_op("transpose", {{"permutation", perm}}), op); + // compute output row + row = reorder_dims(row, perm); + + return op; + } + + std::pair + apply_broadcast_op(const onnx_parser::node_info& info, + instruction_ref opl, + instruction_ref opr, + const std::vector& common_labels) const + { + std::pair ret; + + auto llens = opl->get_shape().lens(); + auto rlens = opr->get_shape().lens(); + + bool lbc = false; + bool rbc = false; + for(auto l : common_labels) + { + if(llens[l] == 1 and rlens[l] == 1) + continue; + + if(llens[l] == 1) + { + lbc = true; + llens[l] = rlens[l]; + } + + if(rlens[l] == 1) + { + rbc = true; + rlens[l] = llens[l]; + } + } + + if(lbc) + opl = info.add_instruction(make_op("multibroadcast", {{"out_lens", llens}}), opl); + if(rbc) + opr = info.add_instruction(make_op("multibroadcast", {{"out_lens", rlens}}), opr); + + ret.first = opl; + ret.second = opr; + return ret; + } + + instruction_ref apply_reduce_sum_op(const onnx_parser::node_info& info, + instruction_ref op, + const std::vector& axes, + std::vector& row) const + { + if(axes.empty()) + return op; + + for(int a : axes) + row[a] = -1; + + return info.add_instruction(make_op("reduce_sum", {{"axes", axes}}), op); + } + + // Utility + + int_mat make_matrix(int cur_pair, int cols, int fill_value) const + { + return {static_cast(cur_pair), std::vector(cols, fill_value)}; + } + + std::vector extract_column(int_mat map_mat, int col_idx, int row_begin, int row_end) const + { + std::vector ret; + ret.reserve(row_end - row_begin); + + std::transform(map_mat.begin() + row_begin, + map_mat.begin() + row_end, + std::back_inserter(ret), + [col_idx](const auto& x) { return x[col_idx]; }); + + return ret; + } + + std::vector set_union(const std::vector& lhs, const std::vector& rhs) const + { + std::vector ret; + std::set_union(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), std::back_inserter(ret)); + + return ret; + } + + std::vector set_difference(const std::vector& lhs, const std::vector& rhs) const + { + std::vector ret; + std::set_difference( + lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), std::back_inserter(ret)); + + return ret; + } + + // Equivalent to numpy.arange without the step parameter + std::vector arange(int start_value, int end_value) const + { + std::vector ret(end_value - start_value); + std::iota(ret.begin(), ret.end(), start_value); + return ret; + } + + template + Vec concat_vectors(Vec vec, Vecs&&... vecs) const + { + size_t reserve_size = vec.size(); + each_args([&](auto&& v) { reserve_size += v.size(); }, vecs...); + + vec.reserve(reserve_size); + each_args([&](auto&& v) { vec.insert(vec.end(), v.begin(), v.end()); }, vecs...); + + return vec; + } + + size_t calc_dim(const std::vector& axes, const std::vector& lens) const + { + return std::accumulate( + axes.begin(), axes.end(), 1, [&](auto acc, auto axis) { return acc * lens[axis]; }); + }; +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/param_utils.cpp b/src/param_utils.cpp index 61302a0afba..20c1ad8b0e2 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -25,14 +25,20 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { std::string param_name(std::size_t i, const std::string& prefix) { - assert(i < 10); - return prefix + std::to_string(i); + if(i < 10) + return prefix + std::to_string(i); + const std::size_t max_digits = 5; + if(i >= std::pow(10, max_digits)) + MIGRAPHX_THROW("Too many parameters."); + std::size_t n = log10(i) + 1; + return prefix + ":" + std::string(max_digits - n, '0') + std::to_string(i); } void sort_params(std::vector& params) diff --git a/src/program.cpp b/src/program.cpp index ec7b66d178f..ddccd40eed3 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1081,10 +1081,12 @@ const module* program::get_module(const std::string& name) const { return &impl- module* program::create_module(const std::string& name) { + assert(not contains(impl->modules, name)); auto r = impl->modules.emplace(name, name); return &(r.first->second); } + module* program::create_module(const std::string& name, module m) { assert(not contains(impl->modules, name)); diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 4e79caefd90..16df584593f 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -482,23 +482,174 @@ struct find_double_add_lit_broadcast } }; +/// Find elementswise operators that have all broadcast inputs. It then +/// rewrites the elementwise to do the computation on the non-broadcasted +/// axes, and then broadcast that result. struct find_inner_broadcast { auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); } - static auto non_scalar_op(const std::string& name) + static auto get_non_broadcast_input(instruction_ref ins) { - return [=](instruction_ref ins) { - if(ins->get_shape().scalar()) - return false; - return ins->name() == name; - }; + if(ins->inputs().size() != 1) + return ins; + auto input = ins->inputs().front(); + if(contains(input->name(), "broadcast")) + return get_non_broadcast_input(input); + return input; + } + + static bool is_unsqueeze_needed_for_multibroadcast(const shape& input, const shape& output) + { + if(input.elements() == 1) + return false; + auto shift = output.ndim() - input.ndim(); + if(shift == 0) + return false; + if(std::equal(input.lens().begin(), + input.lens().end(), + output.lens().begin() + shift, + output.lens().end())) + { + return std::all_of(output.lens().begin(), output.lens().begin() + shift, [](auto x) { + return x == 1; + }); + } + return true; + } + // Simple case + void apply_same_broadcasts(module& m, instruction_ref ins) const + { + const auto& broadcasts = ins->inputs(); + // Scalars can have different ndim, so find the largest ndim input + auto max_broadcast = *std::max_element( + broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref broadcast) { + return get_non_broadcast_input(broadcast)->get_shape().ndim(); + })); + auto max_ndim = max_broadcast->get_shape().ndim(); + std::vector inputs; + std::transform(broadcasts.begin(), + broadcasts.end(), + std::back_inserter(inputs), + [&](instruction_ref broadcast) { + auto input = get_non_broadcast_input(broadcast); + auto s = input->get_shape(); + // If scalar doesnt match the other input dims then add a squeeze + if(s.elements() == 1 and s.ndim() > 1 and s.ndim() != max_ndim) + return m.insert_instruction(broadcast, make_op("squeeze"), input); + return input; + }); + auto op = insert_common_op(m, ins, ins->get_operator(), inputs); + m.replace_instruction(ins, broadcasts.front()->get_operator(), op); + } + + void apply_diff_broadcasts(module& m, instruction_ref ins) const + { + const auto& broadcasts = ins->inputs(); + auto ndim = ins->get_shape().ndim(); + // Compute the inner dimensions and axes that the computation will + // use. Also compute the axes that will be broadcasted + std::vector idims; + std::vector iaxes; + std::vector axes; + for(auto axis : range(ndim)) + { + if(std::all_of(broadcasts.begin(), broadcasts.end(), [&](instruction_ref i) { + auto s = i->get_shape(); + return s.lens()[axis] == 1 or s.strides()[axis] == 0; + })) + { + axes.push_back(axis); + } + else + { + iaxes.push_back(axis); + idims.push_back(ins->get_shape().lens()[axis]); + } + } + // If the inner axes are the same as the original operator then + // there is no reason to do this transformation. + if(iaxes.size() == ndim) + return; + std::vector inputs; + std::transform( + broadcasts.begin(), + broadcasts.end(), + std::back_inserter(inputs), + [&](instruction_ref broadcast) { + auto input = broadcast->inputs().front(); + auto s = input->get_shape(); + + // If its a single element then just return that as an input + if(s.elements() == 1) + { + if(s.lens().size() > 1) + return m.insert_instruction(broadcast, make_op("squeeze"), input); + return input; + } + + // Find how the axes are shifted from the broadcast + std::int64_t shift = ndim - s.ndim(); + if(broadcast->name() == "broadcast") + shift = broadcast->get_operator().to_value()["axis"].to(); + // Compute the squeeze axes to be used by taking the inner + // axes and shifting to what the axes will be on the + // input + std::vector sq_axes; + for(auto axis : axes) + { + auto iaxis = axis - shift; + if(iaxis < 0) + continue; + if(iaxis >= s.ndim()) + continue; + sq_axes.push_back(iaxis); + } + instruction_ref result = input; + if(not sq_axes.empty()) + result = m.insert_instruction( + broadcast, make_op("squeeze", {{"axes", sq_axes}}), result); + // If the number of dimension are still smaller than the + // number of inner axes, then we need to insert a + // broadcast to have the same dimensions for all inputs. + if(result->get_shape().ndim() < iaxes.size()) + { + // We find the first inner axis that can be mapped to the input + auto start_axis = std::find_if(iaxes.begin(), + iaxes.end(), + [&](auto x) { return x >= shift; }) - + iaxes.begin(); + result = m.insert_instruction( + broadcast, + make_op("broadcast", {{"axis", start_axis}, {"out_lens", idims}}), + result); + } + return result; + }); + auto op = insert_common_op(m, ins, ins->get_operator(), inputs); + if(iaxes.size() == 1) + { + m.replace_instruction( + ins, + make_op("broadcast", + {{"axis", iaxes.front()}, {"out_lens", ins->get_shape().lens()}}), + op); + } + else + { + auto unsqueeze = + is_unsqueeze_needed_for_multibroadcast(op->get_shape(), ins->get_shape()) + ? m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), op) + : op; + m.replace_instruction( + ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze); + } } void apply(module& m, const match::matcher_result& r) const { - auto ins = r.result; - auto broadcasts = ins->inputs(); + auto ins = r.result; + const auto& broadcasts = ins->inputs(); if(broadcasts.empty()) return; // Skip if different data types are used @@ -506,65 +657,45 @@ struct find_inner_broadcast return i->get_shape().type() != broadcasts.front()->get_shape().type(); })) return; - bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and - any_of(broadcasts, non_scalar_op("multibroadcast")); - // If the broadcast is not a single dimension, then dont perform inner_broadcast - if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) { - if(i->get_shape().scalar()) - return false; - if(i->name() == "multibroadcast") - return false; - auto input = i->inputs().at(0); - const auto& lens = input->get_shape().lens(); - return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) { - return d == 1; - }) < (lens.size() - 1); + + // All inputs should have less elements + if(not all_of(broadcasts, [&](instruction_ref broadcast) { + auto input = broadcast->inputs().front(); + return input->get_shape().elements() < ins->get_shape().elements(); })) return; - if(broadcasts.size() > 1) + // Find first broadcast that is not a scalar + auto first = + std::find_if(broadcasts.begin(), broadcasts.end(), [&](instruction_ref broadcast) { + return not broadcast->get_shape().scalar(); + }); + // Try to see if we can do a simple case that just applies the op to + // the inputs of the broadcasts, and then just put that same + // broadcast after the op. For this case we need each of the + // broadcasts to be the same and the inputs to have the same dimesion + // (or be scalar). + const bool same_broadcasts = + std::all_of(first, broadcasts.end(), [&](instruction_ref broadcast) { + if(broadcast->get_operator() != (*first)->get_operator()) + return false; + auto s1 = get_non_broadcast_input(broadcast)->get_shape(); + auto s2 = get_non_broadcast_input(*first)->get_shape(); + if(s1.elements() == 1) + return true; + return s1.lens() == s2.lens(); + }); + if(same_broadcasts) { - auto bcast_strides = broadcasts.front()->get_shape().strides().size(); - std::vector common_axis(bcast_strides, 0); - // go through the strides of each broadcast, - // keep track of values that are equal to 0 in a dimension - for(auto i = 0; i < bcast_strides; i++) - { - for(const auto& broadcast : broadcasts) - { - if(broadcast->get_shape().strides()[i] == 0) - common_axis[i]++; - } - } - // if no common broadcast axis, transformation is not useful - if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) { - return num_common > 1; - }) == common_axis.end()) - return; + apply_same_broadcasts(m, ins); + } + // Skip if any input to the broadcasted inputs is already broadcasted + // as the below algorithm may not be able to handle such case. + else if(std::none_of(broadcasts.begin(), broadcasts.end(), [](instruction_ref broadcast) { + return broadcast->inputs().front()->get_shape().broadcasted(); + })) + { + apply_diff_broadcasts(m, ins); } - - std::vector inputs; - std::transform(broadcasts.begin(), - broadcasts.end(), - std::back_inserter(inputs), - [&](instruction_ref i) { - auto input = i->inputs().front(); - if(mixed_broadcasts and not i->get_shape().scalar() and - i->get_shape().lens().size() > 1) - return m.insert_instruction(i, make_op("squeeze"), input); - return input; - }); - - std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) { - if(i->get_shape().scalar()) - return 2; - else if(i->name() == "broadcast") - return 0; - if(i->name() == "multibroadcast") - return 1; - return 3; - })); - auto op = insert_common_op(m, ins, ins->get_operator(), inputs); - m.replace_instruction(ins, broadcasts.front()->get_operator(), op); } }; @@ -792,6 +923,9 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector get_splits(instruction_ref ins) { std::vector result; @@ -803,16 +937,22 @@ std::vector get_splits(instruction_ref ins) return {}; auto get_slice = [](auto& i) -> auto& { return any_cast(i->get_operator()); }; auto&& axes = get_slice(result.front()).axes; + + // "slice" instructions must all have the same axes if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; })) return {}; auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; }; auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; }; + + // Sort the "slice" instructions in order of starts std::sort( result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); }); if(std::any_of(get_start(result.front()).begin(), get_start(result.front()).end(), [&](auto i) { return i != 0; })) return {}; + + // one slice must "start" where the last slice "end" auto it = std::adjacent_find( result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); }); if(it != result.end()) @@ -998,6 +1138,10 @@ struct find_splits } }; +/** + * Matcher for a sequence of "slice" operations whose outputs are put back + * together by a "concat". + */ struct find_split_concat { auto matcher() const @@ -1008,40 +1152,56 @@ struct find_split_concat void apply(module& m, const match::matcher_result& r) const { + // Verifies that the slices meet several conditions: they must all output to the same + // concat instruction, slice on the same (1 only) axis, and the end of one slice + // must match the start of the next. auto ins = r.result; auto splits = get_splits(ins); if(splits.empty()) return; + // Each slice must output to only one instruction if(std::any_of( splits.begin(), splits.end(), [](auto i) { return i->outputs().size() != 1; })) return; - // Check for concat operator + // The single output instruction for all items in the list must be the same one auto concat = splits.front()->outputs().front(); if(std::any_of(splits.begin(), splits.end(), [&](auto i) { return i->outputs().front() != concat; })) return; - // Check axis match + + // The axis for the common output instruction must be the same as for the split ops auto concat_op = any_cast(concat->get_operator()); auto split_op = any_cast(splits.front()->get_operator()); if(split_op.axes.size() != 1) return; if(split_op.axes.front() != concat_op.axis) return; - // Replace args + + // Find where the slices are in the concat instruction's inputs (concat can have + // any number of inputs) auto args = concat->inputs(); auto it = std::find_if(args.begin(), args.end(), [&](auto i) { return i == splits.front(); }); + // Verify the slices were found, and the list is long enough if(std::distance(it, args.end()) < splits.size()) return; - // If the slices are not in order then stop + // Don't do anything if the "slice" inputs to the concat op have other operations mixed in + // among them + if(std::any_of(it, it + splits.size(), [](instruction_ref x) { + return x->get_operator().name() != "slice"; + })) + return; + // Check that the slices passed to concat are in order. if(not std::is_sorted(it, it + splits.size(), [](instruction_ref x, instruction_ref y) { auto xop = any_cast(x->get_operator()); auto yop = any_cast(y->get_operator()); return std::tie(xop.starts, xop.ends) < std::tie(yop.starts, yop.ends); })) return; + + // Perform the substitution *it = splits.front()->inputs().front(); args.erase(std::next(it), it + splits.size()); diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index a651d6e2432..a0a952d6aac 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -244,6 +244,21 @@ struct find_nested_slice } }; +/** + * Example case + * From: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * mb0: multibroadcast(param0, output_lens = [2, 3, 4]) + * mb1: multibroadcast(param1, output_lens = [2, 3, 4]) + * concat(mb0, mb1, axis = 2) + * + * To: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * con0: concat(param0, param1, axis = 1) + * multibroadcast(con0, lens = [2, 3, 4]) + */ struct find_concat_multibroadcasts { auto matcher() const @@ -253,32 +268,62 @@ struct find_concat_multibroadcasts void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto op = any_cast(ins->get_operator()); - auto out_lens = ins->get_shape().lens(); - auto inputs = ins->inputs(); - auto in_strides = inputs.front()->get_shape().strides(); + auto concat_ins = mr.result; + auto concat_op = any_cast(concat_ins->get_operator()); + auto concat_out_lens = concat_ins->get_shape().lens(); + auto concat_inputs = concat_ins->inputs(); + auto front_mb_strides = concat_inputs.front()->get_shape().strides(); + assert(concat_op.axis >= 0); // Only apply when concat axis is not a broadcasted dimension - if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) { - return i->get_shape().strides()[op.axis] == 0; + if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) { + return i->get_shape().strides()[concat_op.axis] == 0; })) { return; } - // Use inputs of multibroadcast ops as inputs to new concat op - std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) { + // Get the inputs of multibroadcast ops. Will be used as inputs to new concat op + std::vector mb_inputs(concat_inputs.size()); + std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) { return i->inputs().front(); }); + // Check that the inputs into the multibroadcasts have the same rank + const auto& first_shape = mb_inputs.front()->get_shape(); + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) { + return mb_in->get_shape().ndim() == first_shape.ndim(); + })) + { + return; + } + // Reduce axis by number of leading broadcasted dimensions - if(inputs.front()->get_shape().lens().size() < out_lens.size()) - op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0); + if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size()) + { + concat_op.axis -= + std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); + } - auto concat = m.insert_instruction(ins, op, inputs); - m.replace_instruction( - ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat); + // Inputs to multibroadcasts should have the same dimensions except for the axis to + // concatenate over + const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) { + const auto& lens = input_to_mb->get_shape().lens(); + return std::equal( + lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and + std::equal(lens.begin() + concat_op.axis + 1, + lens.end(), + front_in_lens.begin() + concat_op.axis + 1); + })) + { + return; + } + + auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); + m.replace_instruction(concat_ins, + migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}), + new_concat_ins); } }; diff --git a/src/targets/cpu/CMakeLists.txt b/src/targets/cpu/CMakeLists.txt index 4c08d60123c..e8a57422b01 100755 --- a/src/targets/cpu/CMakeLists.txt +++ b/src/targets/cpu/CMakeLists.txt @@ -97,6 +97,7 @@ else() endif() rocm_install_targets( + PRIVATE TARGETS migraphx_cpu INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include diff --git a/src/targets/fpga/CMakeLists.txt b/src/targets/fpga/CMakeLists.txt index 39df3be6400..304c92851c8 100644 --- a/src/targets/fpga/CMakeLists.txt +++ b/src/targets/fpga/CMakeLists.txt @@ -36,6 +36,7 @@ rocm_clang_tidy_check(migraphx_fpga) target_link_libraries(migraphx_fpga migraphx) rocm_install_targets( + PRIVATE TARGETS migraphx_fpga INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index d2e73ef1de7..da733498f7d 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -332,6 +332,7 @@ add_subdirectory(driver) add_subdirectory(hiprtc) rocm_install_targets( + PRIVATE TARGETS migraphx_gpu migraphx_device compile_for_gpu INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index d2ae8584d46..bf9a269f3e1 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -21,11 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include -#include + #include #include +#include #include +#include +#include #include namespace migraphx { @@ -106,7 +108,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) return false; } auto device_name = trim(split_string(get_device_name(), ':').front()); - if(device_name == "gfx940") + if(starts_with(device_name, "gfx94")) { if(ins->get_shape().type() == shape::half_type) { diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index c6b40b753a1..70163a61365 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include namespace migraphx { @@ -200,8 +201,8 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, { auto [upper_input, op_stream] = get_fusable_input_op_stream(input); top_inputs.push_back(upper_input); - instruction_ref prev_input = mm->add_parameter("y" + std::to_string(input_cnt++), - upper_input->get_shape().as_standard()); + instruction_ref prev_input = + mm->add_parameter(param_name(input_cnt++, "y"), upper_input->get_shape().as_standard()); for(const auto& op : reverse(op_stream)) { prev_input = mm->add_instruction(op, {prev_input}); diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp index d97712a992e..579f4bfd552 100644 --- a/src/targets/gpu/fuse_ops.cpp +++ b/src/targets/gpu/fuse_ops.cpp @@ -166,7 +166,7 @@ struct fusion const std::unordered_set& get_supported_archs() { static std::unordered_set supported_archs{ - "gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"}; + "gfx900", "gfx906", "gfx908", "gfx1030", "gfx940", "gfx941", "gfx942"}; return supported_archs; } diff --git a/src/targets/gpu/include/migraphx/gpu/mlir.hpp b/src/targets/gpu/include/migraphx/gpu/mlir.hpp index 4d88e8d5800..dc395b8eece 100644 --- a/src/targets/gpu/include/migraphx/gpu/mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/mlir.hpp @@ -39,7 +39,8 @@ namespace gpu { MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m); -MIGRAPHX_GPU_EXPORT bool is_module_fusible(const module& m, const value& solution); +MIGRAPHX_GPU_EXPORT bool +is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution); struct MIGRAPHX_GPU_EXPORT mlir_code_object { diff --git a/src/targets/gpu/include/migraphx/gpu/rocblas.hpp b/src/targets/gpu/include/migraphx/gpu/rocblas.hpp index a03fa8385d2..d23c40f9d05 100644 --- a/src/targets/gpu/include/migraphx/gpu/rocblas.hpp +++ b/src/targets/gpu/include/migraphx/gpu/rocblas.hpp @@ -38,13 +38,13 @@ using rocblas_handle_ptr = MIGRAPHX_MANAGE_PTR(rocblas_handle, rocblas_destroy_h rocblas_handle_ptr create_rocblas_handle_ptr(); rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s); - +#endif struct context; MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag(); MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available(); -#endif + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index e998ed49b85..656c28d2ec9 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -48,19 +48,20 @@ static module create_pointwise_module(module_ref in_mod) pw_mod.add_parameter(any_cast(param->get_operator()).parameter, shape{param->get_shape().type()}); } - pw_mod.add_instructions(in_mod, - &map_ins, - [](module& m, - instruction_ref ins, - const operation& op, - const std::vector& inputs, - const std::vector& mod_args) -> instruction_ref { - if(op.name() == "multibroadcast" and - inputs.front()->name() == "@literal") - return inputs.front(); - else - return m.insert_instruction(ins, op, inputs, mod_args); - }); + auto return_args = pw_mod.add_instructions( + in_mod, + &map_ins, + [](module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) -> instruction_ref { + if(op.name() == "multibroadcast" and inputs.front()->name() == "@literal") + return inputs.front(); + else + return m.insert_instruction(ins, op, inputs, mod_args); + }); + pw_mod.add_return(return_args); return pw_mod; } @@ -81,7 +82,7 @@ struct mlir_compiler : compiler // check if (a) module is fused (b) contains a dot instruction and (c) perfConfig can not // allow fused module if(gemm_ins != smod->end() and std::distance(gemm_ins, smod->end()) > 2 and - not is_module_fusible(*smod, solution)) + not is_module_fusible(*smod, ctx, solution)) { auto input_args = ins->inputs(); input_args.pop_back(); diff --git a/src/targets/gpu/kernel.cpp b/src/targets/gpu/kernel.cpp index 1cbb45852b1..033adbd07a2 100644 --- a/src/targets/gpu/kernel.cpp +++ b/src/targets/gpu/kernel.cpp @@ -27,6 +27,9 @@ #include #include +#ifdef _WIN32 +#include +#else // extern declare the function since hip/hip_ext.h header is broken extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT uint32_t, @@ -42,6 +45,7 @@ extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT hipEvent_t = nullptr, hipEvent_t = nullptr, uint32_t = 0); +#endif namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp index b52a61eb498..c64ab553159 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,18 @@ namespace migraphx { +template +struct acc_type +{ + using type = float; +}; + +template <> +struct acc_type +{ + using type = double; +}; + template constexpr auto vec_reduce(const array& a, Op op) { @@ -50,33 +62,33 @@ __device__ void generic_binary_layernorm( using reduce_output = reduce::with_axis; block::template run([&](auto, auto r) { - auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); - using value_type = typename Input1::type; - using vec_value_type = vec_type; + using value_type = typename Input1::type; + using vec_value_type = typename acc_type>::type; + + auto input = r.inner([&](auto x1, auto x2) { + return migraphx::convert(op(x1, x2)); + })(input1, input2); + constexpr auto relements = r.template elements(); constexpr auto relements_r = vec_value_type{1.0 / relements}; auto relements_rsqrt = sqrt(relements_r); - auto means = r.reduce(op::sum{}, - make_array(vec_value_type{0}, vec_value_type{0}), - [&](auto x) { - auto x_out = x * relements_r; - // dividing x by sqrt(relements) before squaring allows computing - // higher values before overflow in low precision - auto x2_sqrt = x * relements_rsqrt; - return make_array(x_out, x2_sqrt * x2_sqrt); - })(input); + auto means = r.reduce(op::sum{}, make_array(0, 0), [&](auto x) { + auto x_out = x * relements_r; + // dividing x by sqrt(relements) before squaring allows computing + // higher values before overflow in low precision + auto x2_sqrt = x * relements_rsqrt; + return make_array(x_out, x2_sqrt * x2_sqrt); + })(input); - auto mean_x = means[0]; - auto mean_x2 = means[1]; - auto variance = mean_x2 - (mean_x * mean_x); - value_type eps_val = implicit_conversion(eps); + auto mean_x = means[0]; + auto mean_x2 = means[1]; + auto variance = mean_x2 - (mean_x * mean_x); + vec_value_type eps_val = implicit_conversion(eps); + auto rsqrt_val = rsqrt(variance + eps_val); r.inner([&](auto& y, auto x, auto... xs) { - auto m = x - mean_x; - - // m * rsqrt(mean(m ^ 2) + epsilon) - y = compute(m * rsqrt(variance + eps_val), xs...); + y = compute(migraphx::convert>((x - mean_x) * rsqrt_val), xs...); })(output, input, inputs...); }); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index 5a6cca7bc24..da00ff9c781 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -256,6 +256,21 @@ constexpr auto min(const T& a, const U& b) return min>(a, b); } +template ())> +constexpr T mod(const T& a, const T& b) +{ + if constexpr(is_integral{}) + // onnx mod operator requires numpy style modulus + return ((a % b) + b) % b; + return static_cast(fmod(remainder(a, b) + b, b)); +} + +template {} and not is_any_vec())> +constexpr auto mod(const T& a, const U& b) +{ + return mod>(a, b); +} + MIGRAPHX_DEVICE_MATH_VEC(abs) MIGRAPHX_DEVICE_MATH_VEC(acos) MIGRAPHX_DEVICE_MATH_VEC(acosh) @@ -275,6 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(isnan) MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(min) +MIGRAPHX_DEVICE_MATH_VEC(mod) MIGRAPHX_DEVICE_MATH_VEC(nearbyint) MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(remainder) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index cf59cdae554..d12e9b56a2d 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -947,10 +947,12 @@ struct mlir_program std::string sym_name; }; -bool is_module_fusible(const module& m, const value& solution) +bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution) { mlir_program mp; + mp.set_gpu_properties(migraphx_ctx); mp.parse(m); + mp.run_high_level_pipeline(); return mlirIsModuleFusible(mp.mmodule.get(), make_mlir_string_ref(*solution.if_string())); } diff --git a/src/targets/gpu/rocblas.cpp b/src/targets/gpu/rocblas.cpp index 5779f7ae6e1..798fefbb811 100644 --- a/src/targets/gpu/rocblas.cpp +++ b/src/targets/gpu/rocblas.cpp @@ -48,7 +48,7 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s) rocblas_set_stream(rb.get(), s); return rb; } - +#endif bool get_compute_fp32_flag() { const auto device_name = trim(split_string(get_device_name(), ':').front()); @@ -57,13 +57,17 @@ bool get_compute_fp32_flag() bool rocblas_fp8_available() { +#if MIGRAPHX_USE_ROCBLAS #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API return false; #else return gfx_has_fp8_intrinsics(); #endif -} +#else + return false; #endif +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 3264f298f62..4a18e25aab5 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -97,13 +97,11 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::tuple_type); // whiltelist supported Ops for the FP8 std::set unsupported_fp8_ops = {}; -#if MIGRAPHX_USE_ROCBLAS if(not gpu::rocblas_fp8_available()) { unsupported_fp8_ops.insert("dot"); unsupported_fp8_ops.insert("quant_dot"); } -#endif // MIOpen doesn't have support for fp8 pooling yet. unsupported_fp8_ops.insert("pooling"); if(not gpu::gfx_has_fp8_intrinsics()) diff --git a/src/targets/ref/CMakeLists.txt b/src/targets/ref/CMakeLists.txt index 0d72afba107..3f442406a49 100644 --- a/src/targets/ref/CMakeLists.txt +++ b/src/targets/ref/CMakeLists.txt @@ -36,6 +36,7 @@ target_link_libraries(migraphx_ref PUBLIC migraphx) migraphx_generate_export_header(migraphx_ref) rocm_install_targets( + PRIVATE TARGETS migraphx_ref INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include diff --git a/src/tf/CMakeLists.txt b/src/tf/CMakeLists.txt index 316bd7371ed..6257be2afa3 100644 --- a/src/tf/CMakeLists.txt +++ b/src/tf/CMakeLists.txt @@ -60,6 +60,7 @@ endif() target_link_libraries(migraphx_tf PUBLIC migraphx) rocm_install_targets( + PRIVATE TARGETS migraphx_tf ) diff --git a/src/tf/tf.cpp b/src/tf/tf.cpp index e5b5cdff055..7b6c1322d66 100644 --- a/src/tf/tf.cpp +++ b/src/tf/tf.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,9 +37,9 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -program parse_tf(const std::string& name, const tf_options& options) +template +program parse_tf_from(const tf_options& options, Ts&&... xs) { - std::fstream input(name.c_str(), std::ios::in | std::ios::binary); tf::tf_parser parser; parser.is_nhwc = options.is_nhwc; parser.batch_size = options.batch_size; @@ -50,7 +50,7 @@ program parse_tf(const std::string& name, const tf_options& options) // Log the program when it can't be parsed try { - parser.parse_from(input); + parser.parse_from(std::forward(xs)...); } catch(...) { @@ -58,11 +58,27 @@ program parse_tf(const std::string& name, const tf_options& options) throw; } #else - parser.parse_from(input); + parser.parse_from(std::forward(xs)...); #endif return std::move(parser.prog); } +program parse_tf(const std::string& name, const tf_options& options) +{ + std::fstream input(name.c_str(), std::ios::in | std::ios::binary); + return parse_tf_from(options, input); +} + +program parse_tf_buffer(const std::string& buffer, const tf_options& options) +{ + return parse_tf_from(options, buffer.data(), buffer.size()); +} + +program parse_tf_buffer(const void* data, std::size_t size, const tf_options& options) +{ + return parse_tf_from(options, data, size); +} + std::vector get_tf_operators() { return tf::get_op_parsers(); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/tf/tf_parser.cpp b/src/tf/tf_parser.cpp index 6d95c34d988..a53c7b29ec1 100644 --- a/src/tf/tf_parser.cpp +++ b/src/tf/tf_parser.cpp @@ -393,6 +393,19 @@ void tf_parser::parse_from(std::istream& is) } } +void tf_parser::parse_from(const void* data, std::size_t size) +{ + tensorflow::GraphDef graph; + if(graph.ParseFromArray(data, size)) + { + this->parse_graph(graph); + } + else + { + throw std::runtime_error("Failed reading tf buffer array"); + } +} + shape::type_t tf_parser::parse_type(const tensorflow::DataType t) const { shape::type_t shape_type{}; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b6ecc2e35d3..88eb38340a7 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,7 +1,7 @@ # #################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -82,12 +82,7 @@ add_subdirectory(onnx) # tf test set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf) -add_executable(test_tf tf/tf_test.cpp) -rocm_mark_as_test(test_tf) -rocm_clang_tidy_check(test_tf) -target_link_libraries(test_tf migraphx_tf) -target_include_directories(test_tf PUBLIC include) -add_test(NAME test_tf COMMAND $ WORKING_DIRECTORY ${TEST_TF_DIR}) +add_subdirectory(tf) add_subdirectory(api) add_subdirectory(verify) diff --git a/test/api/test_tf_parser.cpp b/test/api/test_tf_parser.cpp index 4cdc23fb9f9..b440e6b64f9 100644 --- a/test/api/test_tf_parser.cpp +++ b/test/api/test_tf_parser.cpp @@ -27,7 +27,7 @@ TEST_CASE(load_tf) { - auto p = migraphx::parse_tf("add_test.pb"); + auto p = migraphx::parse_tf("models/add_test.pb"); auto shapes = p.get_output_shapes(); CHECK(shapes.size() == 1); } @@ -38,7 +38,7 @@ TEST_CASE(load_tf_default_dim) size_t batch = 2; tf_options.set_default_dim_value(batch); tf_options.set_nhwc(); - auto p = migraphx::parse_tf("conv_batch_test.pb", tf_options); + auto p = migraphx::parse_tf("models/conv_batch_test.pb", tf_options); auto shapes = p.get_output_shapes(); CHECK(shapes.size() == 1); CHECK(shapes.front().lengths().front() == batch); @@ -50,7 +50,7 @@ TEST_CASE(load_tf_param_shape) std::vector new_shape{1, 3}; tf_options.set_input_parameter_shape("0", new_shape); tf_options.set_input_parameter_shape("1", new_shape); - auto p = migraphx::parse_tf("add_test.pb", tf_options); + auto p = migraphx::parse_tf("models/add_test.pb", tf_options); auto shapes = p.get_output_shapes(); CHECK(shapes.size() == 1); CHECK(shapes.front().lengths() == new_shape); @@ -60,7 +60,7 @@ TEST_CASE(load_tf_multi_outputs) { migraphx::tf_options tf_options; tf_options.set_output_names({"relu", "tanh"}); - auto p = migraphx::parse_tf("multi_output_test.pb", tf_options); + auto p = migraphx::parse_tf("models/multi_output_test.pb", tf_options); auto shapes = p.get_output_shapes(); CHECK(shapes.size() == 2); } diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 0c29455f518..ba6cf4579b5 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -113,6 +113,7 @@ TEST_CASE(pointwise_reduce) mm->add_return({rsum}); } run_pass(p1); + migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -133,6 +134,108 @@ TEST_CASE(pointwise_reduce) EXPECT(p1 == p2); } +TEST_CASE(scalar_multibroadcast) +{ + // Matches the find_pointwise_reduce matcher, but input x has a (scalar) shape + // incompatible with the multibroadcast instruction; therefore it + // creates a fused_reduce module but does not add a submodule for the + // multibroadcast instruction. + migraphx::shape sdot{migraphx::shape::double_type, {80, 204, 204}}; + migraphx::shape sdot_double{migraphx::shape::double_type, {80, 204, 204}}; + migraphx::shape scalar{migraphx::shape::double_type, {1}, {0}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", scalar); + auto zap = add_pointwise(p1, "main:pointwise0", {x}, single_pointwise("sqrt")); + auto pow = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sdot.lens()}}), zap); + auto bip = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), pow); + + mm->add_return({bip}); + } + run_pass(p1); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", scalar); + auto zap = add_pointwise(p2, mm, "main:pointwise0", {x}, single_pointwise("sqrt")); + + auto pow = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sdot.lens()}}), zap); + + // Add a reduce module. These are created by fuse_reduce::apply() for any reduce + // instruction whether the individual matchers do anything or not. + auto* reduce_mod = p2.create_module("main:reduce_sum0"); + auto x0 = reduce_mod->add_parameter("x0", sdot_double); + auto sqrtbc = + reduce_mod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), x0); + reduce_mod->add_return({sqrtbc}); + + EXPECT(test::throws([&] { + mm->add_instruction( + migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod}); + })); + // reduce modules must be flagged for bypass when running subsequent passes + reduce_mod->set_bypass(); + auto bip = mm->add_instruction( + migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod}); + mm->add_return({bip}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(scalar_multibroadcast_contiguous) +{ + // Contains a contiguous op which is not passed through. + migraphx::shape sdot{migraphx::shape::double_type, {80, 204, 204}}; + migraphx::shape scalar{migraphx::shape::double_type, {1}, {0}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", scalar); + auto zap = add_pointwise(p1, "main:pointwise0", {x}, single_pointwise("sqrt")); + auto pow = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sdot.lens()}}), zap); + auto bip = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), pow); + auto sqrtbc = mm->add_instruction(migraphx::make_op("contiguous"), bip); + + mm->add_return({sqrtbc}); + } + run_pass(p1); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", scalar); + auto zap = add_pointwise(p2, mm, "main:pointwise0", {x}, single_pointwise("sqrt")); + + auto pow = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sdot.lens()}}), zap); + + // Add a reduce module. These are created by fuse_reduce::apply() for any reduce + // instruction whether the individual matchers do anything or not. + auto* reduce_mod = p2.create_module("main:reduce_sum0"); + + auto x0 = reduce_mod->add_parameter("x0", sdot); + auto sqrtbc = + reduce_mod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), x0); + reduce_mod->add_return({sqrtbc}); + + EXPECT(test::throws([&] { + mm->add_instruction( + migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod}); + })); + // reduce modules must be flagged for bypass when running subsequent passes + reduce_mod->set_bypass(); + auto bip = mm->add_instruction( + migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod}); + mm->add_return({bip}); + } + EXPECT(p1 == p2); +} + TEST_CASE(pointwise_broadcast_reduce_reshape) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -792,11 +895,12 @@ TEST_CASE(reduce_reshape_reduce) auto y = mm->add_parameter("y", s2); auto x1r = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x1); auto x2r = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3r.lens()}}), x2); + auto yr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y); auto freduce = add_reduce( p2, - "main:pointwise2:main:reduce_sum2_reshape_reshape:main:reduce_sum1:main:reduce_sum0:" - "main:pointwise0:main:pointwise1_reshape", - {x1r, x2r}, + "main:pointwise2:main:reduce_sum2_reshape_reshape:main:pointwise3_reshape:main:reduce_" + "sum1:main:reduce_sum0:main:pointwise0:main:pointwise1_reshape", + {x1r, x2r, yr}, {3, 4}, [&](auto* rm, const auto& inputs, const auto& axes) { auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), @@ -813,15 +917,16 @@ TEST_CASE(reduce_reshape_reduce) migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), rsum2); auto sub2 = add_pointwise( p2, rm, "main:pointwise2", {rsum2b, inputs[0]}, single_pointwise("sub")); - return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), sub2); + auto rsum3 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), sub2); + auto rsum3b = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), rsum3); + return add_pointwise( + p2, rm, "main:pointwise3", {rsum3b, inputs[2]}, single_pointwise("add")); }); auto freducer = - mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2r.lens()}}), freduce); - // TODO: Fuse the last add as well - auto freducerb = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), freducer); - auto add = add_pointwise(p2, "main:pointwise3", {freducerb, y}, single_pointwise("add")); - mm->add_return({add}); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), freduce); + mm->add_return({freducer}); } EXPECT(p1.sort() == p2.sort()); } diff --git a/test/gpu/gemm_tune.cpp b/test/gpu/gemm_tune.cpp index 75081bc37ba..52a32a771b2 100644 --- a/test/gpu/gemm_tune.cpp +++ b/test/gpu/gemm_tune.cpp @@ -182,7 +182,6 @@ TEST_CASE(gemm_tune_strided_lowered) EXPECT(0 == solution_idx.to()); #endif } -#endif TEST_CASE(gemm_tune_invalid_sol_index) { @@ -223,5 +222,6 @@ TEST_CASE(gemm_tune_invalid_sol_index) EXPECT(0 != solution_idx.to()); #endif } +#endif int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/multi_target/multitarget_test.cpp b/test/multi_target/multitarget_test.cpp index 8dee6a0ca77..9adbfe6ad79 100644 --- a/test/multi_target/multitarget_test.cpp +++ b/test/multi_target/multitarget_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -95,7 +95,8 @@ bool is_compiled_cpu_module(const migraphx::module& m) { return std::all_of(m.begin(), m.end(), [](auto ins) { auto ins_name = ins.name(); - if(not migraphx::starts_with(ins_name, "@")) + // sub is not lowered on CPU backend due to vectorization on non-aligned memory. + if(not migraphx::starts_with(ins_name, "@") and ins_name != "sub") { if(not migraphx::starts_with(ins_name, "cpu::") and not migraphx::starts_with(ins_name, "dnnl::") and diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit index 7dc1aa5712d..843156e909b 100644 --- a/test/onnx/.onnxrt-commit +++ b/test/onnx/.onnxrt-commit @@ -1 +1 @@ -7c80c39f74d2508f1c553ce991914be8bf694944 +33a68d221f28bd8d412f2e9188e50bac8a255b71 diff --git a/test/onnx/CMakeLists.txt b/test/onnx/CMakeLists.txt index 97fdff3295c..78827b9bbdf 100644 --- a/test/onnx/CMakeLists.txt +++ b/test/onnx/CMakeLists.txt @@ -24,14 +24,16 @@ function(add_onnx_test TEST_NAME) - add_executable(${TEST_NAME} ${ARGN}) + rocm_add_test_executable(${TEST_NAME} ${ARGN}) rocm_clang_tidy_check(${TEST_NAME}) - target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) + target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref onnx_files) target_include_directories(${TEST_NAME} PUBLIC ../include include) - add_test(NAME ${TEST_NAME} COMMAND $ WORKING_DIRECTORY ${TEST_ONNX_DIR}) - rocm_mark_as_test(${TEST_NAME}) endfunction() +include(Embed) +file(GLOB_RECURSE ONNX_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.onnx ${CMAKE_CURRENT_SOURCE_DIR}/*.weight) +add_embed_library(onnx_files ${ONNX_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}) + file(GLOB ONNX_PARSE_TESTS CONFIGURE_DEPENDS parse/*.cpp) file(GLOB ONNX_VERIFY_TESTS CONFIGURE_DEPENDS verify/*.cpp) diff --git a/test/onnx/einsum_2d_3d_multiplication_test.onnx b/test/onnx/einsum_2d_3d_multiplication_test.onnx new file mode 100644 index 00000000000..1dc5f8b9fd3 --- /dev/null +++ b/test/onnx/einsum_2d_3d_multiplication_test.onnx @@ -0,0 +1,18 @@ +  einsum_2d_3d_multiplication_test: +* +x1 +x2y"Einsum* +equation"ij,jkl  einsum_2d_3d_multiplication_testZ +x1 +  + +Z +x2 + + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_3_inputs_test.onnx b/test/onnx/einsum_3_inputs_test.onnx new file mode 100644 index 00000000000..fe5d6649428 --- /dev/null +++ b/test/onnx/einsum_3_inputs_test.onnx @@ -0,0 +1,25 @@ + einsum_3_inputs_test:² +7 +x1 +x2 +x3y"Einsum* +equation"bac,cd,def->ebc einsum_3_inputs_testZ +x1 + + + +Z +x2 +  + +Z +x3 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_3d_broadcast_test.onnx b/test/onnx/einsum_3d_broadcast_test.onnx new file mode 100644 index 00000000000..d2acb96b6df --- /dev/null +++ b/test/onnx/einsum_3d_broadcast_test.onnx @@ -0,0 +1,20 @@ + einsum_3d_broadcast_test:™ +0 +x1 +x2y"Einsum* +equation" bik,bkj->bij einsum_3d_broadcast_testZ +x1 + + + +Z +x2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_3d_diagonal_test.onnx b/test/onnx/einsum_3d_diagonal_test.onnx new file mode 100644 index 00000000000..78afe5f3172 --- /dev/null +++ b/test/onnx/einsum_3d_diagonal_test.onnx @@ -0,0 +1,13 @@ + einsum_3d_diagonal_test:n +% +xy"Einsum* +equation"iii->i einsum_3d_diagonal_testZ +x + + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_3d_opposite_broadcast_test.onnx b/test/onnx/einsum_3d_opposite_broadcast_test.onnx new file mode 100644 index 00000000000..14bffd25d02 --- /dev/null +++ b/test/onnx/einsum_3d_opposite_broadcast_test.onnx @@ -0,0 +1,20 @@ + !einsum_3d_opposite_broadcast_test:¢ +0 +x1 +x2y"Einsum* +equation" bik,bkj->bij !einsum_3d_opposite_broadcast_testZ +x1 + + + +Z +x2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_batch_matrix_diagonal_test.onnx b/test/onnx/einsum_batch_matrix_diagonal_test.onnx new file mode 100644 index 00000000000..db52eaf73ca --- /dev/null +++ b/test/onnx/einsum_batch_matrix_diagonal_test.onnx @@ -0,0 +1,13 @@ + !einsum_batch_matrix_diagonal_test:} +* +xy"Einsum* +equation" ...ii->...i !einsum_batch_matrix_diagonal_testZ +x + + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_batch_matrix_multiplication_test.onnx b/test/onnx/einsum_batch_matrix_multiplication_test.onnx new file mode 100644 index 00000000000..f27df9fb6e0 --- /dev/null +++ b/test/onnx/einsum_batch_matrix_multiplication_test.onnx @@ -0,0 +1,20 @@ + 'einsum_batch_matrix_multiplication_test:¨ +0 +x1 +x2y"Einsum* +equation" ijk,ikl->ijl 'einsum_batch_matrix_multiplication_testZ +x1 + + + +Z +x2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_bilinear_transformation_test.onnx b/test/onnx/einsum_bilinear_transformation_test.onnx new file mode 100644 index 00000000000..258b4b9b007 --- /dev/null +++ b/test/onnx/einsum_bilinear_transformation_test.onnx @@ -0,0 +1,23 @@ + #einsum_bilinear_transformation_test:· +5 +x1 +x2 +x3y"Einsum* +equation" ik,jkl,il->ij #einsum_bilinear_transformation_testZ +x1 +  + +Z +x2 + + + +Z +x3 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_broadcast_test.onnx b/test/onnx/einsum_broadcast_test.onnx new file mode 100644 index 00000000000..7ead2ed4dcc --- /dev/null +++ b/test/onnx/einsum_broadcast_test.onnx @@ -0,0 +1,17 @@ + einsum_broadcast_test:Š +0 +x1 +x2y"Einsum* +equation" ij, jk -> ik einsum_broadcast_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_column_sum_test.onnx b/test/onnx/einsum_column_sum_test.onnx new file mode 100644 index 00000000000..c280d867dc7 --- /dev/null +++ b/test/onnx/einsum_column_sum_test.onnx @@ -0,0 +1,12 @@ + einsum_column_sum_test:d +$ +xy"Einsum* +equation"ij->j einsum_column_sum_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_comma_in_output_negative_test.onnx b/test/onnx/einsum_comma_in_output_negative_test.onnx new file mode 100644 index 00000000000..f9c413fae79 --- /dev/null +++ b/test/onnx/einsum_comma_in_output_negative_test.onnx @@ -0,0 +1,18 @@ + $einsum_comma_in_output_negative_test:— +. +x1 +x2y"Einsum* +equation" +ii,jj->i,j $einsum_comma_in_output_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_1_test.onnx b/test/onnx/einsum_common_1_test.onnx new file mode 100644 index 00000000000..b2c1d03c4f1 --- /dev/null +++ b/test/onnx/einsum_common_1_test.onnx @@ -0,0 +1,23 @@ + einsum_common_1_test:¤ +3 +x1 +x2y"Einsum* +equation"bsnh,btnh->bnts einsum_common_1_testZ +x1 + + + + +Z +x2 + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_2_test.onnx b/test/onnx/einsum_common_2_test.onnx new file mode 100644 index 00000000000..79a4d2a1431 --- /dev/null +++ b/test/onnx/einsum_common_2_test.onnx @@ -0,0 +1,22 @@ + einsum_common_2_test:Ÿ +2 +x1 +x2y"Einsum* +equation"bsnh,ctnh->nts einsum_common_2_testZ +x1 + + + + +Z +x2 + + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_3_test.onnx b/test/onnx/einsum_common_3_test.onnx new file mode 100644 index 00000000000..359e3a9566e --- /dev/null +++ b/test/onnx/einsum_common_3_test.onnx @@ -0,0 +1,22 @@ + einsum_common_3_test:Ÿ +2 +x1 +x2y"Einsum* +equation"bnst,chst->shn einsum_common_3_testZ +x1 + + + + +Z +x2 + + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_4_test.onnx b/test/onnx/einsum_common_4_test.onnx new file mode 100644 index 00000000000..63266e23e1e --- /dev/null +++ b/test/onnx/einsum_common_4_test.onnx @@ -0,0 +1,23 @@ + einsum_common_4_test:¤ +3 +x1 +x2y"Einsum* +equation"bcxd,bcyd->bcxy einsum_common_4_testZ +x1 + + + + +Z +x2 + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_5_test.onnx b/test/onnx/einsum_common_5_test.onnx new file mode 100644 index 00000000000..1db9d5983b1 --- /dev/null +++ b/test/onnx/einsum_common_5_test.onnx @@ -0,0 +1,23 @@ + einsum_common_5_test:ª +9 +x1 +x2y"Einsum*$ +equation"...qhd,...khd->...hqk einsum_common_5_testZ +x1 + + + + +Z +x2 + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_6_test.onnx b/test/onnx/einsum_common_6_test.onnx new file mode 100644 index 00000000000..4b3d088e46c --- /dev/null +++ b/test/onnx/einsum_common_6_test.onnx @@ -0,0 +1,20 @@ + einsum_common_6_test:› +6 +x1 +x2y"Einsum*! +equation"i...k,k...j->i...j einsum_common_6_testZ +x1 + + + +Z +x2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_7_test.onnx b/test/onnx/einsum_common_7_test.onnx new file mode 100644 index 00000000000..375d0ae1d44 --- /dev/null +++ b/test/onnx/einsum_common_7_test.onnx @@ -0,0 +1,12 @@ + einsum_common_7_test:f +( +xy"Einsum* +equation" ...j->... einsum_common_7_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_common_8_test.onnx b/test/onnx/einsum_common_8_test.onnx new file mode 100644 index 00000000000..4f1935dfdc3 --- /dev/null +++ b/test/onnx/einsum_common_8_test.onnx @@ -0,0 +1,17 @@ + einsum_common_8_test:† +- +x1 +x2y"Einsum* +equation" ii,jj->ij einsum_common_8_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_diag_vector_multiply_test.onnx b/test/onnx/einsum_diag_vector_multiply_test.onnx new file mode 100644 index 00000000000..ac9e169db0c --- /dev/null +++ b/test/onnx/einsum_diag_vector_multiply_test.onnx @@ -0,0 +1,17 @@ +  einsum_diag_vector_multiply_test:ˆ ++ +x1 +x2y"Einsum* +equation"ii,i->i  einsum_diag_vector_multiply_testZ +x1 +  + +Z +x2 + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_diagonal_dim_mismatch_negative_test.onnx b/test/onnx/einsum_diagonal_dim_mismatch_negative_test.onnx new file mode 100644 index 00000000000..692a71dde96 --- /dev/null +++ b/test/onnx/einsum_diagonal_dim_mismatch_negative_test.onnx @@ -0,0 +1,12 @@ + *einsum_diagonal_dim_mismatch_negative_test:x +$ +xy"Einsum* +equation"ii->i *einsum_diagonal_dim_mismatch_negative_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_element_wise_multiplication_and_row_sum_test.onnx b/test/onnx/einsum_element_wise_multiplication_and_row_sum_test.onnx new file mode 100644 index 00000000000..dd5b8b6eb5a --- /dev/null +++ b/test/onnx/einsum_element_wise_multiplication_and_row_sum_test.onnx @@ -0,0 +1,17 @@ + 3einsum_element_wise_multiplication_and_row_sum_test:› ++ +x1 +x2y"Einsum* +equation"i,ij->i 3einsum_element_wise_multiplication_and_row_sum_testZ +x1 + + +Z +x2 +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_ellipsis_implicit_form_test.onnx b/test/onnx/einsum_ellipsis_implicit_form_test.onnx new file mode 100644 index 00000000000..e706a21d803 --- /dev/null +++ b/test/onnx/einsum_ellipsis_implicit_form_test.onnx @@ -0,0 +1,22 @@ + "einsum_ellipsis_implicit_form_test:¬ +1 +x1 +x2y"Einsum* +equation" ...qhd,...khd "einsum_ellipsis_implicit_form_testZ +x1 + + + + +Z +x2 + + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_ellipsis_mismatch_negative_test.onnx b/test/onnx/einsum_ellipsis_mismatch_negative_test.onnx new file mode 100644 index 00000000000..c93e6a69c5a --- /dev/null +++ b/test/onnx/einsum_ellipsis_mismatch_negative_test.onnx @@ -0,0 +1,21 @@ + &einsum_ellipsis_mismatch_negative_test:± +6 +x1 +x2y"Einsum*! +equation"...ii,...jj->...ij &einsum_ellipsis_mismatch_negative_testZ +x1 + + + +Z +x2 + + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_ellipsis_multidim_test.onnx b/test/onnx/einsum_ellipsis_multidim_test.onnx new file mode 100644 index 00000000000..d43f99d8dc1 --- /dev/null +++ b/test/onnx/einsum_ellipsis_multidim_test.onnx @@ -0,0 +1,23 @@ + einsum_ellipsis_multidim_test:° +6 +x1 +x2y"Einsum*! +equation"...ik,kj...->ij... einsum_ellipsis_multidim_testZ +x1 + + + + +Z +x2 + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_ellipsis_scalar_multiplication_test.onnx b/test/onnx/einsum_ellipsis_scalar_multiplication_test.onnx new file mode 100644 index 00000000000..f2c64b96d74 --- /dev/null +++ b/test/onnx/einsum_ellipsis_scalar_multiplication_test.onnx @@ -0,0 +1,17 @@ + *einsum_ellipsis_scalar_multiplication_test:› +, +x1 +x2y"Einsum* +equation"..., ... *einsum_ellipsis_scalar_multiplication_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_ellipsis_test.onnx b/test/onnx/einsum_ellipsis_test.onnx new file mode 100644 index 00000000000..3073e777bbe --- /dev/null +++ b/test/onnx/einsum_ellipsis_test.onnx @@ -0,0 +1,20 @@ + einsum_ellipsis_test:› +6 +x1 +x2y"Einsum*! +equation"...ik,kj...->ij... einsum_ellipsis_testZ +x1 + + + +Z +x2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_ellipsis_zero_test.onnx b/test/onnx/einsum_ellipsis_zero_test.onnx new file mode 100644 index 00000000000..fc799ae0c72 --- /dev/null +++ b/test/onnx/einsum_ellipsis_zero_test.onnx @@ -0,0 +1,20 @@ + einsum_ellipsis_zero_test:£ +9 +x1 +x2y"Einsum*$ +equation"...qhd,...khd->...hqk einsum_ellipsis_zero_testZ +x1 + + + +Z +x2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_empty_term_before_arrow_negative_test.onnx b/test/onnx/einsum_empty_term_before_arrow_negative_test.onnx new file mode 100644 index 00000000000..e558bc7c728 --- /dev/null +++ b/test/onnx/einsum_empty_term_before_arrow_negative_test.onnx @@ -0,0 +1,17 @@ + ,einsum_empty_term_before_arrow_negative_test:œ ++ +x1 +x2y"Einsum* +equation"ii,->ij ,einsum_empty_term_before_arrow_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_empty_term_before_comma_negative_test.onnx b/test/onnx/einsum_empty_term_before_comma_negative_test.onnx new file mode 100644 index 00000000000..fe9b994d712 --- /dev/null +++ b/test/onnx/einsum_empty_term_before_comma_negative_test.onnx @@ -0,0 +1,18 @@ + ,einsum_empty_term_before_comma_negative_test:Ÿ +. +x1 +x2y"Einsum* +equation" +ii,,jj->ij ,einsum_empty_term_before_comma_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_hadamard_product_test.onnx b/test/onnx/einsum_hadamard_product_test.onnx new file mode 100644 index 00000000000..b72c23efb3f --- /dev/null +++ b/test/onnx/einsum_hadamard_product_test.onnx @@ -0,0 +1,17 @@ + einsum_hadamard_product_test:Ž +- +x1 +x2y"Einsum* +equation" ij,ij->ij einsum_hadamard_product_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_last_input_missing_negative_test.onnx b/test/onnx/einsum_last_input_missing_negative_test.onnx new file mode 100644 index 00000000000..2546abfcf3f --- /dev/null +++ b/test/onnx/einsum_last_input_missing_negative_test.onnx @@ -0,0 +1,17 @@ + 'einsum_last_input_missing_negative_test:– +* +x1 +x2y"Einsum* +equation"ii,jj, 'einsum_last_input_missing_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_matrix_diagonal_test.onnx b/test/onnx/einsum_matrix_diagonal_test.onnx new file mode 100644 index 00000000000..dc979a40e5a --- /dev/null +++ b/test/onnx/einsum_matrix_diagonal_test.onnx @@ -0,0 +1,12 @@ + einsum_matrix_diagonal_test:i +$ +xy"Einsum* +equation"ii->i einsum_matrix_diagonal_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_matrix_dot_product_test.onnx b/test/onnx/einsum_matrix_dot_product_test.onnx new file mode 100644 index 00000000000..a2a048f29b0 --- /dev/null +++ b/test/onnx/einsum_matrix_dot_product_test.onnx @@ -0,0 +1,17 @@ + einsum_matrix_dot_product_test:Š ++ +x1 +x2y"Einsum* +equation"ij,ij-> einsum_matrix_dot_product_testZ +x1 +  + +Z +x2 +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_matrix_matrix_multiplication_test.onnx b/test/onnx/einsum_matrix_matrix_multiplication_test.onnx new file mode 100644 index 00000000000..e15eb7e1308 --- /dev/null +++ b/test/onnx/einsum_matrix_matrix_multiplication_test.onnx @@ -0,0 +1,17 @@ + (einsum_matrix_matrix_multiplication_test:š +- +x1 +x2y"Einsum* +equation" ij,kj->ik (einsum_matrix_matrix_multiplication_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_matrix_outer_product_test.onnx b/test/onnx/einsum_matrix_outer_product_test.onnx new file mode 100644 index 00000000000..0ba528b0833 --- /dev/null +++ b/test/onnx/einsum_matrix_outer_product_test.onnx @@ -0,0 +1,19 @@ +  einsum_matrix_outer_product_test:œ +/ +x1 +x2y"Einsum* +equation" ij,kl->ijkl  einsum_matrix_outer_product_testZ +x1 +  + +Z +x2 +  + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_matrix_trace_implicit_test.onnx b/test/onnx/einsum_matrix_trace_implicit_test.onnx new file mode 100644 index 00000000000..97782fa6558 --- /dev/null +++ b/test/onnx/einsum_matrix_trace_implicit_test.onnx @@ -0,0 +1,12 @@ + !einsum_matrix_trace_implicit_test:l +! +xy"Einsum* +equation"ii !einsum_matrix_trace_implicit_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_matrix_trace_test.onnx b/test/onnx/einsum_matrix_trace_test.onnx new file mode 100644 index 00000000000..e805356724b --- /dev/null +++ b/test/onnx/einsum_matrix_trace_test.onnx @@ -0,0 +1,12 @@ + einsum_matrix_trace_test:e +# +xy"Einsum* +equation"ii-> einsum_matrix_trace_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_matrix_vector_multiplication_test.onnx b/test/onnx/einsum_matrix_vector_multiplication_test.onnx new file mode 100644 index 00000000000..931e94ecccb --- /dev/null +++ b/test/onnx/einsum_matrix_vector_multiplication_test.onnx @@ -0,0 +1,17 @@ + (einsum_matrix_vector_multiplication_test: +) +x +vy"Einsum* +equation"ij,j->i (einsum_matrix_vector_multiplication_testZ +x +  + +Z +v + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_missing_equation_negative_test.onnx b/test/onnx/einsum_missing_equation_negative_test.onnx new file mode 100644 index 00000000000..23fae664389 --- /dev/null +++ b/test/onnx/einsum_missing_equation_negative_test.onnx @@ -0,0 +1,16 @@ + %einsum_missing_equation_negative_test:} + +x1 +x2y"Einsum%einsum_missing_equation_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_multiple_arrows_negative_test.onnx b/test/onnx/einsum_multiple_arrows_negative_test.onnx new file mode 100644 index 00000000000..02a7e9fe425 --- /dev/null +++ b/test/onnx/einsum_multiple_arrows_negative_test.onnx @@ -0,0 +1,17 @@ + $einsum_multiple_arrows_negative_test:˜ +/ +x1 +x2y"Einsum* +equation" ii,jj->->ij $einsum_multiple_arrows_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_multiple_diagonals_negative_test.onnx b/test/onnx/einsum_multiple_diagonals_negative_test.onnx new file mode 100644 index 00000000000..022ac5508f4 --- /dev/null +++ b/test/onnx/einsum_multiple_diagonals_negative_test.onnx @@ -0,0 +1,14 @@ + 'einsum_multiple_diagonals_negative_test:„ +' +xy"Einsum* +equation"iijj->ij 'einsum_multiple_diagonals_negative_testZ +x + + + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_multiple_ellipses_negative_test.onnx b/test/onnx/einsum_multiple_ellipses_negative_test.onnx new file mode 100644 index 00000000000..289f1eadb70 --- /dev/null +++ b/test/onnx/einsum_multiple_ellipses_negative_test.onnx @@ -0,0 +1,20 @@ + &einsum_multiple_ellipses_negative_test:° +9 +x1 +x2y"Einsum*$ +equation"......ii,...jj->...ij &einsum_multiple_ellipses_negative_testZ +x1 + + + +Z +x2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_output_missing_ellipsis_negative_test.onnx b/test/onnx/einsum_output_missing_ellipsis_negative_test.onnx new file mode 100644 index 00000000000..abb8ffd3edc --- /dev/null +++ b/test/onnx/einsum_output_missing_ellipsis_negative_test.onnx @@ -0,0 +1,19 @@ + ,einsum_output_missing_ellipsis_negative_test:¬ +3 +x1 +x2y"Einsum* +equation"...ii,...jj->ij ,einsum_output_missing_ellipsis_negative_testZ +x1 + + + +Z +x2 + + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_output_surplus_label_negative_test.onnx b/test/onnx/einsum_output_surplus_label_negative_test.onnx new file mode 100644 index 00000000000..0ecda44f262 --- /dev/null +++ b/test/onnx/einsum_output_surplus_label_negative_test.onnx @@ -0,0 +1,18 @@ + )einsum_output_surplus_label_negative_test:œ +. +x1 +x2y"Einsum* +equation" +ii,jj->ijk )einsum_output_surplus_label_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_permute_test.onnx b/test/onnx/einsum_permute_test.onnx new file mode 100644 index 00000000000..ece6fc147c4 --- /dev/null +++ b/test/onnx/einsum_permute_test.onnx @@ -0,0 +1,12 @@ + einsum_permute_test:f +% +xy"Einsum* +equation"ij->ji einsum_permute_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_rank_mismatch_negative_test.onnx b/test/onnx/einsum_rank_mismatch_negative_test.onnx new file mode 100644 index 00000000000..c65fe549015 --- /dev/null +++ b/test/onnx/einsum_rank_mismatch_negative_test.onnx @@ -0,0 +1,18 @@ + "einsum_rank_mismatch_negative_test:• +. +x1 +x2y"Einsum* +equation" +iik,jj->ij "einsum_rank_mismatch_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_right_batch_diagonal_negative_test.onnx b/test/onnx/einsum_right_batch_diagonal_negative_test.onnx new file mode 100644 index 00000000000..72143c619b7 --- /dev/null +++ b/test/onnx/einsum_right_batch_diagonal_negative_test.onnx @@ -0,0 +1,13 @@ + )einsum_right_batch_diagonal_negative_test:… +* +xy"Einsum* +equation" ii...->i... )einsum_right_batch_diagonal_negative_testZ +x + + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_row_sum_test.onnx b/test/onnx/einsum_row_sum_test.onnx new file mode 100644 index 00000000000..7e2b61abdf2 --- /dev/null +++ b/test/onnx/einsum_row_sum_test.onnx @@ -0,0 +1,12 @@ + einsum_row_sum_test:a +$ +xy"Einsum* +equation"ij->i einsum_row_sum_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_summation_test.onnx b/test/onnx/einsum_summation_test.onnx new file mode 100644 index 00000000000..92bba55d972 --- /dev/null +++ b/test/onnx/einsum_summation_test.onnx @@ -0,0 +1,12 @@ + einsum_summation_test:b +# +xy"Einsum* +equation"ij-> einsum_summation_testZ +x +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_tensor_contraction_test.onnx b/test/onnx/einsum_tensor_contraction_test.onnx new file mode 100644 index 00000000000..0ca55621722 --- /dev/null +++ b/test/onnx/einsum_tensor_contraction_test.onnx @@ -0,0 +1,25 @@ + einsum_tensor_contraction_test:¸ +5 +x1 +x2y"Einsum* +equation"pqrs,tuqvr->pstuv einsum_tensor_contraction_testZ +x1 + + + + +Z +x2 + + + + + +b +y + + + + + +B \ No newline at end of file diff --git a/test/onnx/einsum_term_input_mismatch_negative_test.onnx b/test/onnx/einsum_term_input_mismatch_negative_test.onnx new file mode 100644 index 00000000000..eb59ebf5aa0 --- /dev/null +++ b/test/onnx/einsum_term_input_mismatch_negative_test.onnx @@ -0,0 +1,17 @@ + (einsum_term_input_mismatch_negative_test:ž +1 +x1 +x2y"Einsum* +equation" ii,jj,kk->ijk (einsum_term_input_mismatch_negative_testZ +x1 +  + +Z +x2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/einsum_vector_dot_product_test.onnx b/test/onnx/einsum_vector_dot_product_test.onnx new file mode 100644 index 00000000000..1bb949ac3b2 --- /dev/null +++ b/test/onnx/einsum_vector_dot_product_test.onnx @@ -0,0 +1,17 @@ + einsum_vector_dot_product_test:€ +) +x1 +x2y"Einsum* +equation"i,i-> einsum_vector_dot_product_testZ +x1 + + +Z +x2 + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/einsum_vector_outer_product_test.onnx b/test/onnx/einsum_vector_outer_product_test.onnx new file mode 100644 index 00000000000..d91b4e7c056 --- /dev/null +++ b/test/onnx/einsum_vector_outer_product_test.onnx @@ -0,0 +1,17 @@ +  einsum_vector_outer_product_test:ˆ ++ +x1 +x2y"Einsum* +equation"i,j->ij  einsum_vector_outer_product_testZ +x1 + + +Z +x2 + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index e5afa0068fc..e1effe70183 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -2138,6 +2138,749 @@ def dynamicquantizelinear_2d_test(): return ([node], [x], [y, y_scale, y_zero_point]) +@onnx_test() +def einsum_permute_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ij->ji') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_summation_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ij->') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_column_sum_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ij->j') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_row_sum_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ij->i') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_matrix_vector_multiplication_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + v = helper.make_tensor_value_info('v', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 1]) + + node = onnx.helper.make_node('Einsum', + inputs=['x', 'v'], + outputs=['y'], + equation='ij,j->i') + + return ([node], [x, v], [y]) + + +@onnx_test() +def einsum_matrix_matrix_multiplication_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ij,kj->ik') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_vector_dot_product_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='i,i->') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_matrix_dot_product_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ij,ij->') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_hadamard_product_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ij,ij->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_vector_outer_product_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 5]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='i,j->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_matrix_outer_product_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 2, 5]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ij,kl->ijkl') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_batch_matrix_multiplication_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 2, 5]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 5, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ijk,ikl->ijl') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_tensor_contraction_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3, 5, 7]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, + [1, 3, 3, 7, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 7, 1, 3, 7]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='pqrs,tuqvr->pstuv') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_matrix_diagonal_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ii->i') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_batch_matrix_diagonal_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='...ii->...i') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_3d_diagonal_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='iii->i') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_diag_vector_multiply_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,i->i') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_matrix_trace_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ii->') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_matrix_trace_implicit_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ii') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_2d_3d_multiplication_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ij,jkl') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_element_wise_multiplication_and_row_sum_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='i,ij->i') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_broadcast_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 1]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ij, jk -> ik') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_3d_broadcast_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [1, 3, 1]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='bik,bkj->bij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_3d_opposite_broadcast_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [1, 3, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 1, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='bik,bkj->bij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_3_inputs_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 2, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2]) + x3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT, [2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2', 'x3'], + outputs=['y'], + equation='bac,cd,def->ebc') + + return ([node], [x1, x2, x3], [y]) + + +@onnx_test() +def einsum_bilinear_transformation_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [5, 3, 7]) + x3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT, [2, 7]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 5]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2', 'x3'], + outputs=['y'], + equation='ik,jkl,il->ij') + + return ([node], [x1, x2, x3], [y]) + + +@onnx_test() +def einsum_ellipsis_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 4, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='...ik,kj...->ij...') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_ellipsis_multidim_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 2, 3, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 4, 3, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 3, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='...ik,kj...->ij...') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_ellipsis_zero_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [4, 3, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2, 4]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='...qhd,...khd->...hqk') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_ellipsis_implicit_form_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 2, 3, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 4, 3, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='...qhd,...khd') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_ellipsis_scalar_multiplication_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='..., ...') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_common_1_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 2, 2, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='bsnh,btnh->bnts') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_common_2_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 2, 2, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='bsnh,ctnh->nts') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_common_3_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 2, 2, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='bnst,chst->shn') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_common_4_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 2, 3, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2, 4, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3, 4]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='bcxd,bcyd->bcxy') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_common_5_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 2, 3, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 4, 3, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3, 2, 4]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='...qhd,...khd->...hqk') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_common_6_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 2, 2]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='i...k,k...j->i...j') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_common_7_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='...j->...') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_common_8_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,jj->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_missing_equation_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', inputs=['x1', 'x2'], outputs=['y']) + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_multiple_arrows_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,jj->->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_empty_term_before_arrow_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_multiple_ellipses_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='......ii,...jj->...ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_comma_in_output_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,jj->i,j') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_empty_term_before_comma_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,,jj->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_last_input_missing_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,jj,') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_term_input_mismatch_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,jj,kk->ijk') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_ellipsis_mismatch_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='...ii,...jj->...ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_rank_mismatch_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='iik,jj->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_output_surplus_label_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='ii,jj->ijk') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_output_missing_ellipsis_negative_test(): + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3, 3, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x1', 'x2'], + outputs=['y'], + equation='...ii,...jj->ij') + + return ([node], [x1, x2], [y]) + + +@onnx_test() +def einsum_multiple_diagonals_negative_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='iijj->ij') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_diagonal_dim_mismatch_negative_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ii->i') + + return ([node], [x], [y]) + + +@onnx_test() +def einsum_right_batch_diagonal_negative_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3]) + + node = onnx.helper.make_node('Einsum', + inputs=['x'], + outputs=['y'], + equation='ii...->i...') + + return ([node], [x], [y]) + + @onnx_test() def elu_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) diff --git a/test/onnx/include/onnx_test.hpp b/test/onnx/include/onnx_test.hpp index aa753444f22..0281059d534 100644 --- a/test/onnx/include/onnx_test.hpp +++ b/test/onnx/include/onnx_test.hpp @@ -35,14 +35,49 @@ #include #include #include - +#include +#include +#include +#include +#include #include +inline static std::string +read_weight_files(const std::unordered_map& onnx_files) +{ + static migraphx::tmp_dir td{"weights"}; + for(const auto& i : onnx_files) + { + if(not migraphx::ends_with(std::string{i.first}, "weight")) + continue; + migraphx::fs::path full_path = td.path / i.first; + migraphx::fs::path parent_path = full_path.parent_path(); + migraphx::fs::create_directories(parent_path); + migraphx::write_buffer(full_path, i.second.data(), i.second.size()); + } + return td.path.string(); +} + +inline migraphx::program read_onnx(const std::string& name, + migraphx::onnx_options options = migraphx::onnx_options{}) +{ + static auto onnx_files{::onnx_files()}; + static std::string external_data_path = read_weight_files(onnx_files); + options.external_data_path = external_data_path; + if(onnx_files.find(name) == onnx_files.end()) + { + std::cerr << "ONNX model file: " << name << " not found, aborting the test." << std::endl; + std::abort(); + } + auto prog = migraphx::parse_onnx_buffer(std::string{onnx_files.at(name)}, options); + return prog; +} + inline migraphx::program optimize_onnx(const std::string& name, bool run_passes = false) { migraphx::onnx_options options; options.skip_unknown_operators = true; - auto prog = migraphx::parse_onnx(name, options); + auto prog = read_onnx(name, options); auto* mm = prog.get_main_module(); if(run_passes) migraphx::run_passes(*mm, diff --git a/test/onnx/include/onnx_test_utils.hpp b/test/onnx/include/onnx_test_utils.hpp index 59ae33ab0ae..f6ebbb150f8 100644 --- a/test/onnx/include/onnx_test_utils.hpp +++ b/test/onnx/include/onnx_test_utils.hpp @@ -25,6 +25,7 @@ #ifndef MIGRAPHX_GUARD_TEST_ONNX_ONNX_TEST_UTILS_HPP #define MIGRAPHX_GUARD_TEST_ONNX_ONNX_TEST_UTILS_HPP +#include #include #include #include @@ -395,7 +396,7 @@ inline void scatter_test_base(const std::string& reduction, int axis, const std: auto r = mm->add_instruction( migraphx::make_op("scatter_" + reduction, {{"axis", axis}}), l0, l1, l2); mm->add_return({r}); - auto prog = migraphx::parse_onnx(onnx_file); + auto prog = read_onnx(onnx_file); EXPECT(p == prog); } diff --git a/test/onnx/onnx_rnn_test.cpp b/test/onnx/onnx_rnn_test.cpp index 80b310043ea..6e2da606388 100644 --- a/test/onnx/onnx_rnn_test.cpp +++ b/test/onnx/onnx_rnn_test.cpp @@ -34,12 +34,12 @@ #include #include - +#include #include "test.hpp" -migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = true) +migraphx::program read_rnn_onnx(const std::string& name, bool eliminate_deadcode = true) { - auto prog = migraphx::parse_onnx(name); + auto prog = read_onnx(name); auto* mm = prog.get_main_module(); if(eliminate_deadcode) migraphx::run_passes(*mm, {migraphx::dead_code_elimination{}}); @@ -95,7 +95,7 @@ TEST_CASE(rnn_test_bidirectional) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_rnn_bi.onnx"); + auto prog = read_rnn_onnx("onnx_rnn_bi.onnx"); EXPECT(p == prog); } @@ -149,7 +149,7 @@ TEST_CASE(rnn_test_bidirectional_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("rnn_bi_layout_test.onnx"); + auto prog = read_rnn_onnx("rnn_bi_layout_test.onnx"); EXPECT(p == prog); } @@ -195,7 +195,7 @@ TEST_CASE(rnn_test_one_direction) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_rnn_forward.onnx"); + auto prog = read_rnn_onnx("onnx_rnn_forward.onnx"); EXPECT(p == prog); } @@ -224,7 +224,7 @@ TEST_CASE(rnn_test_one_direction) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("rnn_f_default_af_test.onnx"); + auto prog = read_rnn_onnx("rnn_f_default_af_test.onnx"); EXPECT(p == prog); } @@ -254,7 +254,7 @@ TEST_CASE(rnn_test_one_direction) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_rnn_reverse.onnx"); + auto prog = read_rnn_onnx("onnx_rnn_reverse.onnx"); EXPECT(p == prog); } @@ -282,7 +282,7 @@ TEST_CASE(rnn_test_one_direction) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_rnn_3args.onnx"); + auto prog = read_rnn_onnx("onnx_rnn_3args.onnx"); EXPECT(p == prog); } @@ -314,7 +314,7 @@ TEST_CASE(rnn_test_one_direction) seq_len, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_rnn_5args.onnx"); + auto prog = read_rnn_onnx("onnx_rnn_5args.onnx"); EXPECT(p == prog); } @@ -369,7 +369,7 @@ TEST_CASE(rnn_test_one_direction_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("rnn_f_layout_test.onnx"); + auto prog = read_rnn_onnx("rnn_f_layout_test.onnx"); EXPECT(p == prog); } @@ -408,7 +408,7 @@ TEST_CASE(rnn_test_one_direction_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("rnn_r_layout_test.onnx"); + auto prog = read_rnn_onnx("rnn_r_layout_test.onnx"); EXPECT(p == prog); } @@ -444,7 +444,7 @@ TEST_CASE(rnn_test_one_direction_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("rnn_r_3arg_layout_test.onnx"); + auto prog = read_rnn_onnx("rnn_r_3arg_layout_test.onnx"); EXPECT(p == prog); } @@ -483,7 +483,7 @@ TEST_CASE(rnn_test_one_direction_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("rnn_f_5arg_layout_test.onnx"); + auto prog = read_rnn_onnx("rnn_f_5arg_layout_test.onnx"); EXPECT(p == prog); } @@ -538,7 +538,7 @@ TEST_CASE(gru_test) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_forward.onnx"); + auto prog = read_rnn_onnx("onnx_gru_forward.onnx"); EXPECT(p == prog); } @@ -578,7 +578,7 @@ TEST_CASE(gru_test) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_reverse.onnx"); + auto prog = read_rnn_onnx("onnx_gru_reverse.onnx"); EXPECT(p == prog); } @@ -620,7 +620,7 @@ TEST_CASE(gru_test) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_bi.onnx"); + auto prog = read_rnn_onnx("onnx_gru_bi.onnx"); EXPECT(p == prog); } @@ -678,7 +678,7 @@ TEST_CASE(gru_layout_test) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("gru_f_layout_test.onnx"); + auto prog = read_rnn_onnx("gru_f_layout_test.onnx"); EXPECT(p == prog); } @@ -726,7 +726,7 @@ TEST_CASE(gru_layout_test) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("gru_r_layout_test.onnx"); + auto prog = read_rnn_onnx("gru_r_layout_test.onnx"); EXPECT(p == prog); } @@ -776,7 +776,7 @@ TEST_CASE(gru_layout_test) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("gru_bi_layout_test.onnx"); + auto prog = read_rnn_onnx("gru_bi_layout_test.onnx"); EXPECT(p == prog); } @@ -820,7 +820,7 @@ TEST_CASE(gru_test_args) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_3arg.onnx"); + auto prog = read_rnn_onnx("onnx_gru_3arg.onnx"); EXPECT(p == prog); } @@ -857,7 +857,7 @@ TEST_CASE(gru_test_args) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_4arg.onnx"); + auto prog = read_rnn_onnx("onnx_gru_4arg.onnx"); EXPECT(p == prog); } @@ -898,7 +898,7 @@ TEST_CASE(gru_test_args) seq_len, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_5arg.onnx"); + auto prog = read_rnn_onnx("onnx_gru_5arg.onnx"); EXPECT(p == prog); } @@ -950,7 +950,7 @@ TEST_CASE(gru_test_args_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("gru_f_3arg_layout_test.onnx"); + auto prog = read_rnn_onnx("gru_f_3arg_layout_test.onnx"); EXPECT(p == prog); } @@ -995,7 +995,7 @@ TEST_CASE(gru_test_args_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("gru_r_4arg_layout_test.onnx"); + auto prog = read_rnn_onnx("gru_r_4arg_layout_test.onnx"); EXPECT(p == prog); } @@ -1044,7 +1044,7 @@ TEST_CASE(gru_test_args_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("gru_bi_5arg_layout_test.onnx"); + auto prog = read_rnn_onnx("gru_bi_5arg_layout_test.onnx"); EXPECT(p == prog); } @@ -1095,7 +1095,7 @@ TEST_CASE(gru_test_actv_funcs) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_bi_0.onnx"); + auto prog = read_rnn_onnx("onnx_gru_bi_0.onnx"); EXPECT(p == prog); } @@ -1138,7 +1138,7 @@ TEST_CASE(gru_test_actv_funcs) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_bi_1.onnx"); + auto prog = read_rnn_onnx("onnx_gru_bi_1.onnx"); EXPECT(p == prog); } @@ -1180,7 +1180,7 @@ TEST_CASE(gru_test_actv_funcs) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_bi_2.onnx"); + auto prog = read_rnn_onnx("onnx_gru_bi_2.onnx"); EXPECT(p == prog); } @@ -1220,7 +1220,7 @@ TEST_CASE(gru_test_actv_funcs) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_forward_0.onnx"); + auto prog = read_rnn_onnx("onnx_gru_forward_0.onnx"); EXPECT(p == prog); } @@ -1260,7 +1260,7 @@ TEST_CASE(gru_test_actv_funcs) seq_len, ih); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_gru_reverse_1.onnx"); + auto prog = read_rnn_onnx("onnx_gru_reverse_1.onnx"); EXPECT(p == prog); } @@ -1319,7 +1319,7 @@ TEST_CASE(lstm_forward) ic, pph); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_forward.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_forward.onnx"); EXPECT(p == prog); } @@ -1353,7 +1353,7 @@ TEST_CASE(lstm_forward) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f3args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f3args.onnx"); EXPECT(p == prog); } @@ -1386,7 +1386,7 @@ TEST_CASE(lstm_forward) und, und, und); - auto prog = optimize_onnx("onnx_lstm_hs.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_hs.onnx"); EXPECT(p == prog); } @@ -1420,7 +1420,7 @@ TEST_CASE(lstm_forward) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_last.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_last.onnx"); EXPECT(p == prog); } @@ -1454,7 +1454,7 @@ TEST_CASE(lstm_forward) und, und); mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_cell.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_cell.onnx"); EXPECT(p == prog); } @@ -1489,7 +1489,7 @@ TEST_CASE(lstm_forward) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f4args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f4args.onnx"); EXPECT(p == prog); } @@ -1526,7 +1526,7 @@ TEST_CASE(lstm_forward) und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f5args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f5args.onnx"); EXPECT(p == prog); } @@ -1564,7 +1564,7 @@ TEST_CASE(lstm_forward) und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f6args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f6args.onnx"); EXPECT(p == prog); } @@ -1603,7 +1603,7 @@ TEST_CASE(lstm_forward) und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f7args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f7args.onnx"); EXPECT(p == prog); } @@ -1669,7 +1669,7 @@ TEST_CASE(lstm_forward_layout) out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("lstm_f_layout_hs_test.onnx"); + auto prog = read_rnn_onnx("lstm_f_layout_hs_test.onnx"); EXPECT(p == prog); } @@ -1712,7 +1712,7 @@ TEST_CASE(lstm_forward_layout) pph); auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell); - auto prog = optimize_onnx("lstm_f_layout_cell_test.onnx"); + auto prog = read_rnn_onnx("lstm_f_layout_cell_test.onnx"); EXPECT(p == prog); } @@ -1763,7 +1763,7 @@ TEST_CASE(lstm_forward_actv_func) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f0af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f0af.onnx"); EXPECT(p == prog); } @@ -1799,7 +1799,7 @@ TEST_CASE(lstm_forward_actv_func) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f1af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f1af.onnx"); EXPECT(p == prog); } @@ -1837,7 +1837,7 @@ TEST_CASE(lstm_forward_actv_func) und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_f2af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_f2af.onnx"); EXPECT(p == prog); } @@ -1891,7 +1891,7 @@ TEST_CASE(lstm_reverse) ic, pph); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_reverse.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_reverse.onnx"); EXPECT(p == prog); } @@ -1928,7 +1928,7 @@ TEST_CASE(lstm_reverse) und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_r5args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_r5args.onnx"); EXPECT(p == prog); } @@ -1962,7 +1962,7 @@ TEST_CASE(lstm_reverse) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_r0af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_r0af.onnx"); EXPECT(p == prog); } @@ -2025,7 +2025,7 @@ TEST_CASE(lstm_reverse_layout) std::vector perm_hid{2, 0, 1, 3}; out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); - auto prog = optimize_onnx("lstm_r_layout_test.onnx"); + auto prog = read_rnn_onnx("lstm_r_layout_test.onnx"); EXPECT(p == prog); } @@ -2073,7 +2073,7 @@ TEST_CASE(lstm_reverse_layout) last_output); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell); - auto prog = optimize_onnx("lstm_r_layout_hs_cell_test.onnx"); + auto prog = read_rnn_onnx("lstm_r_layout_hs_cell_test.onnx"); EXPECT(p == prog); } @@ -2130,7 +2130,7 @@ TEST_CASE(lstm_bidirectional) ic, pph); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi.onnx"); EXPECT(p == prog); } @@ -2167,7 +2167,7 @@ TEST_CASE(lstm_bidirectional) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi3args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi3args.onnx"); EXPECT(p == prog); } @@ -2205,7 +2205,7 @@ TEST_CASE(lstm_bidirectional) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi4args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi4args.onnx"); EXPECT(p == prog); } @@ -2244,7 +2244,7 @@ TEST_CASE(lstm_bidirectional) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi5args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi5args.onnx"); EXPECT(p == prog); } @@ -2284,7 +2284,7 @@ TEST_CASE(lstm_bidirectional) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi6args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi6args.onnx"); EXPECT(p == prog); } @@ -2325,7 +2325,7 @@ TEST_CASE(lstm_bidirectional) ic, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi7args.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi7args.onnx"); EXPECT(p == prog); } @@ -2392,7 +2392,7 @@ TEST_CASE(lstm_bidirectional_layout) out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); - auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx"); + auto prog = read_rnn_onnx("lstm_bi_layout_last_test.onnx"); EXPECT(p == prog); } @@ -2437,7 +2437,7 @@ TEST_CASE(lstm_bidirectional_layout) pph); auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell); - auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx"); + auto prog = read_rnn_onnx("lstm_bi_layout_cell_test.onnx"); EXPECT(p == prog); } @@ -2492,7 +2492,7 @@ TEST_CASE(lstm_bi_actv_funcs) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi0af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi0af.onnx"); EXPECT(p == prog); } @@ -2531,7 +2531,7 @@ TEST_CASE(lstm_bi_actv_funcs) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi1af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi1af.onnx"); EXPECT(p == prog); } @@ -2570,7 +2570,7 @@ TEST_CASE(lstm_bi_actv_funcs) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi2af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi2af.onnx"); EXPECT(p == prog); } @@ -2610,7 +2610,7 @@ TEST_CASE(lstm_bi_actv_funcs) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi3af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi3af.onnx"); EXPECT(p == prog); } @@ -2651,7 +2651,7 @@ TEST_CASE(lstm_bi_actv_funcs) ic, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi4af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi4af.onnx"); EXPECT(p == prog); } @@ -2688,7 +2688,7 @@ TEST_CASE(lstm_bi_actv_funcs) und, und); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); - auto prog = optimize_onnx("onnx_lstm_bi5af.onnx"); + auto prog = read_rnn_onnx("onnx_lstm_bi5af.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/add_scalar_test.cpp b/test/onnx/parse/add_scalar_test.cpp index 67adc677e73..c08059c52f3 100644 --- a/test/onnx/parse/add_scalar_test.cpp +++ b/test/onnx/parse/add_scalar_test.cpp @@ -32,9 +32,8 @@ TEST_CASE(add_scalar_test) auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint8_type}); auto m1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1); - auto r = mm->add_instruction(migraphx::make_op("add"), l0, m1); - mm->add_return({r}); - auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); + mm->add_instruction(migraphx::make_op("add"), l0, m1); + auto prog = optimize_onnx("add_scalar_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/argmax_dyn_test.cpp b/test/onnx/parse/argmax_dyn_test.cpp index e9e66c250dc..f1fc595afca 100644 --- a/test/onnx/parse/argmax_dyn_test.cpp +++ b/test/onnx/parse/argmax_dyn_test.cpp @@ -36,7 +36,7 @@ TEST_CASE(argmax_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("argmax_dyn_test.onnx", options); + auto prog = read_onnx("argmax_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/averagepool_dyn_asym_padding_error_test.cpp b/test/onnx/parse/averagepool_dyn_asym_padding_error_test.cpp index 429993a5828..69984894827 100644 --- a/test/onnx/parse/averagepool_dyn_asym_padding_error_test.cpp +++ b/test/onnx/parse/averagepool_dyn_asym_padding_error_test.cpp @@ -28,6 +28,6 @@ TEST_CASE(averagepool_dyn_asym_padding_error_test) { migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws( - [&] { migraphx::parse_onnx("averagepool_dyn_asym_padding_error_test.onnx", options); })); + EXPECT( + test::throws([&] { read_onnx("averagepool_dyn_asym_padding_error_test.onnx", options); })); } diff --git a/test/onnx/parse/averagepool_dyn_autopad_test.cpp b/test/onnx/parse/averagepool_dyn_autopad_test.cpp index 25df78589ef..891d7b25afb 100644 --- a/test/onnx/parse/averagepool_dyn_autopad_test.cpp +++ b/test/onnx/parse/averagepool_dyn_autopad_test.cpp @@ -47,6 +47,6 @@ TEST_CASE(averagepool_dyn_autopad_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("averagepool_dyn_autopad_test.onnx", options); + auto prog = read_onnx("averagepool_dyn_autopad_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/averagepool_dyn_cip_error_test.cpp b/test/onnx/parse/averagepool_dyn_cip_error_test.cpp index 470b069edd9..d847c7ad5f6 100644 --- a/test/onnx/parse/averagepool_dyn_cip_error_test.cpp +++ b/test/onnx/parse/averagepool_dyn_cip_error_test.cpp @@ -28,6 +28,5 @@ TEST_CASE(averagepool_dyn_cip_error_test) { migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws( - [&] { migraphx::parse_onnx("averagepool_dyn_cip_error_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("averagepool_dyn_cip_error_test.onnx", options); })); } diff --git a/test/onnx/parse/averagepool_dyn_test.cpp b/test/onnx/parse/averagepool_dyn_test.cpp index aa26b6b2034..ce14078edc7 100644 --- a/test/onnx/parse/averagepool_dyn_test.cpp +++ b/test/onnx/parse/averagepool_dyn_test.cpp @@ -47,6 +47,6 @@ TEST_CASE(averagepool_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("averagepool_dyn_test.onnx", options); + auto prog = read_onnx("averagepool_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/averagepool_notset_test.cpp b/test/onnx/parse/averagepool_notset_test.cpp index 55cea093667..0f2f8fb1a61 100644 --- a/test/onnx/parse/averagepool_notset_test.cpp +++ b/test/onnx/parse/averagepool_notset_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(averagepool_notset_test) auto ret = mm->add_instruction( migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("averagepool_notset_test.onnx"); + auto prog = read_onnx("averagepool_notset_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/averagepool_nt_cip_test.cpp b/test/onnx/parse/averagepool_nt_cip_test.cpp index 1b72f0d05dd..ae69c6b8cf1 100644 --- a/test/onnx/parse/averagepool_nt_cip_test.cpp +++ b/test/onnx/parse/averagepool_nt_cip_test.cpp @@ -41,6 +41,6 @@ TEST_CASE(averagepool_nt_cip_test) ins_pad); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("averagepool_nt_cip_test.onnx"); + auto prog = read_onnx("averagepool_nt_cip_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/averagepool_same_lower_test.cpp b/test/onnx/parse/averagepool_same_lower_test.cpp index 87a77892d93..4e0331ec564 100644 --- a/test/onnx/parse/averagepool_same_lower_test.cpp +++ b/test/onnx/parse/averagepool_same_lower_test.cpp @@ -46,7 +46,7 @@ TEST_CASE(averagepool_same_lower_test) auto ret = mm->add_instruction( migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {0, 0}}, {"ends", {5, 5}}}), ins); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("averagepool_same_lower_test.onnx"); + auto prog = read_onnx("averagepool_same_lower_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/averagepool_same_upper_test.cpp b/test/onnx/parse/averagepool_same_upper_test.cpp index 9278da5920e..7d1a7bb1b2a 100644 --- a/test/onnx/parse/averagepool_same_upper_test.cpp +++ b/test/onnx/parse/averagepool_same_upper_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(averagepool_same_upper_test) auto ret = mm->add_instruction( migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("averagepool_same_upper_test.onnx"); + auto prog = read_onnx("averagepool_same_upper_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/averagepool_sl_cip_test.cpp b/test/onnx/parse/averagepool_sl_cip_test.cpp index 3b4ca36dd3f..8ac7940704f 100644 --- a/test/onnx/parse/averagepool_sl_cip_test.cpp +++ b/test/onnx/parse/averagepool_sl_cip_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(averagepool_sl_cip_test) {"dilations", {1, 1}}}), ins_pad); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx"); + auto prog = read_onnx("averagepool_sl_cip_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/binary_dyn_brcst_add_test.cpp b/test/onnx/parse/binary_dyn_brcst_add_test.cpp index dce3ab61037..ac7a64885c4 100644 --- a/test/onnx/parse/binary_dyn_brcst_add_test.cpp +++ b/test/onnx/parse/binary_dyn_brcst_add_test.cpp @@ -37,7 +37,7 @@ TEST_CASE(binary_dyn_brcst_add_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("binary_dyn_brcst_add_test.onnx", options); + auto prog = read_onnx("binary_dyn_brcst_add_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/binary_dyn_brcst_attr_error_test.cpp b/test/onnx/parse/binary_dyn_brcst_attr_error_test.cpp index 90316099c2a..4b0f1fed564 100644 --- a/test/onnx/parse/binary_dyn_brcst_attr_error_test.cpp +++ b/test/onnx/parse/binary_dyn_brcst_attr_error_test.cpp @@ -28,6 +28,5 @@ TEST_CASE(binary_dyn_brcst_attr_error_test) { migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws( - [&] { migraphx::parse_onnx("binary_dyn_brcst_attr_error_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("binary_dyn_brcst_attr_error_test.onnx", options); })); } diff --git a/test/onnx/parse/binary_dyn_brcst_mul_test.cpp b/test/onnx/parse/binary_dyn_brcst_mul_test.cpp index 51c05aa60de..f146c3cbd77 100644 --- a/test/onnx/parse/binary_dyn_brcst_mul_test.cpp +++ b/test/onnx/parse/binary_dyn_brcst_mul_test.cpp @@ -47,7 +47,7 @@ TEST_CASE(binary_dyn_brcst_mul_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("binary_dyn_brcst_mul_test.onnx", options); + auto prog = read_onnx("binary_dyn_brcst_mul_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/binary_dyn_brcst_prelu_test.cpp b/test/onnx/parse/binary_dyn_brcst_prelu_test.cpp index 2b8c6f37b02..e0fda227858 100644 --- a/test/onnx/parse/binary_dyn_brcst_prelu_test.cpp +++ b/test/onnx/parse/binary_dyn_brcst_prelu_test.cpp @@ -37,7 +37,7 @@ TEST_CASE(binary_dyn_brcst_prelu_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("binary_dyn_brcst_prelu_test.onnx", options); + auto prog = read_onnx("binary_dyn_brcst_prelu_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/castlike_error_test.cpp b/test/onnx/parse/castlike_error_test.cpp index 942b3a8a4f2..ae3a0f3d68a 100644 --- a/test/onnx/parse/castlike_error_test.cpp +++ b/test/onnx/parse/castlike_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(castlike_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("castlike_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("castlike_error_test.onnx"); })); } diff --git a/test/onnx/parse/celu_wrong_type_test.cpp b/test/onnx/parse/celu_wrong_type_test.cpp index e15f7f7d092..8a853275f8d 100644 --- a/test/onnx/parse/celu_wrong_type_test.cpp +++ b/test/onnx/parse/celu_wrong_type_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(celu_wrong_type_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("celu_wrong_type_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("celu_wrong_type_test.onnx"); })); } diff --git a/test/onnx/parse/celu_zero_alpha_test.cpp b/test/onnx/parse/celu_zero_alpha_test.cpp index 258bc698df0..66c532ff252 100644 --- a/test/onnx/parse/celu_zero_alpha_test.cpp +++ b/test/onnx/parse/celu_zero_alpha_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(celu_zero_alpha_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("celu_zero_alpha_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("celu_zero_alpha_test.onnx"); })); } diff --git a/test/onnx/parse/clip_dyn_min_max_test.cpp b/test/onnx/parse/clip_dyn_min_max_test.cpp index db3b1935f2d..c7d250396e2 100644 --- a/test/onnx/parse/clip_dyn_min_max_test.cpp +++ b/test/onnx/parse/clip_dyn_min_max_test.cpp @@ -46,7 +46,7 @@ TEST_CASE(clip_dyn_min_max_test) migraphx::onnx_options options; options.default_dyn_dim_value = {2, 8, {3}}; - auto prog = parse_onnx("clip_dyn_min_max_test.onnx", options); + auto prog = read_onnx("clip_dyn_min_max_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/clip_dyn_min_only_test.cpp b/test/onnx/parse/clip_dyn_min_only_test.cpp index 66d2c1236d6..de72642cec9 100644 --- a/test/onnx/parse/clip_dyn_min_only_test.cpp +++ b/test/onnx/parse/clip_dyn_min_only_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(clip_dyn_min_only_test) migraphx::onnx_options options; options.default_dyn_dim_value = {2, 8, {3}}; - auto prog = parse_onnx("clip_dyn_min_only_test.onnx", options); + auto prog = read_onnx("clip_dyn_min_only_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/clip_test_args_type_mismatch.cpp b/test/onnx/parse/clip_test_args_type_mismatch.cpp index 4cc6af80470..c4f07ff8552 100644 --- a/test/onnx/parse/clip_test_args_type_mismatch.cpp +++ b/test/onnx/parse/clip_test_args_type_mismatch.cpp @@ -42,6 +42,6 @@ TEST_CASE(clip_test_args_type_mismatch) migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), max_val); auto r = mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_return({r}); - auto prog = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx"); + auto prog = read_onnx("clip_test_args_type_mismatch.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/clip_test_op11_max_only.cpp b/test/onnx/parse/clip_test_op11_max_only.cpp index c75afd4a23f..f62bdec660d 100644 --- a/test/onnx/parse/clip_test_op11_max_only.cpp +++ b/test/onnx/parse/clip_test_op11_max_only.cpp @@ -36,7 +36,7 @@ TEST_CASE(clip_test_op11_max_only) auto r = mm->add_instruction(migraphx::make_op("min"), l0, max_val); mm->add_return({r}); - auto prog = migraphx::parse_onnx("clip_test_op11_max_only.onnx"); + auto prog = read_onnx("clip_test_op11_max_only.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/clip_test_op11_no_args1.cpp b/test/onnx/parse/clip_test_op11_no_args1.cpp index 60817594f05..69b9a50fb26 100644 --- a/test/onnx/parse/clip_test_op11_no_args1.cpp +++ b/test/onnx/parse/clip_test_op11_no_args1.cpp @@ -33,7 +33,7 @@ TEST_CASE(clip_test_op11_no_args1) mm->add_instruction(migraphx::make_op("undefined")); auto r = mm->add_instruction(migraphx::make_op("identity"), l0); mm->add_return({r}); - auto prog = migraphx::parse_onnx("clip_test_op11_no_args1.onnx"); + auto prog = read_onnx("clip_test_op11_no_args1.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/concat_dyn_test.cpp b/test/onnx/parse/concat_dyn_test.cpp index e82b855450f..6e2a28558ed 100644 --- a/test/onnx/parse/concat_dyn_test.cpp +++ b/test/onnx/parse/concat_dyn_test.cpp @@ -38,7 +38,7 @@ TEST_CASE(concat_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("concat_dyn_test.onnx", options); + auto prog = read_onnx("concat_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/const_of_shape_dyn_float_test.cpp b/test/onnx/parse/const_of_shape_dyn_float_test.cpp index 065cc199735..2d6e8ec6344 100644 --- a/test/onnx/parse/const_of_shape_dyn_float_test.cpp +++ b/test/onnx/parse/const_of_shape_dyn_float_test.cpp @@ -38,6 +38,6 @@ TEST_CASE(const_of_shape_dyn_float_test) mm->add_return({fill_ins}); migraphx::onnx_options options; - auto prog = parse_onnx("const_of_shape_dyn_float_test.onnx", options); + auto prog = read_onnx("const_of_shape_dyn_float_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/const_of_shape_dyn_int64_test.cpp b/test/onnx/parse/const_of_shape_dyn_int64_test.cpp index 8daf30b34a1..1642d2c2c53 100644 --- a/test/onnx/parse/const_of_shape_dyn_int64_test.cpp +++ b/test/onnx/parse/const_of_shape_dyn_int64_test.cpp @@ -38,6 +38,6 @@ TEST_CASE(const_of_shape_dyn_int64_test) mm->add_return({fill_ins}); migraphx::onnx_options options; - auto prog = parse_onnx("const_of_shape_dyn_int64_test.onnx", options); + auto prog = read_onnx("const_of_shape_dyn_int64_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_attr_fail_test.cpp b/test/onnx/parse/conv_attr_fail_test.cpp index 82e75d6a940..99c5d01ee98 100644 --- a/test/onnx/parse/conv_attr_fail_test.cpp +++ b/test/onnx/parse/conv_attr_fail_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(conv_attr_fail_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("conv_attr_fail_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("conv_attr_fail_test.onnx"); })); } diff --git a/test/onnx/parse/conv_bad_bias_test.cpp b/test/onnx/parse/conv_bad_bias_test.cpp index c0ffac73dc6..aec4b14f4e1 100644 --- a/test/onnx/parse/conv_bad_bias_test.cpp +++ b/test/onnx/parse/conv_bad_bias_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(conv_bad_bias_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("conv_bad_bias_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("conv_bad_bias_test.onnx"); })); } diff --git a/test/onnx/parse/conv_dynamic_batch_same_upper_test.cpp b/test/onnx/parse/conv_dynamic_batch_same_upper_test.cpp index 9d37678bbe8..8b3b83374ff 100644 --- a/test/onnx/parse/conv_dynamic_batch_same_upper_test.cpp +++ b/test/onnx/parse/conv_dynamic_batch_same_upper_test.cpp @@ -41,6 +41,6 @@ TEST_CASE(conv_dynamic_batch_same_upper) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10}; - auto prog = migraphx::parse_onnx("conv_dynamic_batch_same_upper_test.onnx", options); + auto prog = read_onnx("conv_dynamic_batch_same_upper_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_dynamic_batch_test.cpp b/test/onnx/parse/conv_dynamic_batch_test.cpp index 0cb0c906ecc..1aaf2509530 100644 --- a/test/onnx/parse/conv_dynamic_batch_test.cpp +++ b/test/onnx/parse/conv_dynamic_batch_test.cpp @@ -41,6 +41,6 @@ TEST_CASE(conv_dynamic_batch_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 6}; - auto prog = migraphx::parse_onnx("conv_dynamic_batch_test.onnx", options); + auto prog = read_onnx("conv_dynamic_batch_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_dynamic_bias_test.cpp b/test/onnx/parse/conv_dynamic_bias_test.cpp index 20f8d91464d..1f3b5c6090e 100644 --- a/test/onnx/parse/conv_dynamic_bias_test.cpp +++ b/test/onnx/parse/conv_dynamic_bias_test.cpp @@ -39,6 +39,6 @@ TEST_CASE(conv_dynamic_bias_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 6}; - auto prog = migraphx::parse_onnx("conv_dynamic_bias_test.onnx", options); + auto prog = read_onnx("conv_dynamic_bias_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_dynamic_img_and_weights_test.cpp b/test/onnx/parse/conv_dynamic_img_and_weights_test.cpp index 8797ffc83d8..0ce49a1f599 100644 --- a/test/onnx/parse/conv_dynamic_img_and_weights_test.cpp +++ b/test/onnx/parse/conv_dynamic_img_and_weights_test.cpp @@ -43,6 +43,6 @@ TEST_CASE(conv_dynamic_img_and_weights_test) options.default_dyn_dim_value = {5, 10}; options.map_dyn_input_dims["1"] = {{1, 1}, {3, 3}, {2, 4}, {2, 4}}; - auto prog = migraphx::parse_onnx("conv_dynamic_img_and_weights_test.onnx", options); + auto prog = read_onnx("conv_dynamic_img_and_weights_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_dynamic_img_same_upper_test.cpp b/test/onnx/parse/conv_dynamic_img_same_upper_test.cpp index 3033f4c8d54..2c0d93ed747 100644 --- a/test/onnx/parse/conv_dynamic_img_same_upper_test.cpp +++ b/test/onnx/parse/conv_dynamic_img_same_upper_test.cpp @@ -45,6 +45,6 @@ TEST_CASE(conv_dynamic_img_same_upper) migraphx::onnx_options options; options.default_dyn_dim_value = {5, 10}; - auto prog = migraphx::parse_onnx("conv_dynamic_img_same_upper_test.onnx", options); + auto prog = read_onnx("conv_dynamic_img_same_upper_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_dynamic_img_test.cpp b/test/onnx/parse/conv_dynamic_img_test.cpp index d6156fd362b..a02ffbebc1b 100644 --- a/test/onnx/parse/conv_dynamic_img_test.cpp +++ b/test/onnx/parse/conv_dynamic_img_test.cpp @@ -41,6 +41,6 @@ TEST_CASE(conv_dynamic_img_test) migraphx::onnx_options options; options.default_dyn_dim_value = {5, 10}; - auto prog = migraphx::parse_onnx("conv_dynamic_img_test.onnx", options); + auto prog = read_onnx("conv_dynamic_img_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_dynamic_kernel_same_lower_test.cpp b/test/onnx/parse/conv_dynamic_kernel_same_lower_test.cpp index c4d71ce22c9..922f2320075 100644 --- a/test/onnx/parse/conv_dynamic_kernel_same_lower_test.cpp +++ b/test/onnx/parse/conv_dynamic_kernel_same_lower_test.cpp @@ -44,6 +44,6 @@ TEST_CASE(conv_dynamic_kernel_same_lower) migraphx::onnx_options options; options.default_dyn_dim_value = {2, 4}; - auto prog = migraphx::parse_onnx("conv_dynamic_kernel_same_lower_test.onnx", options); + auto prog = read_onnx("conv_dynamic_kernel_same_lower_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_dynamic_weights_test.cpp b/test/onnx/parse/conv_dynamic_weights_test.cpp index 62655ba7db5..fd4646b857a 100644 --- a/test/onnx/parse/conv_dynamic_weights_test.cpp +++ b/test/onnx/parse/conv_dynamic_weights_test.cpp @@ -41,6 +41,6 @@ TEST_CASE(conv_dynamic_weights_test) migraphx::onnx_options options; options.default_dyn_dim_value = {2, 4}; - auto prog = migraphx::parse_onnx("conv_dynamic_weights_test.onnx", options); + auto prog = read_onnx("conv_dynamic_weights_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_transpose_auto_pad_test.cpp b/test/onnx/parse/conv_transpose_auto_pad_test.cpp index 5c5630abedd..6a8fcef6d55 100644 --- a/test/onnx/parse/conv_transpose_auto_pad_test.cpp +++ b/test/onnx/parse/conv_transpose_auto_pad_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(conv_transpose_auto_pad_error) { - EXPECT(test::throws([&] { migraphx::parse_onnx("conv_transpose_auto_pad_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("conv_transpose_auto_pad_test.onnx"); })); } diff --git a/test/onnx/parse/conv_transpose_dyn_asym_padding_test.cpp b/test/onnx/parse/conv_transpose_dyn_asym_padding_test.cpp index a19bdaa554d..4c266444824 100644 --- a/test/onnx/parse/conv_transpose_dyn_asym_padding_test.cpp +++ b/test/onnx/parse/conv_transpose_dyn_asym_padding_test.cpp @@ -28,6 +28,5 @@ TEST_CASE(conv_transpose_dyn_asym_padding_error) { migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws( - [&] { migraphx::parse_onnx("conv_transpose_dyn_asym_padding_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("conv_transpose_dyn_asym_padding_test.onnx", options); })); } diff --git a/test/onnx/parse/conv_transpose_dyn_batch_test.cpp b/test/onnx/parse/conv_transpose_dyn_batch_test.cpp index a72e19da7a0..f18cc059689 100644 --- a/test/onnx/parse/conv_transpose_dyn_batch_test.cpp +++ b/test/onnx/parse/conv_transpose_dyn_batch_test.cpp @@ -36,6 +36,6 @@ TEST_CASE(conv_transpose_dyn_batch_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("conv_transpose_dyn_batch_test.onnx", options); + auto prog = read_onnx("conv_transpose_dyn_batch_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_transpose_dyn_img_test.cpp b/test/onnx/parse/conv_transpose_dyn_img_test.cpp index caf21a1003d..230caa4067f 100644 --- a/test/onnx/parse/conv_transpose_dyn_img_test.cpp +++ b/test/onnx/parse/conv_transpose_dyn_img_test.cpp @@ -36,6 +36,6 @@ TEST_CASE(conv_transpose_dyn_img_test) migraphx::onnx_options options; options.default_dyn_dim_value = {3, 6}; - auto prog = parse_onnx("conv_transpose_dyn_img_test.onnx", options); + auto prog = read_onnx("conv_transpose_dyn_img_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/conv_transpose_dyn_output_shape_test.cpp b/test/onnx/parse/conv_transpose_dyn_output_shape_test.cpp index 3cfdde3c2b0..54cd097323f 100644 --- a/test/onnx/parse/conv_transpose_dyn_output_shape_test.cpp +++ b/test/onnx/parse/conv_transpose_dyn_output_shape_test.cpp @@ -28,6 +28,5 @@ TEST_CASE(conv_transpose_dyn_output_shape_error) { migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws( - [&] { migraphx::parse_onnx("conv_transpose_dyn_output_shape_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("conv_transpose_dyn_output_shape_test.onnx", options); })); } diff --git a/test/onnx/parse/dim_param_test.cpp b/test/onnx/parse/dim_param_test.cpp index 50fc21bdeb0..de8b8feae7c 100644 --- a/test/onnx/parse/dim_param_test.cpp +++ b/test/onnx/parse/dim_param_test.cpp @@ -34,7 +34,7 @@ TEST_CASE(dim_param_fixed_test) migraphx::onnx_options opt; opt.dim_params = {{"dim0", migraphx::shape::dynamic_dimension{2, 2}}, {"dim1", migraphx::shape::dynamic_dimension{4, 4}}}; - auto prog = migraphx::parse_onnx("dim_param_test.onnx", opt); + auto prog = read_onnx("dim_param_test.onnx", opt); EXPECT(p == prog); } @@ -51,6 +51,6 @@ TEST_CASE(dim_param_dynamic_test) migraphx::onnx_options opt; opt.dim_params = {{"dim0", migraphx::shape::dynamic_dimension{1, 2}}, {"dim1", migraphx::shape::dynamic_dimension{2, 4}}}; - auto prog = migraphx::parse_onnx("dim_param_test.onnx", opt); + auto prog = read_onnx("dim_param_test.onnx", opt); EXPECT(p == prog); } diff --git a/test/onnx/parse/dropout_test.cpp b/test/onnx/parse/dropout_test.cpp index 7ef0cab32eb..af39859ce93 100644 --- a/test/onnx/parse/dropout_test.cpp +++ b/test/onnx/parse/dropout_test.cpp @@ -35,6 +35,6 @@ TEST_CASE(dropout_test) mm->add_literal(migraphx::literal(s, vec)); mm->add_return({out}); - auto prog = migraphx::parse_onnx("dropout_test.onnx"); + auto prog = read_onnx("dropout_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/einsum_negative_tests.cpp b/test/onnx/parse/einsum_negative_tests.cpp new file mode 100644 index 00000000000..973070d10ae --- /dev/null +++ b/test/onnx/parse/einsum_negative_tests.cpp @@ -0,0 +1,100 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(einsum_missing_equation_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_missing_equation_negative_test.onnx"); })); +} + +TEST_CASE(einsum_multiple_arrows_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_multiple_arrows_negative_test.onnx"); })); +} + +TEST_CASE(einsum_empty_term_before_arrow_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_empty_term_before_arrow_negative_test.onnx"); })); +} + +TEST_CASE(einsum_multiple_ellipses_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_multiple_ellipses_negative_test.onnx"); })); +} + +TEST_CASE(einsum_comma_in_output_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_comma_in_output_negative_test.onnx"); })); +} + +TEST_CASE(einsum_empty_term_before_comma_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_empty_term_before_comma_negative_test.onnx"); })); +} + +TEST_CASE(einsum_last_input_missing_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_last_input_missing_negative_test.onnx"); })); +} + +TEST_CASE(einsum_term_input_mismatch_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_term_input_mismatch_negative_test.onnx"); })); +} + +TEST_CASE(einsum_ellipsis_mismatch_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_ellipsis_mismatch_negative_test.onnx"); })); +} + +TEST_CASE(einsum_rank_mismatch_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_rank_mismatch_negative_test.onnx"); })); +} + +TEST_CASE(einsum_output_surplus_label_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_output_surplus_label_negative_test.onnx"); })); +} + +TEST_CASE(einsum_output_missing_ellipsis_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_output_missing_ellipsis_negative_test.onnx"); })); +} + +TEST_CASE(einsum_multiple_diagonals_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_multiple_diagonals_negative_test.onnx"); })); +} + +TEST_CASE(einsum_diagonal_dim_mismatch_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_diagonal_dim_mismatch_negative_test.onnx"); })); +} + +TEST_CASE(einsum_right_batch_diagonal_negative_test) +{ + EXPECT(test::throws([&] { read_onnx("einsum_right_batch_diagonal_negative_test.onnx"); })); +} diff --git a/test/onnx/parse/embedding_bag_offset_test.cpp b/test/onnx/parse/embedding_bag_offset_test.cpp index fceb1386f89..63dba9a64ee 100644 --- a/test/onnx/parse/embedding_bag_offset_test.cpp +++ b/test/onnx/parse/embedding_bag_offset_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(embedding_bag_offset_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("embedding_bag_offset_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("embedding_bag_offset_test.onnx"); })); } diff --git a/test/onnx/parse/embedding_bag_test.cpp b/test/onnx/parse/embedding_bag_test.cpp index 9d5b4e62282..dba4f4086cb 100644 --- a/test/onnx/parse/embedding_bag_test.cpp +++ b/test/onnx/parse/embedding_bag_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(embedding_bag_test) auto r3 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), l6); mm->add_return({r1, r2, r3}); - auto prog = migraphx::parse_onnx("embedding_bag_test.onnx"); + auto prog = read_onnx("embedding_bag_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/equal_bool_test.cpp b/test/onnx/parse/equal_bool_test.cpp index 8efc9e27904..d8a33132705 100644 --- a/test/onnx/parse/equal_bool_test.cpp +++ b/test/onnx/parse/equal_bool_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(equal_bool_test) auto ret = mm->add_instruction(migraphx::make_op("equal"), cin1, input2); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("equal_bool_test.onnx"); + auto prog = read_onnx("equal_bool_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/equal_test.cpp b/test/onnx/parse/equal_test.cpp index f37507b8827..5b5ad76cfe6 100644 --- a/test/onnx/parse/equal_test.cpp +++ b/test/onnx/parse/equal_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(equal_test) eq); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("equal_test.onnx"); + auto prog = read_onnx("equal_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/expand_test.cpp b/test/onnx/parse/expand_test.cpp index 9f48934f402..6933b31c768 100644 --- a/test/onnx/parse/expand_test.cpp +++ b/test/onnx/parse/expand_test.cpp @@ -65,7 +65,7 @@ TEST_CASE(expand_dyn_input_dyn_output_test) migraphx::onnx_options options; options.default_dyn_dim_value = {3, 8}; - auto prog = parse_onnx("expand_dyn_input_dyn_output_test.onnx", options); + auto prog = read_onnx("expand_dyn_input_dyn_output_test.onnx", options); EXPECT(p == prog); } @@ -73,6 +73,5 @@ TEST_CASE(expand_dyn_input_static_dims_throw) { migraphx::onnx_options options; options.default_dyn_dim_value = {3, 8}; - EXPECT(test::throws( - [&] { migraphx::parse_onnx("expand_dyn_input_static_dims_throw.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("expand_dyn_input_static_dims_throw.onnx", options); })); } diff --git a/test/onnx/parse/ext_path_external_data_test.cpp b/test/onnx/parse/ext_path_external_data_test.cpp index ceeceb1c09f..f339d476fa2 100644 --- a/test/onnx/parse/ext_path_external_data_test.cpp +++ b/test/onnx/parse/ext_path_external_data_test.cpp @@ -22,7 +22,12 @@ * THE SOFTWARE. */ -#include +#include "onnx_test.hpp" +#include +#include +#include +#include +#include #include TEST_CASE(external_data_diff_path_test) diff --git a/test/onnx/parse/eyelike_k_outofbounds_neg_test.cpp b/test/onnx/parse/eyelike_k_outofbounds_neg_test.cpp index ffcd698fabf..e3505a3fb56 100644 --- a/test/onnx/parse/eyelike_k_outofbounds_neg_test.cpp +++ b/test/onnx/parse/eyelike_k_outofbounds_neg_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(eyelike_k_outofbounds_neg_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("eyelike_k_outofbounds_neg_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("eyelike_k_outofbounds_neg_test.onnx"); })); } diff --git a/test/onnx/parse/eyelike_k_outofbounds_pos_test.cpp b/test/onnx/parse/eyelike_k_outofbounds_pos_test.cpp index 02b3e933a5a..12543d8769f 100644 --- a/test/onnx/parse/eyelike_k_outofbounds_pos_test.cpp +++ b/test/onnx/parse/eyelike_k_outofbounds_pos_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(eyelike_k_outofbounds_pos_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("eyelike_k_outofbounds_pos_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("eyelike_k_outofbounds_pos_test.onnx"); })); } diff --git a/test/onnx/parse/eyelike_not_rank2_test.cpp b/test/onnx/parse/eyelike_not_rank2_test.cpp index 200dcbab8c6..e9e5942ac82 100644 --- a/test/onnx/parse/eyelike_not_rank2_test.cpp +++ b/test/onnx/parse/eyelike_not_rank2_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(eyelike_not_rank2_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("eyelike_not_rank2_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("eyelike_not_rank2_test.onnx"); })); } diff --git a/test/onnx/parse/flatten_dyn_test.cpp b/test/onnx/parse/flatten_dyn_test.cpp index cd35ed5fc9b..bf97f1c7160 100644 --- a/test/onnx/parse/flatten_dyn_test.cpp +++ b/test/onnx/parse/flatten_dyn_test.cpp @@ -36,6 +36,6 @@ TEST_CASE(flatten_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("flatten_dyn_test.onnx", options); + auto prog = read_onnx("flatten_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/gather_dyn_test.cpp b/test/onnx/parse/gather_dyn_test.cpp index 903a2b3f9a2..13b3e6c0557 100644 --- a/test/onnx/parse/gather_dyn_test.cpp +++ b/test/onnx/parse/gather_dyn_test.cpp @@ -42,7 +42,7 @@ TEST_CASE(gather_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("gather_dyn_test.onnx", options); + auto prog = read_onnx("gather_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/gather_elements_axis0_test.cpp b/test/onnx/parse/gather_elements_axis0_test.cpp index 7ed63e933b7..54e53c4a0e3 100644 --- a/test/onnx/parse/gather_elements_axis0_test.cpp +++ b/test/onnx/parse/gather_elements_axis0_test.cpp @@ -48,7 +48,7 @@ TEST_CASE(gather_elements_axis0_test) auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("gather_elements_axis0_test.onnx"); + auto prog = read_onnx("gather_elements_axis0_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/gather_elements_axis1_test.cpp b/test/onnx/parse/gather_elements_axis1_test.cpp index f63286fcce7..bcbadd19fa5 100644 --- a/test/onnx/parse/gather_elements_axis1_test.cpp +++ b/test/onnx/parse/gather_elements_axis1_test.cpp @@ -48,7 +48,7 @@ TEST_CASE(gather_elements_axis1_test) auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("gather_elements_axis1_test.onnx"); + auto prog = read_onnx("gather_elements_axis1_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/gathernd_dyn_test.cpp b/test/onnx/parse/gathernd_dyn_test.cpp index 38b65053579..e72a4438c1f 100644 --- a/test/onnx/parse/gathernd_dyn_test.cpp +++ b/test/onnx/parse/gathernd_dyn_test.cpp @@ -38,6 +38,6 @@ TEST_CASE(gathernd_dyn_test) migraphx::onnx_options options; options.map_dyn_input_dims["data"] = {{2, 4, {2}}, {2, 4}}; options.map_dyn_input_dims["indices"] = {{1, 3}, {2, 2}}; - auto prog = migraphx::parse_onnx("gathernd_dyn_test.onnx", options); + auto prog = read_onnx("gathernd_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/gelu_bias_invalid_type_test.cpp b/test/onnx/parse/gelu_bias_invalid_type_test.cpp index 270a084fddd..af5bd0fa8c4 100644 --- a/test/onnx/parse/gelu_bias_invalid_type_test.cpp +++ b/test/onnx/parse/gelu_bias_invalid_type_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(gelu_bias_invalid_type_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("gelu_bias_invalid_type_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("gelu_bias_invalid_type_test.onnx"); })); } diff --git a/test/onnx/parse/gelu_fast_invald_bias_test.cpp b/test/onnx/parse/gelu_fast_invald_bias_test.cpp index 52664038d6b..cf6c262cb65 100644 --- a/test/onnx/parse/gelu_fast_invald_bias_test.cpp +++ b/test/onnx/parse/gelu_fast_invald_bias_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(gelu_fast_invalid_bias_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("gelu_fast_invalid_bias_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("gelu_fast_invalid_bias_test.onnx"); })); } diff --git a/test/onnx/parse/gelu_fast_invald_x_test.cpp b/test/onnx/parse/gelu_fast_invald_x_test.cpp index f4ef9eda473..9644401353b 100644 --- a/test/onnx/parse/gelu_fast_invald_x_test.cpp +++ b/test/onnx/parse/gelu_fast_invald_x_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(gelu_fast_invalid_x_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("gelu_fast_invalid_x_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("gelu_fast_invalid_x_test.onnx"); })); } diff --git a/test/onnx/parse/gelu_invalid_input_type_test.cpp b/test/onnx/parse/gelu_invalid_input_type_test.cpp index 7e6e9fa5276..9c088ae3710 100644 --- a/test/onnx/parse/gelu_invalid_input_type_test.cpp +++ b/test/onnx/parse/gelu_invalid_input_type_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(gelu_invalid_input_type_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("gelu_invalid_input_type_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("gelu_invalid_input_type_test.onnx"); })); } diff --git a/test/onnx/parse/gemm_dyn_bias_test.cpp b/test/onnx/parse/gemm_dyn_bias_test.cpp index 29fe4971a86..65caa831467 100644 --- a/test/onnx/parse/gemm_dyn_bias_test.cpp +++ b/test/onnx/parse/gemm_dyn_bias_test.cpp @@ -40,6 +40,6 @@ TEST_CASE(gemm_dyn_bias_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10}; - auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options); + auto prog = read_onnx("gemm_dyn_bias_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/gemm_dyn_inner_test.cpp b/test/onnx/parse/gemm_dyn_inner_test.cpp index c72e3781e0e..288232c947b 100644 --- a/test/onnx/parse/gemm_dyn_inner_test.cpp +++ b/test/onnx/parse/gemm_dyn_inner_test.cpp @@ -42,6 +42,6 @@ TEST_CASE(gemm_dyn_inner_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10, {8}}; - auto prog = migraphx::parse_onnx("gemm_dyn_inner_test.onnx", options); + auto prog = read_onnx("gemm_dyn_inner_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/gemm_dyn_outer_test.cpp b/test/onnx/parse/gemm_dyn_outer_test.cpp index f8bff6212bb..9b983ee6408 100644 --- a/test/onnx/parse/gemm_dyn_outer_test.cpp +++ b/test/onnx/parse/gemm_dyn_outer_test.cpp @@ -42,6 +42,6 @@ TEST_CASE(gemm_dyn_outer_test) migraphx::onnx_options options; options.default_dyn_dim_value = {5, 10, {7}}; - auto prog = migraphx::parse_onnx("gemm_dyn_outer_test.onnx", options); + auto prog = read_onnx("gemm_dyn_outer_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/gemm_rank_error.cpp b/test/onnx/parse/gemm_rank_error.cpp index 5af081a8fc2..10b441f5624 100644 --- a/test/onnx/parse/gemm_rank_error.cpp +++ b/test/onnx/parse/gemm_rank_error.cpp @@ -26,5 +26,5 @@ TEST_CASE(gemm_rank_error) { - EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); })); + EXPECT(test::throws([&] { read_onnx("gemm_rank_error.onnx"); })); } diff --git a/test/onnx/parse/globalavgpool_dyn_test.cpp b/test/onnx/parse/globalavgpool_dyn_test.cpp index b4a0a62425e..2a0726ee070 100644 --- a/test/onnx/parse/globalavgpool_dyn_test.cpp +++ b/test/onnx/parse/globalavgpool_dyn_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(globalavgpool_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("globalavgpool_dyn_test.onnx", options); + auto prog = read_onnx("globalavgpool_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/globallppool_dyn_test.cpp b/test/onnx/parse/globallppool_dyn_test.cpp index 4202b1cd117..f10aa18f724 100644 --- a/test/onnx/parse/globallppool_dyn_test.cpp +++ b/test/onnx/parse/globallppool_dyn_test.cpp @@ -41,7 +41,7 @@ TEST_CASE(globallppool_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {16, 32}; - auto prog = migraphx::parse_onnx("globallppool_dyn_test.onnx", options); + auto prog = read_onnx("globallppool_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/globalmaxpool_dyn_test.cpp b/test/onnx/parse/globalmaxpool_dyn_test.cpp index 618eeb98cd6..617726fb037 100644 --- a/test/onnx/parse/globalmaxpool_dyn_test.cpp +++ b/test/onnx/parse/globalmaxpool_dyn_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(globalmaxpool_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("globalmaxpool_dyn_test.onnx", options); + auto prog = read_onnx("globalmaxpool_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/greater_bool_test.cpp b/test/onnx/parse/greater_bool_test.cpp index 19ce678b96c..fa4f6ea9984 100644 --- a/test/onnx/parse/greater_bool_test.cpp +++ b/test/onnx/parse/greater_bool_test.cpp @@ -40,6 +40,6 @@ TEST_CASE(greater_bool_test) auto ret = mm->add_instruction(migraphx::make_op("greater"), cin1, input2); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("greater_bool_test.onnx"); + auto prog = read_onnx("greater_bool_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/greater_test.cpp b/test/onnx/parse/greater_test.cpp index c42b2b8b96f..72df7aa83a1 100644 --- a/test/onnx/parse/greater_test.cpp +++ b/test/onnx/parse/greater_test.cpp @@ -40,6 +40,6 @@ TEST_CASE(greater_test) gr); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("greater_test.onnx"); + auto prog = read_onnx("greater_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/greaterorequal_test.cpp b/test/onnx/parse/greaterorequal_test.cpp index 1c9c982ac26..abbd971989e 100644 --- a/test/onnx/parse/greaterorequal_test.cpp +++ b/test/onnx/parse/greaterorequal_test.cpp @@ -38,6 +38,6 @@ TEST_CASE(greaterorequal_test) mm->add_return({ge}); - auto prog = migraphx::parse_onnx("greaterorequal_test.onnx"); + auto prog = read_onnx("greaterorequal_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/group_norm_invalid_bias_shape_test.cpp b/test/onnx/parse/group_norm_invalid_bias_shape_test.cpp index af223de3a09..dde23ec9660 100644 --- a/test/onnx/parse/group_norm_invalid_bias_shape_test.cpp +++ b/test/onnx/parse/group_norm_invalid_bias_shape_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(group_norm_invalid_bias_shape_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("group_norm_invalid_bias_shape_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("group_norm_invalid_bias_shape_test.onnx"); })); } diff --git a/test/onnx/parse/group_norm_invalid_input_count_error_test.cpp b/test/onnx/parse/group_norm_invalid_input_count_error_test.cpp index 046aef8429f..713459f70e5 100644 --- a/test/onnx/parse/group_norm_invalid_input_count_error_test.cpp +++ b/test/onnx/parse/group_norm_invalid_input_count_error_test.cpp @@ -26,6 +26,5 @@ TEST_CASE(group_norm_invalid_input_count_error_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("group_norm_invalid_input_count_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("group_norm_invalid_input_count_error_test.onnx"); })); } diff --git a/test/onnx/parse/group_norm_invalid_input_shape_error_test.cpp b/test/onnx/parse/group_norm_invalid_input_shape_error_test.cpp index 5e817c813cd..098e4cdc517 100644 --- a/test/onnx/parse/group_norm_invalid_input_shape_error_test.cpp +++ b/test/onnx/parse/group_norm_invalid_input_shape_error_test.cpp @@ -26,6 +26,5 @@ TEST_CASE(group_norm_invalid_input_shape_error_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("group_norm_invalid_input_shape_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("group_norm_invalid_input_shape_error_test.onnx"); })); } diff --git a/test/onnx/parse/group_norm_invalid_num_groups_error_test.cpp b/test/onnx/parse/group_norm_invalid_num_groups_error_test.cpp index 27029b92808..127229c44a9 100644 --- a/test/onnx/parse/group_norm_invalid_num_groups_error_test.cpp +++ b/test/onnx/parse/group_norm_invalid_num_groups_error_test.cpp @@ -26,6 +26,5 @@ TEST_CASE(group_norm_invalid_num_groups_error_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("group_norm_invalid_num_groups_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("group_norm_invalid_num_groups_error_test.onnx"); })); } diff --git a/test/onnx/parse/group_norm_invalid_scale_shape_test.cpp b/test/onnx/parse/group_norm_invalid_scale_shape_test.cpp index bdfc3597150..18b70c9f02e 100644 --- a/test/onnx/parse/group_norm_invalid_scale_shape_test.cpp +++ b/test/onnx/parse/group_norm_invalid_scale_shape_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(group_norm_invalid_scale_shape_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("group_norm_invalid_scale_shape_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("group_norm_invalid_scale_shape_test.onnx"); })); } diff --git a/test/onnx/parse/group_norm_missing_attribute_error_test.cpp b/test/onnx/parse/group_norm_missing_attribute_error_test.cpp index b0b3741fb20..99809bca202 100644 --- a/test/onnx/parse/group_norm_missing_attribute_error_test.cpp +++ b/test/onnx/parse/group_norm_missing_attribute_error_test.cpp @@ -26,6 +26,5 @@ TEST_CASE(group_norm_missing_attribute_error_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("group_norm_missing_attribute_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("group_norm_missing_attribute_error_test.onnx"); })); } diff --git a/test/onnx/parse/if_else_test.cpp b/test/onnx/parse/if_else_test.cpp index d169e0bb098..498023c68bb 100644 --- a/test/onnx/parse/if_else_test.cpp +++ b/test/onnx/parse/if_else_test.cpp @@ -52,6 +52,6 @@ TEST_CASE(if_else_test) auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); mm->add_return({r}); - auto prog = migraphx::parse_onnx("if_else_test.onnx"); + auto prog = read_onnx("if_else_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_else_test_inlined.cpp b/test/onnx/parse/if_else_test_inlined.cpp index 81e46db72f7..8b93026e5b6 100644 --- a/test/onnx/parse/if_else_test_inlined.cpp +++ b/test/onnx/parse/if_else_test_inlined.cpp @@ -45,6 +45,6 @@ TEST_CASE(if_else_test_inlined) auto re = mm->add_instruction(migraphx::make_op("mul"), y, l2); mm->add_return({re}); - auto prog = migraphx::parse_onnx("if_else_test_inlined.onnx"); + auto prog = read_onnx("if_else_test_inlined.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_literal_test.cpp b/test/onnx/parse/if_literal_test.cpp index 9582fe87046..c2679fd4197 100644 --- a/test/onnx/parse/if_literal_test.cpp +++ b/test/onnx/parse/if_literal_test.cpp @@ -49,6 +49,6 @@ TEST_CASE(if_literal_test) auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); mm->add_return({r}); - auto prog = migraphx::parse_onnx("if_literal_test.onnx"); + auto prog = read_onnx("if_literal_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_param_excp1_test.cpp b/test/onnx/parse/if_param_excp1_test.cpp index f2136628765..f7110cc73b4 100644 --- a/test/onnx/parse/if_param_excp1_test.cpp +++ b/test/onnx/parse/if_param_excp1_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(if_param_excp1_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("if_param_excp1_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("if_param_excp1_test.onnx"); })); } diff --git a/test/onnx/parse/if_param_excp_test.cpp b/test/onnx/parse/if_param_excp_test.cpp index ac94229b6ff..7fa65f02876 100644 --- a/test/onnx/parse/if_param_excp_test.cpp +++ b/test/onnx/parse/if_param_excp_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(if_param_excp_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("if_param_excp_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("if_param_excp_test.onnx"); })); } diff --git a/test/onnx/parse/if_param_test.cpp b/test/onnx/parse/if_param_test.cpp index b84778691b0..0578129c593 100644 --- a/test/onnx/parse/if_param_test.cpp +++ b/test/onnx/parse/if_param_test.cpp @@ -50,6 +50,6 @@ TEST_CASE(if_param_test) auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); mm->add_return({r}); - auto prog = migraphx::parse_onnx("if_param_test.onnx"); + auto prog = read_onnx("if_param_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_pl_test.cpp b/test/onnx/parse/if_pl_test.cpp index dcf9c66391f..b875551436a 100644 --- a/test/onnx/parse/if_pl_test.cpp +++ b/test/onnx/parse/if_pl_test.cpp @@ -55,6 +55,6 @@ TEST_CASE(if_pl_test) mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r}); - auto prog = migraphx::parse_onnx("if_pl_test.onnx"); + auto prog = read_onnx("if_pl_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_then_else_multi_output_shapes_inlined_test.cpp b/test/onnx/parse/if_then_else_multi_output_shapes_inlined_test.cpp index b5b9b8438d8..1dc0c9775dd 100644 --- a/test/onnx/parse/if_then_else_multi_output_shapes_inlined_test.cpp +++ b/test/onnx/parse/if_then_else_multi_output_shapes_inlined_test.cpp @@ -48,6 +48,6 @@ TEST_CASE(if_then_else_multi_output_shapes_inlined_test) mm->add_return({rt, rt2}); - auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx"); + auto prog = read_onnx("if_then_else_multi_output_shapes_inlined_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_then_else_multi_output_shapes_test.cpp b/test/onnx/parse/if_then_else_multi_output_shapes_test.cpp index bdf3c099a90..41f6c8caa94 100644 --- a/test/onnx/parse/if_then_else_multi_output_shapes_test.cpp +++ b/test/onnx/parse/if_then_else_multi_output_shapes_test.cpp @@ -56,6 +56,6 @@ TEST_CASE(if_then_else_multi_output_shapes_test) auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r1, r2}); - auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx"); + auto prog = read_onnx("if_then_else_multi_output_shapes_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_then_test.cpp b/test/onnx/parse/if_then_test.cpp index 7081fc17217..f9a0358df3b 100644 --- a/test/onnx/parse/if_then_test.cpp +++ b/test/onnx/parse/if_then_test.cpp @@ -52,6 +52,6 @@ TEST_CASE(if_then_test) auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); mm->add_return({r}); - auto prog = migraphx::parse_onnx("if_then_test.onnx"); + auto prog = read_onnx("if_then_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_then_test_inlined.cpp b/test/onnx/parse/if_then_test_inlined.cpp index afaeae0f203..6820eac164a 100644 --- a/test/onnx/parse/if_then_test_inlined.cpp +++ b/test/onnx/parse/if_then_test_inlined.cpp @@ -46,6 +46,6 @@ TEST_CASE(if_then_test_inlined) auto rt = mm->add_instruction(migraphx::make_op("add"), x, l1); mm->add_return({rt}); - auto prog = migraphx::parse_onnx("if_then_test_inlined.onnx"); + auto prog = read_onnx("if_then_test_inlined.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/if_tuple_test.cpp b/test/onnx/parse/if_tuple_test.cpp index f463c868275..5ad0b8ae9d8 100644 --- a/test/onnx/parse/if_tuple_test.cpp +++ b/test/onnx/parse/if_tuple_test.cpp @@ -62,6 +62,6 @@ TEST_CASE(if_tuple_test) auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r0, r1}); - auto prog = migraphx::parse_onnx("if_tuple_test.onnx"); + auto prog = read_onnx("if_tuple_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/implicit_add_bcast_test.cpp b/test/onnx/parse/implicit_add_bcast_test.cpp index af02569b1ba..51f0898b95a 100644 --- a/test/onnx/parse/implicit_add_bcast_test.cpp +++ b/test/onnx/parse/implicit_add_bcast_test.cpp @@ -53,7 +53,7 @@ TEST_CASE(implicit_add_bcast_user_input_shape_test) migraphx::onnx_options options; options.map_input_dims["0"] = {3, 4, 5, 6}; options.map_input_dims["1"] = {4, 5, 1}; - auto prog = migraphx::parse_onnx("implicit_add_bcast_test.onnx", options); + auto prog = read_onnx("implicit_add_bcast_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/instance_norm_dyn_batch_half_test.cpp b/test/onnx/parse/instance_norm_dyn_batch_half_test.cpp index 20873089655..548ae05df22 100644 --- a/test/onnx/parse/instance_norm_dyn_batch_half_test.cpp +++ b/test/onnx/parse/instance_norm_dyn_batch_half_test.cpp @@ -69,6 +69,6 @@ TEST_CASE(instance_norm_dyn_batch_half_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 2, {2}}; - auto prog = migraphx::parse_onnx("instance_norm_dyn_batch_half_test.onnx", options); + auto prog = read_onnx("instance_norm_dyn_batch_half_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/instance_norm_dyn_batch_test.cpp b/test/onnx/parse/instance_norm_dyn_batch_test.cpp index 245dc51e485..c58705ddb8c 100644 --- a/test/onnx/parse/instance_norm_dyn_batch_test.cpp +++ b/test/onnx/parse/instance_norm_dyn_batch_test.cpp @@ -55,7 +55,7 @@ TEST_CASE(instance_norm_dyn_batch_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 2, {2}}; - auto prog = migraphx::parse_onnx("instance_norm_dyn_batch_test.onnx", options); + auto prog = read_onnx("instance_norm_dyn_batch_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/instance_norm_invalid_type_test.cpp b/test/onnx/parse/instance_norm_invalid_type_test.cpp index 151f69009af..72359343eee 100644 --- a/test/onnx/parse/instance_norm_invalid_type_test.cpp +++ b/test/onnx/parse/instance_norm_invalid_type_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(instance_norm_invalid_type_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("instance_norm_invalid_type_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("instance_norm_invalid_type_test.onnx"); })); } diff --git a/test/onnx/parse/instance_norm_nonbroadcastable_test.cpp b/test/onnx/parse/instance_norm_nonbroadcastable_test.cpp index 22b4eef8c52..f228b868e0e 100644 --- a/test/onnx/parse/instance_norm_nonbroadcastable_test.cpp +++ b/test/onnx/parse/instance_norm_nonbroadcastable_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(instance_norm_nonbroadcastable_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("instance_norm_nonbroadcastable_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("instance_norm_nonbroadcastable_test.onnx"); })); } diff --git a/test/onnx/parse/instance_norm_test.cpp b/test/onnx/parse/instance_norm_test.cpp index d0115043867..375ae176ab0 100644 --- a/test/onnx/parse/instance_norm_test.cpp +++ b/test/onnx/parse/instance_norm_test.cpp @@ -58,7 +58,7 @@ TEST_CASE(instance_norm_test) mm->add_return({ret}); migraphx::onnx_options options; - auto prog = migraphx::parse_onnx("instance_norm_test.onnx", options); + auto prog = read_onnx("instance_norm_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/instance_norm_type_mismatch_test.cpp b/test/onnx/parse/instance_norm_type_mismatch_test.cpp index 47f62fe0347..73262e2ef1a 100644 --- a/test/onnx/parse/instance_norm_type_mismatch_test.cpp +++ b/test/onnx/parse/instance_norm_type_mismatch_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(instance_norm_type_mismatch_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("instance_norm_type_mismatch_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("instance_norm_type_mismatch_test.onnx"); })); } diff --git a/test/onnx/parse/isinf_double_pos_test.cpp b/test/onnx/parse/isinf_double_pos_test.cpp index 58374ac9051..cb77c706fcd 100644 --- a/test/onnx/parse/isinf_double_pos_test.cpp +++ b/test/onnx/parse/isinf_double_pos_test.cpp @@ -44,6 +44,6 @@ TEST_CASE(isinf_double_pos_test) auto ret = mm->add_instruction(migraphx::make_op("logical_and"), is_inf, is_neg); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("isinf_double_pos_test.onnx"); + auto prog = read_onnx("isinf_double_pos_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/isinf_half_test.cpp b/test/onnx/parse/isinf_half_test.cpp index 4c34da0c955..8848872f090 100644 --- a/test/onnx/parse/isinf_half_test.cpp +++ b/test/onnx/parse/isinf_half_test.cpp @@ -33,6 +33,6 @@ TEST_CASE(isinf_half_test) auto ret = mm->add_instruction(migraphx::make_op("isinf"), t1); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("isinf_half_test.onnx"); + auto prog = read_onnx("isinf_half_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/isinf_neg_test.cpp b/test/onnx/parse/isinf_neg_test.cpp index ce6107c3214..bb097beea16 100644 --- a/test/onnx/parse/isinf_neg_test.cpp +++ b/test/onnx/parse/isinf_neg_test.cpp @@ -44,6 +44,6 @@ TEST_CASE(isinf_neg_test) auto ret = mm->add_instruction(migraphx::make_op("logical_and"), is_inf, is_neg); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("isinf_neg_test.onnx"); + auto prog = read_onnx("isinf_neg_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/isinf_no_detect_test.cpp b/test/onnx/parse/isinf_no_detect_test.cpp index 879c5f40fdf..9a9384b2218 100644 --- a/test/onnx/parse/isinf_no_detect_test.cpp +++ b/test/onnx/parse/isinf_no_detect_test.cpp @@ -35,6 +35,6 @@ TEST_CASE(isinf_no_detect_test) mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::bool_type}, {false}})); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("isinf_no_detect_test.onnx"); + auto prog = read_onnx("isinf_no_detect_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/isnan_float_test.cpp b/test/onnx/parse/isnan_float_test.cpp index 03d43981f63..c057e8cd3c7 100644 --- a/test/onnx/parse/isnan_float_test.cpp +++ b/test/onnx/parse/isnan_float_test.cpp @@ -33,6 +33,6 @@ TEST_CASE(isnan_float_test) auto ret = mm->add_instruction(migraphx::make_op("isnan"), t1); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("isnan_float_test.onnx"); + auto prog = read_onnx("isnan_float_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/isnan_half_test.cpp b/test/onnx/parse/isnan_half_test.cpp index 66bdef8de5a..5919e25b5a1 100644 --- a/test/onnx/parse/isnan_half_test.cpp +++ b/test/onnx/parse/isnan_half_test.cpp @@ -33,6 +33,6 @@ TEST_CASE(isnan_half_test) auto ret = mm->add_instruction(migraphx::make_op("isnan"), t1); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("isnan_half_test.onnx"); + auto prog = read_onnx("isnan_half_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/layer_norm_invalid_axis_error_test.cpp b/test/onnx/parse/layer_norm_invalid_axis_error_test.cpp index c8eb4872189..41e61e09aa2 100644 --- a/test/onnx/parse/layer_norm_invalid_axis_error_test.cpp +++ b/test/onnx/parse/layer_norm_invalid_axis_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(layer_norm_invalid_axis_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("layer_norm_invalid_axis_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("layer_norm_invalid_axis_error_test.onnx"); })); } diff --git a/test/onnx/parse/layer_norm_invalid_input_count_error_test.cpp b/test/onnx/parse/layer_norm_invalid_input_count_error_test.cpp index 2c1e09b3245..6b58e366c74 100644 --- a/test/onnx/parse/layer_norm_invalid_input_count_error_test.cpp +++ b/test/onnx/parse/layer_norm_invalid_input_count_error_test.cpp @@ -26,6 +26,5 @@ TEST_CASE(layer_norm_invalid_input_count_error_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("layer_norm_invalid_input_count_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("layer_norm_invalid_input_count_error_test.onnx"); })); } diff --git a/test/onnx/parse/layer_norm_invalid_minus_axis_error_test.cpp b/test/onnx/parse/layer_norm_invalid_minus_axis_error_test.cpp index 0db89aa9a76..2f0e4f12ebc 100644 --- a/test/onnx/parse/layer_norm_invalid_minus_axis_error_test.cpp +++ b/test/onnx/parse/layer_norm_invalid_minus_axis_error_test.cpp @@ -26,6 +26,5 @@ TEST_CASE(layer_norm_invalid_minus_axis_error_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("layer_norm_invalid_minus_axis_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("layer_norm_invalid_minus_axis_error_test.onnx"); })); } diff --git a/test/onnx/parse/less_bool_test.cpp b/test/onnx/parse/less_bool_test.cpp index 249e77d50e1..aec6675634d 100644 --- a/test/onnx/parse/less_bool_test.cpp +++ b/test/onnx/parse/less_bool_test.cpp @@ -40,6 +40,6 @@ TEST_CASE(less_bool_test) auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("less_bool_test.onnx"); + auto prog = read_onnx("less_bool_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/less_test.cpp b/test/onnx/parse/less_test.cpp index e085b4756d3..aea6b9c7a29 100644 --- a/test/onnx/parse/less_test.cpp +++ b/test/onnx/parse/less_test.cpp @@ -40,6 +40,6 @@ TEST_CASE(less_test) le); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("less_test.onnx"); + auto prog = read_onnx("less_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/lessorequal_test.cpp b/test/onnx/parse/lessorequal_test.cpp index 3a6f08399a1..c27f9e0b2da 100644 --- a/test/onnx/parse/lessorequal_test.cpp +++ b/test/onnx/parse/lessorequal_test.cpp @@ -38,6 +38,6 @@ TEST_CASE(lessorequal_test) mm->add_return({le}); - auto prog = migraphx::parse_onnx("lessorequal_test.onnx"); + auto prog = read_onnx("lessorequal_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/logical_and_bcast_test.cpp b/test/onnx/parse/logical_and_bcast_test.cpp index 97e5af6242d..2e407eddde1 100644 --- a/test/onnx/parse/logical_and_bcast_test.cpp +++ b/test/onnx/parse/logical_and_bcast_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(logical_and_bcast_test) auto ret = mm->add_instruction(migraphx::make_op("logical_and"), l0, l2); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("logical_and_bcast_test.onnx"); + auto prog = read_onnx("logical_and_bcast_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/logical_or_test.cpp b/test/onnx/parse/logical_or_test.cpp index b321402db0d..ceb121ae23a 100644 --- a/test/onnx/parse/logical_or_test.cpp +++ b/test/onnx/parse/logical_or_test.cpp @@ -33,7 +33,7 @@ TEST_CASE(logical_or_test) auto ret = mm->add_instruction(migraphx::make_op("logical_or"), l0, l1); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("logical_or_test.onnx"); + auto prog = read_onnx("logical_or_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/logical_xor_bcast_test.cpp b/test/onnx/parse/logical_xor_bcast_test.cpp index e83dbd3c503..7c5136fe832 100644 --- a/test/onnx/parse/logical_xor_bcast_test.cpp +++ b/test/onnx/parse/logical_xor_bcast_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(logical_xor_bcast_test) auto ret = mm->add_instruction(migraphx::make_op("logical_xor"), l0, l2); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("logical_xor_bcast_test.onnx"); + auto prog = read_onnx("logical_xor_bcast_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/logsoftmax_nonstd_input_test.cpp b/test/onnx/parse/logsoftmax_nonstd_input_test.cpp index 46e09830606..7022e5e9a95 100644 --- a/test/onnx/parse/logsoftmax_nonstd_input_test.cpp +++ b/test/onnx/parse/logsoftmax_nonstd_input_test.cpp @@ -34,7 +34,7 @@ TEST_CASE(logsoftmax_nonstd_input_test) auto l2 = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", -1}}), l1); mm->add_return({l2}); - auto prog = migraphx::parse_onnx("logsoftmax_nonstd_input_test.onnx"); + auto prog = read_onnx("logsoftmax_nonstd_input_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/loop_default_test.cpp b/test/onnx/parse/loop_default_test.cpp index 050c4d6e08c..87bb5d572d2 100644 --- a/test/onnx/parse/loop_default_test.cpp +++ b/test/onnx/parse/loop_default_test.cpp @@ -58,7 +58,7 @@ TEST_CASE(loop_default_test) auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp); mm->add_return({r0, r2}); - auto prog = migraphx::parse_onnx("loop_default_test.onnx"); + auto prog = read_onnx("loop_default_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/loop_test.cpp b/test/onnx/parse/loop_test.cpp index 8917e638ec8..7a94181f53a 100644 --- a/test/onnx/parse/loop_test.cpp +++ b/test/onnx/parse/loop_test.cpp @@ -56,7 +56,7 @@ TEST_CASE(loop_test) auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp); mm->add_return({r0, r2}); - auto prog = migraphx::parse_onnx("loop_test.onnx"); + auto prog = read_onnx("loop_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/lpnormalization_axis_error_test.cpp b/test/onnx/parse/lpnormalization_axis_error_test.cpp index 24f6f838a59..af226b44ac5 100644 --- a/test/onnx/parse/lpnormalization_axis_error_test.cpp +++ b/test/onnx/parse/lpnormalization_axis_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(lpnormalization_axis_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("lpnormalization_axis_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("lpnormalization_axis_error_test.onnx"); })); } diff --git a/test/onnx/parse/lpnormalization_p_error_test.cpp b/test/onnx/parse/lpnormalization_p_error_test.cpp index 873e4adfbf2..040c19e07a4 100644 --- a/test/onnx/parse/lpnormalization_p_error_test.cpp +++ b/test/onnx/parse/lpnormalization_p_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(lpnormalization_p_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("lpnormalization_p_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("lpnormalization_p_error_test.onnx"); })); } diff --git a/test/onnx/parse/matmul_dyn_broadcast_test.cpp b/test/onnx/parse/matmul_dyn_broadcast_test.cpp index 12149fb84be..bcad585c301 100644 --- a/test/onnx/parse/matmul_dyn_broadcast_test.cpp +++ b/test/onnx/parse/matmul_dyn_broadcast_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(matmul_dyn_broadcast_test) migraphx::onnx_options options; options.map_dyn_input_dims["2"] = {{5, 5}, {7, 7}, {4, 8, {6}}}; - auto prog = parse_onnx("matmul_dyn_broadcast_test.onnx", options); + auto prog = read_onnx("matmul_dyn_broadcast_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/matmul_dyn_mm_test.cpp b/test/onnx/parse/matmul_dyn_mm_test.cpp index 9c75f25b7ec..fafaa644635 100644 --- a/test/onnx/parse/matmul_dyn_mm_test.cpp +++ b/test/onnx/parse/matmul_dyn_mm_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(matmul_dyn_mm_test) migraphx::onnx_options options; options.map_dyn_input_dims["1"] = {{4, 8, {6}}, {7, 7}}; options.map_dyn_input_dims["2"] = {{7, 7}, {1, 5, {3}}}; - auto prog = parse_onnx("matmul_dyn_mm_test.onnx", options); + auto prog = read_onnx("matmul_dyn_mm_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/matmul_dyn_mv_test.cpp b/test/onnx/parse/matmul_dyn_mv_test.cpp index d6bf8e5c6c1..005611cd835 100644 --- a/test/onnx/parse/matmul_dyn_mv_test.cpp +++ b/test/onnx/parse/matmul_dyn_mv_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(matmul_dyn_mv_test) migraphx::onnx_options options; options.map_dyn_input_dims["1"] = {{4, 8, {6}}, {7, 7}}; - auto prog = parse_onnx("matmul_dyn_mv_test.onnx", options); + auto prog = read_onnx("matmul_dyn_mv_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/matmul_dyn_vm_test.cpp b/test/onnx/parse/matmul_dyn_vm_test.cpp index 49fea5ebc52..bfd65c0c288 100644 --- a/test/onnx/parse/matmul_dyn_vm_test.cpp +++ b/test/onnx/parse/matmul_dyn_vm_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(matmul_dyn_vm_test) migraphx::onnx_options options; options.map_dyn_input_dims["2"] = {{7, 7}, {4, 10, {8}}}; - auto prog = parse_onnx("matmul_dyn_vm_test.onnx", options); + auto prog = read_onnx("matmul_dyn_vm_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/matmul_dyn_vv_test.cpp b/test/onnx/parse/matmul_dyn_vv_test.cpp index d613363698f..481ffc73755 100644 --- a/test/onnx/parse/matmul_dyn_vv_test.cpp +++ b/test/onnx/parse/matmul_dyn_vv_test.cpp @@ -42,7 +42,7 @@ TEST_CASE(matmul_dyn_vv_test) migraphx::onnx_options options; options.default_dyn_dim_value = dd; - auto prog = parse_onnx("matmul_dyn_vv_test.onnx", options); + auto prog = read_onnx("matmul_dyn_vv_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/matmulinteger_dyn_error.cpp b/test/onnx/parse/matmulinteger_dyn_error.cpp index 251fa6624fb..6def59fddcd 100644 --- a/test/onnx/parse/matmulinteger_dyn_error.cpp +++ b/test/onnx/parse/matmulinteger_dyn_error.cpp @@ -28,5 +28,5 @@ TEST_CASE(matmulinteger_dyn_error) { migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws([&] { migraphx::parse_onnx("matmulinteger_dyn_error.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("matmulinteger_dyn_error.onnx", options); })); } diff --git a/test/onnx/parse/mean_invalid_broadcast_test.cpp b/test/onnx/parse/mean_invalid_broadcast_test.cpp index fe3b7aa1cd3..8c70f251328 100644 --- a/test/onnx/parse/mean_invalid_broadcast_test.cpp +++ b/test/onnx/parse/mean_invalid_broadcast_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(mean_invalid_broadcast_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("mean_invalid_broadcast_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("mean_invalid_broadcast_test.onnx"); })); } diff --git a/test/onnx/parse/mean_single_input_test.cpp b/test/onnx/parse/mean_single_input_test.cpp index d9ad7aa0ce4..07106a2d75a 100644 --- a/test/onnx/parse/mean_single_input_test.cpp +++ b/test/onnx/parse/mean_single_input_test.cpp @@ -31,7 +31,7 @@ TEST_CASE(mean_single_input_test) auto data0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}); mm->add_return({data0}); - auto prog = migraphx::parse_onnx("mean_single_input_test.onnx"); + auto prog = read_onnx("mean_single_input_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/mod_test_half.cpp b/test/onnx/parse/mod_test_half.cpp index e6e8d146742..129c606d7c1 100644 --- a/test/onnx/parse/mod_test_half.cpp +++ b/test/onnx/parse/mod_test_half.cpp @@ -26,5 +26,5 @@ TEST_CASE(mod_test_half) { - EXPECT(test::throws([&] { migraphx::parse_onnx("mod_test_half.onnx"); })); + EXPECT(test::throws([&] { read_onnx("mod_test_half.onnx"); })); } diff --git a/test/onnx/parse/multinomial_autoseed_dyn_test.cpp b/test/onnx/parse/multinomial_autoseed_dyn_test.cpp index 09641cb08b6..1d447cc1fe6 100644 --- a/test/onnx/parse/multinomial_autoseed_dyn_test.cpp +++ b/test/onnx/parse/multinomial_autoseed_dyn_test.cpp @@ -70,6 +70,6 @@ TEST_CASE(multinomial_autoseed_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, categories}; options.print_program_on_error = true; - auto prog = migraphx::parse_onnx("multinomial_autoseed_dyn_test.onnx", options); + auto prog = read_onnx("multinomial_autoseed_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/multinomial_dtype_error_test.cpp b/test/onnx/parse/multinomial_dtype_error_test.cpp index 11b612f6c49..1cd1093e33e 100644 --- a/test/onnx/parse/multinomial_dtype_error_test.cpp +++ b/test/onnx/parse/multinomial_dtype_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(multinomial_dtype_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("multinomial_dtype_error_test.onnx"); })); } diff --git a/test/onnx/parse/multinomial_dyn_test.cpp b/test/onnx/parse/multinomial_dyn_test.cpp index d3c790f0594..625a1f16317 100644 --- a/test/onnx/parse/multinomial_dyn_test.cpp +++ b/test/onnx/parse/multinomial_dyn_test.cpp @@ -76,6 +76,6 @@ TEST_CASE(multinomial_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, categories}; options.print_program_on_error = true; - auto prog = migraphx::parse_onnx("multinomial_dyn_test.onnx", options); + auto prog = read_onnx("multinomial_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/mvn_axes_rank_too_small_test.cpp b/test/onnx/parse/mvn_axes_rank_too_small_test.cpp index 58764f8cb97..8de46a2e96b 100644 --- a/test/onnx/parse/mvn_axes_rank_too_small_test.cpp +++ b/test/onnx/parse/mvn_axes_rank_too_small_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(mvn_axes_rank_too_small_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_axes_rank_too_small_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("mvn_axes_rank_too_small_test.onnx"); })); } diff --git a/test/onnx/parse/mvn_default_axes_rank_too_big_test.cpp b/test/onnx/parse/mvn_default_axes_rank_too_big_test.cpp index d3fd92eedfe..015b55140c6 100644 --- a/test/onnx/parse/mvn_default_axes_rank_too_big_test.cpp +++ b/test/onnx/parse/mvn_default_axes_rank_too_big_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(mvn_default_axes_rank_too_big_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_default_axes_rank_too_big_test.onnx"); })); + EXPECT(test::throws([&] { optimize_onnx("mvn_default_axes_rank_too_big_test.onnx"); })); } diff --git a/test/onnx/parse/neg_dynamic_test.cpp b/test/onnx/parse/neg_dynamic_test.cpp index 6e9a5b4c25b..087149f0807 100644 --- a/test/onnx/parse/neg_dynamic_test.cpp +++ b/test/onnx/parse/neg_dynamic_test.cpp @@ -35,6 +35,6 @@ TEST_CASE(neg_dynamic_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10}; - auto prog = migraphx::parse_onnx("neg_dynamic_test.onnx", options); + auto prog = read_onnx("neg_dynamic_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/neg_test.cpp b/test/onnx/parse/neg_test.cpp index 15f37f47ff4..94978339c9c 100644 --- a/test/onnx/parse/neg_test.cpp +++ b/test/onnx/parse/neg_test.cpp @@ -33,7 +33,7 @@ TEST_CASE(neg_test) auto ret = mm->add_instruction(migraphx::make_op("neg"), input); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("neg_test.onnx"); + auto prog = read_onnx("neg_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/nms_dynamic_batch_test.cpp b/test/onnx/parse/nms_dynamic_batch_test.cpp index bfca5ea736c..d836f4db403 100644 --- a/test/onnx/parse/nms_dynamic_batch_test.cpp +++ b/test/onnx/parse/nms_dynamic_batch_test.cpp @@ -52,6 +52,6 @@ TEST_CASE(nms_dynamic_batch_test) options.default_dyn_dim_value = {1, 10}; options.use_dyn_output = true; - auto prog = migraphx::parse_onnx("nms_dynamic_batch_test.onnx", options); + auto prog = read_onnx("nms_dynamic_batch_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/nms_dynamic_boxes_test.cpp b/test/onnx/parse/nms_dynamic_boxes_test.cpp index 71fc6fe2670..5c84306e98f 100644 --- a/test/onnx/parse/nms_dynamic_boxes_test.cpp +++ b/test/onnx/parse/nms_dynamic_boxes_test.cpp @@ -46,6 +46,6 @@ TEST_CASE(nms_dynamic_boxes_test) options.default_dyn_dim_value = {6, 20}; options.use_dyn_output = true; - auto prog = migraphx::parse_onnx("nms_dynamic_boxes_test.onnx", options); + auto prog = read_onnx("nms_dynamic_boxes_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/nms_dynamic_classes_test.cpp b/test/onnx/parse/nms_dynamic_classes_test.cpp index b19756a92d9..090734f8a4b 100644 --- a/test/onnx/parse/nms_dynamic_classes_test.cpp +++ b/test/onnx/parse/nms_dynamic_classes_test.cpp @@ -46,6 +46,6 @@ TEST_CASE(nms_dynamic_classes_test) options.default_dyn_dim_value = {1, 10}; options.use_dyn_output = true; - auto prog = migraphx::parse_onnx("nms_dynamic_classes_test.onnx", options); + auto prog = read_onnx("nms_dynamic_classes_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/nms_test.cpp b/test/onnx/parse/nms_test.cpp index 34ba9920b17..c69c0584b08 100644 --- a/test/onnx/parse/nms_test.cpp +++ b/test/onnx/parse/nms_test.cpp @@ -47,6 +47,6 @@ TEST_CASE(nms_test) migraphx::make_op("nonmaxsuppression", {{"center_point_box", true}}), b, s, mo, iou, st); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("nms_test.onnx"); + auto prog = read_onnx("nms_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/nms_use_dyn_output_false_test.cpp b/test/onnx/parse/nms_use_dyn_output_false_test.cpp index 947b03e8e94..77f5c519d89 100644 --- a/test/onnx/parse/nms_use_dyn_output_false_test.cpp +++ b/test/onnx/parse/nms_use_dyn_output_false_test.cpp @@ -50,6 +50,6 @@ TEST_CASE(nms_overwrite_use_dyn_output_test) migraphx::onnx_options options; options.use_dyn_output = true; - auto prog = migraphx::parse_onnx("nms_use_dyn_output_false_test.onnx", options); + auto prog = read_onnx("nms_use_dyn_output_false_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/nonzero_dynamic_test.cpp b/test/onnx/parse/nonzero_dynamic_test.cpp index 9f2c40913c8..ae4a456ce83 100644 --- a/test/onnx/parse/nonzero_dynamic_test.cpp +++ b/test/onnx/parse/nonzero_dynamic_test.cpp @@ -33,6 +33,6 @@ TEST_CASE(nonzero_dynamic_test) auto r = mm->add_instruction(migraphx::make_op("nonzero"), data); mm->add_return({r}); - auto prog = migraphx::parse_onnx("nonzero_dynamic_test.onnx"); + auto prog = read_onnx("nonzero_dynamic_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/nonzero_int_test.cpp b/test/onnx/parse/nonzero_int_test.cpp index dd73d883612..835c52c8d58 100644 --- a/test/onnx/parse/nonzero_int_test.cpp +++ b/test/onnx/parse/nonzero_int_test.cpp @@ -37,6 +37,6 @@ TEST_CASE(nonzero_int_test) auto r = mm->add_literal(migraphx::literal(si, indices)); mm->add_return({r}); - auto prog = migraphx::parse_onnx("nonzero_int_test.onnx"); + auto prog = read_onnx("nonzero_int_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/nonzero_test.cpp b/test/onnx/parse/nonzero_test.cpp index c396077e917..5f2d49d57c2 100644 --- a/test/onnx/parse/nonzero_test.cpp +++ b/test/onnx/parse/nonzero_test.cpp @@ -37,6 +37,6 @@ TEST_CASE(nonzero_test) auto r = mm->add_literal(migraphx::literal(si, indices)); mm->add_return({r}); - auto prog = migraphx::parse_onnx("nonzero_test.onnx"); + auto prog = read_onnx("nonzero_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/not_bool_test.cpp b/test/onnx/parse/not_bool_test.cpp index 3566351d5df..86f0da61cd3 100644 --- a/test/onnx/parse/not_bool_test.cpp +++ b/test/onnx/parse/not_bool_test.cpp @@ -32,7 +32,7 @@ TEST_CASE(not_bool_test) auto ret = mm->add_instruction(migraphx::make_op("not"), l0); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("not_bool_test.onnx"); + auto prog = read_onnx("not_bool_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/not_test.cpp b/test/onnx/parse/not_test.cpp index 13901200224..a02fc79b0b7 100644 --- a/test/onnx/parse/not_test.cpp +++ b/test/onnx/parse/not_test.cpp @@ -32,7 +32,7 @@ TEST_CASE(not_test) auto ret = mm->add_instruction(migraphx::make_op("not"), l0); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("not_test.onnx"); + auto prog = read_onnx("not_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/onehot_test.cpp b/test/onnx/parse/onehot_test.cpp index c035580603b..ab6b8c50042 100644 --- a/test/onnx/parse/onehot_test.cpp +++ b/test/onnx/parse/onehot_test.cpp @@ -52,7 +52,7 @@ TEST_CASE(onehot_test) auto r = mm->add_instruction(migraphx::make_op("add"), mul, mb_off_val); mm->add_return({r}); - auto prog = migraphx::parse_onnx("onehot_test.onnx"); + auto prog = read_onnx("onehot_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_3arg_test.cpp b/test/onnx/parse/pad_3arg_test.cpp index 5aaaa2e4fb8..8a793ba79ea 100644 --- a/test/onnx/parse/pad_3arg_test.cpp +++ b/test/onnx/parse/pad_3arg_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(pad_3arg_test) migraphx::make_op("pad", {{"pads", {1, 1, 2, 2}}, {"value", 1.0f}}), l0); mm->add_return({r}); - auto prog = migraphx::parse_onnx("pad_3arg_test.onnx"); + auto prog = read_onnx("pad_3arg_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_4arg_axes_test.cpp b/test/onnx/parse/pad_4arg_axes_test.cpp index 7e3cae37604..d1d2d6c3086 100644 --- a/test/onnx/parse/pad_4arg_axes_test.cpp +++ b/test/onnx/parse/pad_4arg_axes_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(pad_4arg_axes_test) migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0); mm->add_return({r}); - auto prog = migraphx::parse_onnx("pad_4arg_axes_test.onnx"); + auto prog = read_onnx("pad_4arg_axes_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_4arg_invalid_axes_error_test.cpp b/test/onnx/parse/pad_4arg_invalid_axes_error_test.cpp index e723be96c9e..d35268e5dbd 100644 --- a/test/onnx/parse/pad_4arg_invalid_axes_error_test.cpp +++ b/test/onnx/parse/pad_4arg_invalid_axes_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(pad_4arg_invalid_axes_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("pad_4arg_invalid_axes_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("pad_4arg_invalid_axes_error_test.onnx"); })); } diff --git a/test/onnx/parse/pad_4arg_neg_axes_test.cpp b/test/onnx/parse/pad_4arg_neg_axes_test.cpp index 60083b5d99c..422e5bd80e7 100644 --- a/test/onnx/parse/pad_4arg_neg_axes_test.cpp +++ b/test/onnx/parse/pad_4arg_neg_axes_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(pad_4arg_neg_axes_test) migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0); mm->add_return({r}); - auto prog = migraphx::parse_onnx("pad_4arg_neg_axes_test.onnx"); + auto prog = read_onnx("pad_4arg_neg_axes_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_asym_invalid_pads_error_test.cpp b/test/onnx/parse/pad_asym_invalid_pads_error_test.cpp index 6c54ccb4f28..fc0c6368ec9 100644 --- a/test/onnx/parse/pad_asym_invalid_pads_error_test.cpp +++ b/test/onnx/parse/pad_asym_invalid_pads_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(pad_asym_invalid_pads_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("pad_asym_invalid_pads_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("pad_asym_invalid_pads_error_test.onnx"); })); } diff --git a/test/onnx/parse/pad_attr_dyn_test.cpp b/test/onnx/parse/pad_attr_dyn_test.cpp index 5d01257bfef..a161373afcc 100644 --- a/test/onnx/parse/pad_attr_dyn_test.cpp +++ b/test/onnx/parse/pad_attr_dyn_test.cpp @@ -35,6 +35,6 @@ TEST_CASE(pad_attr_dyn_test) migraphx::onnx_options options; options.map_dyn_input_dims["0"] = {{2, 4, {2}}, {2, 4, {2}}}; - auto prog = parse_onnx("pad_attr_dyn_test.onnx", options); + auto prog = read_onnx("pad_attr_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_cnst_dyn_test.cpp b/test/onnx/parse/pad_cnst_dyn_test.cpp index ccbfdf550ac..76eb6c098fe 100644 --- a/test/onnx/parse/pad_cnst_dyn_test.cpp +++ b/test/onnx/parse/pad_cnst_dyn_test.cpp @@ -36,6 +36,6 @@ TEST_CASE(pad_cnst_dyn_test) migraphx::onnx_options options; options.map_dyn_input_dims["0"] = {{2, 4, {2}}, {2, 4, {2}}}; - auto prog = parse_onnx("pad_cnst_dyn_test.onnx", options); + auto prog = read_onnx("pad_cnst_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_dyn_reflect_error.cpp b/test/onnx/parse/pad_dyn_reflect_error.cpp index 82a4f3bd281..057f95f6fce 100644 --- a/test/onnx/parse/pad_dyn_reflect_error.cpp +++ b/test/onnx/parse/pad_dyn_reflect_error.cpp @@ -28,5 +28,5 @@ TEST_CASE(pad_dyn_reflect_error) { migraphx::onnx_options options; options.default_dyn_dim_value = {2, 4, {2}}; - EXPECT(test::throws([&] { migraphx::parse_onnx("pad_dyn_reflect_error.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("pad_dyn_reflect_error.onnx", options); })); } diff --git a/test/onnx/parse/pad_reflect_multiaxis_test.cpp b/test/onnx/parse/pad_reflect_multiaxis_test.cpp index 0588b01745f..a7e04084e69 100644 --- a/test/onnx/parse/pad_reflect_multiaxis_test.cpp +++ b/test/onnx/parse/pad_reflect_multiaxis_test.cpp @@ -42,7 +42,7 @@ TEST_CASE(pad_reflect_multiaxis_test) auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), l3, l4, l5); mm->add_return({r}); - auto prog = migraphx::parse_onnx("pad_reflect_multiaxis_test.onnx"); + auto prog = read_onnx("pad_reflect_multiaxis_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_reflect_test.cpp b/test/onnx/parse/pad_reflect_test.cpp index 887f1122f8f..141c0fc3123 100644 --- a/test/onnx/parse/pad_reflect_test.cpp +++ b/test/onnx/parse/pad_reflect_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(pad_reflect_test) auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3); mm->add_return({r}); - auto prog = migraphx::parse_onnx("pad_reflect_test.onnx"); + auto prog = read_onnx("pad_reflect_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pad_reflect_with_axes_test.cpp b/test/onnx/parse/pad_reflect_with_axes_test.cpp index 65ff4c355fb..20848a30f4b 100644 --- a/test/onnx/parse/pad_reflect_with_axes_test.cpp +++ b/test/onnx/parse/pad_reflect_with_axes_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(pad_reflect_with_axes_test) auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3); mm->add_return({r}); - auto prog = migraphx::parse_onnx("pad_reflect_with_axes_test.onnx"); + auto prog = read_onnx("pad_reflect_with_axes_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pow_fp32_i64_test.cpp b/test/onnx/parse/pow_fp32_i64_test.cpp index 8168f9186d1..3619b788f6d 100644 --- a/test/onnx/parse/pow_fp32_i64_test.cpp +++ b/test/onnx/parse/pow_fp32_i64_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(pow_fp32_i64_test) auto ret = mm->add_instruction(migraphx::make_op("pow"), l0, l1f); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("pow_fp32_i64_test.onnx"); + auto prog = read_onnx("pow_fp32_i64_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/pow_i64_fp32_test.cpp b/test/onnx/parse/pow_i64_fp32_test.cpp index 63d15be1776..fcb0d5748ec 100644 --- a/test/onnx/parse/pow_i64_fp32_test.cpp +++ b/test/onnx/parse/pow_i64_fp32_test.cpp @@ -37,7 +37,7 @@ TEST_CASE(pow_i64_fp32_test) migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), fr); mm->add_return({ir}); - auto prog = migraphx::parse_onnx("pow_i64_fp32_test.onnx"); + auto prog = read_onnx("pow_i64_fp32_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/prefix_scan_sum_test.cpp b/test/onnx/parse/prefix_scan_sum_test.cpp index b14a2fd4091..18458571cb3 100644 --- a/test/onnx/parse/prefix_scan_sum_test.cpp +++ b/test/onnx/parse/prefix_scan_sum_test.cpp @@ -35,6 +35,6 @@ TEST_CASE(prefix_scan_sum) l0); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("prefix_scan_sum_test.onnx"); + auto prog = read_onnx("prefix_scan_sum_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/prelu_brcst_test.cpp b/test/onnx/parse/prelu_brcst_test.cpp index 97498907bab..759fefe5f43 100644 --- a/test/onnx/parse/prelu_brcst_test.cpp +++ b/test/onnx/parse/prelu_brcst_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(prelu_brcst_test) auto ret = mm->add_instruction(migraphx::make_op("prelu"), l0, bl1); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("prelu_brcst_test.onnx"); + auto prog = read_onnx("prelu_brcst_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/qlinearadd_test.cpp b/test/onnx/parse/qlinearadd_test.cpp index 3b0a7ea1115..431a3e615bf 100644 --- a/test/onnx/parse/qlinearadd_test.cpp +++ b/test/onnx/parse/qlinearadd_test.cpp @@ -72,7 +72,7 @@ TEST_CASE(qlinearadd_test) mm->add_return({c}); - auto prog = migraphx::parse_onnx("qlinearadd_test.onnx"); + auto prog = read_onnx("qlinearadd_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/qlinearaveragepool_notset_test.cpp b/test/onnx/parse/qlinearaveragepool_notset_test.cpp index e43fc8a702e..9d006f7db15 100644 --- a/test/onnx/parse/qlinearaveragepool_notset_test.cpp +++ b/test/onnx/parse/qlinearaveragepool_notset_test.cpp @@ -68,7 +68,7 @@ TEST_CASE(qlinearaveragepool_notset_test) mm->add_instruction(migraphx::make_op("quantizelinear"), fp_y, scale_y_bcast, z_pt_y_bcast); mm->add_return({y}); - auto prog = migraphx::parse_onnx("qlinearaveragepool_notset_test.onnx"); + auto prog = read_onnx("qlinearaveragepool_notset_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/qlinearconcat_test.cpp b/test/onnx/parse/qlinearconcat_test.cpp index 0c5f06ed671..bf24cae938a 100644 --- a/test/onnx/parse/qlinearconcat_test.cpp +++ b/test/onnx/parse/qlinearconcat_test.cpp @@ -72,7 +72,7 @@ TEST_CASE(qlinearconcat_test) mm->add_return({y}); - auto prog = migraphx::parse_onnx("qlinearconcat_test.onnx"); + auto prog = read_onnx("qlinearconcat_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/qlinearconv_test.cpp b/test/onnx/parse/qlinearconv_test.cpp index 04813c02b56..10acc772c44 100644 --- a/test/onnx/parse/qlinearconv_test.cpp +++ b/test/onnx/parse/qlinearconv_test.cpp @@ -74,7 +74,7 @@ TEST_CASE(qlinearconv_test) mm->add_return({y}); - auto prog = migraphx::parse_onnx("qlinearconv_test.onnx"); + auto prog = read_onnx("qlinearconv_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/qlinearglobalavgpool_test.cpp b/test/onnx/parse/qlinearglobalavgpool_test.cpp index f98d79c9d91..c7e085f818d 100644 --- a/test/onnx/parse/qlinearglobalavgpool_test.cpp +++ b/test/onnx/parse/qlinearglobalavgpool_test.cpp @@ -65,7 +65,7 @@ TEST_CASE(qlinearglobalavgpool_test) mm->add_return({y}); - auto prog = migraphx::parse_onnx("qlinearglobalavgpool_test.onnx"); + auto prog = read_onnx("qlinearglobalavgpool_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/qlinearleakyrelu_test.cpp b/test/onnx/parse/qlinearleakyrelu_test.cpp index bdc9e0e75bc..331c9cef319 100644 --- a/test/onnx/parse/qlinearleakyrelu_test.cpp +++ b/test/onnx/parse/qlinearleakyrelu_test.cpp @@ -59,7 +59,7 @@ TEST_CASE(qlinearleakyrelu_test) mm->add_return({y}); - auto prog = migraphx::parse_onnx("qlinearleakyrelu_test.onnx"); + auto prog = read_onnx("qlinearleakyrelu_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/qlinearmatmul_1D_test.cpp b/test/onnx/parse/qlinearmatmul_1D_test.cpp index 39e3f1f1d9b..9550beb255e 100644 --- a/test/onnx/parse/qlinearmatmul_1D_test.cpp +++ b/test/onnx/parse/qlinearmatmul_1D_test.cpp @@ -78,7 +78,7 @@ TEST_CASE(qlinearmatmul_1D_test) mm->add_return({c}); - auto prog = migraphx::parse_onnx("qlinearmatmul_1D_test.onnx"); + auto prog = read_onnx("qlinearmatmul_1D_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/qlinearmatmul_2D_test.cpp b/test/onnx/parse/qlinearmatmul_2D_test.cpp index 4ccd3378308..b61f34d5f1f 100644 --- a/test/onnx/parse/qlinearmatmul_2D_test.cpp +++ b/test/onnx/parse/qlinearmatmul_2D_test.cpp @@ -72,7 +72,7 @@ TEST_CASE(qlinearmatmul_2D_test) mm->add_return({c}); - auto prog = migraphx::parse_onnx("qlinearmatmul_2D_test.onnx"); + auto prog = read_onnx("qlinearmatmul_2D_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/qlinearmul_test.cpp b/test/onnx/parse/qlinearmul_test.cpp index d307655b5ec..fbf31d45b1d 100644 --- a/test/onnx/parse/qlinearmul_test.cpp +++ b/test/onnx/parse/qlinearmul_test.cpp @@ -72,7 +72,7 @@ TEST_CASE(qlinearmul_test) mm->add_return({c}); - auto prog = migraphx::parse_onnx("qlinearmul_test.onnx"); + auto prog = read_onnx("qlinearmul_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/qlinearsigmoid_test.cpp b/test/onnx/parse/qlinearsigmoid_test.cpp index a92885d6201..06753e06ea8 100644 --- a/test/onnx/parse/qlinearsigmoid_test.cpp +++ b/test/onnx/parse/qlinearsigmoid_test.cpp @@ -59,7 +59,7 @@ TEST_CASE(qlinearsigmoid_test) mm->add_return({y}); - auto prog = migraphx::parse_onnx("qlinearsigmoid_test.onnx"); + auto prog = read_onnx("qlinearsigmoid_test.onnx"); EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/randomnormal_dtype_error_test.cpp b/test/onnx/parse/randomnormal_dtype_error_test.cpp index cec5462c322..01947d014ad 100644 --- a/test/onnx/parse/randomnormal_dtype_error_test.cpp +++ b/test/onnx/parse/randomnormal_dtype_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(randomnormal_dtype_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_dtype_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("randomnormal_dtype_error_test.onnx"); })); } diff --git a/test/onnx/parse/randomnormal_shape_error_test.cpp b/test/onnx/parse/randomnormal_shape_error_test.cpp index 590d2f72ab3..df0ee3200da 100644 --- a/test/onnx/parse/randomnormal_shape_error_test.cpp +++ b/test/onnx/parse/randomnormal_shape_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(randomnormal_shape_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_shape_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("randomnormal_shape_error_test.onnx"); })); } diff --git a/test/onnx/parse/randomnormallike_type_error_test.cpp b/test/onnx/parse/randomnormallike_type_error_test.cpp index 8015439e605..5a80fcfe0a4 100644 --- a/test/onnx/parse/randomnormallike_type_error_test.cpp +++ b/test/onnx/parse/randomnormallike_type_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(randomnormallike_type_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormallike_type_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("randomnormallike_type_error_test.onnx"); })); } diff --git a/test/onnx/parse/randomuniform_dtype_error_test.cpp b/test/onnx/parse/randomuniform_dtype_error_test.cpp index 53caf45b283..87acca789d5 100644 --- a/test/onnx/parse/randomuniform_dtype_error_test.cpp +++ b/test/onnx/parse/randomuniform_dtype_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(randomuniform_dtype_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_dtype_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("randomuniform_dtype_error_test.onnx"); })); } diff --git a/test/onnx/parse/randomuniform_shape_error_test.cpp b/test/onnx/parse/randomuniform_shape_error_test.cpp index 052095b1b5c..dd72dfd1bc9 100644 --- a/test/onnx/parse/randomuniform_shape_error_test.cpp +++ b/test/onnx/parse/randomuniform_shape_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(randomuniform_shape_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_shape_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("randomuniform_shape_error_test.onnx"); })); } diff --git a/test/onnx/parse/randomuniformlike_type_error_test.cpp b/test/onnx/parse/randomuniformlike_type_error_test.cpp index c79930ae229..7749d7797b3 100644 --- a/test/onnx/parse/randomuniformlike_type_error_test.cpp +++ b/test/onnx/parse/randomuniformlike_type_error_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(randomuniformlike_type_error_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniformlike_type_error_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("randomuniformlike_type_error_test.onnx"); })); } diff --git a/test/onnx/parse/reducel1_dyn_test.cpp b/test/onnx/parse/reducel1_dyn_test.cpp index f2491a2c228..2050887a36b 100644 --- a/test/onnx/parse/reducel1_dyn_test.cpp +++ b/test/onnx/parse/reducel1_dyn_test.cpp @@ -41,7 +41,7 @@ TEST_CASE(reducel1_dyn_test) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}; - auto prog = migraphx::parse_onnx("reducel1_dyn_test.onnx", options); + auto prog = read_onnx("reducel1_dyn_test.onnx", options); EXPECT(p == prog); } @@ -62,7 +62,7 @@ TEST_CASE(reducel1_dyn_test) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}; - auto prog = migraphx::parse_onnx("reducel1_dyn_noaxes_test.onnx", options); + auto prog = read_onnx("reducel1_dyn_noaxes_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/reducemax_dyn_test.cpp b/test/onnx/parse/reducemax_dyn_test.cpp index 17b82b2ade7..9be36c3e383 100644 --- a/test/onnx/parse/reducemax_dyn_test.cpp +++ b/test/onnx/parse/reducemax_dyn_test.cpp @@ -37,7 +37,7 @@ TEST_CASE(reducemax_dyn_test) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{3, 5}, {4, 4}, {5, 5}, {6, 6}}; - auto prog = migraphx::parse_onnx("reducemax_dyn_test.onnx", options); + auto prog = read_onnx("reducemax_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/reducesum_empty_axes_test.cpp b/test/onnx/parse/reducesum_empty_axes_test.cpp index d5a1e544ecb..02e925009b0 100644 --- a/test/onnx/parse/reducesum_empty_axes_test.cpp +++ b/test/onnx/parse/reducesum_empty_axes_test.cpp @@ -34,7 +34,7 @@ TEST_CASE(reducesum_empty_axes_test) auto r = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 2, 3}}}), l1); mm->add_return({r}); - auto prog = migraphx::parse_onnx("reducesum_empty_axes_test.onnx"); + auto prog = read_onnx("reducesum_empty_axes_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/reducesum_noop_test.cpp b/test/onnx/parse/reducesum_noop_test.cpp index eeaa2eab964..aa5eea77cee 100644 --- a/test/onnx/parse/reducesum_noop_test.cpp +++ b/test/onnx/parse/reducesum_noop_test.cpp @@ -31,7 +31,7 @@ TEST_CASE(reducesum_noop_test) mm->add_literal(migraphx::literal{migraphx::shape::int64_type}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); mm->add_return({x}); - auto prog = migraphx::parse_onnx("reducesum_noop_test.onnx"); + auto prog = read_onnx("reducesum_noop_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/reducesum_variable_axes_keepdims_clear_test.cpp b/test/onnx/parse/reducesum_variable_axes_keepdims_clear_test.cpp index d7ca00f188a..6aff7d30dff 100644 --- a/test/onnx/parse/reducesum_variable_axes_keepdims_clear_test.cpp +++ b/test/onnx/parse/reducesum_variable_axes_keepdims_clear_test.cpp @@ -25,6 +25,5 @@ TEST_CASE(reducesum_variable_axes_keepdims_clear_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("reducesum_variable_axes_keepdims_clear_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("reducesum_variable_axes_keepdims_clear_test.onnx"); })); } diff --git a/test/onnx/parse/reducesum_variable_dynamic_axes_test.cpp b/test/onnx/parse/reducesum_variable_dynamic_axes_test.cpp index 63c5200ff71..a5747d6805e 100644 --- a/test/onnx/parse/reducesum_variable_dynamic_axes_test.cpp +++ b/test/onnx/parse/reducesum_variable_dynamic_axes_test.cpp @@ -50,6 +50,6 @@ TEST_CASE(reducesum_variable_dynamic_axes_test) migraphx::onnx_options options; options.map_dyn_input_dims["axes"] = axes->get_shape().dyn_dims(); - auto prog = parse_onnx("reducesum_variable_dynamic_axes_test.onnx", options); + auto prog = read_onnx("reducesum_variable_dynamic_axes_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/reducesum_variable_empty_axes_test.cpp b/test/onnx/parse/reducesum_variable_empty_axes_test.cpp index cb660001b97..e55e1999e12 100644 --- a/test/onnx/parse/reducesum_variable_empty_axes_test.cpp +++ b/test/onnx/parse/reducesum_variable_empty_axes_test.cpp @@ -40,6 +40,6 @@ TEST_CASE(reducesum_variable_empty_axes_test) migraphx::onnx_options options; options.map_input_dims["axes"] = axes->get_shape().lens(); - auto prog = parse_onnx("reducesum_variable_axes_test.onnx", options); + auto prog = read_onnx("reducesum_variable_axes_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/reshape_variable_input_dyn_test.cpp b/test/onnx/parse/reshape_variable_input_dyn_test.cpp index 71aeb5356a3..26ea428c298 100644 --- a/test/onnx/parse/reshape_variable_input_dyn_test.cpp +++ b/test/onnx/parse/reshape_variable_input_dyn_test.cpp @@ -38,6 +38,6 @@ TEST_CASE(reshape_variable_input_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("reshape_variable_input_dyn_test.onnx", options); + auto prog = read_onnx("reshape_variable_input_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_downsample_c_test.cpp b/test/onnx/parse/resize_downsample_c_test.cpp index fcc45faff53..ff676505f3b 100644 --- a/test/onnx/parse/resize_downsample_c_test.cpp +++ b/test/onnx/parse/resize_downsample_c_test.cpp @@ -46,7 +46,7 @@ TEST_CASE(resize_downsample_c_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx"); + auto prog = read_onnx("resize_downsample_c_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_downsample_f_dyn3_test.cpp b/test/onnx/parse/resize_downsample_f_dyn3_test.cpp index e5e30344b78..e67cad332f3 100644 --- a/test/onnx/parse/resize_downsample_f_dyn3_test.cpp +++ b/test/onnx/parse/resize_downsample_f_dyn3_test.cpp @@ -45,6 +45,6 @@ TEST_CASE(resize_downsample_f_dyn3_test) migraphx::onnx_options options; options.map_dyn_input_dims["X"] = {{1, 4, {1, 4}}, {1, 1}, {5, 5}, {9, 9}}; - auto prog = migraphx::parse_onnx("resize_downsample_f_dyn3_test.onnx", options); + auto prog = read_onnx("resize_downsample_f_dyn3_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_downsample_f_dyn_test.cpp b/test/onnx/parse/resize_downsample_f_dyn_test.cpp index 57156653bee..97cd0afed80 100644 --- a/test/onnx/parse/resize_downsample_f_dyn_test.cpp +++ b/test/onnx/parse/resize_downsample_f_dyn_test.cpp @@ -49,6 +49,6 @@ TEST_CASE(resize_downsample_f_dyn_test) migraphx::onnx_options options; options.map_dyn_input_dims["X"] = {{1, 4, {1, 4}}, {1, 1}, {5, 5}, {9, 9}}; - auto prog = migraphx::parse_onnx("resize_downsample_f_dyn_test.onnx", options); + auto prog = read_onnx("resize_downsample_f_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_downsample_f_test.cpp b/test/onnx/parse/resize_downsample_f_test.cpp index 142842466c5..b4e4f0c92f3 100644 --- a/test/onnx/parse/resize_downsample_f_test.cpp +++ b/test/onnx/parse/resize_downsample_f_test.cpp @@ -45,7 +45,7 @@ TEST_CASE(resize_downsample_f_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx"); + auto prog = read_onnx("resize_downsample_f_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_downsample_linear_test.cpp b/test/onnx/parse/resize_downsample_linear_test.cpp index 279baf60f36..9bea2481c64 100644 --- a/test/onnx/parse/resize_downsample_linear_test.cpp +++ b/test/onnx/parse/resize_downsample_linear_test.cpp @@ -88,6 +88,6 @@ TEST_CASE(resize_downsample_linear_test) auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10); mm->add_return({add1}); - auto prog = migraphx::parse_onnx("resize_downsample_linear_test.onnx"); + auto prog = read_onnx("resize_downsample_linear_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_dyn_err1test.cpp b/test/onnx/parse/resize_dyn_err1test.cpp index 04ad4998956..5f10d54ddf6 100644 --- a/test/onnx/parse/resize_dyn_err1test.cpp +++ b/test/onnx/parse/resize_dyn_err1test.cpp @@ -31,5 +31,5 @@ TEST_CASE(resize_dyn_err1_test) migraphx::onnx_options options; options.default_dyn_dim_value = dd; - EXPECT(test::throws([&] { migraphx::parse_onnx("resize_dyn_err1_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("resize_dyn_err1_test.onnx", options); })); } diff --git a/test/onnx/parse/resize_linear_non_const_test.cpp b/test/onnx/parse/resize_linear_non_const_test.cpp index e0ed9c393db..a14f458d4ff 100644 --- a/test/onnx/parse/resize_linear_non_const_test.cpp +++ b/test/onnx/parse/resize_linear_non_const_test.cpp @@ -28,5 +28,5 @@ TEST_CASE(resize_linear_non_const_test) { // runtime (non-constant) input is only supported in "nearest" mode migraphx::onnx_options options; - EXPECT(test::throws([&] { parse_onnx("resize_linear_non_const_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("resize_linear_non_const_test.onnx", options); })); } diff --git a/test/onnx/parse/resize_no_scale_test.cpp b/test/onnx/parse/resize_no_scale_test.cpp index abb54e5f606..0a363a89af4 100644 --- a/test/onnx/parse/resize_no_scale_test.cpp +++ b/test/onnx/parse/resize_no_scale_test.cpp @@ -27,5 +27,5 @@ TEST_CASE(resize_no_scale_test) { // input node has neither shapes nor scales - EXPECT(test::throws([&] { migraphx::parse_onnx("resize_no_scale_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("resize_no_scale_test.onnx"); })); } diff --git a/test/onnx/parse/resize_nonstd_input_test.cpp b/test/onnx/parse/resize_nonstd_input_test.cpp index a7388466237..c485d185f3a 100644 --- a/test/onnx/parse/resize_nonstd_input_test.cpp +++ b/test/onnx/parse/resize_nonstd_input_test.cpp @@ -48,7 +48,7 @@ TEST_CASE(resize_nonstd_input_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("resize_nonstd_input_test.onnx"); + auto prog = read_onnx("resize_nonstd_input_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_outsize_test.cpp b/test/onnx/parse/resize_outsize_test.cpp index e80f8b9eb9e..b859713eed2 100644 --- a/test/onnx/parse/resize_outsize_test.cpp +++ b/test/onnx/parse/resize_outsize_test.cpp @@ -46,7 +46,7 @@ TEST_CASE(resize_outsize_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("resize_outsize_test.onnx"); + auto prog = read_onnx("resize_outsize_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_upsample_linear_ac_test.cpp b/test/onnx/parse/resize_upsample_linear_ac_test.cpp index b7db0efaabb..fc31f4ef8ca 100644 --- a/test/onnx/parse/resize_upsample_linear_ac_test.cpp +++ b/test/onnx/parse/resize_upsample_linear_ac_test.cpp @@ -28,6 +28,6 @@ TEST_CASE(resize_upsample_linear_ac_test) { auto p = create_upsample_linear_prog(); - auto prog = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx"); + auto prog = read_onnx("resize_upsample_linear_ac_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_upsample_linear_test.cpp b/test/onnx/parse/resize_upsample_linear_test.cpp index 27d828a608c..21b3df96baa 100644 --- a/test/onnx/parse/resize_upsample_linear_test.cpp +++ b/test/onnx/parse/resize_upsample_linear_test.cpp @@ -115,6 +115,6 @@ TEST_CASE(resize_upsample_linear_test) auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10); mm->add_return({add1}); - auto prog = migraphx::parse_onnx("resize_upsample_linear_test.onnx"); + auto prog = read_onnx("resize_upsample_linear_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_upsample_pc_test.cpp b/test/onnx/parse/resize_upsample_pc_test.cpp index 6e95db22a07..b7d71f84a71 100644 --- a/test/onnx/parse/resize_upsample_pc_test.cpp +++ b/test/onnx/parse/resize_upsample_pc_test.cpp @@ -46,7 +46,7 @@ TEST_CASE(resize_upsample_pc_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("resize_upsample_pc_test.onnx"); + auto prog = read_onnx("resize_upsample_pc_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/resize_upsample_pf_test.cpp b/test/onnx/parse/resize_upsample_pf_test.cpp index 28a10de5da9..d84138b4166 100644 --- a/test/onnx/parse/resize_upsample_pf_test.cpp +++ b/test/onnx/parse/resize_upsample_pf_test.cpp @@ -46,7 +46,7 @@ TEST_CASE(resize_upsample_pf_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("resize_upsample_pf_test.onnx"); + auto prog = read_onnx("resize_upsample_pf_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/reversesequence_batch_axis_err_test.cpp b/test/onnx/parse/reversesequence_batch_axis_err_test.cpp index 115eee8d0a2..ebb752bcbd6 100644 --- a/test/onnx/parse/reversesequence_batch_axis_err_test.cpp +++ b/test/onnx/parse/reversesequence_batch_axis_err_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(reversesequence_batch_axis_err_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_batch_axis_err_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("reversesequence_batch_axis_err_test.onnx"); })); } diff --git a/test/onnx/parse/reversesequence_batch_test.cpp b/test/onnx/parse/reversesequence_batch_test.cpp index 769fc9959f0..fb9d63ca470 100644 --- a/test/onnx/parse/reversesequence_batch_test.cpp +++ b/test/onnx/parse/reversesequence_batch_test.cpp @@ -63,6 +63,6 @@ TEST_CASE(reversesequence_batch_test) } mm->add_return({ret}); - auto prog = migraphx::parse_onnx("reversesequence_batch_test.onnx"); + auto prog = read_onnx("reversesequence_batch_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/reversesequence_rank_err_test.cpp b/test/onnx/parse/reversesequence_rank_err_test.cpp index dbf89d6d7bc..86e45a2141e 100644 --- a/test/onnx/parse/reversesequence_rank_err_test.cpp +++ b/test/onnx/parse/reversesequence_rank_err_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(reversesequence_rank_err_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_rank_err_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("reversesequence_rank_err_test.onnx"); })); } diff --git a/test/onnx/parse/reversesequence_same_axis_err_test.cpp b/test/onnx/parse/reversesequence_same_axis_err_test.cpp index 1d1954ebc0f..5e390293fbe 100644 --- a/test/onnx/parse/reversesequence_same_axis_err_test.cpp +++ b/test/onnx/parse/reversesequence_same_axis_err_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(reversesequence_same_axis_err_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_same_axis_err_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("reversesequence_same_axis_err_test.onnx"); })); } diff --git a/test/onnx/parse/reversesequence_sequence_lens_shape_err_test.cpp b/test/onnx/parse/reversesequence_sequence_lens_shape_err_test.cpp index 424bc41cc85..5735083ce0b 100644 --- a/test/onnx/parse/reversesequence_sequence_lens_shape_err_test.cpp +++ b/test/onnx/parse/reversesequence_sequence_lens_shape_err_test.cpp @@ -26,6 +26,5 @@ TEST_CASE(reversesequence_sequence_lens_shape_err_test) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("reversesequence_sequence_lens_shape_err_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("reversesequence_sequence_lens_shape_err_test.onnx"); })); } diff --git a/test/onnx/parse/reversesequence_time_axis_err_test.cpp b/test/onnx/parse/reversesequence_time_axis_err_test.cpp index 198c89fc739..2527114ac59 100644 --- a/test/onnx/parse/reversesequence_time_axis_err_test.cpp +++ b/test/onnx/parse/reversesequence_time_axis_err_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(reversesequence_time_axis_err_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_time_axis_err_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("reversesequence_time_axis_err_test.onnx"); })); } diff --git a/test/onnx/parse/reversesequence_time_test.cpp b/test/onnx/parse/reversesequence_time_test.cpp index 6d7de89aa9e..6dc453395a6 100644 --- a/test/onnx/parse/reversesequence_time_test.cpp +++ b/test/onnx/parse/reversesequence_time_test.cpp @@ -71,6 +71,6 @@ TEST_CASE(reversesequence_time_test) ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("reversesequence_time_test.onnx"); + auto prog = read_onnx("reversesequence_time_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/roialign_default_test.cpp b/test/onnx/parse/roialign_default_test.cpp index 5bd4872a248..ceca14df9aa 100644 --- a/test/onnx/parse/roialign_default_test.cpp +++ b/test/onnx/parse/roialign_default_test.cpp @@ -45,7 +45,7 @@ TEST_CASE(roialign_default_test) bi); mm->add_return({r}); - auto prog = migraphx::parse_onnx("roialign_default_test.onnx"); + auto prog = read_onnx("roialign_default_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/roialign_test.cpp b/test/onnx/parse/roialign_test.cpp index 853fc68da77..567117e6a0f 100644 --- a/test/onnx/parse/roialign_test.cpp +++ b/test/onnx/parse/roialign_test.cpp @@ -48,7 +48,7 @@ TEST_CASE(roialign_test) bi); mm->add_return({r}); - auto prog = migraphx::parse_onnx("roialign_test.onnx"); + auto prog = read_onnx("roialign_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/scatternd_dyn_test.cpp b/test/onnx/parse/scatternd_dyn_test.cpp index ee7546e5e1e..d4d15f90594 100644 --- a/test/onnx/parse/scatternd_dyn_test.cpp +++ b/test/onnx/parse/scatternd_dyn_test.cpp @@ -42,7 +42,7 @@ TEST_CASE(scatternd_dyn_test) options.map_dyn_input_dims["data"] = {{1, 3, {2}}, {2, 2}, {2, 2}}; options.map_dyn_input_dims["indices"] = {{2, 1, {2}}, {1, 1}, {2, 2}}; options.map_dyn_input_dims["updates"] = {{2, 1, {2}}, {1, 1}, {2, 2}}; - auto prog = migraphx::parse_onnx("scatternd_dyn_test.onnx", options); + auto prog = read_onnx("scatternd_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/scatternd_invalid_reduction_test.cpp b/test/onnx/parse/scatternd_invalid_reduction_test.cpp index 4afab8583a1..b5d049d97e6 100644 --- a/test/onnx/parse/scatternd_invalid_reduction_test.cpp +++ b/test/onnx/parse/scatternd_invalid_reduction_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(scatternd_invalid_reduction_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("scatternd_invalid_reduction_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("scatternd_invalid_reduction_test.onnx"); })); } diff --git a/test/onnx/parse/selu_test.cpp b/test/onnx/parse/selu_test.cpp index 8506ef30bfa..b77d324f447 100644 --- a/test/onnx/parse/selu_test.cpp +++ b/test/onnx/parse/selu_test.cpp @@ -52,7 +52,7 @@ TEST_CASE(selu_test) auto r = mm->add_instruction(migraphx::make_op("mul"), item12, mblg); mm->add_return({r}); - auto prog = migraphx::parse_onnx("selu_test.onnx"); + auto prog = read_onnx("selu_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/shape_dyn_test0.cpp b/test/onnx/parse/shape_dyn_test0.cpp index 41119c2ead3..abf00e77195 100644 --- a/test/onnx/parse/shape_dyn_test0.cpp +++ b/test/onnx/parse/shape_dyn_test0.cpp @@ -36,7 +36,7 @@ TEST_CASE(shape_dyn_test0) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}; - auto prog = parse_onnx("shape_dyn_test0.onnx", options); + auto prog = read_onnx("shape_dyn_test0.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/shape_dyn_test1.cpp b/test/onnx/parse/shape_dyn_test1.cpp index d6eebbca512..0e4de4a6cf7 100644 --- a/test/onnx/parse/shape_dyn_test1.cpp +++ b/test/onnx/parse/shape_dyn_test1.cpp @@ -37,7 +37,7 @@ TEST_CASE(shape_dyn_test1) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}; - auto prog = parse_onnx("shape_dyn_test1.onnx", options); + auto prog = read_onnx("shape_dyn_test1.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/shape_dyn_test2.cpp b/test/onnx/parse/shape_dyn_test2.cpp index 90d6ee908d3..da405f1bbfd 100644 --- a/test/onnx/parse/shape_dyn_test2.cpp +++ b/test/onnx/parse/shape_dyn_test2.cpp @@ -37,7 +37,7 @@ TEST_CASE(shape_dyn_test2) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}; - auto prog = parse_onnx("shape_dyn_test2.onnx", options); + auto prog = read_onnx("shape_dyn_test2.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/shape_dyn_test3.cpp b/test/onnx/parse/shape_dyn_test3.cpp index 5c6a6de9de9..b5852be0d45 100644 --- a/test/onnx/parse/shape_dyn_test3.cpp +++ b/test/onnx/parse/shape_dyn_test3.cpp @@ -37,7 +37,7 @@ TEST_CASE(shape_dyn_test3) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}; - auto prog = parse_onnx("shape_dyn_test3.onnx", options); + auto prog = read_onnx("shape_dyn_test3.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/shape_end_less_start_error.cpp b/test/onnx/parse/shape_end_less_start_error.cpp index cda43eb3b54..8497606ec9b 100644 --- a/test/onnx/parse/shape_end_less_start_error.cpp +++ b/test/onnx/parse/shape_end_less_start_error.cpp @@ -28,5 +28,5 @@ TEST_CASE(shape_end_less_start_error) { migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}; - EXPECT(test::throws([&] { migraphx::parse_onnx("shape_end_less_start_error.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("shape_end_less_start_error.onnx", options); })); } diff --git a/test/onnx/parse/shape_end_oob_test.cpp b/test/onnx/parse/shape_end_oob_test.cpp index 06a3d8979aa..cc4f5e26dc2 100644 --- a/test/onnx/parse/shape_end_oob_test.cpp +++ b/test/onnx/parse/shape_end_oob_test.cpp @@ -36,7 +36,7 @@ TEST_CASE(shape_end_oob_test) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}; - auto prog = migraphx::parse_onnx("shape_end_oob_test.onnx", options); + auto prog = read_onnx("shape_end_oob_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/shape_start_oob_test.cpp b/test/onnx/parse/shape_start_oob_test.cpp index 9c0e4f17292..dc19c775fbb 100644 --- a/test/onnx/parse/shape_start_oob_test.cpp +++ b/test/onnx/parse/shape_start_oob_test.cpp @@ -36,7 +36,7 @@ TEST_CASE(shape_start_oob_test) migraphx::onnx_options options; options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}; - auto prog = migraphx::parse_onnx("shape_start_oob_test.onnx", options); + auto prog = read_onnx("shape_start_oob_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/sinh_dynamic_test.cpp b/test/onnx/parse/sinh_dynamic_test.cpp index 6b5390b6b9c..187468c6e75 100644 --- a/test/onnx/parse/sinh_dynamic_test.cpp +++ b/test/onnx/parse/sinh_dynamic_test.cpp @@ -37,7 +37,7 @@ TEST_CASE(sinh_dynamic_test) migraphx::onnx_options options; options.default_dyn_dim_value = dd; - auto prog = parse_onnx("sinh_dynamic_test.onnx", options); + auto prog = read_onnx("sinh_dynamic_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_3arg_test.cpp b/test/onnx/parse/slice_3arg_test.cpp index 13376a36292..46b3a66dd10 100644 --- a/test/onnx/parse/slice_3arg_test.cpp +++ b/test/onnx/parse/slice_3arg_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(slice_3arg_test) migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 5}}}), l0); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("slice_3arg_test.onnx"); + auto prog = read_onnx("slice_3arg_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_5arg_reverse_test.cpp b/test/onnx/parse/slice_5arg_reverse_test.cpp index 66dec623032..3cb6ba2a461 100644 --- a/test/onnx/parse/slice_5arg_reverse_test.cpp +++ b/test/onnx/parse/slice_5arg_reverse_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(slice_5arg_reverse_test) auto ret = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("slice_5arg_reverse_test.onnx"); + auto prog = read_onnx("slice_5arg_reverse_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_5arg_step_test.cpp b/test/onnx/parse/slice_5arg_step_test.cpp index d03b84967c2..e80892b9662 100644 --- a/test/onnx/parse/slice_5arg_step_test.cpp +++ b/test/onnx/parse/slice_5arg_step_test.cpp @@ -43,7 +43,7 @@ TEST_CASE(slice_5arg_step_test) migraphx::make_op("step", {{"axes", {-1, -2}}, {"steps", {2, 2}}}), reverse_out); mm->add_return({step_out}); - auto prog = migraphx::parse_onnx("slice_5arg_step_test.onnx"); + auto prog = read_onnx("slice_5arg_step_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_5arg_test.cpp b/test/onnx/parse/slice_5arg_test.cpp index 44a632af248..9981abf8888 100644 --- a/test/onnx/parse/slice_5arg_test.cpp +++ b/test/onnx/parse/slice_5arg_test.cpp @@ -38,7 +38,7 @@ TEST_CASE(slice_5arg_test) l0); mm->add_return({ret}); - auto prog = migraphx::parse_onnx("slice_5arg_test.onnx"); + auto prog = read_onnx("slice_5arg_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_dyn_test.cpp b/test/onnx/parse/slice_dyn_test.cpp index 835ccd2d37b..44479cfa3d4 100644 --- a/test/onnx/parse/slice_dyn_test.cpp +++ b/test/onnx/parse/slice_dyn_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(slice_dyn_test) // Parser converts the dynamic input shape to static unless there is at least one non-fixed // dynamic dimension. Slicing is not allowed along the non-fixed axis 1. options.map_dyn_input_dims["0"] = {{3, 3}, {1, 3}, {2, 2}}; - auto prog = migraphx::parse_onnx("slice_dyn_test.onnx", options); + auto prog = read_onnx("slice_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_reverse_dyn_test.cpp b/test/onnx/parse/slice_reverse_dyn_test.cpp index 1a9c6b0cf70..dfde0180bb2 100644 --- a/test/onnx/parse/slice_reverse_dyn_test.cpp +++ b/test/onnx/parse/slice_reverse_dyn_test.cpp @@ -30,5 +30,5 @@ TEST_CASE(slice_reverse_dyn_test) // parsing. At the time of writing, Reverse doesn't support dynamic shape input. migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws([&] { migraphx::parse_onnx("slice_reverse_dyn_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("slice_reverse_dyn_test.onnx", options); })); } diff --git a/test/onnx/parse/slice_step_dyn_test.cpp b/test/onnx/parse/slice_step_dyn_test.cpp index d4b0aaf6cc9..2e97b62048e 100644 --- a/test/onnx/parse/slice_step_dyn_test.cpp +++ b/test/onnx/parse/slice_step_dyn_test.cpp @@ -30,5 +30,5 @@ TEST_CASE(slice_step_dyn_test) // At the time of writing, Step doesn't support dynamic shape input. migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws([&] { migraphx::parse_onnx("slice_step_dyn_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("slice_step_dyn_test.onnx", options); })); } diff --git a/test/onnx/parse/slice_var_input_default_steps.cpp b/test/onnx/parse/slice_var_input_default_steps.cpp index 0e52410bfd4..19e853be8f1 100644 --- a/test/onnx/parse/slice_var_input_default_steps.cpp +++ b/test/onnx/parse/slice_var_input_default_steps.cpp @@ -39,6 +39,6 @@ TEST_CASE(slice_var_input_default_steps) migraphx::onnx_options options; options.default_dyn_dim_value = {3, 8}; - auto prog = parse_onnx("slice_var_input_default_steps.onnx", options); + auto prog = read_onnx("slice_var_input_default_steps.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_var_input_dyn0.cpp b/test/onnx/parse/slice_var_input_dyn0.cpp index 963b408ddb2..a5a3495f8f9 100644 --- a/test/onnx/parse/slice_var_input_dyn0.cpp +++ b/test/onnx/parse/slice_var_input_dyn0.cpp @@ -38,6 +38,6 @@ TEST_CASE(slice_var_input_dyn0) migraphx::onnx_options options; options.default_dyn_dim_value = {3, 8}; - auto prog = parse_onnx("slice_var_input_dyn0.onnx", options); + auto prog = read_onnx("slice_var_input_dyn0.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_var_input_dyn1.cpp b/test/onnx/parse/slice_var_input_dyn1.cpp index 60c741982a8..adcfc4d1833 100644 --- a/test/onnx/parse/slice_var_input_dyn1.cpp +++ b/test/onnx/parse/slice_var_input_dyn1.cpp @@ -38,6 +38,6 @@ TEST_CASE(slice_var_input_dyn1) migraphx::onnx_options options; options.default_dyn_dim_value = {3, 8}; - auto prog = parse_onnx("slice_var_input_dyn1.onnx", options); + auto prog = read_onnx("slice_var_input_dyn1.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/slice_var_input_steps_error.cpp b/test/onnx/parse/slice_var_input_steps_error.cpp index 2eef62e4ebf..28acdeb6e82 100644 --- a/test/onnx/parse/slice_var_input_steps_error.cpp +++ b/test/onnx/parse/slice_var_input_steps_error.cpp @@ -26,5 +26,5 @@ TEST_CASE(slice_var_input_steps_error) { - EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); })); + EXPECT(test::throws([&] { read_onnx("slice_var_input_steps_error.onnx"); })); } diff --git a/test/onnx/parse/softmax_dyn_test.cpp b/test/onnx/parse/softmax_dyn_test.cpp index 023f1bea8b1..d582dfffad6 100644 --- a/test/onnx/parse/softmax_dyn_test.cpp +++ b/test/onnx/parse/softmax_dyn_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(softmax_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("softmax_dyn_test.onnx", options); + auto prog = read_onnx("softmax_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/softmax_nonstd_input_test.cpp b/test/onnx/parse/softmax_nonstd_input_test.cpp index 6cba85dbc06..b04b52546c7 100644 --- a/test/onnx/parse/softmax_nonstd_input_test.cpp +++ b/test/onnx/parse/softmax_nonstd_input_test.cpp @@ -34,7 +34,7 @@ TEST_CASE(softmax_nonstd_input_test) auto l2 = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), l1); mm->add_return({l2}); - auto prog = migraphx::parse_onnx("softmax_nonstd_input_test.onnx"); + auto prog = read_onnx("softmax_nonstd_input_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/spacetodepth_invalid_blocksize_test.cpp b/test/onnx/parse/spacetodepth_invalid_blocksize_test.cpp index 9c879a9b408..a0ad7588b8e 100644 --- a/test/onnx/parse/spacetodepth_invalid_blocksize_test.cpp +++ b/test/onnx/parse/spacetodepth_invalid_blocksize_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(spacetodepth_invalid_blocksize) { - EXPECT(test::throws([&] { migraphx::parse_onnx("spacetodepth_invalid_blocksize_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("spacetodepth_invalid_blocksize_test.onnx"); })); } diff --git a/test/onnx/parse/spacetodepth_nondivisibility_test.cpp b/test/onnx/parse/spacetodepth_nondivisibility_test.cpp index 2aa0dfb7b85..9722f625135 100644 --- a/test/onnx/parse/spacetodepth_nondivisibility_test.cpp +++ b/test/onnx/parse/spacetodepth_nondivisibility_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(spacetodepth_nondivisibility_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("spacetodepth_nondivisibility_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("spacetodepth_nondivisibility_test.onnx"); })); } diff --git a/test/onnx/parse/split_minus_axis_test.cpp b/test/onnx/parse/split_minus_axis_test.cpp index 4d467a69617..c1200eb7878 100644 --- a/test/onnx/parse/split_minus_axis_test.cpp +++ b/test/onnx/parse/split_minus_axis_test.cpp @@ -37,7 +37,7 @@ TEST_CASE(split_minus_axis_test) migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {10}}, {"ends", {15}}}), input); mm->add_return({r1, r2, r3}); - auto prog = migraphx::parse_onnx("split_minus_axis_test.onnx"); + auto prog = read_onnx("split_minus_axis_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/split_test.cpp b/test/onnx/parse/split_test.cpp index 2036c17a335..698022dd8ea 100644 --- a/test/onnx/parse/split_test.cpp +++ b/test/onnx/parse/split_test.cpp @@ -37,6 +37,6 @@ TEST_CASE(split_test) migraphx::make_op("slice", {{"axes", {1}}, {"starts", {11}}, {"ends", {15}}}), input); mm->add_return({r1, r2, r3}); - auto prog = migraphx::parse_onnx("split_test.onnx"); + auto prog = read_onnx("split_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/split_test_default.cpp b/test/onnx/parse/split_test_default.cpp index 189ada5ca00..b52aabd04fe 100644 --- a/test/onnx/parse/split_test_default.cpp +++ b/test/onnx/parse/split_test_default.cpp @@ -35,6 +35,6 @@ TEST_CASE(split_test_default) migraphx::make_op("slice", {{"axes", {0}}, {"starts", {5}}, {"ends", {10}}}), input); mm->add_return({r1, r2}); - auto prog = migraphx::parse_onnx("split_test_default.onnx"); + auto prog = read_onnx("split_test_default.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/split_test_invalid_num_outputs.cpp b/test/onnx/parse/split_test_invalid_num_outputs.cpp index 8ad8a532cf6..e80c421b2ce 100644 --- a/test/onnx/parse/split_test_invalid_num_outputs.cpp +++ b/test/onnx/parse/split_test_invalid_num_outputs.cpp @@ -26,5 +26,5 @@ TEST_CASE(split_test_invalid_num_outputs) { - EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_num_outputs.onnx"); })); + EXPECT(test::throws([&] { read_onnx("split_test_invalid_num_outputs.onnx"); })); } diff --git a/test/onnx/parse/split_test_invalid_split.cpp b/test/onnx/parse/split_test_invalid_split.cpp index 9c8cf7f46e2..a0dafe36116 100644 --- a/test/onnx/parse/split_test_invalid_split.cpp +++ b/test/onnx/parse/split_test_invalid_split.cpp @@ -26,5 +26,5 @@ TEST_CASE(split_test_invalid_split) { - EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_split.onnx"); })); + EXPECT(test::throws([&] { read_onnx("split_test_invalid_split.onnx"); })); } diff --git a/test/onnx/parse/split_test_no_attribute.cpp b/test/onnx/parse/split_test_no_attribute.cpp index 937a9d3bd99..615143a1b6f 100644 --- a/test/onnx/parse/split_test_no_attribute.cpp +++ b/test/onnx/parse/split_test_no_attribute.cpp @@ -45,6 +45,6 @@ TEST_CASE(split_test_no_attribute) mm->add_return({r1, r2, r3, r4}); - auto prog = migraphx::parse_onnx("split_test_no_attribute.onnx"); + auto prog = read_onnx("split_test_no_attribute.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/split_test_no_attribute_invalid_input_split.cpp b/test/onnx/parse/split_test_no_attribute_invalid_input_split.cpp index 7e87265d945..5503886f5ba 100644 --- a/test/onnx/parse/split_test_no_attribute_invalid_input_split.cpp +++ b/test/onnx/parse/split_test_no_attribute_invalid_input_split.cpp @@ -26,6 +26,5 @@ TEST_CASE(split_test_no_attribute_invalid_input_split) { - EXPECT(test::throws( - [&] { migraphx::parse_onnx("split_test_no_attribute_invalid_input_split.onnx"); })); + EXPECT(test::throws([&] { read_onnx("split_test_no_attribute_invalid_input_split.onnx"); })); } diff --git a/test/onnx/parse/split_test_no_attribute_invalid_split.cpp b/test/onnx/parse/split_test_no_attribute_invalid_split.cpp index 76f4a2793f1..b849ca6fc6c 100644 --- a/test/onnx/parse/split_test_no_attribute_invalid_split.cpp +++ b/test/onnx/parse/split_test_no_attribute_invalid_split.cpp @@ -26,6 +26,5 @@ TEST_CASE(split_test_no_attribute_invalid_split) { - EXPECT( - test::throws([&] { migraphx::parse_onnx("split_test_no_attribute_invalid_split.onnx"); })); + EXPECT(test::throws([&] { read_onnx("split_test_no_attribute_invalid_split.onnx"); })); } diff --git a/test/onnx/parse/split_test_uneven.cpp b/test/onnx/parse/split_test_uneven.cpp index 834de5b3180..cb4d4583319 100644 --- a/test/onnx/parse/split_test_uneven.cpp +++ b/test/onnx/parse/split_test_uneven.cpp @@ -41,6 +41,6 @@ TEST_CASE(split_test_uneven) migraphx::make_op("slice", {{"axes", {0}}, {"starts", {12}}, {"ends", {12}}}), input); mm->add_return({r1, r2, r3, r4, r5}); - auto prog = migraphx::parse_onnx("split_test_uneven.onnx"); + auto prog = read_onnx("split_test_uneven.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/split_test_uneven_num_outputs.cpp b/test/onnx/parse/split_test_uneven_num_outputs.cpp index b84addd2fef..5ee32b8399f 100644 --- a/test/onnx/parse/split_test_uneven_num_outputs.cpp +++ b/test/onnx/parse/split_test_uneven_num_outputs.cpp @@ -39,6 +39,6 @@ TEST_CASE(split_test_uneven_num_outputs) migraphx::make_op("slice", {{"axes", {0}}, {"starts", {9}}, {"ends", {11}}}), input); mm->add_return({r1, r2, r3, r4}); - auto prog = migraphx::parse_onnx("split_test_uneven_num_outputs.onnx"); + auto prog = read_onnx("split_test_uneven_num_outputs.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/squeeze_axes_input_test.cpp b/test/onnx/parse/squeeze_axes_input_test.cpp index 32768bf2893..10d2b41615c 100644 --- a/test/onnx/parse/squeeze_axes_input_test.cpp +++ b/test/onnx/parse/squeeze_axes_input_test.cpp @@ -33,7 +33,7 @@ TEST_CASE(squeeze_axes_input_test) auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), l0); mm->add_return({l1}); - auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx"); + auto prog = read_onnx("squeeze_axes_input_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/squeeze_empty_axes_test.cpp b/test/onnx/parse/squeeze_empty_axes_test.cpp index bb34c341b02..6f30bb27c40 100644 --- a/test/onnx/parse/squeeze_empty_axes_test.cpp +++ b/test/onnx/parse/squeeze_empty_axes_test.cpp @@ -33,7 +33,7 @@ TEST_CASE(squeeze_empty_axes_test) auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0); mm->add_return({l1}); - auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx"); + auto prog = read_onnx("squeeze_empty_axes_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/squeeze_unsqueeze_dyn_test.cpp b/test/onnx/parse/squeeze_unsqueeze_dyn_test.cpp index 79db7fa518a..c520f788cd5 100644 --- a/test/onnx/parse/squeeze_unsqueeze_dyn_test.cpp +++ b/test/onnx/parse/squeeze_unsqueeze_dyn_test.cpp @@ -41,7 +41,7 @@ TEST_CASE(squeeze_unsqueeze_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("squeeze_unsqueeze_dyn_test.onnx", options); + auto prog = read_onnx("squeeze_unsqueeze_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/sum_type_test.cpp b/test/onnx/parse/sum_type_test.cpp index 5f3b4f13930..de4920b7672 100644 --- a/test/onnx/parse/sum_type_test.cpp +++ b/test/onnx/parse/sum_type_test.cpp @@ -69,7 +69,7 @@ TEST_CASE(sum_type_test) auto s6 = mm->add_instruction(migraphx::make_op("add"), s5, l_raw); mm->add_return({s6}); - auto prog = migraphx::parse_onnx("sum_type_test.onnx"); + auto prog = read_onnx("sum_type_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/topk_attrk_test.cpp b/test/onnx/parse/topk_attrk_test.cpp index d33e5ac2dc7..b134ab65e4a 100644 --- a/test/onnx/parse/topk_attrk_test.cpp +++ b/test/onnx/parse/topk_attrk_test.cpp @@ -35,7 +35,7 @@ TEST_CASE(topk_attrk_test) auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out); mm->add_return({val, ind}); - auto prog = migraphx::parse_onnx("topk_attrk_test.onnx"); + auto prog = read_onnx("topk_attrk_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/topk_neg_axis_test.cpp b/test/onnx/parse/topk_neg_axis_test.cpp index d11de4aca42..7c4974c894b 100644 --- a/test/onnx/parse/topk_neg_axis_test.cpp +++ b/test/onnx/parse/topk_neg_axis_test.cpp @@ -38,7 +38,7 @@ TEST_CASE(topk_neg_axis_test) auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out); mm->add_return({val, ind}); - auto prog = migraphx::parse_onnx("topk_neg_axis_test.onnx"); + auto prog = read_onnx("topk_neg_axis_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/topk_test.cpp b/test/onnx/parse/topk_test.cpp index bf806c3dc6f..d9529260ecd 100644 --- a/test/onnx/parse/topk_test.cpp +++ b/test/onnx/parse/topk_test.cpp @@ -38,7 +38,7 @@ TEST_CASE(topk_test) auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out); mm->add_return({val, ind}); - auto prog = migraphx::parse_onnx("topk_test.onnx"); + auto prog = read_onnx("topk_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/transpose_default_perm_test.cpp b/test/onnx/parse/transpose_default_perm_test.cpp index 8fdfbf2d047..8a2eef1fbff 100644 --- a/test/onnx/parse/transpose_default_perm_test.cpp +++ b/test/onnx/parse/transpose_default_perm_test.cpp @@ -33,7 +33,7 @@ TEST_CASE(transpose_default_perm_test) auto r = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input); mm->add_return({r}); - auto prog = migraphx::parse_onnx("transpose_default_perm_test.onnx"); + auto prog = read_onnx("transpose_default_perm_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/transpose_dyn_test.cpp b/test/onnx/parse/transpose_dyn_test.cpp index 62e47230271..92b820baf91 100644 --- a/test/onnx/parse/transpose_dyn_test.cpp +++ b/test/onnx/parse/transpose_dyn_test.cpp @@ -36,7 +36,7 @@ TEST_CASE(transpose_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("transpose_dyn_test.onnx", options); + auto prog = read_onnx("transpose_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/transpose_invalid_perm_test.cpp b/test/onnx/parse/transpose_invalid_perm_test.cpp index 93645129d3a..8580bdf26e7 100644 --- a/test/onnx/parse/transpose_invalid_perm_test.cpp +++ b/test/onnx/parse/transpose_invalid_perm_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(transpose_invalid_perm_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("transpose_invalid_perm_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("transpose_invalid_perm_test.onnx"); })); } diff --git a/test/onnx/parse/undefined_test.cpp b/test/onnx/parse/undefined_test.cpp index 8385a90f561..47bba1a4bec 100644 --- a/test/onnx/parse/undefined_test.cpp +++ b/test/onnx/parse/undefined_test.cpp @@ -33,7 +33,7 @@ TEST_CASE(undefined_test) auto l2 = mm->add_instruction(migraphx::make_op("identity"), l1); mm->add_return({l2}); - auto prog = migraphx::parse_onnx("undefined_test.onnx"); + auto prog = read_onnx("undefined_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/unique_dynamic_sorted_3D_test.cpp b/test/onnx/parse/unique_dynamic_sorted_3D_test.cpp index db5597bc771..9e4370bdbe5 100644 --- a/test/onnx/parse/unique_dynamic_sorted_3D_test.cpp +++ b/test/onnx/parse/unique_dynamic_sorted_3D_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(unique_dynamic_sorted_3D_test) auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out); mm->add_return({y, y_ind, x_ind, count}); - auto prog = migraphx::parse_onnx("unique_dynamic_sorted_3D_test.onnx"); + auto prog = read_onnx("unique_dynamic_sorted_3D_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/unique_dynamic_sorted_test.cpp b/test/onnx/parse/unique_dynamic_sorted_test.cpp index e8ed74710c7..9c0f957de35 100644 --- a/test/onnx/parse/unique_dynamic_sorted_test.cpp +++ b/test/onnx/parse/unique_dynamic_sorted_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(unique_dynamic_sorted_test) auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out); mm->add_return({y, y_ind, x_ind, count}); - auto prog = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx"); + auto prog = read_onnx("unique_dynamic_sorted_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/unique_sorted_test.cpp b/test/onnx/parse/unique_sorted_test.cpp index b077d181a2b..93dac2d57ee 100644 --- a/test/onnx/parse/unique_sorted_test.cpp +++ b/test/onnx/parse/unique_sorted_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(unique_sorted_test) auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out); auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out); mm->add_return({y, y_idx, x_idx, count}); - auto prog = migraphx::parse_onnx("unique_sorted_test.onnx"); + auto prog = read_onnx("unique_sorted_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/unique_unsorted_test.cpp b/test/onnx/parse/unique_unsorted_test.cpp index 687b8972427..7e81da1033d 100644 --- a/test/onnx/parse/unique_unsorted_test.cpp +++ b/test/onnx/parse/unique_unsorted_test.cpp @@ -39,7 +39,7 @@ TEST_CASE(unique_unsorted_test) auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out); auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out); mm->add_return({y, y_idx, x_idx, count}); - auto prog = migraphx::parse_onnx("unique_unsorted_test.onnx"); + auto prog = read_onnx("unique_unsorted_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/unknown_aten_test.cpp b/test/onnx/parse/unknown_aten_test.cpp index 75455da4fcc..be6facbf06c 100644 --- a/test/onnx/parse/unknown_aten_test.cpp +++ b/test/onnx/parse/unknown_aten_test.cpp @@ -26,5 +26,5 @@ TEST_CASE(unknown_aten_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_aten_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("unknown_aten_test.onnx"); })); } diff --git a/test/onnx/parse/unknown_test.cpp b/test/onnx/parse/unknown_test.cpp index 9530a748d52..05e169ccaca 100644 --- a/test/onnx/parse/unknown_test.cpp +++ b/test/onnx/parse/unknown_test.cpp @@ -40,12 +40,12 @@ TEST_CASE(unknown_test) TEST_CASE(unknown_test_throw) { - EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); })); + EXPECT(test::throws([&] { read_onnx("unknown_test.onnx"); })); } TEST_CASE(unknown_test_throw_print_error) { migraphx::onnx_options options; options.print_program_on_error = true; - EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("unknown_test.onnx", options); })); } diff --git a/test/onnx/parse/upsample_linear_test.cpp b/test/onnx/parse/upsample_linear_test.cpp index a7a29342700..d644ef87b00 100644 --- a/test/onnx/parse/upsample_linear_test.cpp +++ b/test/onnx/parse/upsample_linear_test.cpp @@ -28,6 +28,6 @@ TEST_CASE(upsample_linear_test) { auto p = create_upsample_linear_prog(); - auto prog = migraphx::parse_onnx("upsample_linear_test.onnx"); + auto prog = read_onnx("upsample_linear_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/upsample_test.cpp b/test/onnx/parse/upsample_test.cpp index 38ebd19f1ea..040aa96373d 100644 --- a/test/onnx/parse/upsample_test.cpp +++ b/test/onnx/parse/upsample_test.cpp @@ -42,7 +42,7 @@ TEST_CASE(upsample_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("upsample_test.onnx"); + auto prog = read_onnx("upsample_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/upsample_ver7_test.cpp b/test/onnx/parse/upsample_ver7_test.cpp index 5541fe30e5e..ad20e25d1ce 100644 --- a/test/onnx/parse/upsample_ver7_test.cpp +++ b/test/onnx/parse/upsample_ver7_test.cpp @@ -40,7 +40,7 @@ TEST_CASE(upsample_ver7_test) auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li); mm->add_return({r}); - auto prog = migraphx::parse_onnx("upsample_ver7_test.onnx"); + auto prog = read_onnx("upsample_ver7_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/variable_batch_test.cpp b/test/onnx/parse/variable_batch_test.cpp index bfb1cc79e2b..05a715facb4 100644 --- a/test/onnx/parse/variable_batch_test.cpp +++ b/test/onnx/parse/variable_batch_test.cpp @@ -46,7 +46,7 @@ TEST_CASE(variable_batch_user_input_test1) migraphx::onnx_options options; options.default_dyn_dim_value = {2, 2}; - auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options); + auto prog = read_onnx("variable_batch_test.onnx", options); EXPECT(p == prog); } @@ -63,7 +63,7 @@ TEST_CASE(variable_batch_user_input_test2) migraphx::onnx_options options; options.default_dyn_dim_value = {2, 5}; - auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options); + auto prog = read_onnx("variable_batch_test.onnx", options); EXPECT(p == prog); } @@ -80,7 +80,7 @@ TEST_CASE(variable_batch_user_input_test3) migraphx::onnx_options options; options.map_dyn_input_dims["0"] = {{2, 5}, {3, 3}, {16, 16}, {16, 16}}; - auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options); + auto prog = read_onnx("variable_batch_test.onnx", options); EXPECT(p == prog); } @@ -96,7 +96,7 @@ TEST_CASE(variable_batch_user_input_test4) migraphx::onnx_options options; options.default_dim_value = 2; - auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options); + auto prog = read_onnx("variable_batch_test.onnx", options); EXPECT(p == prog); } @@ -108,7 +108,7 @@ TEST_CASE(variable_batch_user_input_test5) options.default_dim_value = 2; options.default_dyn_dim_value = {1, 2}; - EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("variable_batch_test.onnx", options); })); } TEST_CASE(variable_batch_user_input_test6) @@ -118,7 +118,7 @@ TEST_CASE(variable_batch_user_input_test6) options.map_dyn_input_dims["0"] = {{2, 5}, {3, 3}, {16, 16}, {16, 16}}; options.map_input_dims["0"] = {2, 3, 16, 16}; - EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("variable_batch_test.onnx", options); })); } TEST_CASE(variable_batch_user_input_test7) @@ -134,7 +134,7 @@ TEST_CASE(variable_batch_user_input_test7) migraphx::onnx_options options; options.map_dyn_input_dims["0"] = {{2, 2, {2}}, {3, 3}, {16, 16}, {16, 16}}; - auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options); + auto prog = read_onnx("variable_batch_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/where_dyn_test.cpp b/test/onnx/parse/where_dyn_test.cpp index 5fc38b9040e..2a71f241d81 100644 --- a/test/onnx/parse/where_dyn_test.cpp +++ b/test/onnx/parse/where_dyn_test.cpp @@ -42,7 +42,7 @@ TEST_CASE(where_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = parse_onnx("where_dyn_test.onnx", options); + auto prog = read_onnx("where_dyn_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/where_mixed_test.cpp b/test/onnx/parse/where_mixed_test.cpp index f34cbe228c7..2cee95e961c 100644 --- a/test/onnx/parse/where_mixed_test.cpp +++ b/test/onnx/parse/where_mixed_test.cpp @@ -29,5 +29,5 @@ TEST_CASE(where_mixed_test) // mixture of static and dynamic input shapes is not supported migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws([&] { migraphx::parse_onnx("where_mixed_test.onnx", options); })); + EXPECT(test::throws([&] { read_onnx("where_mixed_test.onnx", options); })); } diff --git a/test/onnx/parse/where_test.cpp b/test/onnx/parse/where_test.cpp index 893e8095781..c6cee4a65cc 100644 --- a/test/onnx/parse/where_test.cpp +++ b/test/onnx/parse/where_test.cpp @@ -42,7 +42,7 @@ TEST_CASE(where_test) auto r = mm->add_instruction(migraphx::make_op("where"), lccm, lxm, lym); mm->add_return({r}); - auto prog = migraphx::parse_onnx("where_test.onnx"); + auto prog = read_onnx("where_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/verify/averagepool_notset_test.cpp b/test/onnx/verify/averagepool_notset_test.cpp index 2fa58a88a72..eb4c4cd78ea 100644 --- a/test/onnx/verify/averagepool_notset_test.cpp +++ b/test/onnx/verify/averagepool_notset_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(averagepool_notset_test) { - auto p = migraphx::parse_onnx("averagepool_notset_test.onnx"); + auto p = read_onnx("averagepool_notset_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; diff --git a/test/onnx/verify/averagepool_nt_cip_test.cpp b/test/onnx/verify/averagepool_nt_cip_test.cpp index 3d2d720f587..6ca2bf98560 100644 --- a/test/onnx/verify/averagepool_nt_cip_test.cpp +++ b/test/onnx/verify/averagepool_nt_cip_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(averagepool_nt_cip_test) { - auto p = migraphx::parse_onnx("averagepool_nt_cip_test.onnx"); + auto p = read_onnx("averagepool_nt_cip_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; diff --git a/test/onnx/verify/batch_norm_1d_test.cpp b/test/onnx/verify/batch_norm_1d_test.cpp index 0f24f6dc36e..edf59f9b42d 100644 --- a/test/onnx/verify/batch_norm_1d_test.cpp +++ b/test/onnx/verify/batch_norm_1d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(batch_norm_1d_test) { - migraphx::program p = migraphx::parse_onnx("batch_norm_1d_test.onnx"); + migraphx::program p = read_onnx("batch_norm_1d_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x_shape{migraphx::shape::half_type, {2, 3, 4}}; diff --git a/test/onnx/verify/batch_norm_2d_test.cpp b/test/onnx/verify/batch_norm_2d_test.cpp index 12d20c4e180..0e7ccc99dea 100644 --- a/test/onnx/verify/batch_norm_2d_test.cpp +++ b/test/onnx/verify/batch_norm_2d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(batch_norm_2d_test) { - migraphx::program p = migraphx::parse_onnx("batch_norm_2d_test.onnx"); + migraphx::program p = read_onnx("batch_norm_2d_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; diff --git a/test/onnx/verify/batch_norm_3d_test.cpp b/test/onnx/verify/batch_norm_3d_test.cpp index 6dd0f457d70..d4b0044e0d8 100644 --- a/test/onnx/verify/batch_norm_3d_test.cpp +++ b/test/onnx/verify/batch_norm_3d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(batch_norm_3d_test) { - migraphx::program p = migraphx::parse_onnx("batch_norm_3d_test.onnx"); + migraphx::program p = read_onnx("batch_norm_3d_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x_shape{migraphx::shape::half_type, {2, 2, 2, 2, 2}}; diff --git a/test/onnx/verify/batch_norm_flat_test.cpp b/test/onnx/verify/batch_norm_flat_test.cpp index 4ceaf45c980..6fd0c498f5d 100644 --- a/test/onnx/verify/batch_norm_flat_test.cpp +++ b/test/onnx/verify/batch_norm_flat_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(batch_norm_flat_test) { - migraphx::program p = migraphx::parse_onnx("batch_norm_flat_test.onnx"); + migraphx::program p = read_onnx("batch_norm_flat_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x_shape{migraphx::shape::float_type, {10}}; diff --git a/test/onnx/verify/batch_norm_rank_2_test.cpp b/test/onnx/verify/batch_norm_rank_2_test.cpp index 0d5c5f26502..564cb2beab6 100644 --- a/test/onnx/verify/batch_norm_rank_2_test.cpp +++ b/test/onnx/verify/batch_norm_rank_2_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(batch_norm_rank_2_test) { - migraphx::program p = migraphx::parse_onnx("batch_norm_rank_2_test.onnx"); + migraphx::program p = read_onnx("batch_norm_rank_2_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x_shape{migraphx::shape::float_type, {2, 5}}; diff --git a/test/onnx/verify/celu_verify_test.cpp b/test/onnx/verify/celu_verify_test.cpp index 414a65e4c74..26674cb718f 100644 --- a/test/onnx/verify/celu_verify_test.cpp +++ b/test/onnx/verify/celu_verify_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(celu_verify_test) { - migraphx::program p = migraphx::parse_onnx("celu_verify_test.onnx"); + migraphx::program p = read_onnx("celu_verify_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/onnx/verify/clip_test_args_type_mismatch.cpp b/test/onnx/verify/clip_test_args_type_mismatch.cpp index bed3f0c130e..b49eb1953e0 100644 --- a/test/onnx/verify/clip_test_args_type_mismatch.cpp +++ b/test/onnx/verify/clip_test_args_type_mismatch.cpp @@ -28,7 +28,7 @@ TEST_CASE(clip_args_type_mismatch) { - auto p = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx"); + auto p = read_onnx("clip_test_args_type_mismatch.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_0{migraphx::shape::float_type, {3, 3}}; migraphx::parameter_map pp; diff --git a/test/onnx/verify/depthtospace_simple_test.cpp b/test/onnx/verify/depthtospace_simple_test.cpp index 01e2b9b25e5..18f8af4d8a7 100644 --- a/test/onnx/verify/depthtospace_simple_test.cpp +++ b/test/onnx/verify/depthtospace_simple_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(depthtospace_simple_test) { - auto p = migraphx::parse_onnx("depthtospace_simple_test.onnx"); + auto p = read_onnx("depthtospace_simple_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_in(48); std::iota(std::begin(data_in), std::end(data_in), 0); diff --git a/test/onnx/verify/dynamicquantizelinear_1d_test.cpp b/test/onnx/verify/dynamicquantizelinear_1d_test.cpp index c6b8ea22600..edca4bb9e5f 100644 --- a/test/onnx/verify/dynamicquantizelinear_1d_test.cpp +++ b/test/onnx/verify/dynamicquantizelinear_1d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(dynamicquantizelinear_1d_test) { - auto p = migraphx::parse_onnx("dynamicquantizelinear_1d_test.onnx"); + auto p = read_onnx("dynamicquantizelinear_1d_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data{0, 2, -3, -2.5, 1.34, 0.5}; @@ -55,7 +55,7 @@ TEST_CASE(dynamicquantizelinear_1d_test) TEST_CASE(dynamicquantizelinear_1d_max_adjusted_test) { - auto p = migraphx::parse_onnx("dynamicquantizelinear_1d_test.onnx"); + auto p = read_onnx("dynamicquantizelinear_1d_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data{-1.0, -2.1, -1.3, -2.5, -3.34, -4.0}; diff --git a/test/onnx/verify/dynamicquantizelinear_2d_test.cpp b/test/onnx/verify/dynamicquantizelinear_2d_test.cpp index 709cbd43333..693a4d29a57 100644 --- a/test/onnx/verify/dynamicquantizelinear_2d_test.cpp +++ b/test/onnx/verify/dynamicquantizelinear_2d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(dynamicquantizelinear_2d_test) { - auto p = migraphx::parse_onnx("dynamicquantizelinear_2d_test.onnx"); + auto p = read_onnx("dynamicquantizelinear_2d_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data{1.0, 2.1, 1.3, 2.5, 3.34, 4.0, 1.5, 2.6, 3.9, 4.0, 3.0, 2.345}; diff --git a/test/onnx/verify/einsum_tests.cpp b/test/onnx/verify/einsum_tests.cpp new file mode 100644 index 00000000000..65346feadae --- /dev/null +++ b/test/onnx/verify/einsum_tests.cpp @@ -0,0 +1,1676 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include + +static migraphx::shape make_shape(std::vector lens) +{ + return migraphx::shape{migraphx::shape::float_type, std::move(lens)}; +} + +TEST_CASE(einsum_permute_test) +{ + migraphx::program p = read_onnx("einsum_permute_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.06727745, 0.21160052, 0.1340474, 0.74153227, 0.40337096, 0.81284493}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.06727745, 0.74153227, 0.21160052, 0.40337096, 0.1340474, 0.81284493}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_summation_test) +{ + migraphx::program p = read_onnx("einsum_summation_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.79413969, 0.45169144, 0.06846618, 0.67973967, 0.83375529, 0.44838823}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape().scalar()); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {3.2761804984270566}; + EXPECT(result_vector == gold); +} + +TEST_CASE(einsum_column_sum_test) +{ + migraphx::program p = read_onnx("einsum_column_sum_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.22235926, 0.83263138, 0.04747776, 0.96030827, 0.18947713, 0.48815767}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.18266753, 1.0221085, 0.53563543}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_row_sum_test) +{ + migraphx::program p = read_onnx("einsum_row_sum_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.17123185, 0.59008514, 0.37948294, 0.73022965, 0.22919172, 0.27532941}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.14079993, 1.23475077}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_matrix_vector_multiplication_test) +{ + migraphx::program p = read_onnx("einsum_matrix_vector_multiplication_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.4834133, 0.14106742, 0.50055824, 0.91764271, 0.95528452, 0.98199955}; + + migraphx::shape v_shape{migraphx::shape::float_type, {3}}; + std::vector v_data = {0.73961958, 0.53071864, 0.34152803}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + pm["v"] = migraphx::argument{v_shape, v_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.60336371, 1.52107419}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_matrix_matrix_multiplication_test) +{ + migraphx::program p = read_onnx("einsum_matrix_matrix_multiplication_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.45176257, 0.84846429, 0.4374105, 0.25132236, 0.70519571, 0.4902031}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x_shape, x_data.data()}; + pm["x2"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.11530901, 0.92629139, 0.92629139, 0.80076299}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_vector_dot_product_test) +{ + migraphx::program p = read_onnx("einsum_vector_dot_product_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {3}}; + std::vector x_data = {0.45263196, 0.90876706, 0.9584567}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x_shape, x_data.data()}; + pm["x2"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape().scalar()); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.94937252}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_matrix_dot_product_test) +{ + migraphx::program p = read_onnx("einsum_matrix_dot_product_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.50001808, 0.12468059, 0.85439214, 0.00773521, 0.84764693, 0.87185525}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x_shape, x_data.data()}; + pm["x2"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape().scalar()); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {2.47424599}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_hadamard_product_test) +{ + migraphx::program p = read_onnx("einsum_hadamard_product_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.86162928, 0.76609605, 0.03362172, 0.21778614, 0.27204858, 0.83778314}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x_shape, x_data.data()}; + pm["x2"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.74240502, 0.58690315, 0.00113042, 0.0474308, 0.07401043, 0.70188058}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_vector_outer_product_test) +{ + migraphx::program p = read_onnx("einsum_vector_outer_product_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3}}; + std::vector x1_data = {0.35935151, 0.51298139, 0.46076789}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {5}}; + std::vector x2_data = {0.82417482, 0.17984153, 0.17680769, 0.55499376, 0.74447638}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 5})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.29616847, + 0.06462632, + 0.06353611, + 0.19943785, + 0.26752871, + 0.42278634, + 0.09225536, + 0.09069905, + 0.28470147, + 0.38190252, + 0.37975329, + 0.0828652, + 0.08146731, + 0.2557233, + 0.34303081}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_matrix_outer_product_test) +{ + migraphx::program p = read_onnx("einsum_matrix_outer_product_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x1_data = { + 0.25870501, 0.06755926, 0.18247427, 0.19436556, 0.61580192, 0.20010939}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 5}}; + std::vector x2_data = {0.30771264, + 0.86270274, + 0.55251869, + 0.35880608, + 0.3234085, + 0.24642323, + 0.82411907, + 0.33488431, + 0.69288027, + 0.21717812}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 3, 2, 5})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.0796068, 0.22318552, 0.14293935, 0.09282493, 0.0836674, 0.06375092, 0.21320373, + 0.08663625, 0.17925159, 0.05618507, 0.02078884, 0.05828356, 0.03732775, 0.02424067, + 0.02184924, 0.01664817, 0.05567687, 0.02262453, 0.04681048, 0.01467239, 0.05614964, + 0.15742105, 0.10082044, 0.06547288, 0.05901373, 0.0449659, 0.15038052, 0.06110777, + 0.12643282, 0.03962942, 0.05980874, 0.1676797, 0.1073906, 0.06973954, 0.06285947, + 0.04789619, 0.16018036, 0.06508997, 0.13467206, 0.04221195, 0.18949004, 0.53125401, + 0.34024207, 0.22095347, 0.19915557, 0.1517479, 0.50749411, 0.2062224, 0.426677, + 0.1337387, 0.06157619, 0.17263492, 0.11056418, 0.07180047, 0.06471708, 0.0493116, + 0.16491396, 0.06701349, 0.13865185, 0.04345938}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_batch_matrix_multiplication_test) +{ + migraphx::program p = read_onnx("einsum_batch_matrix_multiplication_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 2, 5}}; + std::vector x1_data = {0.99236023, 0.6848901, 0.37916487, 0.35448254, 0.06103943, + 0.88991707, 0.20816843, 0.12124124, 0.90632983, 0.88490338, + 0.93530363, 0.41393917, 0.95269137, 0.95556378, 0.63113954, + 0.87936215, 0.66831395, 0.38079353, 0.74128241, 0.05493966, + 0.12545692, 0.77418839, 0.17562823, 0.5558762, 0.95698858, + 0.49207445, 0.81934147, 0.50168285, 0.13782384, 0.71351839}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {3, 5, 3}}; + std::vector x2_data = { + 0.72870257, 0.44635711, 0.05938103, 0.7031737, 0.52116502, 0.01719079, 0.99837568, + 0.29989025, 0.63673246, 0.39255282, 0.39796917, 0.03082538, 0.20994321, 0.11431396, + 0.06561894, 0.99749458, 0.45970296, 0.76957234, 0.98073012, 0.63154904, 0.22862209, + 0.71098086, 0.68895963, 0.92763041, 0.61730666, 0.54453456, 0.99719059, 0.05984043, + 0.64232788, 0.9754334, 0.39450223, 0.1005812, 0.11753032, 0.59885466, 0.75932222, + 0.45269589, 0.26201765, 0.39022748, 0.96507247, 0.55260731, 0.42233854, 0.50671452, + 0.60313192, 0.32628192, 0.40066181}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 2, 3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.73524908, + 1.06164644, + 0.32706016, + 1.45746952, + 1.00391812, + 0.21962538, + 2.64391179, + 2.27348666, + 3.26667873, + 2.26421769, + 1.52761296, + 1.97554961, + 1.44350867, + 1.21602803, + 1.19981019, + 1.32274886, + 1.15842452, + 1.2686234}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_tensor_contraction_test) +{ + migraphx::program p = read_onnx("einsum_tensor_contraction_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 3, 5, 7}}; + std::vector x1_data = {0.95685496, 0.40756636, 0.65360334, 0.96968506, + 0.50135366, 0.50255377, 0.54263245, 0.40919774, + 0.0512559, 0.18771721, 0.79265052, 0.76059609, + 0.31619353, 0.62297555, 0.70398181, 0.82378161, + 0.50388425, 0.56257752, 0.29233331, 0.98995162, + 0.38240504, 0.29803141, 0.23344604, 0.78356941, + 0.67958479, 0.10005701, 0.15588056, 0.29163352, + 0.90480928, 0.35649064, 0.77419322, 0.56301202, + 0.133201, 0.33165803, 0.37175546, 0.63959881, + 0.6058814, 0.43169871, 0.65272681, 0.17943427, + 0.30863453, 0.39029972, 0.66189176, 0.2311467, + 0.77007359, 0.33601537, 0.28087721, 0.65732174, + 0.67537887, 0.65066593, 0.89716601, 0.92921684, + 0.69368177, 0.86772161, 0.82583412, 0.32274594, + 0.0739795, 0.7573278, 0.82209441, 0.44979001, + 0.52619926, 0.68870551, 0.8586619, 0.32302478, + 0.30437449, 0.22181276, 0.41919667, 0.16351355, + 0.10825966, 0.20406509, 0.32577585, 0.89748513, + 0.78650319, 0.55487763, 0.74600253, 0.68125503, + 0.59796741, 0.75181214, 0.27655496, 0.87750203, + 0.50401991, 0.30561784, 0.82724439, 0.04727558, + 0.9224091, 0.24823561, 0.05547919, 0.93431458, + 0.51550858, 0.64800403, 0.95942825, 0.04009098, + 0.55616792, 0.71433063, 0.0753035, 0.0479713, + 0.19538077, 0.29627466, 0.47649694, 0.49999562, + 0.05246693, 0.29663604, 0.29992186, 0.62328915, + 0.00265317, 0.50642525, 0.73613139, 0.5998967, + 0.37132279, 0.02788106, 0.99984085, 0.87220473, + 0.08963238, 0.20698509, 0.17961793, 0.32962012, + 0.8046416, 0.96530006, 0.27079326, 0.07223538, + 0.72336279, 0.54842596, 0.38904735, 0.21660217, + 0.05165004, 0.60308648, 0.98992912, 0.01950237, + 0.19094762, 0.2928557, 0.18129261, 0.23948649, + 0.65970424, 0.0217831, 0.89637346, 0.25872699, + 0.98701943, 0.43783966, 0.65803132, 0.06773888, + 0.11277457, 0.68990466, 0.80914248, 0.66815968, + 0.10671669, 0.15578704, 0.78813393, 0.71601124, + 0.41304412, 0.93551562, 0.28607031, 0.16353775, + 0.54597636, 0.10405413, 0.05332971, 0.8301183, + 0.0991274, 0.1152268, 0.86477572, 0.20824363, + 0.77115011, 0.62202978, 0.87562719, 0.17638816, + 0.00798768, 0.46176706, 0.33432177, 0.93926911, + 0.60557399, 0.38483151, 0.23797486, 0.83815198, + 0.27293845, 0.62067518, 0.56702013, 0.80762545, + 0.47669687, 0.13692723, 0.40838777, 0.3148337, + 0.55255245, 0.24319153, 0.39330312, 0.22781179, + 0.101221, 0.80367016, 0.08707603, 0.90069816, + 0.28595044, 0.57599756, 0.71276499, 0.04032091, + 0.50101916, 0.94582167, 0.2091183, 0.17698968, + 0.72687874, 0.08878026, 0.16422912, 0.34543801, + 0.28480515, 0.8740834, 0.18413319, 0.60564407, + 0.94070861, 0.21143538, 0.2715485, 0.76848231, + 0.0064918, 0.36614132 + + }; + + migraphx::shape x2_shape{migraphx::shape::float_type, {1, 3, 3, 7, 5}}; + std::vector x2_data = { + 0.31719105, 0.44506343, 0.59957066, 0.00373946, 0.06497482, 0.30887562, 0.04364479, + 0.09203816, 0.0778086, 0.58357676, 0.49651904, 0.10000999, 0.16565024, 0.46539611, + 0.82516851, 0.64563229, 0.26637135, 0.2141455, 0.69189904, 0.75060041, 0.75433425, + 0.69215069, 0.18186255, 0.89800939, 0.93269204, 0.63033347, 0.9423835, 0.90530682, + 0.07135205, 0.57649693, 0.44479805, 0.94513207, 0.89856664, 0.79120729, 0.63383186, + 0.97271015, 0.69211656, 0.91893391, 0.07601606, 0.90099522, 0.31441974, 0.70932527, + 0.68997715, 0.33528514, 0.24921017, 0.09703337, 0.54714714, 0.98431729, 0.27753988, + 0.78936545, 0.51031898, 0.30604168, 0.53546681, 0.95644451, 0.79345859, 0.3444766, + 0.19356174, 0.41127976, 0.15782141, 0.65660564, 0.76540504, 0.21572256, 0.29864542, + 0.01153175, 0.06708682, 0.82473386, 0.45034386, 0.96212735, 0.5969872, 0.35962495, + 0.60466663, 0.52630816, 0.73655946, 0.11649375, 0.32456538, 0.64199728, 0.08340919, + 0.2237889, 0.09521117, 0.91767416, 0.22842615, 0.46863323, 0.00293057, 0.13495504, + 0.68305119, 0.80013148, 0.24702202, 0.83619373, 0.94419611, 0.25176846, 0.74292949, + 0.68404465, 0.23097011, 0.09664962, 0.44346347, 0.31467353, 0.37099949, 0.54412241, + 0.76552126, 0.1443158, 0.03555697, 0.43584746, 0.10575715, 0.1046359, 0.43291613, + 0.03007743, 0.55544576, 0.80022343, 0.42529416, 0.47484557, 0.84443037, 0.99362024, + 0.78040286, 0.16341681, 0.98059931, 0.64114384, 0.27438947, 0.51972672, 0.24844974, + 0.11630196, 0.86696682, 0.62380654, 0.23221499, 0.93125653, 0.53386878, 0.14323035, + 0.46524576, 0.24347234, 0.43592108, 0.68938894, 0.83452471, 0.67473429, 0.11704585, + 0.01223517, 0.61133307, 0.19640497, 0.94062148, 0.09548036, 0.27914148, 0.28533241, + 0.32062872, 0.27619432, 0.18284111, 0.73646915, 0.07043039, 0.10841211, 0.25284529, + 0.73262578, 0.63395762, 0.75505585, 0.66397536, 0.60934204, 0.17561379, 0.44185177, + 0.90064761, 0.87593443, 0.04697443, 0.90844936, 0.4878133, 0.17061924, 0.37868238, + 0.03991319, 0.99918374, 0.05644218, 0.11533688, 0.36478255, 0.74207249, 0.02537966, + 0.73720329, 0.41510019, 0.87408442, 0.0902388, 0.77849296, 0.22027469, 0.66811554, + 0.535826, 0.40478544, 0.47295354, 0.53722756, 0.81697433, 0.17400588, 0.52628511, + 0.57033592, 0.74645826, 0.58147372, 0.25898702, 0.03268815, 0.37127404, 0.04316943, + 0.86187713, 0.33330374, 0.58282901, 0.32484663, 0.8295674, 0.34023535, 0.48430125, + 0.5626468, 0.48469659, 0.16184832, 0.71399316, 0.5417521, 0.11897383, 0.84953376, + 0.98761605, 0.58273874, 0.89537346, 0.83282794, 0.78849938, 0.42528756, 0.08624209, + 0.7689597, 0.92518944, 0.25278458, 0.0732656, 0.0057378, 0.74097687, 0.13263284, + 0.73757523, 0.01510422, 0.8650508, 0.21755823, 0.38417346, 0.77236815, 0.80464568, + 0.23389132, 0.24982259, 0.3034747, 0.99357576, 0.69974824, 0.62271656, 0.43386392, + 0.3517672, 0.01739671, 0.54493487, 0.07725586, 0.75756086, 0.86409372, 0.50906544, + 0.87797418, 0.41355064, 0.11812738, 0.9809903, 0.67759122, 0.44601677, 0.53664097, + 0.75512155, 0.27589464, 0.12141359, 0.74533628, 0.95179317, 0.31788316, 0.41200016, + 0.81161753, 0.84035926, 0.42866542, 0.97692811, 0.14777789, 0.54256825, 0.03691842, + 0.71298109, 0.27676914, 0.31342084, 0.09905633, 0.01056144, 0.28488026, 0.39330704, + 0.07871612, 0.61847332, 0.48494692, 0.14455078, 0.53627478, 0.78087393, 0.24899241, + 0.78534409, 0.29844719, 0.33439453, 0.62448919, 0.21187341, 0.21381023, 0.25570138, + 0.67919933, 0.73611559, 0.45109776, 0.25360901, 0.17702297, 0.41635495, 0.80213947, + 0.01236559, 0.0112422, 0.03389217, 0.87942468, 0.25273501, 0.511234, 0.82734509, + 0.58747506, 0.31687443, 0.89906645, 0.96090575, 0.04004779, 0.02298561, 0.10433042, + 0.7104134, 0.79670464, 0.9930637, 0.5446879, 0.06004139, 0.41158374, 0.17676018, + 0.10056314, 0.01345726, 0.82521847, 0.76125409, 0.17694037, 0.05363529, 0.32265118}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 7, 1, 3, 7})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 4.37385737, 3.07363193, 3.61847664, 4.34283839, 4.26894546, 4.00093768, 4.51345157, + 3.28585485, 4.98955956, 3.40062413, 4.32430907, 3.58727315, 4.01024983, 4.04214073, + 3.71183284, 4.04117845, 3.9304425, 4.05446572, 3.19462145, 3.75153593, 3.63370359, + 3.6737565, 2.89999382, 3.0174889, 4.35349886, 3.165444, 3.10185148, 3.86251195, + 3.3873455, 4.06622752, 2.90101219, 3.93475191, 3.00084537, 3.36253104, 3.50215565, + 3.2272778, 3.63297086, 4.11360191, 2.55025226, 2.89909597, 2.8134455, 2.91506006, + 3.7938589, 3.12994095, 3.93469812, 5.19912284, 4.38534872, 3.50334177, 4.71274384, + 3.59957887, 4.82387001, 2.82827241, 5.04315375, 3.42817516, 3.97827684, 4.0792739, + 3.73622444, 4.59885202, 4.20690004, 3.39733812, 3.56861724, 4.18875149, 3.80445766, + 4.34760619, 2.83154296, 3.39897749, 4.91619741, 4.55085299, 4.02356989, 4.83137925, + 3.49172193, 5.09758452, 3.46814603, 4.8534725, 3.58561246, 4.17459184, 4.57103074, + 4.31924652, 3.86027525, 4.33725934, 3.88334716, 3.51074837, 4.2163728, 3.76365513, + 3.13004972, 2.27159717, 2.35669807, 3.25755431, 2.85534261, 2.56412151, 3.19951963, + 2.50814311, 3.53231318, 2.20002443, 3.12059903, 2.63204045, 2.90076584, 3.36582992, + 3.06683373, 2.76686275, 2.77506122, 2.09060484, 2.37978869, 2.59300135, 2.73194814, + 4.12941618, 3.09876995, 3.26773346, 4.15566501, 3.49722972, 3.46654242, 4.2842499, + 3.77358659, 4.61660476, 3.14276911, 3.88478492, 3.36244681, 3.70141846, 3.77154536, + 3.59743975, 4.07663608, 3.81503321, 3.53650377, 3.19912915, 3.41346893, 3.6696098, + 3.22521498, 2.26604057, 2.16539957, 4.2136737, 2.91410526, 3.02978768, 3.33819415, + 2.9409972, 3.83464087, 2.65153712, 3.32360785, 2.24438948, 3.95703137, 3.35290512, + 3.41760415, 2.86825506, 3.08274974, 2.72484017, 2.65706605, 3.36092398, 2.83630318, + 2.89697041, 2.50152336, 2.73918816, 4.5120665, 3.40255688, 2.21408714, 2.82712268, + 3.04826657, 3.41090928, 2.96534728, 3.52745057, 2.24957446, 3.84521048, 3.08574989, + 3.28188229, 2.31822221, 3.76298328, 2.57778028, 3.19081461, 3.07155158, 2.73609241, + 4.19950589, 3.6560231, 3.78387066, 4.79181063, 3.83391543, 3.55914169, 4.5795992, + 3.80991087, 5.12966262, 3.81299104, 4.21955081, 3.59584019, 4.29810986, 3.70353926, + 3.70364291, 4.26908068, 3.98312417, 3.12472346, 3.16217195, 3.4642648, 3.22122407, + 2.62355294, 1.82932863, 1.87920164, 2.36533037, 2.06395846, 2.33422825, 2.78131656, + 1.83772458, 2.43196754, 2.45650722, 2.37074638, 1.36516771, 2.47311739, 1.85973378, + 2.28547527, 2.22058881, 2.42265217, 1.82521576, 1.42674238, 2.63853633, 2.09125692, + 3.43987729, 2.19115419, 2.93461373, 3.85600443, 3.76977612, 3.15357479, 3.3520207, + 2.6665599, 4.023041, 2.68187355, 3.41405847, 2.72865504, 3.23944437, 3.64514952, + 3.347772, 3.08780622, 3.59354671, 3.2772289, 2.50492638, 2.77853552, 3.07724088, + 3.03408917, 2.45574117, 2.5493586, 3.48528482, 2.74493899, 2.611099, 3.26765525, + 2.93502233, 3.93585413, 2.32960219, 3.09824088, 3.03519943, 3.21090064, 3.3114777, + 2.58394431, 2.2187237, 3.00954904, 2.23092399, 2.83426168, 2.27217761, 2.5014613, + 3.19291058, 2.17091072, 3.02885277, 4.41008881, 4.12811972, 3.61970552, 3.53615268, + 2.78509447, 4.861919, 2.54172549, 4.17995171, 2.56407684, 4.31953876, 3.98183007, + 4.18525975, 3.4355, 3.32306034, 2.80758129, 3.17616352, 3.6386068, 3.45497304, + 3.46339678, 2.31062665, 2.98872364, 4.14619218, 3.33730406, 2.814647, 4.28392461, + 2.85391039, 3.99487077, 3.22812695, 4.24891978, 2.57924025, 3.05409494, 3.2767709, + 3.64664984, 3.49454643, 3.69300505, 2.42169066, 2.93327166, 3.5987843, 2.52333694}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_matrix_diagonal_test) +{ + migraphx::program p = read_onnx("einsum_matrix_diagonal_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 3}}; + std::vector x_data = {0.47776573, + 0.63448645, + 0.89651875, + 0.23679368, + 0.99918665, + 0.27613904, + 0.57251725, + 0.30676534, + 0.01097199}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.47776573, 0.99918665, 0.01097199}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_batch_matrix_diagonal_test) +{ + migraphx::program p = read_onnx("einsum_batch_matrix_diagonal_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 3, 3}}; + std::vector x_data = { + 0.28876273, 0.35989686, 0.87975286, 0.4636637, 0.42481418, 0.15188883, 0.19336828, + 0.24970656, 0.85099181, 0.26858692, 0.70659505, 0.28920736, 0.44962699, 0.02807534, + 0.36833006, 0.41504379, 0.00211731, 0.78780266, 0.23482163, 0.16543172, 0.29376553, + 0.8090205, 0.08804924, 0.16924385, 0.07311857, 0.52459502, 0.66098314}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.28876273, + 0.42481418, + 0.85099181, + 0.26858692, + 0.02807534, + 0.78780266, + 0.23482163, + 0.08804924, + 0.66098314}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_3d_diagonal_test) +{ + migraphx::program p = read_onnx("einsum_3d_diagonal_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 3, 3}}; + std::vector x_data = { + 0.0865182, 0.38083222, 0.67805353, 0.0585945, 0.74171412, 0.1304194, 0.00526353, + 0.43741816, 0.95075246, 0.56668103, 0.66687595, 0.73297639, 0.06474291, 0.27579944, + 0.13203794, 0.01323116, 0.18004087, 0.67450993, 0.86813684, 0.88677573, 0.67944271, + 0.38633242, 0.92832963, 0.02932602, 0.45013121, 0.36562681, 0.0411488}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.0865182, 0.27579944, 0.0411488}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_diag_vector_multiply_test) +{ + migraphx::program p = read_onnx("einsum_diag_vector_multiply_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 3}}; + std::vector x1_data = {0.8628764, + 0.96045198, + 0.14103307, + 0.89249896, + 0.97520951, + 0.7015561, + 0.06408759, + 0.59921615, + 0.76173894}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {3}}; + std::vector x2_data = {0.79284103, 0.61505765, 0.70876231}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.68412382, 0.59981008, 0.53989185}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_matrix_trace_test) +{ + migraphx::program p = read_onnx("einsum_matrix_trace_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 3}}; + std::vector x_data = {0.90812557, + 0.40719192, + 0.71678312, + 0.78176503, + 0.57731702, + 0.23585615, + 0.06292936, + 0.46016886, + 0.37753559}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape().scalar()); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.86297818}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_matrix_trace_implicit_test) +{ + migraphx::program p = read_onnx("einsum_matrix_trace_implicit_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 3}}; + std::vector x_data = {0.78947898, + 0.56206428, + 0.18337164, + 0.58397232, + 0.68795372, + 0.11615468, + 0.22114439, + 0.84875979, + 0.08248506}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape().scalar()); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.559917763052301}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_2d_3d_multiplication_test) +{ + migraphx::program p = read_onnx("einsum_2d_3d_multiplication_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 3}}; + std::vector x1_data = {0.77117604, + 0.10042859, + 0.68555583, + 0.93192629, + 0.39255794, + 0.99285767, + 0.88129697, + 0.56599014, + 0.03828527}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {3, 4, 5}}; + std::vector x2_data = { + 0.19665868, 0.49490562, 0.73175228, 0.89251999, 0.08735652, 0.25944536, 0.37003717, + 0.09387889, 0.75490936, 0.81022481, 0.9987667, 0.04082882, 0.26160334, 0.85590193, + 0.80221833, 0.11203218, 0.31701572, 0.45973754, 0.3452479, 0.85151585, 0.86455042, + 0.19206577, 0.09922319, 0.58911914, 0.15871974, 0.61540675, 0.21682354, 0.69036427, + 0.77451157, 0.91950467, 0.52659111, 0.80857867, 0.63179264, 0.10085509, 0.96412482, + 0.42412458, 0.0330562, 0.13279482, 0.39372801, 0.80698385, 0.1182876, 0.75943908, + 0.59421519, 0.66827559, 0.09009574, 0.66649037, 0.43015355, 0.37795428, 0.11304274, + 0.37406792, 0.33043231, 0.32357327, 0.38079892, 0.42659918, 0.55308245, 0.49437723, + 0.95926415, 0.99762983, 0.70624046, 0.24298556}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 4, 5})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.3195768, 0.92158614, 0.98164236, 1.20559466, 0.14507291, 0.71879884, 0.60203336, + 0.40083822, 0.73744823, 0.97361497, 1.04963956, 0.33451816, 0.5262512, 0.96263736, + 1.09464615, 0.46791396, 0.90542384, 1.05180592, 0.78995572, 0.90429304, 0.64010028, + 1.29062741, 1.31086115, 1.72652878, 0.23316878, 1.14509684, 0.85704442, 0.73375098, + 1.1197959, 1.48742487, 1.46556673, 0.67672563, 0.86988939, 1.26078125, 1.67521536, + 0.76174542, 1.26082452, 1.47107559, 1.17750291, 1.351588, 0.66717038, 0.57394148, + 0.72380011, 1.1455959, 0.17027018, 0.60247933, 0.46530117, 0.48794463, 1.10799312, + 1.24880054, 1.19090614, 0.50601796, 0.60271763, 0.82771923, 1.27385264, 0.35771131, + 0.33482015, 0.51852039, 0.5541507, 1.21648601}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_element_wise_multiplication_and_row_sum_test) +{ + migraphx::program p = read_onnx("einsum_element_wise_multiplication_and_row_sum_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3}}; + std::vector x1_data = {0.66866322, 0.01371844, 0.85036724}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {3, 4}}; + std::vector x2_data = {0.72487469, + 0.24707426, + 0.8735483, + 0.04525622, + 0.52379655, + 0.32056461, + 0.51596208, + 0.10696902, + 0.08682559, + 0.95054461, + 0.16377484, + 0.61029108}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.2642773, 0.02012896, 1.54038595}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_broadcast_test) +{ + migraphx::program p = read_onnx("einsum_broadcast_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 1}}; + std::vector x1_data = {0.39430774, 0.13914788, 0.48328062}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2}}; + std::vector x2_data = {0.71903989, 0.19490621, 0.56431641, 0.09180231}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.50603732, 0.11305139, 0.17857631, 0.03989488, 0.62022123, 0.13856067}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_3d_broadcast_test) +{ + migraphx::program p = read_onnx("einsum_3d_broadcast_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {1, 3, 1}}; + std::vector x1_data = {0.6306304, 0.92378069, 0.09156996}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2, 4}}; + std::vector x2_data = {0.07905765, + 0.27054262, + 0.42684231, + 0.96296392, + 0.20374812, + 0.95058412, + 0.26180494, + 0.65115589, + 0.19317509, + 0.60143068, + 0.54864825, + 0.36401264, + 0.20867305, + 0.90065616, + 0.26377379, + 0.16009663}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 3, 4})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.17834592, 0.77007964, 0.43428189, 1.01791302, 0.26125051, + 1.1280533, 0.63615903, 1.49109271, 0.02589651, 0.11181853, + 0.0630594, 0.14780488, 0.25341766, 0.94726162, 0.51233803, + 0.33051924, 0.37121956, 1.38759882, 0.75049978, 0.48416203, + 0.03679722, 0.13754603, 0.07439345, 0.04799266}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_3d_opposite_broadcast_test) +{ + migraphx::program p = read_onnx("einsum_3d_opposite_broadcast_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {1, 3, 2}}; + std::vector x1_data = { + 0.89996837, 0.62380433, 0.38499382, 0.82576167, 0.71647773, 0.74190884}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 1, 4}}; + std::vector x2_data = {0.83902045, + 0.3002842, + 0.46254963, + 0.42754638, + 0.54720295, + 0.6184629, + 0.99604709, + 0.94529622}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 3, 4})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.27847646, 0.45756486, 0.7048205, 0.65148351, 1.01584862, + 0.36357074, 0.56003451, 0.51765413, 1.22361616, 0.43793044, + 0.67457618, 0.62352791, 0.83381291, 0.94239689, 1.51774936, + 1.44041657, 0.66252897, 0.74880736, 1.20596948, 1.14452259, + 0.79803343, 0.90195799, 1.4526217, 1.37860731}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_3_inputs_test) +{ + migraphx::program p = read_onnx("einsum_3_inputs_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 2, 2}}; + std::vector x1_data = {0.78808491, + 0.6661874, + 0.4170594, + 0.80972418, + 0.22687053, + 0.52144567, + 0.70463225, + 0.8934412}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2}}; + std::vector x2_data = {0.98518483, 0.61526655, 0.89011461, 0.02600793}; + + migraphx::shape x3_shape{migraphx::shape::float_type, {2, 2, 2}}; + std::vector x3_data = {0.04135729, + 0.36723732, + 0.82196749, + 0.35332048, + 0.92673273, + 0.50014512, + 0.91129541, + 0.97557965}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + pm["x3"] = migraphx::argument{x3_shape, x3_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.54312876, + 0.59155446, + 1.19274407, + 0.56709538, + 2.79449706, + 1.61644006, + 2.15997517, + 1.5496049}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_bilinear_transformation_test) +{ + migraphx::program p = read_onnx("einsum_bilinear_transformation_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x1_data = { + 0.34096073, 0.38172764, 0.36543085, 0.28104558, 0.0556053, 0.23574725}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {5, 3, 7}}; + std::vector x2_data = { + 0.27525548, 0.55922006, 0.28504873, 0.48681888, 0.7527785, 0.76094518, 0.99365312, + 0.76470274, 0.44406814, 0.24103473, 0.25141801, 0.51590554, 0.78834812, 0.96411404, + 0.01325493, 0.21739615, 0.25936655, 0.23025532, 0.85856546, 0.33609085, 0.33413049, + 0.60163776, 0.61253489, 0.84028869, 0.2593441, 0.53611056, 0.05595679, 0.30129639, + 0.44404875, 0.71431542, 0.95123376, 0.71387725, 0.05743836, 0.35266739, 0.53284905, + 0.07799213, 0.3639559, 0.72199632, 0.0920087, 0.71882463, 0.09804492, 0.79378518, + 0.2149909, 0.62017677, 0.57284093, 0.1480283, 0.65038853, 0.47830376, 0.18202239, + 0.37421293, 0.65768777, 0.2465394, 0.80183419, 0.65855262, 0.40956847, 0.36430994, + 0.4464513, 0.65720017, 0.29603235, 0.21994904, 0.31797431, 0.64774027, 0.71807814, + 0.67456442, 0.37665375, 0.84645173, 0.10965697, 0.57469259, 0.68129292, 0.28780513, + 0.50772577, 0.67820423, 0.92720621, 0.52615601, 0.5507361, 0.55419857, 0.37244191, + 0.52378246, 0.29057448, 0.14684616, 0.60456568, 0.79814119, 0.51783395, 0.69921548, + 0.12310853, 0.18934048, 0.98081268, 0.51493817, 0.1279986, 0.3868668, 0.42396674, + 0.04160038, 0.56299233, 0.40414454, 0.73163413, 0.3126024, 0.75276068, 0.88847181, + 0.96703089, 0.34357903, 0.34495332, 0.73431682, 0.01318382, 0.15232141, 0.88949811}; + + migraphx::shape x3_shape{migraphx::shape::float_type, {2, 7}}; + std::vector x3_data = {0.22897831, + 0.68897913, + 0.55615068, + 0.77395085, + 0.44879247, + 0.42608676, + 0.45303661, + 0.04397996, + 0.44780993, + 0.98314993, + 0.32980751, + 0.57814391, + 0.91010863, + 0.53235916}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + pm["x3"] = migraphx::argument{x3_shape, x3_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 5})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.82915577, + 1.88971744, + 1.84172272, + 2.0310065, + 1.91888787, + 1.11119172, + 1.03903856, + 1.03828167, + 1.17052253, + 0.98080627}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_ellipsis_test) +{ + migraphx::program p = read_onnx("einsum_ellipsis_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 3, 2}}; + std::vector x1_data = {0.04249489, + 0.55406728, + 0.19941733, + 0.73459709, + 0.85098409, + 0.57610406, + 0.20316778, + 0.43422309, + 0.83122325, + 0.26004847, + 0.75534733, + 0.96759149}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 4, 2}}; + std::vector x2_data = {0.92094713, + 0.79225215, + 0.74592229, + 0.44132894, + 0.33642643, + 0.7196803, + 0.52841641, + 0.19646611, + 0.85507066, + 0.69714208, + 0.61092676, + 0.10550163, + 0.1895, + 0.67025347, + 0.01897078, + 0.63833372}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 4, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.51290222, 0.4636753, 0.37019241, 0.13547507, 0.11929215, + 0.43725538, 0.03296608, 0.31709483, 0.81178524, 0.83982914, + 0.59753485, 0.39427841, 0.20629541, 0.77251339, 0.11931127, + 0.3293049, 1.27632103, 1.27297429, 0.98672538, 0.43543911, + 0.39546526, 1.19214015, 0.4606031, 0.76604642}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_ellipsis_multidim_test) +{ + migraphx::program p = read_onnx("einsum_ellipsis_multidim_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 2, 3, 2}}; + std::vector x1_data = { + 0.98667534, 0.26757447, 0.97607513, 0.82605353, 0.49444144, 0.01681133, + 0.77774229, 0.75994986, 0.11125708, 0.1130032, 0.63612414, 0.1262558, + 0.58148571, 0.03373236, 0.97679914, 0.96362191, 0.81985409, 0.49089541, + 0.20980484, 0.54484447, 0.86032374, 0.03736589, 0.21250823, 0.61016893, + 0.35060633, 0.66305752, 0.15096292, 0.13044199, 0.85426735, 0.35063898, + 0.62050398, 0.42931425, 0.78397709, 0.30081415, 0.13172537, 0.97078161}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 4, 3, 2}}; + std::vector x2_data = { + 0.57040198, 0.53550748, 0.45591515, 0.56752322, 0.50931221, 0.81220443, 0.00733681, + 0.3914752, 0.56944863, 0.57929432, 0.7376043, 0.07466457, 0.62632235, 0.93106704, + 0.75973908, 0.06791374, 0.4220263, 0.30228231, 0.12644542, 0.17381266, 0.6764365, + 0.7179303, 0.78075755, 0.45183063, 0.03752228, 0.54431596, 0.08627314, 0.8015124, + 0.74214063, 0.99574465, 0.26469823, 0.77350918, 0.29052469, 0.38834888, 0.13962948, + 0.7043763, 0.98259846, 0.59013313, 0.67843048, 0.60183051, 0.75242782, 0.49615042, + 0.74438165, 0.99080336, 0.09669321, 0.63712064, 0.45491748, 0.81021691}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 4, 3, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.57284157, 0.83013964, 0.26801834, 0.55576872, 0.67065001, 0.93146345, 0.07806553, + 0.89229501, 0.34092632, 0.3331285, 0.35119111, 0.34872845, 0.88089507, 1.1726018, + 0.46466248, 0.34215266, 0.64686801, 0.40057183, 0.3239381, 0.88814233, 0.39659985, + 0.49775691, 0.57537499, 0.62820037, 0.58775059, 0.12108844, 0.52847222, 0.51820293, + 0.17369356, 0.93628374, 0.22581618, 0.1309634, 0.83619289, 0.51289166, 0.12956445, + 0.27042167, 1.4230166, 0.17027473, 1.39586296, 0.08091573, 0.1618585, 0.38623148, + 0.73831932, 0.13130184, 0.75391828, 0.64145906, 0.17720578, 0.59794957, 0.28266118, + 0.40937228, 0.41613499, 0.60966132, 0.69531223, 1.07363852, 0.00807755, 0.34668684, + 0.60948202, 0.36006323, 0.67907081, 0.69363078, 0.32619851, 0.66678194, 0.9559136, + 0.38165051, 0.62435381, 0.52147196, 0.0750339, 0.2356611, 0.60204548, 0.54131732, + 0.82648748, 0.84606124}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_ellipsis_zero_test) +{ + migraphx::program p = read_onnx("einsum_ellipsis_zero_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 3, 2}}; + std::vector x1_data = {0.66350493, + 0.23942871, + 0.92238018, + 0.62110235, + 0.32076099, + 0.96309398, + 0.52844268, + 0.34438311, + 0.65616714, + 0.20566103, + 0.27886952, + 0.65970714}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {4, 3, 2}}; + std::vector x2_data = {0.80308382, 0.54059368, 0.37399569, 0.1005526, 0.76379294, + 0.67375565, 0.35891999, 0.84426002, 0.09043876, 0.90878662, + 0.94432809, 0.79103325, 0.1105734, 0.4352484, 0.33998431, + 0.05210384, 0.99372845, 0.38982222, 0.99214395, 0.66699468, + 0.11299297, 0.64553585, 0.39052278, 0.66001129}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 2, 4})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.66228372, 0.44028527, 0.17757696, 0.81799008, 0.61055509, + 0.48041753, 0.2083239, 0.7539929, 0.40741967, 0.64786843, + 0.34595661, 0.50516631, 0.26608343, 0.24624494, 0.23380226, + 0.20690385, 0.89388499, 1.06474297, 0.69418476, 0.76091737, + 0.65747998, 0.7851946, 0.53428908, 0.54431906}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_ellipsis_implicit_form_test) +{ + migraphx::program p = read_onnx("einsum_ellipsis_implicit_form_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 2, 3, 2}}; + std::vector x1_data = { + 0.23521871, 0.98377414, 0.89254812, 0.97761717, 0.05081862, 0.68622971, + 0.10890005, 0.2268622, 0.49600579, 0.2676526, 0.42904501, 0.37749836, + 0.79665579, 0.95331325, 0.86434957, 0.79121832, 0.28486632, 0.12174202, + 0.70187, 0.14436634, 0.03751946, 0.61306538, 0.13534059, 0.27080258, + 0.2651645, 0.29432102, 0.04611007, 0.58113752, 0.24878511, 0.17095365, + 0.0815941, 0.29892262, 0.11160549, 0.27367858, 0.36888151, 0.16212635}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {3, 4, 3, 2}}; + std::vector x2_data = { + 0.44591065, 0.88061357, 0.701782, 0.57534276, 0.65403074, 0.81415861, 0.68154153, + 0.55451648, 0.81680318, 0.54274041, 0.44267802, 0.204258, 0.38894043, 0.26743358, + 0.9689122, 0.16832771, 0.70924974, 0.13868791, 0.52965739, 0.41611994, 0.59251147, + 0.03544427, 0.86559268, 0.68808533, 0.01154378, 0.50244414, 0.20684438, 0.15988138, + 0.28233231, 0.10307361, 0.90725685, 0.94720523, 0.42599834, 0.93168414, 0.82026755, + 0.22099913, 0.46835316, 0.90021715, 0.5152653, 0.51409383, 0.33123306, 0.3003667, + 0.07429799, 0.79805729, 0.17255054, 0.29718065, 0.92965361, 0.36905318, 0.69877278, + 0.77362919, 0.14773139, 0.23016429, 0.02718606, 0.39449785, 0.93450467, 0.34742404, + 0.35372862, 0.07290892, 0.79728572, 0.15650619, 0.53751043, 0.44802221, 0.77646259, + 0.65170074, 0.49278255, 0.36228251, 0.17940834, 0.66284468, 0.15208601, 0.83560697, + 0.51165061, 0.14598895}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 4, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {2.75198731, 1.33836971, 2.12812296, 1.01745957, 1.51515599, + 0.98532013, 1.61362211, 1.08658677, 0.88644536, 0.2525403, + 2.99170324, 1.53155007, 2.21435937, 0.91935904, 1.51402355, + 0.58178573, 0.62775842, 0.4417366, 0.63384035, 0.55901237, + 0.87345202, 0.68330958, 0.88752551, 0.67084639}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_ellipsis_scalar_multiplication_test) +{ + migraphx::program p = read_onnx("einsum_ellipsis_scalar_multiplication_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3}}; + std::vector x_data = { + 0.2766607, 0.76752867, 0.28231295, 0.30409753, 0.37753377, 0.73576867}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x_shape, x_data.data()}; + pm["x2"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.07654114, 0.58910026, 0.0797006, 0.09247531, 0.14253175, 0.54135554}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_1_test) +{ + migraphx::program p = read_onnx("einsum_common_1_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x1_data = {0.35498396, + 0.92145607, + 0.81807284, + 0.37990484, + 0.22314499, + 0.90337144, + 0.02492543, + 0.36666091, + 0.33262049, + 0.37052745, + 0.01950226, + 0.83690205, + 0.61551503, + 0.55244304, + 0.62696715, + 0.74933671}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x2_data = {0.44903857, + 0.47304138, + 0.63679145, + 0.78101353, + 0.41525864, + 0.57356733, + 0.83636479, + 0.01236986, + 0.10068789, + 0.46623025, + 0.29825429, + 0.56816588, + 0.00558546, + 0.91900877, + 0.74972012, + 0.4509882}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 2, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.59528833, + 0.52753278, + 0.67592725, + 0.61080723, + 0.81765261, + 0.30223943, + 0.68890669, + 0.0253823, + 0.20624196, + 0.31954056, + 0.34237582, + 0.51113793, + 0.48131582, + 0.6127432, + 0.39205418, + 0.8079919}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_2_test) +{ + migraphx::program p = read_onnx("einsum_common_2_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x1_data = {0.77858647, + 0.8659616, + 0.89981848, + 0.45454779, + 0.27364842, + 0.69225887, + 0.01304595, + 0.14404551, + 0.47394644, + 0.39058325, + 0.977306, + 0.90298946, + 0.01456065, + 0.70478062, + 0.92796867, + 0.00407166}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x2_data = {0.12299003, + 0.42677007, + 0.84213152, + 0.26884624, + 0.85685616, + 0.53033816, + 0.61543941, + 0.00586418, + 0.79310638, + 0.66468861, + 0.22797244, + 0.32789713, + 0.01537162, + 0.28328088, + 0.39257709, + 0.83954883}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {2.51890769, + 1.78883817, + 2.11484282, + 1.38804189, + 2.81881969, + 1.09537142, + 3.0398521, + 1.07377846}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_3_test) +{ + migraphx::program p = read_onnx("einsum_common_3_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x1_data = {0.22151958, + 0.19284961, + 0.8126814, + 0.02360209, + 0.99137254, + 0.0550951, + 0.34794661, + 0.03083101, + 0.03127261, + 0.04609321, + 0.02422953, + 0.30878066, + 0.42532866, + 0.02191982, + 0.34276933, + 0.66997637}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x2_data = {0.76051399, + 0.92365044, + 0.14703117, + 0.07201171, + 0.81879942, + 0.91050362, + 0.90936259, + 0.94197062, + 0.73971579, + 0.08809791, + 0.17392649, + 0.36623704, + 0.23731799, + 0.67476051, + 0.97480632, + 0.35175013}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 2})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.62099637, + 2.20329706, + 0.6457657, + 1.61829179, + 0.4142793, + 0.52881853, + 2.00689201, + 2.20807455}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_4_test) +{ + migraphx::program p = read_onnx("einsum_common_4_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {2, 2, 3, 2}}; + std::vector x1_data = {0.56144416, 0.70795103, 0.10800643, 0.85461707, 0.53053745, + 0.42957473, 0.2801385, 0.91878799, 0.51160639, 0.90354742, + 0.83131358, 0.84237736, 0.01078178, 0.75952001, 0.74426499, + 0.70506648, 0.65528756, 0.54674358, 0.3923791, 0.33558121, + 0.18089114, 0.41982192, 0.50568299, 0.83929267}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2, 4, 2}}; + std::vector x2_data = { + 0.71114916, 0.10373848, 0.85011488, 0.08836512, 0.01426097, 0.63389153, 0.3714056, + 0.42466907, 0.5412509, 0.12682203, 0.88595126, 0.09839624, 0.10689487, 0.1196194, + 0.5887543, 0.51683836, 0.50278953, 0.94187525, 0.98227159, 0.57961915, 0.12739494, + 0.59140361, 0.34997506, 0.43158845, 0.60170823, 0.06098434, 0.24573198, 0.15357368, + 0.99864135, 0.92721276, 0.81457582, 0.49836327}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 3, 4})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.4727123, 0.53985021, 0.4567709, 0.50916841, 0.16546536, 0.16733621, 0.5432748, + 0.40304363, 0.42185469, 0.48897721, 0.27986976, 0.37947168, 0.26814778, 0.33859434, + 0.13985024, 0.63979763, 0.39149714, 0.54216399, 0.1627699, 0.76819843, 0.55678123, + 0.81939007, 0.18962783, 0.92481237, 0.72079407, 0.45082298, 0.45055642, 0.33157342, + 1.03829331, 1.13974038, 0.51179445, 0.56477273, 0.84443597, 0.9605734, 0.40682645, + 0.46530252, 0.25656293, 0.14795654, 0.70300118, 0.48686388, 0.13444625, 0.10892434, + 0.56990961, 0.35657337, 0.35545733, 0.25315575, 1.28319881, 0.83018978}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_5_test) +{ + migraphx::program p = read_onnx("einsum_common_5_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 2, 3, 2}}; + std::vector x1_data = { + 0.54568637, 0.37482154, 0.04235242, 0.65373642, 0.33087863, 0.31717808, + 0.95558492, 0.04292704, 0.41062909, 0.15678733, 0.42269055, 0.52439126, + 0.79640916, 0.84653066, 0.07768967, 0.27527369, 0.89984151, 0.51484382, + 0.16384989, 0.91806877, 0.21812376, 0.11357245, 0.54908942, 0.31401177, + 0.65491277, 0.28771509, 0.78575018, 0.79237873, 0.46273786, 0.76982106, + 0.09757821, 0.22590816, 0.07358939, 0.10590534, 0.83561014, 0.46470277}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {3, 4, 3, 2}}; + std::vector x2_data = { + 0.8106741, 0.59851071, 0.01563264, 0.59371323, 0.92144669, 0.13810113, 0.30200611, + 0.04771728, 0.27000965, 0.15975859, 0.79296359, 0.8423782, 0.14653939, 0.97910498, + 0.92130026, 0.98351422, 0.36302145, 0.34644287, 0.552259, 0.8590351, 0.32266987, + 0.05450608, 0.37737409, 0.28476044, 0.12639262, 0.68674546, 0.36657116, 0.95912161, + 0.25702418, 0.36058756, 0.68556443, 0.71449807, 0.15664292, 0.14519584, 0.96284277, + 0.08696439, 0.21784017, 0.35219703, 0.33682869, 0.65550335, 0.58188946, 0.15934059, + 0.4108815, 0.73728006, 0.18921976, 0.00133056, 0.56921019, 0.10649676, 0.63103856, + 0.06864912, 0.38452259, 0.44953274, 0.53725327, 0.75235172, 0.71780644, 0.56919235, + 0.14419679, 0.27101719, 0.03290223, 0.13075588, 0.99856136, 0.76185492, 0.29195496, + 0.45779837, 0.670453, 0.20837162, 0.90747364, 0.53769863, 0.37493214, 0.46571204, + 0.89671548, 0.16910057}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 3, 2, 4})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.66670851, 0.18268608, 0.44695419, 0.62334507, 0.80036024, 0.29064084, 0.18206091, + 0.56460621, 0.38879404, 0.11587557, 0.68197836, 0.04929846, 0.09950593, 0.13592194, + 0.53251525, 0.1410435, 0.34868967, 0.52955861, 0.23000012, 0.21518479, 0.46190584, + 0.77691399, 0.33511735, 0.30883835, 0.68201133, 1.15083431, 0.47163549, 0.95135997, + 0.65118898, 0.76828803, 0.35903419, 0.74419669, 0.29249974, 0.05213813, 0.20661094, + 0.01506669, 0.18888767, 0.05065779, 0.14791746, 0.04142444, 0.4169273, 0.91117897, + 0.60564381, 0.56702816, 0.25435799, 0.55599462, 0.36954417, 0.34598853, 0.4330266, + 0.63386583, 0.87316774, 0.74902009, 0.07708401, 0.19862746, 0.26954707, 0.21002016, + 0.65833888, 0.32805091, 0.59215335, 0.66362331, 0.0759047, 0.03931352, 0.06996808, + 0.07691242, 0.82778363, 0.11588374, 0.47065285, 0.54512138, 0.79855421, 0.08825606, + 0.65706819, 0.82788605}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_6_test) +{ + migraphx::program p = read_onnx("einsum_common_6_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 2, 2}}; + std::vector x1_data = {0.05474463, + 0.22797254, + 0.87786654, + 0.5430384, + 0.7145002, + 0.27575673, + 0.74687312, + 0.49764738, + 0.3077794, + 0.83018295, + 0.42118662, + 0.04536079}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {2, 2, 3}}; + std::vector x2_data = {0.51540488, + 0.78670115, + 0.71049908, + 0.51739133, + 0.75638524, + 0.50107731, + 0.15112663, + 0.55976972, + 0.09744345, + 0.63967998, + 0.56295837, + 0.95296606}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 2, 3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.06266837, + 0.17067979, + 0.06111044, + 0.80157133, + 0.96971331, + 0.95737617, + 0.40993108, + 0.7164584, + 0.53452242, + 0.70476074, + 0.84507857, + 0.84848224, + 0.28409375, + 0.70684169, + 0.29957287, + 0.24693469, + 0.34411558, + 0.25427435}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_7_test) +{ + migraphx::program p = read_onnx("einsum_common_7_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::float_type, {5, 5}}; + std::vector x_data = {0.45661163, 0.49868523, 0.8806857, 0.45253824, 0.61711842, + 0.19736463, 0.55164341, 0.84964635, 0.50090015, 0.49506288, + 0.19423388, 0.76448901, 0.65602353, 0.2169867, 0.99645268, + 0.62749812, 0.67396942, 0.69806385, 0.23727109, 0.23524408, + 0.84425561, 0.67866378, 0.20223278, 0.34088997, 0.22209943}; + + migraphx::parameter_map pm; + pm["x"] = migraphx::argument{x_shape, x_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({5})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {2.90563922, 2.5946174, 2.82818581, 2.47204655, 2.28814157}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(einsum_common_8_test) +{ + migraphx::program p = read_onnx("einsum_common_8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x1_shape{migraphx::shape::float_type, {3, 3}}; + std::vector x1_data = {0.31281588, + 0.34922652, + 0.79181082, + 0.55581571, + 0.34963734, + 0.39777707, + 0.43040396, + 0.19965846, + 0.68818176}; + + migraphx::shape x2_shape{migraphx::shape::float_type, {3, 3}}; + std::vector x2_data = {0.94199384, + 0.06564557, + 0.36439139, + 0.30556677, + 0.25776106, + 0.59531702, + 0.21481152, + 0.09608821, + 0.41203512}; + + migraphx::parameter_map pm; + pm["x1"] = migraphx::argument{x1_shape, x1_data.data()}; + pm["x2"] = migraphx::argument{x2_shape, x2_data.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({3, 3})); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.29467063, + 0.08063175, + 0.12889113, + 0.32935622, + 0.09012289, + 0.14406286, + 0.64826297, + 0.17738646, + 0.28355505}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/eyelike_verify_negk_test.cpp b/test/onnx/verify/eyelike_verify_negk_test.cpp index 7e7e61328f7..d949df052fe 100644 --- a/test/onnx/verify/eyelike_verify_negk_test.cpp +++ b/test/onnx/verify/eyelike_verify_negk_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(eyelike_verify_negk_test) { - migraphx::program p = migraphx::parse_onnx("eyelike_verify_negk_test.onnx"); + migraphx::program p = read_onnx("eyelike_verify_negk_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {3, 4}}; diff --git a/test/onnx/verify/eyelike_verify_test.cpp b/test/onnx/verify/eyelike_verify_test.cpp index ab33e3c2f51..433b79b761b 100644 --- a/test/onnx/verify/eyelike_verify_test.cpp +++ b/test/onnx/verify/eyelike_verify_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(eyelike_verify_test) { - migraphx::program p = migraphx::parse_onnx("eyelike_verify_test.onnx"); + migraphx::program p = read_onnx("eyelike_verify_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {3, 4}}; diff --git a/test/onnx/verify/gather_elements_axis0_test.cpp b/test/onnx/verify/gather_elements_axis0_test.cpp index c307df1a602..36aee16739c 100644 --- a/test/onnx/verify/gather_elements_axis0_test.cpp +++ b/test/onnx/verify/gather_elements_axis0_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gather_elements) { - migraphx::program p = migraphx::parse_onnx("gather_elements_axis0_test.onnx"); + migraphx::program p = read_onnx("gather_elements_axis0_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_data{migraphx::shape::float_type, {3, 4}}; std::vector data = { diff --git a/test/onnx/verify/gelu_add_bias_test.cpp b/test/onnx/verify/gelu_add_bias_test.cpp index 68a693b2e7e..dbfd891aeb3 100644 --- a/test/onnx/verify/gelu_add_bias_test.cpp +++ b/test/onnx/verify/gelu_add_bias_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gelu_add_bias_test) { - migraphx::program p = migraphx::parse_onnx("gelu_add_bias_test.onnx"); + migraphx::program p = read_onnx("gelu_add_bias_test.onnx"); p.compile(migraphx::make_target("ref")); auto input_type = migraphx::shape::float_type; diff --git a/test/onnx/verify/gelu_default_half_test.cpp b/test/onnx/verify/gelu_default_half_test.cpp index 2cdf2dd1b3a..c2d801343ff 100644 --- a/test/onnx/verify/gelu_default_half_test.cpp +++ b/test/onnx/verify/gelu_default_half_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gelu_default_half_test) { - migraphx::program p = migraphx::parse_onnx("gelu_default_half_test.onnx"); + migraphx::program p = read_onnx("gelu_default_half_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{3, 3}; diff --git a/test/onnx/verify/gelu_default_test.cpp b/test/onnx/verify/gelu_default_test.cpp index aa89f42efc4..880e3df3e6e 100644 --- a/test/onnx/verify/gelu_default_test.cpp +++ b/test/onnx/verify/gelu_default_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gelu_default_test) { - migraphx::program p = migraphx::parse_onnx("gelu_default_test.onnx"); + migraphx::program p = read_onnx("gelu_default_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{3, 3}; diff --git a/test/onnx/verify/gelu_fast_bias_test.cpp b/test/onnx/verify/gelu_fast_bias_test.cpp index 4c881adf33b..e686dc9b567 100644 --- a/test/onnx/verify/gelu_fast_bias_test.cpp +++ b/test/onnx/verify/gelu_fast_bias_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gelu_fast_bias_test) { - migraphx::program p = migraphx::parse_onnx("gelu_fast_bias_test.onnx"); + migraphx::program p = read_onnx("gelu_fast_bias_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape shape{migraphx::shape::half_type, {3, 3}}; diff --git a/test/onnx/verify/gelu_tanh_test.cpp b/test/onnx/verify/gelu_tanh_test.cpp index e820ea45027..a5125d06407 100644 --- a/test/onnx/verify/gelu_tanh_test.cpp +++ b/test/onnx/verify/gelu_tanh_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gelu_tanh_test) { - migraphx::program p = migraphx::parse_onnx("gelu_tanh_test.onnx"); + migraphx::program p = read_onnx("gelu_tanh_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{3, 3}; diff --git a/test/onnx/verify/gemm_brcst_C_test.cpp b/test/onnx/verify/gemm_brcst_C_test.cpp index b0475ec793b..168a354f305 100644 --- a/test/onnx/verify/gemm_brcst_C_test.cpp +++ b/test/onnx/verify/gemm_brcst_C_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gemm_test) { - migraphx::program p = migraphx::parse_onnx("gemm_brcst_C_test.onnx"); + migraphx::program p = read_onnx("gemm_brcst_C_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a_shape{migraphx::shape::float_type, {5, 6}}; diff --git a/test/onnx/verify/gemm_half_test.cpp b/test/onnx/verify/gemm_half_test.cpp index f035200deb2..9dfcebf95e6 100644 --- a/test/onnx/verify/gemm_half_test.cpp +++ b/test/onnx/verify/gemm_half_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(gemm_half_test) { - migraphx::program p = migraphx::parse_onnx("gemm_half_test.onnx"); + migraphx::program p = read_onnx("gemm_half_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a_shape{migraphx::shape::half_type, {8, 6}}; diff --git a/test/onnx/verify/greaterorequal_test.cpp b/test/onnx/verify/greaterorequal_test.cpp index eb7d187b806..2db044f911c 100644 --- a/test/onnx/verify/greaterorequal_test.cpp +++ b/test/onnx/verify/greaterorequal_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(greaterorequal_test) { - migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx"); + migraphx::program p = read_onnx("greaterorequal_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {3}}; diff --git a/test/onnx/verify/group_norm_3d_half_test.cpp b/test/onnx/verify/group_norm_3d_half_test.cpp index cbb66aa0b50..4800257a92e 100644 --- a/test/onnx/verify/group_norm_3d_half_test.cpp +++ b/test/onnx/verify/group_norm_3d_half_test.cpp @@ -32,8 +32,8 @@ TEST_CASE(group_norm_half_test) using migraphx::half; std::vector scale{half{1.2}, half{0.8}}; std::vector bias{half{0.5}, half{0.2}}; - std::vector result_vector = norm_test( - {1, 4, 2}, scale, bias, migraphx::parse_onnx("group_norm_3d_half_test.onnx")); + std::vector result_vector = + norm_test({1, 4, 2}, scale, bias, read_onnx("group_norm_3d_half_test.onnx")); std::vector gold = {half{-1.10996256}, half{-0.0366542}, half{1.0366542}, diff --git a/test/onnx/verify/group_norm_3d_test.cpp b/test/onnx/verify/group_norm_3d_test.cpp index 1403a8849de..1ca79dc6b89 100644 --- a/test/onnx/verify/group_norm_3d_test.cpp +++ b/test/onnx/verify/group_norm_3d_test.cpp @@ -32,7 +32,7 @@ TEST_CASE(group_norm_test) std::vector scale{1.2, 0.8}; std::vector bias{0.5, 0.2}; std::vector result_vector = - norm_test({1, 4, 2}, scale, bias, migraphx::parse_onnx("group_norm_3d_test.onnx")); + norm_test({1, 4, 2}, scale, bias, read_onnx("group_norm_3d_test.onnx")); std::vector gold = {-1.10996256, -0.0366542, 1.0366542, diff --git a/test/onnx/verify/hardmax_axis_neg_test.cpp b/test/onnx/verify/hardmax_axis_neg_test.cpp index dac316a77be..e76a3cd2ee8 100644 --- a/test/onnx/verify/hardmax_axis_neg_test.cpp +++ b/test/onnx/verify/hardmax_axis_neg_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(hardmax_axis_neg_verify_test) { - migraphx::program p = migraphx::parse_onnx("hardmax_axis_neg_test.onnx"); + migraphx::program p = read_onnx("hardmax_axis_neg_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{1, 2, 3, 4}; diff --git a/test/onnx/verify/hardmax_axis_neg_ver11_test.cpp b/test/onnx/verify/hardmax_axis_neg_ver11_test.cpp index 0635a69acae..480a9f6ef8e 100644 --- a/test/onnx/verify/hardmax_axis_neg_ver11_test.cpp +++ b/test/onnx/verify/hardmax_axis_neg_ver11_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(hardmax_axis_neg_ver11_verify_test) { - migraphx::program p = migraphx::parse_onnx("hardmax_axis_neg_ver11_test.onnx"); + migraphx::program p = read_onnx("hardmax_axis_neg_ver11_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{1, 2, 3, 4}; diff --git a/test/onnx/verify/hardmax_axis_test.cpp b/test/onnx/verify/hardmax_axis_test.cpp index 1a1ea0271dc..87c16aee470 100644 --- a/test/onnx/verify/hardmax_axis_test.cpp +++ b/test/onnx/verify/hardmax_axis_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(hardmax_axis_verify_test) { - migraphx::program p = migraphx::parse_onnx("hardmax_axis_test.onnx"); + migraphx::program p = read_onnx("hardmax_axis_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{1, 2, 3, 4}; diff --git a/test/onnx/verify/hardmax_axis_ver11_test.cpp b/test/onnx/verify/hardmax_axis_ver11_test.cpp index 6b1f54242ef..a37762fc2ab 100644 --- a/test/onnx/verify/hardmax_axis_ver11_test.cpp +++ b/test/onnx/verify/hardmax_axis_ver11_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(hardmax_axis_ver11_verify_test) { - migraphx::program p = migraphx::parse_onnx("hardmax_axis_ver11_test.onnx"); + migraphx::program p = read_onnx("hardmax_axis_ver11_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{1, 2, 3, 4}; diff --git a/test/onnx/verify/hardmax_default_test.cpp b/test/onnx/verify/hardmax_default_test.cpp index 0e25e599ec6..c3d5a83f1b6 100644 --- a/test/onnx/verify/hardmax_default_test.cpp +++ b/test/onnx/verify/hardmax_default_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(hardmax_default_verify_test) { - migraphx::program p = migraphx::parse_onnx("hardmax_default_test.onnx"); + migraphx::program p = read_onnx("hardmax_default_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{1, 2, 3, 4}; diff --git a/test/onnx/verify/hardmax_default_ver11_test.cpp b/test/onnx/verify/hardmax_default_ver11_test.cpp index 9e85a59f217..bcf3b3c9f1f 100644 --- a/test/onnx/verify/hardmax_default_ver11_test.cpp +++ b/test/onnx/verify/hardmax_default_ver11_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(hardmax_default_ver11_verify_test) { - migraphx::program p = migraphx::parse_onnx("hardmax_default_ver11_test.onnx"); + migraphx::program p = read_onnx("hardmax_default_ver11_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector input_lens{1, 2, 3, 4}; diff --git a/test/onnx/verify/hardsigmoid_verify_test.cpp b/test/onnx/verify/hardsigmoid_verify_test.cpp index d5e17cab3a9..9955330e2d1 100644 --- a/test/onnx/verify/hardsigmoid_verify_test.cpp +++ b/test/onnx/verify/hardsigmoid_verify_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(hardsigmoid_verify_test) { - migraphx::program p = migraphx::parse_onnx("hardsigmoid_verify_test.onnx"); + migraphx::program p = read_onnx("hardsigmoid_verify_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {2, 5}}; diff --git a/test/onnx/verify/if_else_test.cpp b/test/onnx/verify/if_else_test.cpp index 506b8fb0e18..9ee8882b6ac 100644 --- a/test/onnx/verify/if_else_test.cpp +++ b/test/onnx/verify/if_else_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(if_else_test) { - migraphx::program p = migraphx::parse_onnx("if_else_test.onnx"); + migraphx::program p = read_onnx("if_else_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; std::vector data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; diff --git a/test/onnx/verify/if_else_test_inlined.cpp b/test/onnx/verify/if_else_test_inlined.cpp index 562a5aed430..f181ef1f2c2 100644 --- a/test/onnx/verify/if_else_test_inlined.cpp +++ b/test/onnx/verify/if_else_test_inlined.cpp @@ -28,7 +28,7 @@ TEST_CASE(if_else_test_inlined) { - migraphx::program p = migraphx::parse_onnx("if_else_test_inlined.onnx"); + migraphx::program p = read_onnx("if_else_test_inlined.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; std::vector data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; diff --git a/test/onnx/verify/if_literal_test.cpp b/test/onnx/verify/if_literal_test.cpp index f91e158696f..e779c69114c 100644 --- a/test/onnx/verify/if_literal_test.cpp +++ b/test/onnx/verify/if_literal_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(if_literal_test) { auto run_prog = [](bool cond) { - migraphx::program p = migraphx::parse_onnx("if_literal_test.onnx"); + migraphx::program p = read_onnx("if_literal_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_data{migraphx::shape::bool_type}; std::vector data = {static_cast(cond)}; diff --git a/test/onnx/verify/if_pl_test.cpp b/test/onnx/verify/if_pl_test.cpp index d5cebf4e893..a3fe9289c9e 100644 --- a/test/onnx/verify/if_pl_test.cpp +++ b/test/onnx/verify/if_pl_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(if_pl_test) { auto run_prog = [](bool cond) { - migraphx::program p = migraphx::parse_onnx("if_pl_test.onnx"); + migraphx::program p = read_onnx("if_pl_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; diff --git a/test/onnx/verify/if_then_else_multi_output_shapes_inlined_test.cpp b/test/onnx/verify/if_then_else_multi_output_shapes_inlined_test.cpp index 548529a0cb9..5733dc42f62 100644 --- a/test/onnx/verify/if_then_else_multi_output_shapes_inlined_test.cpp +++ b/test/onnx/verify/if_then_else_multi_output_shapes_inlined_test.cpp @@ -28,8 +28,7 @@ TEST_CASE(if_then_else_multi_output_shapes_inlined_test) { - migraphx::program p = - migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx"); + migraphx::program p = read_onnx("if_then_else_multi_output_shapes_inlined_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x_data{migraphx::shape::float_type, {2, 3, 1}}; migraphx::shape y_data{migraphx::shape::float_type, {2, 3}}; diff --git a/test/onnx/verify/if_then_else_multi_output_shapes_test.cpp b/test/onnx/verify/if_then_else_multi_output_shapes_test.cpp index 457b7dac111..34ec7575b99 100644 --- a/test/onnx/verify/if_then_else_multi_output_shapes_test.cpp +++ b/test/onnx/verify/if_then_else_multi_output_shapes_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(if_then_else_multi_output_shapes_test) { - migraphx::program p = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx"); + migraphx::program p = read_onnx("if_then_else_multi_output_shapes_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_data{migraphx::shape::float_type, {2, 3, 1}}; std::vector data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; diff --git a/test/onnx/verify/if_then_test.cpp b/test/onnx/verify/if_then_test.cpp index 4f015836fd9..2b09d2a565f 100644 --- a/test/onnx/verify/if_then_test.cpp +++ b/test/onnx/verify/if_then_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(if_then_test) { - migraphx::program p = migraphx::parse_onnx("if_then_test.onnx"); + migraphx::program p = read_onnx("if_then_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; std::vector data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; diff --git a/test/onnx/verify/if_then_test_inlined.cpp b/test/onnx/verify/if_then_test_inlined.cpp index e4a740f8856..d7c20be2642 100644 --- a/test/onnx/verify/if_then_test_inlined.cpp +++ b/test/onnx/verify/if_then_test_inlined.cpp @@ -28,7 +28,7 @@ TEST_CASE(if_then_test_inlined) { - migraphx::program p = migraphx::parse_onnx("if_then_test_inlined.onnx"); + migraphx::program p = read_onnx("if_then_test_inlined.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; std::vector data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; diff --git a/test/onnx/verify/if_tuple_test.cpp b/test/onnx/verify/if_tuple_test.cpp index 23c690fe636..364a4eedcfe 100644 --- a/test/onnx/verify/if_tuple_test.cpp +++ b/test/onnx/verify/if_tuple_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(if_tuple_test) { auto run_prog = [](bool cond) { - migraphx::program p = migraphx::parse_onnx("if_tuple_test.onnx"); + migraphx::program p = read_onnx("if_tuple_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape xs{migraphx::shape::float_type, {1, 4}}; migraphx::shape ys{migraphx::shape::float_type, {3, 4}}; diff --git a/test/onnx/verify/instance_norm_dyn_batch_test.cpp b/test/onnx/verify/instance_norm_dyn_batch_test.cpp index 4adf93b7022..0c4bfce21d2 100644 --- a/test/onnx/verify/instance_norm_dyn_batch_test.cpp +++ b/test/onnx/verify/instance_norm_dyn_batch_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(instance_norm_dyn_batch_test) { - migraphx::program p = migraphx::parse_onnx("instance_norm_dyn_batch_test.onnx"); + migraphx::program p = read_onnx("instance_norm_dyn_batch_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s0{migraphx::shape::float_type, {1, 2, 3, 3}}; diff --git a/test/onnx/verify/instance_norm_val_3d_test.cpp b/test/onnx/verify/instance_norm_val_3d_test.cpp index b2d733ab3c3..44fe6998d33 100644 --- a/test/onnx/verify/instance_norm_val_3d_test.cpp +++ b/test/onnx/verify/instance_norm_val_3d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(instance_norm_3d_test) { - migraphx::program p = migraphx::parse_onnx("instance_norm_val_3d_test.onnx"); + migraphx::program p = read_onnx("instance_norm_val_3d_test.onnx"); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); diff --git a/test/onnx/verify/instance_norm_val_test.cpp b/test/onnx/verify/instance_norm_val_test.cpp index d3da4df1447..037fa463b63 100644 --- a/test/onnx/verify/instance_norm_val_test.cpp +++ b/test/onnx/verify/instance_norm_val_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(instance_norm_test) { - migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx"); + migraphx::program p = read_onnx("instance_norm_val_test.onnx"); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); diff --git a/test/onnx/verify/isinf_double_pos_test.cpp b/test/onnx/verify/isinf_double_pos_test.cpp index 5e571979045..93bb5b9d6b9 100644 --- a/test/onnx/verify/isinf_double_pos_test.cpp +++ b/test/onnx/verify/isinf_double_pos_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(isinf_double_pos_test) { - migraphx::program p = migraphx::parse_onnx("isinf_double_pos_test.onnx"); + migraphx::program p = read_onnx("isinf_double_pos_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::double_type, {2, 3}}; diff --git a/test/onnx/verify/isinf_half_test.cpp b/test/onnx/verify/isinf_half_test.cpp index ef4add8bd3f..27a9bcb683c 100644 --- a/test/onnx/verify/isinf_half_test.cpp +++ b/test/onnx/verify/isinf_half_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(isinf_half_test) { - migraphx::program p = migraphx::parse_onnx("isinf_half_test.onnx"); + migraphx::program p = read_onnx("isinf_half_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::half_type, {2, 3}}; diff --git a/test/onnx/verify/isinf_neg_test.cpp b/test/onnx/verify/isinf_neg_test.cpp index 4bb73cb74a0..91cb0fa4cb2 100644 --- a/test/onnx/verify/isinf_neg_test.cpp +++ b/test/onnx/verify/isinf_neg_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(isinf_neg_test) { - migraphx::program p = migraphx::parse_onnx("isinf_neg_test.onnx"); + migraphx::program p = read_onnx("isinf_neg_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/onnx/verify/isinf_no_detect_test.cpp b/test/onnx/verify/isinf_no_detect_test.cpp index f7056a89867..00d763a9257 100644 --- a/test/onnx/verify/isinf_no_detect_test.cpp +++ b/test/onnx/verify/isinf_no_detect_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(isinf_no_detect_test) { - migraphx::program p = migraphx::parse_onnx("isinf_no_detect_test.onnx"); + migraphx::program p = read_onnx("isinf_no_detect_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/onnx/verify/layer_norm_3d_half_test.cpp b/test/onnx/verify/layer_norm_3d_half_test.cpp index 58de25dbd18..6b48840d73f 100644 --- a/test/onnx/verify/layer_norm_3d_half_test.cpp +++ b/test/onnx/verify/layer_norm_3d_half_test.cpp @@ -32,8 +32,8 @@ TEST_CASE(layer_norm_half_test) using migraphx::half; std::vector scale{half{1.2}, half{0.8}}; std::vector bias{half{0.5}, half{0.2}}; - std::vector result_vector = norm_test( - {1, 4, 2}, scale, bias, migraphx::parse_onnx("layer_norm_3d_half_test.onnx")); + std::vector result_vector = + norm_test({1, 4, 2}, scale, bias, read_onnx("layer_norm_3d_half_test.onnx")); std::vector gold = {half{-0.69997597}, half{0.99998398}, half{-0.69997597}, diff --git a/test/onnx/verify/layer_norm_3d_test.cpp b/test/onnx/verify/layer_norm_3d_test.cpp index 1fddf50262c..711151dbd93 100644 --- a/test/onnx/verify/layer_norm_3d_test.cpp +++ b/test/onnx/verify/layer_norm_3d_test.cpp @@ -32,7 +32,7 @@ TEST_CASE(layer_norm_test) std::vector scale{1.2, 0.8}; std::vector bias{0.5, 0.2}; std::vector result_vector = - norm_test({1, 4, 2}, scale, bias, migraphx::parse_onnx("layer_norm_3d_test.onnx")); + norm_test({1, 4, 2}, scale, bias, read_onnx("layer_norm_3d_test.onnx")); std::vector gold = {-0.69997597, 0.99998398, -0.69997597, diff --git a/test/onnx/verify/lessorequal_test.cpp b/test/onnx/verify/lessorequal_test.cpp index 223d8d17541..e239a328e49 100644 --- a/test/onnx/verify/lessorequal_test.cpp +++ b/test/onnx/verify/lessorequal_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(lessorequal_test) { - migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx"); + migraphx::program p = read_onnx("lessorequal_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {3}}; diff --git a/test/onnx/verify/lpnormalization_l1_test.cpp b/test/onnx/verify/lpnormalization_l1_test.cpp index 977d129d8eb..a8737347f1a 100644 --- a/test/onnx/verify/lpnormalization_l1_test.cpp +++ b/test/onnx/verify/lpnormalization_l1_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(lpnormalization_1norm) { - migraphx::program p = migraphx::parse_onnx("lpnormalization_l1_test.onnx"); + migraphx::program p = read_onnx("lpnormalization_l1_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {3, 4}}; std::vector data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f}; diff --git a/test/onnx/verify/lpnormalization_l2_test.cpp b/test/onnx/verify/lpnormalization_l2_test.cpp index 26856ed7080..ca2696cf478 100644 --- a/test/onnx/verify/lpnormalization_l2_test.cpp +++ b/test/onnx/verify/lpnormalization_l2_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(lpnormalization_2norm) { - migraphx::program p = migraphx::parse_onnx("lpnormalization_l2_test.onnx"); + migraphx::program p = read_onnx("lpnormalization_l2_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {3, 4}}; std::vector data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f}; diff --git a/test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp b/test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp index 9b65e086a09..cc15aad6baa 100644 --- a/test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp +++ b/test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(matmulinteger_int8_uint8_dual_zp_test) { - migraphx::program p = migraphx::parse_onnx("matmulinteger_int8_uint8_dual_zp_test.onnx"); + migraphx::program p = read_onnx("matmulinteger_int8_uint8_dual_zp_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s0{migraphx::shape::int8_type, {4, 3}}; diff --git a/test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp b/test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp index 6777df675c4..c7be8b45cb6 100644 --- a/test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp +++ b/test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(matmulinteger_int8_uint8_one_zp_test) { - migraphx::program p = migraphx::parse_onnx("matmulinteger_int8_uint8_one_zp_test.onnx"); + migraphx::program p = read_onnx("matmulinteger_int8_uint8_one_zp_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s0{migraphx::shape::int8_type, {4, 3}}; diff --git a/test/onnx/verify/matmulinteger_int8_uint8_test.cpp b/test/onnx/verify/matmulinteger_int8_uint8_test.cpp index 5448942d988..cd7665c8dc0 100644 --- a/test/onnx/verify/matmulinteger_int8_uint8_test.cpp +++ b/test/onnx/verify/matmulinteger_int8_uint8_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(matmulinteger_int8_uint8_test) { - migraphx::program p = migraphx::parse_onnx("matmulinteger_int8_uint8_test.onnx"); + migraphx::program p = read_onnx("matmulinteger_int8_uint8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s0{migraphx::shape::int8_type, {4, 3}}; diff --git a/test/onnx/verify/matmulinteger_uns_test.cpp b/test/onnx/verify/matmulinteger_uns_test.cpp index 0aaa2a00781..1392edddea2 100644 --- a/test/onnx/verify/matmulinteger_uns_test.cpp +++ b/test/onnx/verify/matmulinteger_uns_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(matmulinteger_uns_test) { - migraphx::program p = migraphx::parse_onnx("matmulinteger_uns_test.onnx"); + migraphx::program p = read_onnx("matmulinteger_uns_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s0{migraphx::shape::uint8_type, {4, 3}}; diff --git a/test/onnx/verify/matmulinteger_uns_zp_test.cpp b/test/onnx/verify/matmulinteger_uns_zp_test.cpp index 0eae90d71cd..29ffd7f6d37 100644 --- a/test/onnx/verify/matmulinteger_uns_zp_test.cpp +++ b/test/onnx/verify/matmulinteger_uns_zp_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(matmulinteger_uns_zp_test) { - migraphx::program p = migraphx::parse_onnx("matmulinteger_uns_zp_test.onnx"); + migraphx::program p = read_onnx("matmulinteger_uns_zp_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s0{migraphx::shape::uint8_type, {4, 3}}; diff --git a/test/onnx/verify/mean_broadcast_test.cpp b/test/onnx/verify/mean_broadcast_test.cpp index fb48936bfc4..0f34032ce69 100644 --- a/test/onnx/verify/mean_broadcast_test.cpp +++ b/test/onnx/verify/mean_broadcast_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(mean_broadcast_test) { - migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx"); + migraphx::program p = read_onnx("mean_broadcast_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s0{migraphx::shape::float_type, {1, 3, 4}}; diff --git a/test/onnx/verify/mean_integral_test.cpp b/test/onnx/verify/mean_integral_test.cpp index eaf4135e753..52cb522f111 100644 --- a/test/onnx/verify/mean_integral_test.cpp +++ b/test/onnx/verify/mean_integral_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(mean_integral_test) { - migraphx::program p = migraphx::parse_onnx("mean_integral_test.onnx"); + migraphx::program p = read_onnx("mean_integral_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::int32_type, {2, 2, 2}}; diff --git a/test/onnx/verify/mean_test.cpp b/test/onnx/verify/mean_test.cpp index 59ae8fafc88..dc12b6fac53 100644 --- a/test/onnx/verify/mean_test.cpp +++ b/test/onnx/verify/mean_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(mean_test) { - migraphx::program p = migraphx::parse_onnx("mean_test.onnx"); + migraphx::program p = read_onnx("mean_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::double_type, {2, 2, 2}}; diff --git a/test/onnx/verify/mod_test.cpp b/test/onnx/verify/mod_test.cpp index e6ede089462..d6676c42383 100644 --- a/test/onnx/verify/mod_test.cpp +++ b/test/onnx/verify/mod_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(mod_test) { - migraphx::program p = migraphx::parse_onnx("mod_test.onnx"); + migraphx::program p = read_onnx("mod_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::int32_type, {3, 3, 3}}; diff --git a/test/onnx/verify/mod_test_different_dtypes.cpp b/test/onnx/verify/mod_test_different_dtypes.cpp index 60125019de6..82f7e819dc2 100644 --- a/test/onnx/verify/mod_test_different_dtypes.cpp +++ b/test/onnx/verify/mod_test_different_dtypes.cpp @@ -28,7 +28,7 @@ TEST_CASE(mod_test_different_types) { - migraphx::program p = migraphx::parse_onnx("mod_test_different_dtypes.onnx"); + migraphx::program p = read_onnx("mod_test_different_dtypes.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_int16{migraphx::shape::int16_type, {3, 3, 3}}; diff --git a/test/onnx/verify/mod_test_fmod.cpp b/test/onnx/verify/mod_test_fmod.cpp index fcd49eba6e8..c5b92d92276 100644 --- a/test/onnx/verify/mod_test_fmod.cpp +++ b/test/onnx/verify/mod_test_fmod.cpp @@ -28,7 +28,7 @@ TEST_CASE(mod_test_fmod) { - migraphx::program p = migraphx::parse_onnx("mod_test_fmod.onnx"); + migraphx::program p = read_onnx("mod_test_fmod.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {3, 3, 3}}; diff --git a/test/onnx/verify/mod_test_fmod_different_dtypes.cpp b/test/onnx/verify/mod_test_fmod_different_dtypes.cpp index d59fb08edc0..bc9fa0510e5 100644 --- a/test/onnx/verify/mod_test_fmod_different_dtypes.cpp +++ b/test/onnx/verify/mod_test_fmod_different_dtypes.cpp @@ -28,7 +28,7 @@ TEST_CASE(mod_test_fmod_different_types) { - migraphx::program p = migraphx::parse_onnx("mod_test_fmod_different_dtypes.onnx"); + migraphx::program p = read_onnx("mod_test_fmod_different_dtypes.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s_float{migraphx::shape::float_type, {3, 3, 3}}; diff --git a/test/onnx/verify/multinomial_dyn_test.cpp b/test/onnx/verify/multinomial_dyn_test.cpp index e05a3890017..2f803129980 100644 --- a/test/onnx/verify/multinomial_dyn_test.cpp +++ b/test/onnx/verify/multinomial_dyn_test.cpp @@ -30,7 +30,7 @@ TEST_CASE(multinomial_dyn_test) { migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto p = migraphx::parse_onnx("multinomial_dyn_test.onnx", options); + auto p = read_onnx("multinomial_dyn_test.onnx", options); const size_t batch_size(2); const size_t categories(5); const size_t sample_size(100000); diff --git a/test/onnx/verify/mvn_default_axes_fp16_test.cpp b/test/onnx/verify/mvn_default_axes_fp16_test.cpp index efda706bfb4..f2ec40b3ce1 100644 --- a/test/onnx/verify/mvn_default_axes_fp16_test.cpp +++ b/test/onnx/verify/mvn_default_axes_fp16_test.cpp @@ -30,8 +30,7 @@ TEST_CASE(mvn_default_axes_fp16_test) { using migraphx::half; - auto result = - mvn_test({2, 2, 2, 2}, migraphx::parse_onnx("mvn_default_axes_fp16_test.onnx")); + auto result = mvn_test({2, 2, 2, 2}, read_onnx("mvn_default_axes_fp16_test.onnx")); std::vector gold{half{-1.324}, half{-1.084}, half{-0.843}, diff --git a/test/onnx/verify/mvn_default_axes_test.cpp b/test/onnx/verify/mvn_default_axes_test.cpp index b9424cf35b7..ffaef035505 100644 --- a/test/onnx/verify/mvn_default_axes_test.cpp +++ b/test/onnx/verify/mvn_default_axes_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(mvn_default_axes_test) { - auto result = mvn_test({2, 2, 2, 2}, migraphx::parse_onnx("mvn_default_axes_test.onnx")); + auto result = mvn_test({2, 2, 2, 2}, read_onnx("mvn_default_axes_test.onnx")); std::vector gold{-1.32424438, -1.08347268, -0.84270097, diff --git a/test/onnx/verify/mvn_rank_2_fp16_test.cpp b/test/onnx/verify/mvn_rank_2_fp16_test.cpp index 15de3b46bc2..e7e5b956e79 100644 --- a/test/onnx/verify/mvn_rank_2_fp16_test.cpp +++ b/test/onnx/verify/mvn_rank_2_fp16_test.cpp @@ -30,8 +30,7 @@ TEST_CASE(mvn_rank_2_fp16_test) { using migraphx::half; - auto result = - mvn_test({2, 2}, migraphx::parse_onnx("mvn_rank_2_fp16_test.onnx")); + auto result = mvn_test({2, 2}, read_onnx("mvn_rank_2_fp16_test.onnx")); std::vector gold{half{-1}, half{1}, half{-1}, half{1}}; EXPECT(migraphx::verify::verify_rms_range(result, gold)); } diff --git a/test/onnx/verify/mvn_rank_2_test.cpp b/test/onnx/verify/mvn_rank_2_test.cpp index 1369d3277a9..508271a094e 100644 --- a/test/onnx/verify/mvn_rank_2_test.cpp +++ b/test/onnx/verify/mvn_rank_2_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(mvn_rank_2_test) { - auto result = mvn_test({2, 2}, migraphx::parse_onnx("mvn_rank_2_test.onnx")); + auto result = mvn_test({2, 2}, read_onnx("mvn_rank_2_test.onnx")); std::vector gold{-1, 1, -1, 1}; EXPECT(migraphx::verify::verify_rms_range(result, gold)); } diff --git a/test/onnx/verify/mvn_rank_3_fp16_test.cpp b/test/onnx/verify/mvn_rank_3_fp16_test.cpp index 264eae323c0..901e5dae6bc 100644 --- a/test/onnx/verify/mvn_rank_3_fp16_test.cpp +++ b/test/onnx/verify/mvn_rank_3_fp16_test.cpp @@ -30,7 +30,7 @@ TEST_CASE(mvn_rank_3_fp16_test) { using migraphx::half; - auto result = mvn_test({2, 2, 2}, migraphx::parse_onnx("mvn_rank_3_fp16_test.onnx")); + auto result = mvn_test({2, 2, 2}, read_onnx("mvn_rank_3_fp16_test.onnx")); std::vector gold{half{-1.342}, half{-1.342}, half{-0.4473}, diff --git a/test/onnx/verify/mvn_rank_3_test.cpp b/test/onnx/verify/mvn_rank_3_test.cpp index 86316c8ab9e..4440be61d67 100644 --- a/test/onnx/verify/mvn_rank_3_test.cpp +++ b/test/onnx/verify/mvn_rank_3_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(mvn_rank_3_test) { - auto result = mvn_test({2, 2, 2}, migraphx::parse_onnx("mvn_rank_3_test.onnx")); + auto result = mvn_test({2, 2, 2}, read_onnx("mvn_rank_3_test.onnx")); std::vector gold{-1.34164079, -1.34164079, -0.4472136, diff --git a/test/onnx/verify/nonzero_dynamic_test.cpp b/test/onnx/verify/nonzero_dynamic_test.cpp index bfed855fe9a..4e12eaa1095 100644 --- a/test/onnx/verify/nonzero_dynamic_test.cpp +++ b/test/onnx/verify/nonzero_dynamic_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(nonzero_test) { - migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx"); + migraphx::program p = read_onnx("nonzero_dynamic_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::bool_type, {2, 2}}; diff --git a/test/onnx/verify/qlinearadd_bcast_test.cpp b/test/onnx/verify/qlinearadd_bcast_test.cpp index dd67e9b1f06..3fe7a3427c3 100644 --- a/test/onnx/verify/qlinearadd_bcast_test.cpp +++ b/test/onnx/verify/qlinearadd_bcast_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearadd_bcast_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAdd - migraphx::program p = migraphx::parse_onnx("qlinearadd_bcast_test.onnx"); + migraphx::program p = read_onnx("qlinearadd_bcast_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::int8_type, {64}}; diff --git a/test/onnx/verify/qlinearadd_test.cpp b/test/onnx/verify/qlinearadd_test.cpp index 12de720c94d..3c278869aeb 100644 --- a/test/onnx/verify/qlinearadd_test.cpp +++ b/test/onnx/verify/qlinearadd_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearadd_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAdd - migraphx::program p = migraphx::parse_onnx("qlinearadd_test.onnx"); + migraphx::program p = read_onnx("qlinearadd_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::uint8_type, {64}}; diff --git a/test/onnx/verify/qlinearaveragepool_1d_test.cpp b/test/onnx/verify/qlinearaveragepool_1d_test.cpp index 4c260ff0338..9288bad22a6 100644 --- a/test/onnx/verify/qlinearaveragepool_1d_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_1d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_1d_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_1d_test.onnx"); + auto p = read_onnx("qlinearaveragepool_1d_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = { -31, 51, 125, 30, -17, -125, 121, -19, -13, 52, 18, -70, 97, 15, 56, 42, diff --git a/test/onnx/verify/qlinearaveragepool_2d_ceil_test.cpp b/test/onnx/verify/qlinearaveragepool_2d_ceil_test.cpp index d0778d4fba7..4fa759c0f8f 100644 --- a/test/onnx/verify/qlinearaveragepool_2d_ceil_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_2d_ceil_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_2d_ceil_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_2d_ceil_test.onnx"); + auto p = read_onnx("qlinearaveragepool_2d_ceil_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32}; migraphx::shape s_x{migraphx::shape::uint8_type, {1, 1, 4, 4}}; diff --git a/test/onnx/verify/qlinearaveragepool_2d_dilations_test.cpp b/test/onnx/verify/qlinearaveragepool_2d_dilations_test.cpp index d57ae4e0e3a..07e1e5bd578 100644 --- a/test/onnx/verify/qlinearaveragepool_2d_dilations_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_2d_dilations_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_2d_dilations_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_2d_dilations_test.onnx"); + auto p = read_onnx("qlinearaveragepool_2d_dilations_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32}; migraphx::shape s_x{migraphx::shape::int8_type, {1, 1, 4, 4}}; diff --git a/test/onnx/verify/qlinearaveragepool_2d_pads_count_include_pad_test.cpp b/test/onnx/verify/qlinearaveragepool_2d_pads_count_include_pad_test.cpp index 7f2795b0348..a17cced175b 100644 --- a/test/onnx/verify/qlinearaveragepool_2d_pads_count_include_pad_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_2d_pads_count_include_pad_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_2d_pads_count_include_pad_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_2d_pads_count_include_pad_test.onnx"); + auto p = read_onnx("qlinearaveragepool_2d_pads_count_include_pad_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {-30, 50, 91, -87, -21, -113, -16, 6, -128, 104, 82, -126, 54, 41, -71, 62, -11, -111, 13, 104, -43, -48, 30, 85, diff --git a/test/onnx/verify/qlinearaveragepool_2d_same_lower_test.cpp b/test/onnx/verify/qlinearaveragepool_2d_same_lower_test.cpp index cb636db82c4..77ce0ec3879 100644 --- a/test/onnx/verify/qlinearaveragepool_2d_same_lower_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_2d_same_lower_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_2d_same_lower_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_2d_same_lower_test.onnx"); + auto p = read_onnx("qlinearaveragepool_2d_same_lower_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {195, 102, 250, 61, 222, 6, 243, 218, 230, 105, 36, 116, 194, 31, 113, 85, 126, 204, 80, 38, 115, 167, 221, 67, diff --git a/test/onnx/verify/qlinearaveragepool_2d_same_upper_test.cpp b/test/onnx/verify/qlinearaveragepool_2d_same_upper_test.cpp index 1daad5cbfa0..27d34abcd82 100644 --- a/test/onnx/verify/qlinearaveragepool_2d_same_upper_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_2d_same_upper_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_2d_same_upper_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_2d_same_upper_test.onnx"); + auto p = read_onnx("qlinearaveragepool_2d_same_upper_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {-61, 102, -6, 61, -34, 6, -13, -38, -26, 105, 36, 116, -62, 31, 113, 85, 126, -52, 80, 38, 115, -89, -35, 67, diff --git a/test/onnx/verify/qlinearaveragepool_2d_strides_test.cpp b/test/onnx/verify/qlinearaveragepool_2d_strides_test.cpp index f695c38b518..4c8d9a83cb2 100644 --- a/test/onnx/verify/qlinearaveragepool_2d_strides_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_2d_strides_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_2d_strides_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_2d_strides_test.onnx"); + auto p = read_onnx("qlinearaveragepool_2d_strides_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = { 84, -73, 117, -2, -97, 72, 67, 27, 1, -44, 110, 51, 9, 7, 58, 113, diff --git a/test/onnx/verify/qlinearaveragepool_2d_test.cpp b/test/onnx/verify/qlinearaveragepool_2d_test.cpp index 547954f9de0..873e8ba5535 100644 --- a/test/onnx/verify/qlinearaveragepool_2d_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_2d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_2d_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_2d_test.onnx"); + auto p = read_onnx("qlinearaveragepool_2d_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {84, -73, 117, -2, -97, 72, 67, 27, 1, -44, 110, 51, 9, 7, 58, 113, -34, 34, 124, -20, 6, 66, 68, 98, diff --git a/test/onnx/verify/qlinearaveragepool_3d_test.cpp b/test/onnx/verify/qlinearaveragepool_3d_test.cpp index 9c3644c4fff..2014f878a11 100644 --- a/test/onnx/verify/qlinearaveragepool_3d_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_3d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_3d_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_3d_test.onnx"); + auto p = read_onnx("qlinearaveragepool_3d_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = { -61, 102, -6, 61, -34, 6, -13, -38, -26, 105, 36, 116, -62, 31, 113, 85, 126, diff --git a/test/onnx/verify/qlinearaveragepool_notset_test.cpp b/test/onnx/verify/qlinearaveragepool_notset_test.cpp index 4c109195576..d1d554a0681 100644 --- a/test/onnx/verify/qlinearaveragepool_notset_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_notset_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearaveragepool_notset_test) { - auto p = migraphx::parse_onnx("qlinearaveragepool_notset_test.onnx"); + auto p = read_onnx("qlinearaveragepool_notset_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; diff --git a/test/onnx/verify/qlinearaveragepool_nt_cip_test.cpp b/test/onnx/verify/qlinearaveragepool_nt_cip_test.cpp index c0cb078e51a..ae8b09d8678 100644 --- a/test/onnx/verify/qlinearaveragepool_nt_cip_test.cpp +++ b/test/onnx/verify/qlinearaveragepool_nt_cip_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearaveragepool_nt_cip_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool - auto p = migraphx::parse_onnx("qlinearaveragepool_nt_cip_test.onnx"); + auto p = read_onnx("qlinearaveragepool_nt_cip_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; diff --git a/test/onnx/verify/qlinearconcat_3d_test.cpp b/test/onnx/verify/qlinearconcat_3d_test.cpp index d4d2da0f483..8bd11b000d6 100644 --- a/test/onnx/verify/qlinearconcat_3d_test.cpp +++ b/test/onnx/verify/qlinearconcat_3d_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearconcat_3d_test) { - auto p = migraphx::parse_onnx("qlinearconcat_3d_test.onnx"); + auto p = read_onnx("qlinearconcat_3d_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_t0 = {10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, diff --git a/test/onnx/verify/qlinearconcat_test.cpp b/test/onnx/verify/qlinearconcat_test.cpp index c5a2278b5d2..1731ff97243 100644 --- a/test/onnx/verify/qlinearconcat_test.cpp +++ b/test/onnx/verify/qlinearconcat_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearconcat_test) { - auto p = migraphx::parse_onnx("qlinearconcat_test.onnx"); + auto p = read_onnx("qlinearconcat_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_t0 = {2, 3}; diff --git a/test/onnx/verify/qlinearconv_pad_0_test.cpp b/test/onnx/verify/qlinearconv_pad_0_test.cpp index a24f5871ed9..d69bc352cf2 100644 --- a/test/onnx/verify/qlinearconv_pad_0_test.cpp +++ b/test/onnx/verify/qlinearconv_pad_0_test.cpp @@ -30,7 +30,7 @@ TEST_CASE(qlinearconv_pad_0_test) { // https:xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__Conv.html - migraphx::program p = migraphx::parse_onnx("qlinearconv_pad_0_test.onnx"); + migraphx::program p = read_onnx("qlinearconv_pad_0_test.onnx"); p.compile(migraphx::make_target("ref")); diff --git a/test/onnx/verify/qlinearconv_pad_1_test.cpp b/test/onnx/verify/qlinearconv_pad_1_test.cpp index c22461af389..7b10dbcace4 100644 --- a/test/onnx/verify/qlinearconv_pad_1_test.cpp +++ b/test/onnx/verify/qlinearconv_pad_1_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearconv_pad_1_test) { // https:xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__Conv.html - migraphx::program p = migraphx::parse_onnx("qlinearconv_pad_1_test.onnx"); + migraphx::program p = read_onnx("qlinearconv_pad_1_test.onnx"); p.compile(migraphx::make_target("ref")); diff --git a/test/onnx/verify/qlinearconv_scale_1D_test.cpp b/test/onnx/verify/qlinearconv_scale_1D_test.cpp index 4b5ae9187ad..011c190cecc 100644 --- a/test/onnx/verify/qlinearconv_scale_1D_test.cpp +++ b/test/onnx/verify/qlinearconv_scale_1D_test.cpp @@ -30,7 +30,7 @@ TEST_CASE(qlinearconv_scale_1D_test) { // https:xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__Conv.html - migraphx::program p = migraphx::parse_onnx("qlinearconv_scale_1D_test.onnx"); + migraphx::program p = read_onnx("qlinearconv_scale_1D_test.onnx"); p.compile(migraphx::make_target("ref")); diff --git a/test/onnx/verify/qlinearconv_test.cpp b/test/onnx/verify/qlinearconv_test.cpp index da8be32763f..2c82e36cb5c 100644 --- a/test/onnx/verify/qlinearconv_test.cpp +++ b/test/onnx/verify/qlinearconv_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearconv_test) { // https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html - migraphx::program p = migraphx::parse_onnx("qlinearconv_test.onnx"); + migraphx::program p = read_onnx("qlinearconv_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::uint8_type, {1, 1, 7, 7}}; diff --git a/test/onnx/verify/qlinearglobalavgpool_test.cpp b/test/onnx/verify/qlinearglobalavgpool_test.cpp index e83720935fa..fb03cfa1522 100644 --- a/test/onnx/verify/qlinearglobalavgpool_test.cpp +++ b/test/onnx/verify/qlinearglobalavgpool_test.cpp @@ -31,7 +31,7 @@ TEST_CASE(qlinearglobalavgpool_test) // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md // #com.microsoft.QLinearGlobalAveragePool - migraphx::program p = migraphx::parse_onnx("qlinearglobalavgpool_test.onnx"); + migraphx::program p = read_onnx("qlinearglobalavgpool_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sh_x{migraphx::shape::uint8_type, {1, 3, 4, 4}}; diff --git a/test/onnx/verify/qlinearleakyrelu_test.cpp b/test/onnx/verify/qlinearleakyrelu_test.cpp index 2d7a10640f3..4c2be5d7f53 100644 --- a/test/onnx/verify/qlinearleakyrelu_test.cpp +++ b/test/onnx/verify/qlinearleakyrelu_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearleakyrelu_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearSigmoid - migraphx::program p = migraphx::parse_onnx("qlinearleakyrelu_test.onnx"); + migraphx::program p = read_onnx("qlinearleakyrelu_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x{migraphx::shape::int8_type, {64}}; diff --git a/test/onnx/verify/qlinearmatmul_1D_test.cpp b/test/onnx/verify/qlinearmatmul_1D_test.cpp index 15b1c50540b..fd1ebf1e126 100644 --- a/test/onnx/verify/qlinearmatmul_1D_test.cpp +++ b/test/onnx/verify/qlinearmatmul_1D_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearmatmul_1D_test) { - migraphx::program p = migraphx::parse_onnx("qlinearmatmul_1D_test.onnx"); + migraphx::program p = read_onnx("qlinearmatmul_1D_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::uint8_type, {8}}; diff --git a/test/onnx/verify/qlinearmatmul_2D_test.cpp b/test/onnx/verify/qlinearmatmul_2D_test.cpp index af7672aee14..a9b48ebe41d 100644 --- a/test/onnx/verify/qlinearmatmul_2D_test.cpp +++ b/test/onnx/verify/qlinearmatmul_2D_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(qlinearmatmul_2D_test) { - migraphx::program p = migraphx::parse_onnx("qlinearmatmul_2D_test.onnx"); + migraphx::program p = read_onnx("qlinearmatmul_2D_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::uint8_type, {1, 8}}; diff --git a/test/onnx/verify/qlinearmatmul_3D_test.cpp b/test/onnx/verify/qlinearmatmul_3D_test.cpp index 612e00bc095..12bf7863006 100644 --- a/test/onnx/verify/qlinearmatmul_3D_test.cpp +++ b/test/onnx/verify/qlinearmatmul_3D_test.cpp @@ -30,7 +30,7 @@ TEST_CASE(qlinearmatmul_3D_test) { // https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearMatMul.html - migraphx::program p = migraphx::parse_onnx("qlinearmatmul_3D_test.onnx"); + migraphx::program p = read_onnx("qlinearmatmul_3D_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::uint8_type, {2, 2, 4}}; diff --git a/test/onnx/verify/qlinearmul_bcast_test.cpp b/test/onnx/verify/qlinearmul_bcast_test.cpp index 59d2a7cfbaf..ba279044d16 100644 --- a/test/onnx/verify/qlinearmul_bcast_test.cpp +++ b/test/onnx/verify/qlinearmul_bcast_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearmul_bcast_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul - migraphx::program p = migraphx::parse_onnx("qlinearmul_bcast_test.onnx"); + migraphx::program p = read_onnx("qlinearmul_bcast_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::int8_type, {64}}; diff --git a/test/onnx/verify/qlinearmul_test.cpp b/test/onnx/verify/qlinearmul_test.cpp index 54eadff9ac3..af80ab54c25 100644 --- a/test/onnx/verify/qlinearmul_test.cpp +++ b/test/onnx/verify/qlinearmul_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearmul_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul - migraphx::program p = migraphx::parse_onnx("qlinearmul_test.onnx"); + migraphx::program p = read_onnx("qlinearmul_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::uint8_type, {64}}; diff --git a/test/onnx/verify/qlinearsigmoid_test.cpp b/test/onnx/verify/qlinearsigmoid_test.cpp index 035bab70487..ea0f6190a8f 100644 --- a/test/onnx/verify/qlinearsigmoid_test.cpp +++ b/test/onnx/verify/qlinearsigmoid_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(qlinearsigmoid_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearSigmoid - migraphx::program p = migraphx::parse_onnx("qlinearsigmoid_test.onnx"); + migraphx::program p = read_onnx("qlinearsigmoid_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape x{migraphx::shape::int8_type, {64}}; diff --git a/test/onnx/verify/quant_convolution_dual_bias_test.cpp b/test/onnx/verify/quant_convolution_dual_bias_test.cpp index a4a8f6fab9c..27a9518c4c1 100644 --- a/test/onnx/verify/quant_convolution_dual_bias_test.cpp +++ b/test/onnx/verify/quant_convolution_dual_bias_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(quant_convolution_dual_zero_bias_test) { - migraphx::program p = migraphx::parse_onnx("convinteger_dual_bias_test.onnx"); + migraphx::program p = read_onnx("convinteger_dual_bias_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::int8_type, {1, 3, 5, 5}}; @@ -82,7 +82,7 @@ TEST_CASE(quant_convolution_dual_zero_bias_test) TEST_CASE(quant_convolution_dual_non_zero_bias_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul - migraphx::program p = migraphx::parse_onnx("convinteger_dual_bias_test.onnx"); + migraphx::program p = read_onnx("convinteger_dual_bias_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::int8_type, {1, 3, 5, 5}}; diff --git a/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp b/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp index caf587d38fa..127286001ad 100644 --- a/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp +++ b/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(quant_convolution_mismatched_inputs_dual_zero_bias_test) { - migraphx::program p = migraphx::parse_onnx("convinteger_mismatched_inputs_dual_bias_test.onnx"); + migraphx::program p = read_onnx("convinteger_mismatched_inputs_dual_bias_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::uint8_type, {1, 3, 5, 5}}; @@ -81,7 +81,7 @@ TEST_CASE(quant_convolution_mismatched_inputs_dual_zero_bias_test) TEST_CASE(quant_convolution_mismatched_inputs_dual_non_zero_bias_test) { - migraphx::program p = migraphx::parse_onnx("convinteger_mismatched_inputs_dual_bias_test.onnx"); + migraphx::program p = read_onnx("convinteger_mismatched_inputs_dual_bias_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::uint8_type, {1, 3, 5, 5}}; diff --git a/test/onnx/verify/quant_convolution_test.cpp b/test/onnx/verify/quant_convolution_test.cpp index 73986631952..e27f9f7a7f6 100644 --- a/test/onnx/verify/quant_convolution_test.cpp +++ b/test/onnx/verify/quant_convolution_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(quant_convolution_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul - migraphx::program p = migraphx::parse_onnx("convinteger_no_bias_test.onnx"); + migraphx::program p = read_onnx("convinteger_no_bias_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::int8_type, {1, 3, 5, 5}}; diff --git a/test/onnx/verify/reducesum_variable_axes_test.cpp b/test/onnx/verify/reducesum_variable_axes_test.cpp index e747c4db702..f583db770f3 100644 --- a/test/onnx/verify/reducesum_variable_axes_test.cpp +++ b/test/onnx/verify/reducesum_variable_axes_test.cpp @@ -32,7 +32,7 @@ auto reducesum_variable_axes_test_base(const std::string& file, size_t axes_size migraphx::onnx_options options; options.map_input_dims["axes"] = std::vector{axes_size}; - migraphx::program p = migraphx::parse_onnx(file, options); + migraphx::program p = read_onnx(file, options); p.compile(migraphx::make_target("ref")); migraphx::parameter_map pm; diff --git a/test/onnx/verify/reducesum_variable_dynamic_axes_test.cpp b/test/onnx/verify/reducesum_variable_dynamic_axes_test.cpp index aceb49cd180..fea9880d10e 100644 --- a/test/onnx/verify/reducesum_variable_dynamic_axes_test.cpp +++ b/test/onnx/verify/reducesum_variable_dynamic_axes_test.cpp @@ -35,7 +35,7 @@ auto reducesum_variable_dynamic_axes_test_base(migraphx::shape axes_shape, migraphx::onnx_options options; const std::vector axes_dims{{0, 3}}; options.map_dyn_input_dims["axes"] = axes_dims; - migraphx::program p = parse_onnx(file, options); + migraphx::program p = read_onnx(file, options); p.compile(migraphx::make_target("ref")); migraphx::parameter_map pm; diff --git a/test/onnx/verify/resize_downsample_f_dyn2_test.cpp b/test/onnx/verify/resize_downsample_f_dyn2_test.cpp index e41d49185e9..87524785f1b 100644 --- a/test/onnx/verify/resize_downsample_f_dyn2_test.cpp +++ b/test/onnx/verify/resize_downsample_f_dyn2_test.cpp @@ -31,12 +31,12 @@ TEST_CASE(resize_downsample_f_dyn2_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10}; - auto p = migraphx::parse_onnx("resize_downsample_f_dyn2_test.onnx", options); + auto p = read_onnx("resize_downsample_f_dyn2_test.onnx", options); p.compile(migraphx::make_target("ref")); // A Resize op. with static input shape goes through a different code path // but should give same result - auto reference_p = migraphx::parse_onnx("resize_downsample_f_ref2_test.onnx", options); + auto reference_p = read_onnx("resize_downsample_f_ref2_test.onnx", options); reference_p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {2, 1, 5, 9}}; diff --git a/test/onnx/verify/resize_downsample_f_dyn3_test.cpp b/test/onnx/verify/resize_downsample_f_dyn3_test.cpp index 7598785371d..9caa07fbd3e 100644 --- a/test/onnx/verify/resize_downsample_f_dyn3_test.cpp +++ b/test/onnx/verify/resize_downsample_f_dyn3_test.cpp @@ -32,12 +32,12 @@ TEST_CASE(resize_downsample_f_dyn3_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10}; - auto p = migraphx::parse_onnx("resize_downsample_f_dyn3_test.onnx", options); + auto p = read_onnx("resize_downsample_f_dyn3_test.onnx", options); p.compile(migraphx::make_target("ref")); // A Resize op. with static input shape goes through a different code path // but should give same result - auto reference_p = migraphx::parse_onnx("resize_downsample_f_ref_test.onnx", options); + auto reference_p = read_onnx("resize_downsample_f_ref_test.onnx", options); reference_p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {2, 1, 5, 9}}; diff --git a/test/onnx/verify/resize_downsample_f_dyn_test.cpp b/test/onnx/verify/resize_downsample_f_dyn_test.cpp index 422ec3c2d63..ad00612accb 100644 --- a/test/onnx/verify/resize_downsample_f_dyn_test.cpp +++ b/test/onnx/verify/resize_downsample_f_dyn_test.cpp @@ -31,12 +31,12 @@ TEST_CASE(resize_downsample_f_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10}; - auto p = migraphx::parse_onnx("resize_downsample_f_dyn_test.onnx", options); + auto p = read_onnx("resize_downsample_f_dyn_test.onnx", options); p.compile(migraphx::make_target("ref")); // A Resize op. with static input shape goes through a different code path // but should give same result - auto reference_p = migraphx::parse_onnx("resize_downsample_f_ref_test.onnx", options); + auto reference_p = read_onnx("resize_downsample_f_ref_test.onnx", options); reference_p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {2, 1, 5, 9}}; diff --git a/test/onnx/verify/resize_downsample_f_test.cpp b/test/onnx/verify/resize_downsample_f_test.cpp index eba40c0212f..bb65752a122 100644 --- a/test/onnx/verify/resize_downsample_f_test.cpp +++ b/test/onnx/verify/resize_downsample_f_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(resize_downsample_f_test) { - migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx"); + migraphx::program p = read_onnx("resize_downsample_f_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}}; diff --git a/test/onnx/verify/resize_outsize_test.cpp b/test/onnx/verify/resize_outsize_test.cpp index 9a9eed50354..9b8f2be9588 100644 --- a/test/onnx/verify/resize_outsize_test.cpp +++ b/test/onnx/verify/resize_outsize_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(resize_outsize_test) { // resize using output_size input, rather than scales - migraphx::program p = migraphx::parse_onnx("resize_outsize_test.onnx"); + migraphx::program p = read_onnx("resize_outsize_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; diff --git a/test/onnx/verify/resize_upsample_f_dyn_test.cpp b/test/onnx/verify/resize_upsample_f_dyn_test.cpp index 4969ea88296..a2b0c8233f1 100644 --- a/test/onnx/verify/resize_upsample_f_dyn_test.cpp +++ b/test/onnx/verify/resize_upsample_f_dyn_test.cpp @@ -32,7 +32,7 @@ TEST_CASE(resize_upsample_f_dyn_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 10}; - auto p = migraphx::parse_onnx("resize_upsample_f_dyn_test.onnx", options); + auto p = read_onnx("resize_upsample_f_dyn_test.onnx", options); p.compile(migraphx::make_target("ref")); // should upscale to 2x4x8 diff --git a/test/onnx/verify/resize_upsample_linear_ac_test.cpp b/test/onnx/verify/resize_upsample_linear_ac_test.cpp index 4c139924013..4b0cc7b908d 100644 --- a/test/onnx/verify/resize_upsample_linear_ac_test.cpp +++ b/test/onnx/verify/resize_upsample_linear_ac_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(resize_upsample_linear_ac_test) { - migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx"); + migraphx::program p = read_onnx("resize_upsample_linear_ac_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; diff --git a/test/onnx/verify/resize_upsample_linear_test.cpp b/test/onnx/verify/resize_upsample_linear_test.cpp index 0810f35516b..495899d9221 100644 --- a/test/onnx/verify/resize_upsample_linear_test.cpp +++ b/test/onnx/verify/resize_upsample_linear_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(resize_upsample_linear_test) { - migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_test.onnx"); + migraphx::program p = read_onnx("resize_upsample_linear_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; diff --git a/test/onnx/verify/resize_upsample_pf_test.cpp b/test/onnx/verify/resize_upsample_pf_test.cpp index 9031daae25f..04c9239a67f 100644 --- a/test/onnx/verify/resize_upsample_pf_test.cpp +++ b/test/onnx/verify/resize_upsample_pf_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(resize_upsample_pf_test) { - migraphx::program p = migraphx::parse_onnx("resize_upsample_pf_test.onnx"); + migraphx::program p = read_onnx("resize_upsample_pf_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; diff --git a/test/onnx/verify/reversesequence_4D_test.cpp b/test/onnx/verify/reversesequence_4D_test.cpp index e73a52636ff..d66a5d4a1f1 100644 --- a/test/onnx/verify/reversesequence_4D_test.cpp +++ b/test/onnx/verify/reversesequence_4D_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(reversesequence_4D_verify_test) { - migraphx::program p = migraphx::parse_onnx("reversesequence_4D_test.onnx"); + migraphx::program p = read_onnx("reversesequence_4D_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape xs{migraphx::shape::float_type, {2, 2, 2, 2}}; diff --git a/test/onnx/verify/reversesequence_batch_test.cpp b/test/onnx/verify/reversesequence_batch_test.cpp index 8fdfcb3aba4..114c651dc61 100644 --- a/test/onnx/verify/reversesequence_batch_test.cpp +++ b/test/onnx/verify/reversesequence_batch_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(reversesequence_batch_verify_test) { - migraphx::program p = migraphx::parse_onnx("reversesequence_batch_test.onnx"); + migraphx::program p = read_onnx("reversesequence_batch_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape xs{migraphx::shape::float_type, {4, 4}}; diff --git a/test/onnx/verify/reversesequence_time_test.cpp b/test/onnx/verify/reversesequence_time_test.cpp index 2142992b813..cfe4638a2d8 100644 --- a/test/onnx/verify/reversesequence_time_test.cpp +++ b/test/onnx/verify/reversesequence_time_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(reversesequence_time_verify_test) { - migraphx::program p = migraphx::parse_onnx("reversesequence_time_test.onnx"); + migraphx::program p = read_onnx("reversesequence_time_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape xs{migraphx::shape::float_type, {4, 4}}; diff --git a/test/onnx/verify/round_half_test.cpp b/test/onnx/verify/round_half_test.cpp index b7368c40168..a2018f53627 100644 --- a/test/onnx/verify/round_half_test.cpp +++ b/test/onnx/verify/round_half_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(round_half_test) { - migraphx::program p = migraphx::parse_onnx("round_half_test.onnx"); + migraphx::program p = read_onnx("round_half_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape xs{migraphx::shape::half_type, {4, 4}}; diff --git a/test/onnx/verify/selu_test.cpp b/test/onnx/verify/selu_test.cpp index b8dca79fa22..f1dd54cbb74 100644 --- a/test/onnx/verify/selu_test.cpp +++ b/test/onnx/verify/selu_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(selu_test) { - migraphx::program p = migraphx::parse_onnx("selu_test.onnx"); + migraphx::program p = read_onnx("selu_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape xs{migraphx::shape::double_type, {2, 3}}; diff --git a/test/onnx/verify/shrink_hard_test.cpp b/test/onnx/verify/shrink_hard_test.cpp index da3259b6d10..e740ae13d84 100644 --- a/test/onnx/verify/shrink_hard_test.cpp +++ b/test/onnx/verify/shrink_hard_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(shrink_hard_test) { - migraphx::program p = migraphx::parse_onnx("shrink_hard_test.onnx"); + migraphx::program p = read_onnx("shrink_hard_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {5}}; diff --git a/test/onnx/verify/shrink_int8_test.cpp b/test/onnx/verify/shrink_int8_test.cpp index 71bfbd7c355..051abbe7e01 100644 --- a/test/onnx/verify/shrink_int8_test.cpp +++ b/test/onnx/verify/shrink_int8_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(shrink_int8_test) { - migraphx::program p = migraphx::parse_onnx("shrink_int8_test.onnx"); + migraphx::program p = read_onnx("shrink_int8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::int8_type, {3, 3}}; diff --git a/test/onnx/verify/shrink_soft_test.cpp b/test/onnx/verify/shrink_soft_test.cpp index d2dd2e3a762..99ea5f8f469 100644 --- a/test/onnx/verify/shrink_soft_test.cpp +++ b/test/onnx/verify/shrink_soft_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(shrink_soft_test) { - migraphx::program p = migraphx::parse_onnx("shrink_soft_test.onnx"); + migraphx::program p = read_onnx("shrink_soft_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {5}}; diff --git a/test/onnx/verify/shrink_uint8_test.cpp b/test/onnx/verify/shrink_uint8_test.cpp index a46c6767899..b2dd31fc2ce 100644 --- a/test/onnx/verify/shrink_uint8_test.cpp +++ b/test/onnx/verify/shrink_uint8_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(shrink_uint8_test) { - migraphx::program p = migraphx::parse_onnx("shrink_uint8_test.onnx"); + migraphx::program p = read_onnx("shrink_uint8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::uint8_type, {3, 3}}; diff --git a/test/onnx/verify/shrink_verify2_test.cpp b/test/onnx/verify/shrink_verify2_test.cpp index db492f2becf..5843c18dceb 100644 --- a/test/onnx/verify/shrink_verify2_test.cpp +++ b/test/onnx/verify/shrink_verify2_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(shrink_verify2_test) { - migraphx::program p = migraphx::parse_onnx("shrink_verify2_test.onnx"); + migraphx::program p = read_onnx("shrink_verify2_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::half_type, {5}}; diff --git a/test/onnx/verify/shrink_verify_test.cpp b/test/onnx/verify/shrink_verify_test.cpp index 2edaced62e6..e4135743c4f 100644 --- a/test/onnx/verify/shrink_verify_test.cpp +++ b/test/onnx/verify/shrink_verify_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(shrink_verify_test) { - migraphx::program p = migraphx::parse_onnx("shrink_verify_test.onnx"); + migraphx::program p = read_onnx("shrink_verify_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::half_type, {5}}; diff --git a/test/onnx/verify/size_verify_test.cpp b/test/onnx/verify/size_verify_test.cpp index 4226f8ae617..13342f43003 100644 --- a/test/onnx/verify/size_verify_test.cpp +++ b/test/onnx/verify/size_verify_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(size_verify_test) { - migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx"); + migraphx::program p = read_onnx("size_verify_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {2, 5, 3}}; diff --git a/test/onnx/verify/slice_5arg_reverse_test.cpp b/test/onnx/verify/slice_5arg_reverse_test.cpp index f4bec5146f7..58c8ba06bfd 100644 --- a/test/onnx/verify/slice_5arg_reverse_test.cpp +++ b/test/onnx/verify/slice_5arg_reverse_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(slice_reverse_test) { - migraphx::program p = migraphx::parse_onnx("slice_5arg_reverse_test.onnx"); + migraphx::program p = read_onnx("slice_5arg_reverse_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start diff --git a/test/onnx/verify/slice_5arg_step_test.cpp b/test/onnx/verify/slice_5arg_step_test.cpp index 9b6ccf08693..035355ee657 100644 --- a/test/onnx/verify/slice_5arg_step_test.cpp +++ b/test/onnx/verify/slice_5arg_step_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(slice_step_test) { - migraphx::program p = migraphx::parse_onnx("slice_5arg_step_test.onnx"); + migraphx::program p = read_onnx("slice_5arg_step_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start diff --git a/test/onnx/verify/slice_5arg_test.cpp b/test/onnx/verify/slice_5arg_test.cpp index d7092d2f4a8..2ada2561d3f 100644 --- a/test/onnx/verify/slice_5arg_test.cpp +++ b/test/onnx/verify/slice_5arg_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(slice_5arg_test) { - migraphx::program p = migraphx::parse_onnx("slice_5arg_test.onnx"); + migraphx::program p = read_onnx("slice_5arg_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start diff --git a/test/onnx/verify/slice_test.cpp b/test/onnx/verify/slice_test.cpp index eb8525080d3..0c4c060faf2 100644 --- a/test/onnx/verify/slice_test.cpp +++ b/test/onnx/verify/slice_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(slice_test) { - migraphx::program p = migraphx::parse_onnx("slice_test.onnx"); + migraphx::program p = read_onnx("slice_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape sh_data{migraphx::shape::float_type, {3, 2}}; diff --git a/test/onnx/verify/softplus_test.cpp b/test/onnx/verify/softplus_test.cpp index ee4dba2a602..4385f275229 100644 --- a/test/onnx/verify/softplus_test.cpp +++ b/test/onnx/verify/softplus_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(softplus_test) { - migraphx::program p = migraphx::parse_onnx("softplus_test.onnx"); + migraphx::program p = read_onnx("softplus_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {5}}; diff --git a/test/onnx/verify/softsign_test.cpp b/test/onnx/verify/softsign_test.cpp index d484549ad36..7c1fd62f419 100644 --- a/test/onnx/verify/softsign_test.cpp +++ b/test/onnx/verify/softsign_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(softsign_test) { - migraphx::program p = migraphx::parse_onnx("softsign_test.onnx"); + migraphx::program p = read_onnx("softsign_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::float_type, {5}}; diff --git a/test/onnx/verify/spacetodepth_simple_test.cpp b/test/onnx/verify/spacetodepth_simple_test.cpp index a28a8ced635..db4691f9624 100644 --- a/test/onnx/verify/spacetodepth_simple_test.cpp +++ b/test/onnx/verify/spacetodepth_simple_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(spacetodepth_simple_test) { - auto p = migraphx::parse_onnx("spacetodepth_simple_test.onnx"); + auto p = read_onnx("spacetodepth_simple_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector data_in(48); std::iota(std::begin(data_in), std::end(data_in), 0); @@ -47,7 +47,7 @@ TEST_CASE(spacetodepth_simple_test) TEST_CASE(spacetodepth_depthtospace_test) { // space to depth - auto p1 = migraphx::parse_onnx("spacetodepth_simple_test.onnx"); + auto p1 = read_onnx("spacetodepth_simple_test.onnx"); p1.compile(migraphx::make_target("ref")); std::vector gold_data_in(48); std::iota(std::begin(gold_data_in), std::end(gold_data_in), 0); @@ -56,7 +56,7 @@ TEST_CASE(spacetodepth_depthtospace_test) pp1["x"] = migraphx::argument(s_x_1, gold_data_in.data()); auto result1 = p1.eval(pp1).back(); // depth to space - auto p2 = migraphx::parse_onnx("depthtospace_simple_test.onnx"); + auto p2 = read_onnx("depthtospace_simple_test.onnx"); p2.compile(migraphx::make_target("ref")); migraphx::parameter_map pp2; pp2["x"] = result1; diff --git a/test/onnx/verify/tril_batch_diff_k_test.cpp b/test/onnx/verify/tril_batch_diff_k_test.cpp index e89db1e8b8b..8b7a3e6a02d 100644 --- a/test/onnx/verify/tril_batch_diff_k_test.cpp +++ b/test/onnx/verify/tril_batch_diff_k_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(tril_batch_diff_k_test) { - migraphx::program p = migraphx::parse_onnx("tril_batch_diff_k_test.onnx"); + migraphx::program p = read_onnx("tril_batch_diff_k_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p); diff --git a/test/onnx/verify/tril_neg_k_test.cpp b/test/onnx/verify/tril_neg_k_test.cpp index 2c0959fa180..87ff9df7a97 100644 --- a/test/onnx/verify/tril_neg_k_test.cpp +++ b/test/onnx/verify/tril_neg_k_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(tril_neg_k_test) { - migraphx::program p = migraphx::parse_onnx("tril_neg_k_test.onnx"); + migraphx::program p = read_onnx("tril_neg_k_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p); diff --git a/test/onnx/verify/tril_out_k_test.cpp b/test/onnx/verify/tril_out_k_test.cpp index 2167de5f7dc..8ec295d5a22 100644 --- a/test/onnx/verify/tril_out_k_test.cpp +++ b/test/onnx/verify/tril_out_k_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(tril_out_k_test) { - migraphx::program p = migraphx::parse_onnx("tril_out_k_test.onnx"); + migraphx::program p = read_onnx("tril_out_k_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p); diff --git a/test/onnx/verify/tril_row_one_test.cpp b/test/onnx/verify/tril_row_one_test.cpp index c3b8a7131a3..ec1d1ff830b 100644 --- a/test/onnx/verify/tril_row_one_test.cpp +++ b/test/onnx/verify/tril_row_one_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(tril_row_one_test) { - migraphx::program p = migraphx::parse_onnx("tril_row_one_test.onnx"); + migraphx::program p = read_onnx("tril_row_one_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p); diff --git a/test/onnx/verify/tril_test.cpp b/test/onnx/verify/tril_test.cpp index e46676d27fa..d95db6e1be0 100644 --- a/test/onnx/verify/tril_test.cpp +++ b/test/onnx/verify/tril_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(tril_test) { - migraphx::program p = migraphx::parse_onnx("tril_test.onnx"); + migraphx::program p = read_onnx("tril_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p); diff --git a/test/onnx/verify/triu_batch_diff_k_test.cpp b/test/onnx/verify/triu_batch_diff_k_test.cpp index a76e87f2bcb..b83b6059b18 100644 --- a/test/onnx/verify/triu_batch_diff_k_test.cpp +++ b/test/onnx/verify/triu_batch_diff_k_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(triu_batch_diff_k_test) { - migraphx::program p = migraphx::parse_onnx("triu_batch_diff_k_test.onnx"); + migraphx::program p = read_onnx("triu_batch_diff_k_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p); diff --git a/test/onnx/verify/triu_neg_k_test.cpp b/test/onnx/verify/triu_neg_k_test.cpp index 339ada4f12c..47ca830d267 100644 --- a/test/onnx/verify/triu_neg_k_test.cpp +++ b/test/onnx/verify/triu_neg_k_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(triu_neg_k_test) { - migraphx::program p = migraphx::parse_onnx("triu_neg_k_test.onnx"); + migraphx::program p = read_onnx("triu_neg_k_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p); diff --git a/test/onnx/verify/triu_out_k_test.cpp b/test/onnx/verify/triu_out_k_test.cpp index c87353de001..b687d200bd0 100644 --- a/test/onnx/verify/triu_out_k_test.cpp +++ b/test/onnx/verify/triu_out_k_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(triu_out_k_test) { - migraphx::program p = migraphx::parse_onnx("triu_out_k_test.onnx"); + migraphx::program p = read_onnx("triu_out_k_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p); diff --git a/test/onnx/verify/triu_row_one_test.cpp b/test/onnx/verify/triu_row_one_test.cpp index 8d97701e2fd..9f91d1b7305 100644 --- a/test/onnx/verify/triu_row_one_test.cpp +++ b/test/onnx/verify/triu_row_one_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(triu_row_one_test) { - migraphx::program p = migraphx::parse_onnx("triu_row_one_test.onnx"); + migraphx::program p = read_onnx("triu_row_one_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p); diff --git a/test/onnx/verify/triu_test.cpp b/test/onnx/verify/triu_test.cpp index 59e2c1b6f15..5cddf854390 100644 --- a/test/onnx/verify/triu_test.cpp +++ b/test/onnx/verify/triu_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(triu_test) { - migraphx::program p = migraphx::parse_onnx("triu_test.onnx"); + migraphx::program p = read_onnx("triu_test.onnx"); std::vector result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p); diff --git a/test/onnx/verify/unique_dynamic_sorted_test.cpp b/test/onnx/verify/unique_dynamic_sorted_test.cpp index a56498ed9db..bf7df608a8a 100644 --- a/test/onnx/verify/unique_dynamic_sorted_test.cpp +++ b/test/onnx/verify/unique_dynamic_sorted_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(unique_dynamic_sorted_test) { - migraphx::program p = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx"); + migraphx::program p = read_onnx("unique_dynamic_sorted_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector x{2, 1, 1, 3, 4, 3}; diff --git a/test/onnx/verify/unique_dynamic_unsorted_test.cpp b/test/onnx/verify/unique_dynamic_unsorted_test.cpp index 0963abc29e7..05e4399d66c 100644 --- a/test/onnx/verify/unique_dynamic_unsorted_test.cpp +++ b/test/onnx/verify/unique_dynamic_unsorted_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(unique_dynamic_unsorted_test) { - migraphx::program p = migraphx::parse_onnx("unique_dynamic_unsorted_test.onnx"); + migraphx::program p = read_onnx("unique_dynamic_unsorted_test.onnx"); p.compile(migraphx::make_target("ref")); std::vector x{2, 1, 1, 3, 4, 3}; diff --git a/test/onnx/verify/upsample_test.cpp b/test/onnx/verify/upsample_test.cpp index cebd7510edb..679ac2d7c41 100644 --- a/test/onnx/verify/upsample_test.cpp +++ b/test/onnx/verify/upsample_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(upsample_test) { - migraphx::program p = migraphx::parse_onnx("upsample_test.onnx"); + migraphx::program p = read_onnx("upsample_test.onnx"); std::vector x_data = {1, 2, 3, 4}; migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; diff --git a/test/onnx/verify/where_test.cpp b/test/onnx/verify/where_test.cpp index df4f5555ccf..73b7bc93a6d 100644 --- a/test/onnx/verify/where_test.cpp +++ b/test/onnx/verify/where_test.cpp @@ -28,7 +28,7 @@ TEST_CASE(where_test) { - migraphx::program p = migraphx::parse_onnx("where_test.onnx"); + migraphx::program p = read_onnx("where_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape c_shape{migraphx::shape::bool_type, {2}}; diff --git a/test/optimize_module_test.cpp b/test/optimize_module_test.cpp index 8324be4a12f..7da3b36de78 100644 --- a/test/optimize_module_test.cpp +++ b/test/optimize_module_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,12 +38,12 @@ TEST_CASE(broadcast_transpose_inner_broadcast) // then finds inner broadcast to become mul+broadcast migraphx::module m1; { - auto l1 = m1.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); - auto l2 = m1.add_parameter("y", {migraphx::shape::float_type, {1}, {0}}); + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto y = m1.add_parameter("y", {migraphx::shape::float_type, {1}, {0}}); auto mb1 = - m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l1); + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), x); auto mb2 = - m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2}}}), l2); + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2}}}), y); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb1); auto mul = m1.add_instruction(migraphx::make_op("mul"), mb2, t1); @@ -52,9 +52,9 @@ TEST_CASE(broadcast_transpose_inner_broadcast) run_pass(m1); migraphx::module m2; { - auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); - auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {1}, {0}}); - auto mul = m2.add_instruction(migraphx::make_op("mul"), l2, l1); + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto y = m2.add_parameter("y", {migraphx::shape::float_type, {1}, {0}}); + auto mul = m2.add_instruction(migraphx::make_op("mul"), y, x); auto mb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2}}}), mul); m2.add_return({mb}); @@ -68,12 +68,12 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic) // then finds inner broadcast to become mul+broadcast migraphx::module m1; { - auto l1 = m1.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); - auto l2 = m1.add_parameter("y", {migraphx::shape::float_type, {5}}); + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); + auto y = m1.add_parameter("y", {migraphx::shape::float_type, {5}}); auto mb1 = - m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), l1); + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), x); auto mb2 = - m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 10, 5}}}), l2); + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 10, 5}}}), y); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb2); auto mul = m1.add_instruction(migraphx::make_op("mul"), mb1, t1); @@ -82,19 +82,18 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic) run_pass(m1); migraphx::module m2; { - auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); - auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}}); - auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2); + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); + auto y = m2.add_parameter("y", {migraphx::shape::float_type, {5}}); + auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), y); auto transpose = m2.add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze); - auto mb1 = - m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), transpose); + auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 10}}}), + squeeze); + auto mul = m2.add_instruction(migraphx::make_op("mul"), x, mb1); auto mb2 = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose); - auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2); - auto mb3 = m2.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul); - m2.add_return({mb3}); + m2.add_return({mb2}); } EXPECT(m1 == m2); } diff --git a/test/param_utils.cpp b/test/param_utils.cpp new file mode 100644 index 00000000000..c47bcb10482 --- /dev/null +++ b/test/param_utils.cpp @@ -0,0 +1,42 @@ +#include +#include +#include +#include + +TEST_CASE(test_param_name) +{ + CHECK(migraphx::param_name(0) == "x0"); + CHECK(migraphx::param_name(1) == "x1"); + CHECK(migraphx::param_name(10) == "x:00010"); + CHECK(migraphx::param_name(11) == "x:00011"); + CHECK(migraphx::param_name(100) == "x:00100"); + CHECK(migraphx::param_name(101) == "x:00101"); + CHECK(migraphx::param_name(10011) == "x:10011"); + CHECK(migraphx::param_name(99999) == "x:99999"); + CHECK(test::throws([] { migraphx::param_name(100000); })); + CHECK(test::throws([] { migraphx::param_name(100001); })); +} + +TEST_CASE(test_param_name_sorted) +{ + auto pname = [](std::size_t i) { return migraphx::param_name(i); }; + std::vector names; + migraphx::transform(migraphx::range(8, 25), std::back_inserter(names), pname); + migraphx::transform(migraphx::range(90, 130), std::back_inserter(names), pname); + migraphx::transform(migraphx::range(990, 1030), std::back_inserter(names), pname); + migraphx::transform(migraphx::range(9990, 10030), std::back_inserter(names), pname); + migraphx::transform(migraphx::range(99990, 100000), std::back_inserter(names), pname); + CHECK(std::is_sorted(names.begin(), names.end())); + + auto xnames = names; + // Shuffled + std::shuffle(xnames.begin(), xnames.end(), std::minstd_rand{}); + std::sort(xnames.begin(), xnames.end()); + EXPECT(xnames == names); + // Reversed + std::reverse(xnames.begin(), xnames.end()); + std::sort(xnames.begin(), xnames.end()); + EXPECT(xnames == names); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 353bcea3944..431c72b844b 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -112,15 +112,8 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_det_2d_cpu') backend_test.exclude(r'test_det_nd_cpu') backend_test.exclude(r'test_edge_pad_cpu') - backend_test.exclude(r'test_einsum_batch_diagonal_cpu') - backend_test.exclude(r'test_einsum_batch_matmul_cpu') - backend_test.exclude(r'test_einsum_inner_prod_cpu') - backend_test.exclude(r'test_einsum_sum_cpu') - backend_test.exclude(r'test_einsum_transpose_cpu') backend_test.exclude(r'test_maxunpool_export_with_output_shape_cpu') backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu') - backend_test.exclude(r'test_mod_mixed_sign_int32_cpu') - backend_test.exclude(r'test_mod_mixed_sign_int8_cpu') backend_test.exclude(r'test_qlinearmatmul_2D_cpu') backend_test.exclude(r'test_qlinearmatmul_3D_cpu') backend_test.exclude(r'test_range_float_type_positive_delta_expanded_cpu') diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index a6ae62d935a..71224070886 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -632,8 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) { auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); - auto yb = - m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y); + auto yb = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {384, 768}}}), y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sumb = m2.add_instruction(b, sum); m2.add_instruction(pass_op{}, sumb); @@ -641,6 +641,122 @@ TEST_CASE(simplify_inner_broadcast_different_dims) EXPECT(m1 == m2); } +TEST_CASE(simplify_inner_broadcast_different_dims_single_element) +{ + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}}); + auto xb = m1.add_instruction(b, x); + auto yb = m1.add_instruction(b, y); + auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}}); + auto xs = m2.add_instruction(migraphx::make_op("squeeze"), x); + auto xb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1}}}), xs); + auto sum = m2.add_instruction(migraphx::make_op("add"), xb, y); + auto sumb = m2.add_instruction(b, sum); + m2.add_instruction(pass_op{}, sumb); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_inner_broadcast_different_dims_single_element_no_squeeze) +{ + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 5}}); + auto z = m1.add_parameter("z", {migraphx::shape::int32_type, {1, 1, 5}}); + auto xb = m1.add_instruction(b, x); + auto yb = m1.add_instruction(b, y); + auto zb = m1.add_instruction(b, z); + auto sum = m1.add_instruction(migraphx::make_op("where"), xb, yb, zb); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 5}}); + auto z = m2.add_parameter("z", {migraphx::shape::int32_type, {1, 1, 5}}); + auto ys = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 2}}}), y); + auto zs = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), z); + auto xb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), x); + auto sum = m2.add_instruction(migraphx::make_op("where"), xb, ys, zs); + auto sumb = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 1, 4, 5}}}), sum); + m2.add_instruction(pass_op{}, sumb); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_inner_broadcast_different_dims_broadcasted) +{ + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 768}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 768}}); + auto xb = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x); + auto xbb = m1.add_instruction(b, xb); + auto yb = m1.add_instruction(b, y); + auto sum = m1.add_instruction(migraphx::make_op("add"), xbb, yb); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 768}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 768}}); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + auto sumb = m2.add_instruction(b, sum); + m2.add_instruction(pass_op{}, sumb); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_inner_broadcast_different_dims_broadcasted_scalar) +{ + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 384}}); + auto xb = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 1}}}), x); + auto xbb = m1.add_instruction(b, xb); + auto yb = m1.add_instruction(b, y); + auto sum = m1.add_instruction(migraphx::make_op("add"), xbb, yb); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 384}}); + auto xb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), x); + auto sum = m2.add_instruction(migraphx::make_op("add"), xb, y); + auto sumb = m2.add_instruction(b, sum); + m2.add_instruction(pass_op{}, sumb); + } + EXPECT(m1 == m2); +} + TEST_CASE(simplify_inner_broadcast_different_broadcasts) { auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}}); @@ -660,15 +776,73 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts) { auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {24}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {24, 1, 1}}); - auto xs = m2.add_instruction(migraphx::make_op("squeeze"), x); - auto ys = m2.add_instruction(migraphx::make_op("squeeze"), y); - auto sum = m2.add_instruction(migraphx::make_op("add"), xs, ys); + auto ys = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), y); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, ys); auto sumb = m2.add_instruction(b, sum); m2.add_instruction(pass_op{}, sumb); } EXPECT(m1 == m2); } +TEST_CASE(simplify_inner_broadcast_different_broadcasts_diff_axis) +{ + auto b = migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {1, 64, 112, 112}}}); + auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 112, 112}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 64}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {64, 1, 1}}); + auto xb = m1.add_instruction(b, x); + auto yb = m1.add_instruction(mb, y); + auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 64}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {64, 1, 1}}); + auto xs = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), x); + auto ys = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), y); + auto sum = m2.add_instruction(migraphx::make_op("add"), xs, ys); + auto sumb = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 112, 112}}}), sum); + m2.add_instruction(pass_op{}, sumb); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_inner_broadcast_different_broadcasts_diff_dims) +{ + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {256}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 256, 31, 31}}); + auto xb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {64, 256, 31, 31}}}), x); + auto yb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {64, 256, 31, 31}}}), y); + auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {256}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 256, 31, 31}}); + auto xb = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {256, 31, 31}}}), x); + auto ys = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), y); + auto sum = m2.add_instruction(migraphx::make_op("add"), xb, ys); + auto sumb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {64, 256, 31, 31}}}), sum); + m2.add_instruction(pass_op{}, sumb); + } + EXPECT(m1 == m2); +} + TEST_CASE(simplify_inner_broadcast_no_common_axis) { auto b = migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}); @@ -1756,6 +1930,28 @@ TEST_CASE(simplify_slice_concat_flipped) EXPECT(m1 == m2); } +TEST_CASE(simplify_slice_concat_interleaved_non_slice) +{ + // Matched by matcher find_split_concat, but no substitution because the "slice" + // instructions in the input list have a different instruction mixed in + migraphx::module m1; + { + migraphx::shape s{migraphx::shape::float_type, {4, 3, 3, 3}}; + auto x = m1.add_parameter("x", s); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), x); + auto relu = m1.add_instruction(migraphx::make_op("relu"), x); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), x); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), slice1, relu, slice2); + m1.add_instruction(pass_op{}, concat); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(simplify_split_add_relu) { auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index f28c3ff8bde..f0d100821ab 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3) EXPECT(new_concat->get_operator().to_value()["axis"].to() == 2); } +// Broadcasted batch dim, axis is broadcasted dim +// matched by find_concat_multibroadcasts but it skips this case TEST_CASE(concat_multibroadcasts4) { - // Broadcasted batch dim, axis is broadcasted dim std::vector in_lens = {3, 4}; std::vector mbcast_lens = {2, 3, 4}; const int axis = 0; @@ -930,6 +931,112 @@ TEST_CASE(concat_multibroadcasts4) EXPECT(m1 == m); } +// Matched by find_concat_multibroadcasts but skipped because dimensions other than concat axis do +// not match +TEST_CASE(concat_multibroadcasts5) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {1, 12, 60, 64, 64}; + std::vector mb_lens1 = {1, 12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Matched by find_concat_multibroadcasts but skipped because parameter inputs are not the same +// rank. +TEST_CASE(concat_multibroadcasts6) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {12, 60, 64, 64}; + std::vector mb_lens1 = {12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Concat axis moved to 2 because rank(in_dims) < rank(out_dims) +// Matched by find_concat_multibroadcasts but skipped because the dimensions +// other than the concat axis are not the same. +// TODO: has common broadcast axes, so can be simplified by moving multibroadcast up to have a +// smaller concat. +TEST_CASE(concat_multibroadcasts7) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {1, 12, 60, 64, 64}; + std::vector mb_lens1 = {1, 12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Shape of inputs to multibroadcasts do not have the same rank. +// Matched by find_concat_multibroadcasts but skipped. +// TODO: has a common broadcast axis, so can be simplified by moving multibroadcast up to have a +// smaller concat. +TEST_CASE(concat_multibroadcasts8) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {64, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {60, 64, 64}; + std::vector mb_lens1 = {60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Shape of inputs to multibroadcasts do not have a common broadcast axis. +// Matched by find_concat_multibroadcasts, but skipped because the dimensions other than +// the concat axis are not the same. +TEST_CASE(concat_multibroadcasts9) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 64, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {60, 64, 64}; + std::vector mb_lens1 = {60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + TEST_CASE(concat_transpose1) { migraphx::module m; diff --git a/test/tf/CMakeLists.txt b/test/tf/CMakeLists.txt new file mode 100755 index 00000000000..d99b10b9cf9 --- /dev/null +++ b/test/tf/CMakeLists.txt @@ -0,0 +1,38 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +##################################################################################### + +function(add_tf_test TEST_NAME) + rocm_add_test_executable(${TEST_NAME} ${ARGN}) + rocm_clang_tidy_check(${TEST_NAME}) + target_link_libraries(${TEST_NAME} migraphx_tf pb_files) + target_include_directories(${TEST_NAME} PUBLIC ../include include) +endfunction() + +include(Embed) +file(GLOB_RECURSE PB_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/models/*.pb) +add_embed_library(pb_files ${PB_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/models) + +file(GLOB TF_TESTS CONFIGRE_DEPENDS tests/*.cpp) + +add_tf_test(test_tf ${TF_TESTS}) diff --git a/test/tf/gen_tf_pb.py b/test/tf/gen_tf_pb.py index b1cc59ad0f0..b177194dabc 100644 --- a/test/tf/gen_tf_pb.py +++ b/test/tf/gen_tf_pb.py @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,7 +32,7 @@ def run_test(): g1 = tf.Graph() op_test(g1) tf.io.write_graph(g1, - '.', + './models', '{}.pb'.format(op_test.__name__), as_text=False) diff --git a/test/tf/include/tf_conv_utils.hpp b/test/tf/include/tf_conv_utils.hpp new file mode 100644 index 00000000000..13f825c99f1 --- /dev/null +++ b/test/tf/include/tf_conv_utils.hpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#ifndef MIGRAPHX_GUARD_TEST_TF_TF_CONV_UTILS_HPP +#define MIGRAPHX_GUARD_TEST_TF_TF_CONV_UTILS_HPP + +inline migraphx::program create_conv() +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + std::vector weight_data(3 * 3 * 3 * 32); + std::fill(weight_data.begin(), weight_data.end(), 1.0f); + auto l1 = + mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data); + + migraphx::op::convolution op; + op.padding = {1, 1, 1, 1}; + op.stride = {1, 1}; + op.dilation = {1, 1}; + auto l2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1); + mm->add_instruction(op, l0, l2); + return p; +} + +#endif diff --git a/test/tf/include/tf_test.hpp b/test/tf/include/tf_test.hpp new file mode 100644 index 00000000000..d477befead1 --- /dev/null +++ b/test/tf/include/tf_test.hpp @@ -0,0 +1,96 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MIGRAPHX_GUARD_TEST_TF_TF_TEST_HPP +#define MIGRAPHX_GUARD_TEST_TF_TF_TEST_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "test.hpp" + +inline migraphx::program read_pb_file(const std::string& name, const migraphx::tf_options& options) +{ + static auto pb_files{::pb_files()}; + if(pb_files.find(name) == pb_files.end()) + { + std::cerr << "Can not find TensorFlow Protobuf file by name: " << name + << " , aborting the program\n" + << std::endl; + std::abort(); + } + return migraphx::parse_tf_buffer(std::string{pb_files.at(name)}, options); +} + +inline migraphx::program +parse_tf(const std::string& name, + bool is_nhwc, + const std::unordered_map>& dim_params = {}, + const std::vector& output_node_names = {}) +{ + + return read_pb_file(name, migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names}); +} + +inline migraphx::program optimize_tf(const std::string& name, bool is_nhwc) +{ + auto prog = read_pb_file(name, migraphx::tf_options{is_nhwc, 1}); + auto* mm = prog.get_main_module(); + if(is_nhwc) + migraphx::run_passes(*mm, + {migraphx::simplify_reshapes{}, + migraphx::dead_code_elimination{}, + migraphx::eliminate_identity{}}); + + // remove the last return instruction + + if(mm->size() > 0) + { + auto last_ins = std::prev(mm->end()); + if(last_ins->name() == "@return") + { + mm->remove_instruction(last_ins); + } + } + return prog; +} + +#endif diff --git a/test/tf/add_bcast_test.pb b/test/tf/models/add_bcast_test.pb similarity index 100% rename from test/tf/add_bcast_test.pb rename to test/tf/models/add_bcast_test.pb diff --git a/test/tf/add_test.pb b/test/tf/models/add_test.pb similarity index 78% rename from test/tf/add_test.pb rename to test/tf/models/add_test.pb index f176c1b2b93..41cc3fef3ed 100644 --- a/test/tf/add_test.pb +++ b/test/tf/models/add_test.pb @@ -7,6 +7,6 @@ 1 Placeholder* dtype0* shape: - -add1Add01* -T0"¸ \ No newline at end of file + +add1AddV201* +T0"æ \ No newline at end of file diff --git a/test/tf/addv2_test.pb b/test/tf/models/addv2_test.pb similarity index 100% rename from test/tf/addv2_test.pb rename to test/tf/models/addv2_test.pb diff --git a/test/tf/argmax_test.pb b/test/tf/models/argmax_test.pb similarity index 100% rename from test/tf/argmax_test.pb rename to test/tf/models/argmax_test.pb diff --git a/test/tf/argmin_test.pb b/test/tf/models/argmin_test.pb similarity index 100% rename from test/tf/argmin_test.pb rename to test/tf/models/argmin_test.pb diff --git a/test/tf/assert_less_equal_test.pb b/test/tf/models/assert_less_equal_test.pb similarity index 100% rename from test/tf/assert_less_equal_test.pb rename to test/tf/models/assert_less_equal_test.pb diff --git a/test/tf/batchmatmul_test.pb b/test/tf/models/batchmatmul_test.pb similarity index 100% rename from test/tf/batchmatmul_test.pb rename to test/tf/models/batchmatmul_test.pb diff --git a/test/tf/batchnorm_half_test.pb b/test/tf/models/batchnorm_half_test.pb similarity index 100% rename from test/tf/batchnorm_half_test.pb rename to test/tf/models/batchnorm_half_test.pb diff --git a/test/tf/batchnorm_test.pb b/test/tf/models/batchnorm_test.pb similarity index 100% rename from test/tf/batchnorm_test.pb rename to test/tf/models/batchnorm_test.pb diff --git a/test/tf/batchnormv3_test.pb b/test/tf/models/batchnormv3_test.pb similarity index 100% rename from test/tf/batchnormv3_test.pb rename to test/tf/models/batchnormv3_test.pb diff --git a/test/tf/biasadd_scalar_test.pb b/test/tf/models/biasadd_scalar_test.pb similarity index 100% rename from test/tf/biasadd_scalar_test.pb rename to test/tf/models/biasadd_scalar_test.pb diff --git a/test/tf/biasadd_test.pb b/test/tf/models/biasadd_test.pb similarity index 100% rename from test/tf/biasadd_test.pb rename to test/tf/models/biasadd_test.pb diff --git a/test/tf/cast_test.pb b/test/tf/models/cast_test.pb similarity index 100% rename from test/tf/cast_test.pb rename to test/tf/models/cast_test.pb diff --git a/test/tf/concat_test.pb b/test/tf/models/concat_test.pb similarity index 100% rename from test/tf/concat_test.pb rename to test/tf/models/concat_test.pb diff --git a/test/tf/constant_test.pb b/test/tf/models/constant_test.pb similarity index 100% rename from test/tf/constant_test.pb rename to test/tf/models/constant_test.pb diff --git a/test/tf/conv_add_test.pb b/test/tf/models/conv_add_test.pb similarity index 100% rename from test/tf/conv_add_test.pb rename to test/tf/models/conv_add_test.pb diff --git a/test/tf/conv_batch_test.pb b/test/tf/models/conv_batch_test.pb similarity index 100% rename from test/tf/conv_batch_test.pb rename to test/tf/models/conv_batch_test.pb diff --git a/test/tf/conv_nchw_test.pb b/test/tf/models/conv_nchw_test.pb similarity index 100% rename from test/tf/conv_nchw_test.pb rename to test/tf/models/conv_nchw_test.pb diff --git a/test/tf/conv_relu6_test.pb b/test/tf/models/conv_relu6_test.pb similarity index 100% rename from test/tf/conv_relu6_test.pb rename to test/tf/models/conv_relu6_test.pb diff --git a/test/tf/conv_relu_test.pb b/test/tf/models/conv_relu_test.pb similarity index 100% rename from test/tf/conv_relu_test.pb rename to test/tf/models/conv_relu_test.pb diff --git a/test/tf/conv_test.pb b/test/tf/models/conv_test.pb similarity index 100% rename from test/tf/conv_test.pb rename to test/tf/models/conv_test.pb diff --git a/test/tf/depthwise_conv_test.pb b/test/tf/models/depthwise_conv_test.pb similarity index 100% rename from test/tf/depthwise_conv_test.pb rename to test/tf/models/depthwise_conv_test.pb diff --git a/test/tf/expanddims_neg_test.pb b/test/tf/models/expanddims_neg_test.pb similarity index 100% rename from test/tf/expanddims_neg_test.pb rename to test/tf/models/expanddims_neg_test.pb diff --git a/test/tf/expanddims_test.pb b/test/tf/models/expanddims_test.pb similarity index 100% rename from test/tf/expanddims_test.pb rename to test/tf/models/expanddims_test.pb diff --git a/test/tf/gather_test.pb b/test/tf/models/gather_test.pb similarity index 100% rename from test/tf/gather_test.pb rename to test/tf/models/gather_test.pb diff --git a/test/tf/identity_test.pb b/test/tf/models/identity_test.pb similarity index 100% rename from test/tf/identity_test.pb rename to test/tf/models/identity_test.pb diff --git a/test/tf/matmul_test.pb b/test/tf/models/matmul_test.pb similarity index 100% rename from test/tf/matmul_test.pb rename to test/tf/models/matmul_test.pb diff --git a/test/tf/mean_test.pb b/test/tf/models/mean_test.pb similarity index 100% rename from test/tf/mean_test.pb rename to test/tf/models/mean_test.pb diff --git a/test/tf/mean_test_nhwc.pb b/test/tf/models/mean_test_nhwc.pb similarity index 100% rename from test/tf/mean_test_nhwc.pb rename to test/tf/models/mean_test_nhwc.pb diff --git a/test/tf/mul_test.pb b/test/tf/models/mul_test.pb similarity index 100% rename from test/tf/mul_test.pb rename to test/tf/models/mul_test.pb diff --git a/test/tf/multi_output_test.pb b/test/tf/models/multi_output_test.pb similarity index 100% rename from test/tf/multi_output_test.pb rename to test/tf/models/multi_output_test.pb diff --git a/test/tf/noop_test.pb b/test/tf/models/noop_test.pb similarity index 100% rename from test/tf/noop_test.pb rename to test/tf/models/noop_test.pb diff --git a/test/tf/onehot_test.pb b/test/tf/models/onehot_test.pb similarity index 100% rename from test/tf/onehot_test.pb rename to test/tf/models/onehot_test.pb diff --git a/test/tf/pack_test.pb b/test/tf/models/pack_test.pb similarity index 100% rename from test/tf/pack_test.pb rename to test/tf/models/pack_test.pb diff --git a/test/tf/pack_test_nhwc.pb b/test/tf/models/pack_test_nhwc.pb similarity index 100% rename from test/tf/pack_test_nhwc.pb rename to test/tf/models/pack_test_nhwc.pb diff --git a/test/tf/pad_test.pb b/test/tf/models/pad_test.pb similarity index 100% rename from test/tf/pad_test.pb rename to test/tf/models/pad_test.pb diff --git a/test/tf/pooling_test.pb b/test/tf/models/pooling_test.pb similarity index 100% rename from test/tf/pooling_test.pb rename to test/tf/models/pooling_test.pb diff --git a/test/tf/pow_test.pb b/test/tf/models/pow_test.pb similarity index 100% rename from test/tf/pow_test.pb rename to test/tf/models/pow_test.pb diff --git a/test/tf/relu6_half_test.pb b/test/tf/models/relu6_half_test.pb similarity index 100% rename from test/tf/relu6_half_test.pb rename to test/tf/models/relu6_half_test.pb diff --git a/test/tf/relu6_test.pb b/test/tf/models/relu6_test.pb similarity index 100% rename from test/tf/relu6_test.pb rename to test/tf/models/relu6_test.pb diff --git a/test/tf/relu_test.pb b/test/tf/models/relu_test.pb similarity index 100% rename from test/tf/relu_test.pb rename to test/tf/models/relu_test.pb diff --git a/test/tf/reshape_test.pb b/test/tf/models/reshape_test.pb similarity index 100% rename from test/tf/reshape_test.pb rename to test/tf/models/reshape_test.pb diff --git a/test/tf/rsqrt_test.pb b/test/tf/models/rsqrt_test.pb similarity index 100% rename from test/tf/rsqrt_test.pb rename to test/tf/models/rsqrt_test.pb diff --git a/test/tf/shape_test.pb b/test/tf/models/shape_test.pb similarity index 100% rename from test/tf/shape_test.pb rename to test/tf/models/shape_test.pb diff --git a/test/tf/slice_test.pb b/test/tf/models/slice_test.pb similarity index 100% rename from test/tf/slice_test.pb rename to test/tf/models/slice_test.pb diff --git a/test/tf/softmax_test.pb b/test/tf/models/softmax_test.pb similarity index 100% rename from test/tf/softmax_test.pb rename to test/tf/models/softmax_test.pb diff --git a/test/tf/split_test.pb b/test/tf/models/split_test.pb similarity index 100% rename from test/tf/split_test.pb rename to test/tf/models/split_test.pb diff --git a/test/tf/split_test_one_output.pb b/test/tf/models/split_test_one_output.pb similarity index 100% rename from test/tf/split_test_one_output.pb rename to test/tf/models/split_test_one_output.pb diff --git a/test/tf/split_test_vector_as_input.pb b/test/tf/models/split_test_vector_as_input.pb similarity index 100% rename from test/tf/split_test_vector_as_input.pb rename to test/tf/models/split_test_vector_as_input.pb diff --git a/test/tf/sqdiff_test.pb b/test/tf/models/sqdiff_test.pb similarity index 100% rename from test/tf/sqdiff_test.pb rename to test/tf/models/sqdiff_test.pb diff --git a/test/tf/squeeze_test.pb b/test/tf/models/squeeze_test.pb similarity index 100% rename from test/tf/squeeze_test.pb rename to test/tf/models/squeeze_test.pb diff --git a/test/tf/stopgradient_test.pb b/test/tf/models/stopgradient_test.pb similarity index 100% rename from test/tf/stopgradient_test.pb rename to test/tf/models/stopgradient_test.pb diff --git a/test/tf/stridedslice_masks_test.pb b/test/tf/models/stridedslice_masks_test.pb similarity index 100% rename from test/tf/stridedslice_masks_test.pb rename to test/tf/models/stridedslice_masks_test.pb diff --git a/test/tf/stridedslice_test.pb b/test/tf/models/stridedslice_test.pb similarity index 100% rename from test/tf/stridedslice_test.pb rename to test/tf/models/stridedslice_test.pb diff --git a/test/tf/sub_test.pb b/test/tf/models/sub_test.pb similarity index 100% rename from test/tf/sub_test.pb rename to test/tf/models/sub_test.pb diff --git a/test/tf/tanh_test.pb b/test/tf/models/tanh_test.pb similarity index 100% rename from test/tf/tanh_test.pb rename to test/tf/models/tanh_test.pb diff --git a/test/tf/transpose_test.pb b/test/tf/models/transpose_test.pb similarity index 100% rename from test/tf/transpose_test.pb rename to test/tf/models/transpose_test.pb diff --git a/test/tf/variable_batch_test.pb b/test/tf/models/variable_batch_test.pb similarity index 100% rename from test/tf/variable_batch_test.pb rename to test/tf/models/variable_batch_test.pb diff --git a/test/tf/tests/add_bcast_test.cpp b/test/tf/tests/add_bcast_test.cpp new file mode 100644 index 00000000000..84e71d3d4d2 --- /dev/null +++ b/test/tf/tests/add_bcast_test.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(add_bcast_test) +{ + + migraphx::program p; + + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {2, 3}}; + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); + auto l2 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l2); + auto prog = optimize_tf("add_bcast_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/add_test.cpp b/test/tf/tests/add_test.cpp new file mode 100644 index 00000000000..eab59fe9c0c --- /dev/null +++ b/test/tf/tests/add_test.cpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(add_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + auto prog = optimize_tf("add_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/addv2_test.cpp b/test/tf/tests/addv2_test.cpp new file mode 100644 index 00000000000..18d162b0ba7 --- /dev/null +++ b/test/tf/tests/addv2_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(addv2_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + auto prog = optimize_tf("addv2_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/argmax_test.cpp b/test/tf/tests/argmax_test.cpp new file mode 100644 index 00000000000..683bd750a0d --- /dev/null +++ b/test/tf/tests/argmax_test.cpp @@ -0,0 +1,41 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(argmax_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}}); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); + auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0); + auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); + mm->add_return({l1}); + auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}}); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/argmin_test.cpp b/test/tf/tests/argmin_test.cpp new file mode 100644 index 00000000000..81f34538f2f --- /dev/null +++ b/test/tf/tests/argmin_test.cpp @@ -0,0 +1,41 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(argmin_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); + auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0); + auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); + mm->add_return({l1}); + auto prog = parse_tf("argmin_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/assert_less_equal_test.cpp b/test/tf/tests/assert_less_equal_test.cpp new file mode 100644 index 00000000000..327f02c4ebb --- /dev/null +++ b/test/tf/tests/assert_less_equal_test.cpp @@ -0,0 +1,44 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(assert_less_equal_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {2, 3}}; + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_parameter("1", s0); + migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}}; + auto l2 = mm->add_literal(l); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + auto l3 = mm->add_instruction(migraphx::make_op("identity"), l0, l1); + mm->add_instruction(migraphx::make_op("identity"), l3, l2); + auto prog = optimize_tf("assert_less_equal_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/batchmatmul_test.cpp b/test/tf/tests/batchmatmul_test.cpp new file mode 100644 index 00000000000..aa2871d5dc1 --- /dev/null +++ b/test/tf/tests/batchmatmul_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(batchmatmul_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}}); + + auto trans_l0 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0); + auto trans_l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + + mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); + auto prog = optimize_tf("batchmatmul_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/batchnorm_half_test.cpp b/test/tf/tests/batchnorm_half_test.cpp new file mode 100644 index 00000000000..1281fd3678e --- /dev/null +++ b/test/tf/tests/batchnorm_half_test.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(batchnorm_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::half_type, {1, 32, 16, 16}}); + auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {32}}); + auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {32}}); + auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {32}}); + + std::vector scale_data(32, 1.0); + auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); + auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-4f}}); + + auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); + auto usq_bias = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); + auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); + auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); + + auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); + auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); + auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps); + auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt}); + auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); + add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); + + auto prog = optimize_tf("batchnorm_half_test.pb", true); + EXPECT(p == prog); +} diff --git a/test/tf/tests/batchnorm_test.cpp b/test/tf/tests/batchnorm_test.cpp new file mode 100644 index 00000000000..4a549ccb53a --- /dev/null +++ b/test/tf/tests/batchnorm_test.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(batchnorm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 32, 16, 16}}); + auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {32}}); + auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {32}}); + auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {32}}); + + std::vector scale_data(32, 1.0); + auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); + auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-4f}}); + + auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); + auto usq_bias = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); + auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); + auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); + + auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); + auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); + auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps); + auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt}); + auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); + add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); + + auto prog = optimize_tf("batchnorm_test.pb", true); + EXPECT(p == prog); +} diff --git a/test/tf/tests/batchnormv3_test.cpp b/test/tf/tests/batchnormv3_test.cpp new file mode 100644 index 00000000000..2450e9c4331 --- /dev/null +++ b/test/tf/tests/batchnormv3_test.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(batchnormv3_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 32, 16, 16}}); + auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {32}}); + auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {32}}); + auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {32}}); + + std::vector scale_data(32, 1.0); + auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); + auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); + + auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); + auto usq_bias = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); + auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); + auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); + + auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); + auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); + auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps); + auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt}); + auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); + add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); + + auto prog = optimize_tf("batchnormv3_test.pb", true); + EXPECT(p == prog); +} diff --git a/test/tf/tests/biasadd_scalar_test.cpp b/test/tf/tests/biasadd_scalar_test.cpp new file mode 100644 index 00000000000..13af83d6027 --- /dev/null +++ b/test/tf/tests/biasadd_scalar_test.cpp @@ -0,0 +1,44 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(biasadd_scalar_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {1, 1}}; + uint64_t axis = 1; + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}}); + auto l2 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l2); + auto prog = optimize_tf("biasadd_scalar_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/biasadd_test.cpp b/test/tf/tests/biasadd_test.cpp new file mode 100644 index 00000000000..caf09833230 --- /dev/null +++ b/test/tf/tests/biasadd_test.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(biasadd_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}}; + uint64_t axis = 1; + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); + auto l2 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l2); + auto prog = optimize_tf("biasadd_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/cast_test.cpp b/test/tf/tests/cast_test.cpp new file mode 100644 index 00000000000..c67f7b77754 --- /dev/null +++ b/test/tf/tests/cast_test.cpp @@ -0,0 +1,41 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(cast_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), + l0); + auto prog = optimize_tf("cast_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/concat_test.cpp b/test/tf/tests/concat_test.cpp new file mode 100644 index 00000000000..afac63368c0 --- /dev/null +++ b/test/tf/tests/concat_test.cpp @@ -0,0 +1,46 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(concat_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); + + int axis = 1; + // tf uses axis as the third input, and it is in int32 format + // add the literal using a vector in order to set stride to 1 (like in tf parser) + mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector{axis}); + + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1); + auto prog = optimize_tf("concat_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/constant_test.cpp b/test/tf/tests/constant_test.cpp new file mode 100644 index 00000000000..b173a48e5fc --- /dev/null +++ b/test/tf/tests/constant_test.cpp @@ -0,0 +1,37 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(const_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + mm->add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector{1.0f}); + auto prog = optimize_tf("constant_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/conv_add_test.cpp b/test/tf/tests/conv_add_test.cpp new file mode 100644 index 00000000000..a29655d21ed --- /dev/null +++ b/test/tf/tests/conv_add_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include +#include + +TEST_CASE(conv_add_test) +{ + migraphx::program p = create_conv(); + auto* mm = p.get_main_module(); + auto l0 = std::prev(mm->end()); + mm->add_instruction(migraphx::make_op("add"), l0, l0); + auto prog = optimize_tf("conv_add_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/conv_nchw_test.cpp b/test/tf/tests/conv_nchw_test.cpp new file mode 100644 index 00000000000..bc2c80e01fa --- /dev/null +++ b/test/tf/tests/conv_nchw_test.cpp @@ -0,0 +1,35 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include +#include + +TEST_CASE(conv_nchw_test) +{ + migraphx::program p = create_conv(); + auto prog = optimize_tf("conv_nchw_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/conv_relu6_test.cpp b/test/tf/tests/conv_relu6_test.cpp new file mode 100644 index 00000000000..ed01dd87e14 --- /dev/null +++ b/test/tf/tests/conv_relu6_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include +#include + +TEST_CASE(conv_relu6_test) +{ + migraphx::program p = create_conv(); + auto* mm = p.get_main_module(); + std::vector input_lens{1, 32, 16, 16}; + auto l0 = std::prev(mm->end()); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + min_val); + max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + max_val); + mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + auto prog = optimize_tf("conv_relu6_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/conv_relu_test.cpp b/test/tf/tests/conv_relu_test.cpp new file mode 100644 index 00000000000..295d66f22eb --- /dev/null +++ b/test/tf/tests/conv_relu_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include +#include + +TEST_CASE(conv_relu_test) +{ + migraphx::program p = create_conv(); + auto* mm = p.get_main_module(); + auto l0 = std::prev(mm->end()); + mm->add_instruction(migraphx::make_op("relu"), l0); + auto prog = optimize_tf("conv_relu_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/conv_test.cpp b/test/tf/tests/conv_test.cpp new file mode 100644 index 00000000000..69249765197 --- /dev/null +++ b/test/tf/tests/conv_test.cpp @@ -0,0 +1,35 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include +#include + +TEST_CASE(conv_test) +{ + migraphx::program p = create_conv(); + auto prog = optimize_tf("conv_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/depthwise_conv_test.cpp b/test/tf/tests/depthwise_conv_test.cpp new file mode 100644 index 00000000000..53839fe96d5 --- /dev/null +++ b/test/tf/tests/depthwise_conv_test.cpp @@ -0,0 +1,54 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include +#include + +TEST_CASE(depthwiseconv_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + std::vector weight_data(3 * 3 * 3 * 1); + std::fill(weight_data.begin(), weight_data.end(), 1.0f); + auto l1 = + mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data); + + migraphx::op::convolution op; + op.padding = {1, 1}; + op.stride = {1, 1}; + op.dilation = {1, 1}; + op.group = 3; + auto l3 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1); + auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3); + auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4); + mm->add_instruction(op, l0, l5); + auto prog = optimize_tf("depthwise_conv_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/expanddims_neg_test.cpp b/test/tf/tests/expanddims_neg_test.cpp new file mode 100644 index 00000000000..3e659bc912d --- /dev/null +++ b/test/tf/tests/expanddims_neg_test.cpp @@ -0,0 +1,41 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(expanddims_test_neg_dims) +{ + // this check makes sure the pb parses negative dim value correctly + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); + mm->add_literal(-1); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), l0); + auto prog = optimize_tf("expanddims_neg_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/expanddims_test.cpp b/test/tf/tests/expanddims_test.cpp new file mode 100644 index 00000000000..85295d83257 --- /dev/null +++ b/test/tf/tests/expanddims_test.cpp @@ -0,0 +1,40 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(expanddims_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); + mm->add_literal(0); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), l0); + auto prog = optimize_tf("expanddims_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/gather_test.cpp b/test/tf/tests/gather_test.cpp new file mode 100644 index 00000000000..ef96f51523d --- /dev/null +++ b/test/tf/tests/gather_test.cpp @@ -0,0 +1,44 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(gather_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}}); + auto l1 = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}}); + mm->add_literal(1); + + int axis = 1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1); + auto prog = optimize_tf("gather_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/identity_test.cpp b/test/tf/tests/identity_test.cpp new file mode 100644 index 00000000000..ae79c3160ac --- /dev/null +++ b/test/tf/tests/identity_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(identity_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto prog = optimize_tf("identity_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/main.cpp b/test/tf/tests/main.cpp new file mode 100644 index 00000000000..336c0391aa6 --- /dev/null +++ b/test/tf/tests/main.cpp @@ -0,0 +1,28 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/tf/tests/matmul_test.cpp b/test/tf/tests/matmul_test.cpp new file mode 100644 index 00000000000..82f46bbf9ce --- /dev/null +++ b/test/tf/tests/matmul_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(matmul_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}}); + + auto trans_l0 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0); + auto trans_l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + + mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); + auto prog = optimize_tf("matmul_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/mean_test.cpp b/test/tf/tests/mean_test.cpp new file mode 100644 index 00000000000..62a19631a37 --- /dev/null +++ b/test/tf/tests/mean_test.cpp @@ -0,0 +1,44 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(mean_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_literal(l); + mm->add_literal(l); + migraphx::op::reduce_mean op{{2, 3}}; + mm->add_instruction(op, l0); + auto l3 = mm->add_instruction(op, l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l3); + auto prog = optimize_tf("mean_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/mean_test_nhwc.cpp b/test/tf/tests/mean_test_nhwc.cpp new file mode 100644 index 00000000000..b3f8eed2b7e --- /dev/null +++ b/test/tf/tests/mean_test_nhwc.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(mean_test_nhwc) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + migraphx::op::reduce_mean op{{1, 2}}; + auto l2 = mm->add_instruction(op, l1); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2); + auto prog = optimize_tf("mean_test_nhwc.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/mul_test.cpp b/test/tf/tests/mul_test.cpp new file mode 100644 index 00000000000..101cd6b6b82 --- /dev/null +++ b/test/tf/tests/mul_test.cpp @@ -0,0 +1,40 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(mul_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); + + mm->add_instruction(migraphx::make_op("mul"), l0, l1); + auto prog = optimize_tf("mul_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/multi_output_test.cpp b/test/tf/tests/multi_output_test.cpp new file mode 100644 index 00000000000..06039b122b6 --- /dev/null +++ b/test/tf/tests/multi_output_test.cpp @@ -0,0 +1,42 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(multi_output_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto l1 = mm->add_instruction(migraphx::make_op("relu"), l0); + auto l2 = mm->add_instruction(migraphx::make_op("tanh"), l0); + mm->add_return({l1, l2}); + + EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); })); + auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"}); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/noop_test.cpp b/test/tf/tests/noop_test.cpp new file mode 100644 index 00000000000..06e82269fd0 --- /dev/null +++ b/test/tf/tests/noop_test.cpp @@ -0,0 +1,34 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(noop_test) +{ + migraphx::program p; + auto prog = optimize_tf("noop_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/onehot_test.cpp b/test/tf/tests/onehot_test.cpp new file mode 100644 index 00000000000..2cb67411590 --- /dev/null +++ b/test/tf/tests/onehot_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(onehot_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}}); + mm->add_literal(2); + mm->add_literal(1.0f); + mm->add_literal(0.0f); + auto l1 = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}}); + int axis = 0; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l0); + auto prog = optimize_tf("onehot_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/pack_test.cpp b/test/tf/tests/pack_test.cpp new file mode 100644 index 00000000000..7c20df43a08 --- /dev/null +++ b/test/tf/tests/pack_test.cpp @@ -0,0 +1,52 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(pack_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}}); + auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}}); + std::vector args{l0, l1, l2}; + std::vector unsqueezed_args; + int64_t axis = 1; + + std::transform( + args.begin(), + args.end(), + std::back_inserter(unsqueezed_args), + [&](migraphx::instruction_ref arg) { + return mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg); + }); + mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast(axis)}}), + unsqueezed_args); + auto prog = optimize_tf("pack_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/pack_test_nhwc.cpp b/test/tf/tests/pack_test_nhwc.cpp new file mode 100644 index 00000000000..d87e53a56ba --- /dev/null +++ b/test/tf/tests/pack_test_nhwc.cpp @@ -0,0 +1,58 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(pack_test_nhwc) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); + auto lt0 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); + auto lt1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1); + auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); + auto lt2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2); + std::vector args{lt0, lt1, lt2}; + std::vector unsqueezed_args; + int64_t nchw_axis = 3; + + std::transform(args.begin(), + args.end(), + std::back_inserter(unsqueezed_args), + [&](migraphx::instruction_ref arg) { + return mm->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), arg); + }); + mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast(nchw_axis)}}), + unsqueezed_args); + auto prog = optimize_tf("pack_test_nhwc.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/pad_test.cpp b/test/tf/tests/pad_test.cpp new file mode 100644 index 00000000000..431eff47236 --- /dev/null +++ b/test/tf/tests/pad_test.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(pad_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}}); + std::vector pad_literals{1, 1, 2, 2}; + std::vector pads{1, 2, 1, 2}; + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, pad_literals); + + mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), l0); + auto prog = optimize_tf("pad_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/pooling_test.cpp b/test/tf/tests/pooling_test.cpp new file mode 100644 index 00000000000..3105328eff2 --- /dev/null +++ b/test/tf/tests/pooling_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(pooling_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + migraphx::op::pooling avg_pool_op{migraphx::op::pooling_mode::average}; + migraphx::op::pooling max_pool_op{migraphx::op::pooling_mode::max}; + avg_pool_op.stride = {2, 2}; + max_pool_op.stride = {2, 2}; + avg_pool_op.lengths = {2, 2}; + max_pool_op.lengths = {2, 2}; + mm->add_instruction(avg_pool_op, l0); + mm->add_instruction(max_pool_op, l0); + auto prog = optimize_tf("pooling_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/pow_test.cpp b/test/tf/tests/pow_test.cpp new file mode 100644 index 00000000000..30f6a01f9ce --- /dev/null +++ b/test/tf/tests/pow_test.cpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(pow_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("pow"), l0, l1); + auto prog = optimize_tf("pow_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/relu6_half_test.cpp b/test/tf/tests/relu6_half_test.cpp new file mode 100644 index 00000000000..84ade6d366f --- /dev/null +++ b/test/tf/tests/relu6_half_test.cpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(relu6_half_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 16, 16}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens}); + auto min_val = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.0f}}); + auto max_val = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {6.0f}}); + min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + min_val); + max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + max_val); + mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + auto prog = optimize_tf("relu6_half_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/relu6_test.cpp b/test/tf/tests/relu6_test.cpp new file mode 100644 index 00000000000..5da363d62b7 --- /dev/null +++ b/test/tf/tests/relu6_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(relu6_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 16, 16}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens}); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + min_val); + max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + max_val); + mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + auto prog = optimize_tf("relu6_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/relu_test.cpp b/test/tf/tests/relu_test.cpp new file mode 100644 index 00000000000..a549df6193d --- /dev/null +++ b/test/tf/tests/relu_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(relu_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("relu"), l0); + auto prog = optimize_tf("relu_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/reshape_test.cpp b/test/tf/tests/reshape_test.cpp new file mode 100644 index 00000000000..2700421de05 --- /dev/null +++ b/test/tf/tests/reshape_test.cpp @@ -0,0 +1,41 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(reshape_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}}); + migraphx::shape s0{migraphx::shape::int32_type, {4}}; + // in tf, the second arg is a literal that contains new dimensions + mm->add_literal(migraphx::literal{s0, {1, 1, 1, 16}}); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), l0); + auto prog = optimize_tf("reshape_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/rsqrt_test.cpp b/test/tf/tests/rsqrt_test.cpp new file mode 100644 index 00000000000..5cbf5b287e6 --- /dev/null +++ b/test/tf/tests/rsqrt_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(rsqrt_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("rsqrt"), l0); + auto prog = optimize_tf("rsqrt_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/shape_test.cpp b/test/tf/tests/shape_test.cpp new file mode 100644 index 00000000000..f749fbdefb4 --- /dev/null +++ b/test/tf/tests/shape_test.cpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(shape_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}}); + auto prog = optimize_tf("shape_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/slice_test.cpp b/test/tf/tests/slice_test.cpp new file mode 100644 index 00000000000..8db6b04e3d7 --- /dev/null +++ b/test/tf/tests/slice_test.cpp @@ -0,0 +1,44 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(slice_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + std::size_t num_axes = 2; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}}); + migraphx::shape s0{migraphx::shape::int32_type, {num_axes}}; + mm->add_literal(migraphx::literal{s0, {1, 0}}); + mm->add_literal(migraphx::literal{s0, {2, -1}}); + + mm->add_instruction( + migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0); + auto prog = optimize_tf("slice_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/softmax_test.cpp b/test/tf/tests/softmax_test.cpp new file mode 100644 index 00000000000..80b39b30dfb --- /dev/null +++ b/test/tf/tests/softmax_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(softmax_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); + mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l0); + auto prog = optimize_tf("softmax_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/split_test.cpp b/test/tf/tests/split_test.cpp new file mode 100644 index 00000000000..6df77ea18b2 --- /dev/null +++ b/test/tf/tests/split_test.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(split_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + std::vector axes{0, 1}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); + mm->add_literal(3); // num_splits + mm->add_literal(1); // split axis + mm->add_literal(1); // concat axis + mm->add_literal(1); // concat axis + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 10}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0); + auto l3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0); + auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); + auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); + mm->add_return({l4, l5}); + auto prog = parse_tf("split_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/split_test_one_output.cpp b/test/tf/tests/split_test_one_output.cpp new file mode 100644 index 00000000000..309ab55a4f0 --- /dev/null +++ b/test/tf/tests/split_test_one_output.cpp @@ -0,0 +1,41 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(split_test_one_output) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); + mm->add_literal(1); // num_splits + mm->add_literal(1); // split axis + auto l1 = mm->add_instruction(migraphx::make_op("identity"), l0); + mm->add_return({l1}); + auto prog = parse_tf("split_test_one_output.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/split_test_vector_as_input.cpp b/test/tf/tests/split_test_vector_as_input.cpp new file mode 100644 index 00000000000..96d566beef5 --- /dev/null +++ b/test/tf/tests/split_test_vector_as_input.cpp @@ -0,0 +1,53 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(split_test_vector_as_input) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + std::vector axes{0, 1}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); + // split sizes + mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}}); + mm->add_literal(1); // split axis + mm->add_literal(1); // concat axis + mm->add_literal(1); // concat axis + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 4}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0); + auto l3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0); + auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); + auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); + mm->add_return({l4, l5}); + auto prog = parse_tf("split_test_vector_as_input.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/sqdiff_test.cpp b/test/tf/tests/sqdiff_test.cpp new file mode 100644 index 00000000000..d1ca678a84c --- /dev/null +++ b/test/tf/tests/sqdiff_test.cpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(sqdiff_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("sqdiff"), l0, l1); + auto prog = optimize_tf("sqdiff_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/squeeze_test.cpp b/test/tf/tests/squeeze_test.cpp new file mode 100644 index 00000000000..da6e6fbf1ee --- /dev/null +++ b/test/tf/tests/squeeze_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(squeeze_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}}); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), l0); + auto prog = optimize_tf("squeeze_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/stopgradient_test.cpp b/test/tf/tests/stopgradient_test.cpp new file mode 100644 index 00000000000..4d3d6d15fcf --- /dev/null +++ b/test/tf/tests/stopgradient_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(stopgradient_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto prog = optimize_tf("stopgradient_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/stridedslice_masks_test.cpp b/test/tf/tests/stridedslice_masks_test.cpp new file mode 100644 index 00000000000..5d2466bd780 --- /dev/null +++ b/test/tf/tests/stridedslice_masks_test.cpp @@ -0,0 +1,54 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(stridedslice_masks_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}}); + // add literals for starts, ends, and strides in tf (NHWC format) + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, + std::vector{0, 1, 1, 0}); + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, + std::vector{0, 0, 0, 0}); + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, + std::vector{1, 1, 1, 1}); + + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op( + "slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}), + l1); + auto l3 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2); + mm->add_return({l3}); + auto prog = parse_tf("stridedslice_masks_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/stridedslice_test.cpp b/test/tf/tests/stridedslice_test.cpp new file mode 100644 index 00000000000..514ea53fb7a --- /dev/null +++ b/test/tf/tests/stridedslice_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(stridedslice_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op( + "slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}), + l1); + auto shrink_axis = 1; + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2); + auto prog = optimize_tf("stridedslice_test.pb", true); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/sub_test.cpp b/test/tf/tests/sub_test.cpp new file mode 100644 index 00000000000..1830fff3f0b --- /dev/null +++ b/test/tf/tests/sub_test.cpp @@ -0,0 +1,40 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(sub_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l2 = mm->add_instruction(migraphx::make_op("sub"), l0, l1); + mm->add_return({l2}); + auto prog = parse_tf("sub_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/tanh_test.cpp b/test/tf/tests/tanh_test.cpp new file mode 100644 index 00000000000..b686312d8c2 --- /dev/null +++ b/test/tf/tests/tanh_test.cpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(tanh_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto l1 = mm->add_instruction(migraphx::make_op("tanh"), l0); + mm->add_return({l1}); + auto prog = parse_tf("tanh_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/transpose_test.cpp b/test/tf/tests/transpose_test.cpp new file mode 100644 index 00000000000..5f290520d7f --- /dev/null +++ b/test/tf/tests/transpose_test.cpp @@ -0,0 +1,40 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(transpose_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + migraphx::shape s0{migraphx::shape::int32_type, {4}}; + mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}}); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto prog = optimize_tf("transpose_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tests/variable_batch_test.cpp b/test/tf/tests/variable_batch_test.cpp new file mode 100644 index 00000000000..fe4e56d6d40 --- /dev/null +++ b/test/tf/tests/variable_batch_test.cpp @@ -0,0 +1,38 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +#include + +TEST_CASE(variable_batch_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto prog = optimize_tf("variable_batch_test.pb", false); + + EXPECT(p == prog); +} diff --git a/test/tf/tf_test.cpp b/test/tf/tf_test.cpp deleted file mode 100644 index b8096065911..00000000000 --- a/test/tf/tf_test.cpp +++ /dev/null @@ -1,1066 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "test.hpp" - -migraphx::program -parse_tf(const std::string& name, - bool is_nhwc, - const std::unordered_map>& dim_params = {}, - const std::vector& output_node_names = {}) -{ - return migraphx::parse_tf(name, - migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names}); -} - -migraphx::program optimize_tf(const std::string& name, bool is_nhwc) -{ - auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1}); - auto* mm = prog.get_main_module(); - if(is_nhwc) - migraphx::run_passes(*mm, - {migraphx::simplify_reshapes{}, - migraphx::dead_code_elimination{}, - migraphx::eliminate_identity{}}); - - // remove the last return instruction - - if(mm->size() > 0) - { - auto last_ins = std::prev(mm->end()); - if(last_ins->name() == "@return") - { - mm->remove_instruction(last_ins); - } - } - return prog; -} - -TEST_CASE(add_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - mm->add_instruction(migraphx::make_op("add"), l0, l1); - auto prog = optimize_tf("add_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(addv2_test) -{ - migraphx::program p; - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - mm->add_instruction(migraphx::make_op("add"), l0, l1); - auto prog = optimize_tf("addv2_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(add_bcast_test) -{ - - migraphx::program p; - - auto* mm = p.get_main_module(); - migraphx::shape s0{migraphx::shape::float_type, {2, 3}}; - auto l0 = mm->add_parameter("0", s0); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); - auto l2 = - mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1); - mm->add_instruction(migraphx::make_op("add"), l0, l2); - auto prog = optimize_tf("add_bcast_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(argmax_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}}); - mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); - auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0); - auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); - mm->add_return({l1}); - auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}}); - - EXPECT(p == prog); -} - -TEST_CASE(argmin_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); - auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0); - auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); - mm->add_return({l1}); - auto prog = parse_tf("argmin_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(assert_less_equal_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - migraphx::shape s0{migraphx::shape::float_type, {2, 3}}; - auto l0 = mm->add_parameter("0", s0); - auto l1 = mm->add_parameter("1", s0); - migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}}; - auto l2 = mm->add_literal(l); - mm->add_instruction(migraphx::make_op("add"), l0, l1); - auto l3 = mm->add_instruction(migraphx::make_op("identity"), l0, l1); - mm->add_instruction(migraphx::make_op("identity"), l3, l2); - auto prog = optimize_tf("assert_less_equal_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(batchmatmul_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}}); - - auto trans_l0 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0); - auto trans_l1 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); - - mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); - auto prog = optimize_tf("batchmatmul_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(batchnorm_test) -{ - migraphx::program p; - auto* mm = p.get_main_module(); - - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 32, 16, 16}}); - auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {32}}); - auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {32}}); - auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {32}}); - - std::vector scale_data(32, 1.0); - auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); - auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-4f}}); - - auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); - auto usq_bias = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); - auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); - auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); - - auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); - auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); - auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps); - auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt}); - auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); - add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); - - auto prog = optimize_tf("batchnorm_test.pb", true); - EXPECT(p == prog); -} - -TEST_CASE(batchnorm_half_test) -{ - migraphx::program p; - auto* mm = p.get_main_module(); - - auto x = mm->add_parameter("x", {migraphx::shape::half_type, {1, 32, 16, 16}}); - auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {32}}); - auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {32}}); - auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {32}}); - - std::vector scale_data(32, 1.0); - auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); - auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-4f}}); - - auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); - auto usq_bias = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); - auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); - auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); - - auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); - auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); - auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps); - auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt}); - auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); - add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); - - auto prog = optimize_tf("batchnorm_half_test.pb", true); - EXPECT(p == prog); -} - -TEST_CASE(batchnormv3_test) -{ - migraphx::program p; - auto* mm = p.get_main_module(); - - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 32, 16, 16}}); - auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {32}}); - auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {32}}); - auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {32}}); - - std::vector scale_data(32, 1.0); - auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); - auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); - - auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); - auto usq_bias = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); - auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); - auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); - - auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); - auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); - auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps); - auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt}); - auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0}); - add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); - - auto prog = optimize_tf("batchnormv3_test.pb", true); - EXPECT(p == prog); -} - -TEST_CASE(biasadd_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}}; - uint64_t axis = 1; - auto l0 = mm->add_parameter("0", s0); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); - auto l2 = mm->add_instruction( - migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1); - mm->add_instruction(migraphx::make_op("add"), l0, l2); - auto prog = optimize_tf("biasadd_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(biasadd_scalar_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - migraphx::shape s0{migraphx::shape::float_type, {1, 1}}; - uint64_t axis = 1; - auto l0 = mm->add_parameter("0", s0); - auto l1 = mm->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}}); - auto l2 = mm->add_instruction( - migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1); - mm->add_instruction(migraphx::make_op("add"), l0, l2); - auto prog = optimize_tf("biasadd_scalar_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(cast_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_instruction( - migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), - l0); - auto prog = optimize_tf("cast_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(concat_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); - - int axis = 1; - // tf uses axis as the third input, and it is in int32 format - // add the literal using a vector in order to set stride to 1 (like in tf parser) - mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector{axis}); - - mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1); - auto prog = optimize_tf("concat_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(const_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - mm->add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector{1.0f}); - auto prog = optimize_tf("constant_test.pb", false); - - EXPECT(p == prog); -} - -migraphx::program create_conv() -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - std::vector weight_data(3 * 3 * 3 * 32); - std::fill(weight_data.begin(), weight_data.end(), 1.0f); - auto l1 = - mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data); - - migraphx::op::convolution op; - op.padding = {1, 1, 1, 1}; - op.stride = {1, 1}; - op.dilation = {1, 1}; - auto l2 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1); - mm->add_instruction(op, l0, l2); - return p; -} - -TEST_CASE(conv_test) -{ - migraphx::program p = create_conv(); - auto prog = optimize_tf("conv_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(conv_add_test) -{ - migraphx::program p = create_conv(); - auto* mm = p.get_main_module(); - auto l0 = std::prev(mm->end()); - mm->add_instruction(migraphx::make_op("add"), l0, l0); - auto prog = optimize_tf("conv_add_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(conv_nchw_test) -{ - migraphx::program p = create_conv(); - auto prog = optimize_tf("conv_nchw_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(conv_relu_test) -{ - migraphx::program p = create_conv(); - auto* mm = p.get_main_module(); - auto l0 = std::prev(mm->end()); - mm->add_instruction(migraphx::make_op("relu"), l0); - auto prog = optimize_tf("conv_relu_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(conv_relu6_test) -{ - migraphx::program p = create_conv(); - auto* mm = p.get_main_module(); - std::vector input_lens{1, 32, 16, 16}; - auto l0 = std::prev(mm->end()); - auto min_val = mm->add_literal(0.0f); - auto max_val = mm->add_literal(6.0f); - min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), - min_val); - max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), - max_val); - mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); - auto prog = optimize_tf("conv_relu6_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(depthwiseconv_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - std::vector weight_data(3 * 3 * 3 * 1); - std::fill(weight_data.begin(), weight_data.end(), 1.0f); - auto l1 = - mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data); - - migraphx::op::convolution op; - op.padding = {1, 1}; - op.stride = {1, 1}; - op.dilation = {1, 1}; - op.group = 3; - auto l3 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1); - auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3); - auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4); - mm->add_instruction(op, l0, l5); - auto prog = optimize_tf("depthwise_conv_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(expanddims_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); - mm->add_literal(0); - mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), l0); - auto prog = optimize_tf("expanddims_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(expanddims_test_neg_dims) -{ - // this check makes sure the pb parses negative dim value correctly - migraphx::program p; - - auto* mm = p.get_main_module(); - - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); - mm->add_literal(-1); - mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), l0); - auto prog = optimize_tf("expanddims_neg_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(gather_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}}); - auto l1 = mm->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}}); - mm->add_literal(1); - - int axis = 1; - mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1); - auto prog = optimize_tf("gather_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(identity_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_instruction(migraphx::make_op("identity"), l0); - auto prog = optimize_tf("identity_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(matmul_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}}); - - auto trans_l0 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0); - auto trans_l1 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); - - mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); - auto prog = optimize_tf("matmul_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(mean_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}}; - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_literal(l); - mm->add_literal(l); - migraphx::op::reduce_mean op{{2, 3}}; - mm->add_instruction(op, l0); - auto l3 = mm->add_instruction(op, l0); - mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l3); - auto prog = optimize_tf("mean_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(mean_test_nhwc) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto l1 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); - migraphx::op::reduce_mean op{{1, 2}}; - auto l2 = mm->add_instruction(op, l1); - mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2); - auto prog = optimize_tf("mean_test_nhwc.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(mul_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); - - mm->add_instruction(migraphx::make_op("mul"), l0, l1); - auto prog = optimize_tf("mul_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(multi_output_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto l1 = mm->add_instruction(migraphx::make_op("relu"), l0); - auto l2 = mm->add_instruction(migraphx::make_op("tanh"), l0); - mm->add_return({l1, l2}); - - EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); })); - auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"}); - - EXPECT(p == prog); -} - -TEST_CASE(onehot_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}}); - mm->add_literal(2); - mm->add_literal(1.0f); - mm->add_literal(0.0f); - auto l1 = mm->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}}); - int axis = 0; - mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l0); - auto prog = optimize_tf("onehot_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(noop_test) -{ - migraphx::program p; - auto prog = optimize_tf("noop_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(pack_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}}); - auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}}); - std::vector args{l0, l1, l2}; - std::vector unsqueezed_args; - int64_t axis = 1; - - std::transform( - args.begin(), - args.end(), - std::back_inserter(unsqueezed_args), - [&](migraphx::instruction_ref arg) { - return mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg); - }); - mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast(axis)}}), - unsqueezed_args); - auto prog = optimize_tf("pack_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(pack_test_nhwc) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); - auto lt0 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); - auto lt1 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1); - auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); - auto lt2 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2); - std::vector args{lt0, lt1, lt2}; - std::vector unsqueezed_args; - int64_t nchw_axis = 3; - - std::transform(args.begin(), - args.end(), - std::back_inserter(unsqueezed_args), - [&](migraphx::instruction_ref arg) { - return mm->add_instruction( - migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), arg); - }); - mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast(nchw_axis)}}), - unsqueezed_args); - auto prog = optimize_tf("pack_test_nhwc.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(pad_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}}); - std::vector pad_literals{1, 1, 2, 2}; - std::vector pads{1, 2, 1, 2}; - mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, pad_literals); - - mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), l0); - auto prog = optimize_tf("pad_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(pooling_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - migraphx::op::pooling avg_pool_op{migraphx::op::pooling_mode::average}; - migraphx::op::pooling max_pool_op{migraphx::op::pooling_mode::max}; - avg_pool_op.stride = {2, 2}; - max_pool_op.stride = {2, 2}; - avg_pool_op.lengths = {2, 2}; - max_pool_op.lengths = {2, 2}; - mm->add_instruction(avg_pool_op, l0); - mm->add_instruction(max_pool_op, l0); - auto prog = optimize_tf("pooling_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(pow_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - mm->add_instruction(migraphx::make_op("pow"), l0, l1); - auto prog = optimize_tf("pow_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(relu_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_instruction(migraphx::make_op("relu"), l0); - auto prog = optimize_tf("relu_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(relu6_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - std::vector input_lens{1, 3, 16, 16}; - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens}); - auto min_val = mm->add_literal(0.0f); - auto max_val = mm->add_literal(6.0f); - min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), - min_val); - max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), - max_val); - mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); - auto prog = optimize_tf("relu6_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(relu6_half_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - std::vector input_lens{1, 3, 16, 16}; - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens}); - auto min_val = - mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.0f}}); - auto max_val = - mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {6.0f}}); - min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), - min_val); - max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), - max_val); - mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); - auto prog = optimize_tf("relu6_half_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(reshape_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}}); - migraphx::shape s0{migraphx::shape::int32_type, {4}}; - // in tf, the second arg is a literal that contains new dimensions - mm->add_literal(migraphx::literal{s0, {1, 1, 1, 16}}); - mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), l0); - auto prog = optimize_tf("reshape_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(rsqrt_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_instruction(migraphx::make_op("rsqrt"), l0); - auto prog = optimize_tf("rsqrt_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(shape_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}}); - auto prog = optimize_tf("shape_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(slice_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - std::size_t num_axes = 2; - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}}); - migraphx::shape s0{migraphx::shape::int32_type, {num_axes}}; - mm->add_literal(migraphx::literal{s0, {1, 0}}); - mm->add_literal(migraphx::literal{s0, {2, -1}}); - - mm->add_instruction( - migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0); - auto prog = optimize_tf("slice_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(softmax_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); - mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l0); - auto prog = optimize_tf("softmax_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(split_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - std::vector axes{0, 1}; - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); - mm->add_literal(3); // num_splits - mm->add_literal(1); // split axis - mm->add_literal(1); // concat axis - mm->add_literal(1); // concat axis - auto l1 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 10}}}), l0); - auto l2 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0); - auto l3 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0); - auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); - auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); - mm->add_return({l4, l5}); - auto prog = parse_tf("split_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(split_test_one_output) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); - mm->add_literal(1); // num_splits - mm->add_literal(1); // split axis - auto l1 = mm->add_instruction(migraphx::make_op("identity"), l0); - mm->add_return({l1}); - auto prog = parse_tf("split_test_one_output.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(split_test_vector_as_input) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - std::vector axes{0, 1}; - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); - // split sizes - mm->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}}); - mm->add_literal(1); // split axis - mm->add_literal(1); // concat axis - mm->add_literal(1); // concat axis - auto l1 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 4}}}), l0); - auto l2 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0); - auto l3 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0); - auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); - auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); - mm->add_return({l4, l5}); - auto prog = parse_tf("split_test_vector_as_input.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(sqdiff_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - mm->add_instruction(migraphx::make_op("sqdiff"), l0, l1); - auto prog = optimize_tf("sqdiff_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(squeeze_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}}); - mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), l0); - auto prog = optimize_tf("squeeze_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(stopgradient_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_instruction(migraphx::make_op("identity"), l0); - auto prog = optimize_tf("stopgradient_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(stridedslice_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); - auto l1 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); - auto l2 = mm->add_instruction( - migraphx::make_op( - "slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}), - l1); - auto shrink_axis = 1; - mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2); - auto prog = optimize_tf("stridedslice_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(stridedslice_masks_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}}); - // add literals for starts, ends, and strides in tf (NHWC format) - mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, - std::vector{0, 1, 1, 0}); - mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, - std::vector{0, 0, 0, 0}); - mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, - std::vector{1, 1, 1, 1}); - - auto l1 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); - auto l2 = mm->add_instruction( - migraphx::make_op( - "slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}), - l1); - auto l3 = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2); - mm->add_return({l3}); - auto prog = parse_tf("stridedslice_masks_test.pb", true); - - EXPECT(p == prog); -} - -TEST_CASE(sub_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l2 = mm->add_instruction(migraphx::make_op("sub"), l0, l1); - mm->add_return({l2}); - auto prog = parse_tf("sub_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(tanh_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto l1 = mm->add_instruction(migraphx::make_op("tanh"), l0); - mm->add_return({l1}); - auto prog = parse_tf("tanh_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(transpose_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - migraphx::shape s0{migraphx::shape::int32_type, {4}}; - mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}}); - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); - auto prog = optimize_tf("transpose_test.pb", false); - - EXPECT(p == prog); -} - -TEST_CASE(variable_batch_test) -{ - migraphx::program p; - - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - mm->add_instruction(migraphx::make_op("identity"), l0); - auto prog = optimize_tf("variable_batch_test.pb", false); - - EXPECT(p == prog); -} - -int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/verify/gemm_add.cpp b/test/verify/gemm_add.cpp index 610d3f385e2..c46a6507545 100644 --- a/test/verify/gemm_add.cpp +++ b/test/verify/gemm_add.cpp @@ -35,9 +35,9 @@ struct gemm_add : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{DType, {1, 2, 3}}; - migraphx::shape m2_shape{DType, {1, 3, 4}}; - migraphx::shape m3_shape{DType, {1, 2, 4}}; + migraphx::shape m1_shape{DType, {1, 2, 1280}}; + migraphx::shape m2_shape{DType, {1, 1280, 320}}; + migraphx::shape m3_shape{DType, {1, 2, 320}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); @@ -47,6 +47,12 @@ struct gemm_add : verify_program> return p; } std::string section() const { return "gemm"; } + + // Turn on Exhaustive-tune to enable split-k GEMM perf-configs from MLIR + migraphx::compile_options get_compile_options() const + { + return migraphx::compile_options{.exhaustive_tune = true}; + } }; template struct gemm_add; diff --git a/test/verify/test_fmod_mod.cpp b/test/verify/test_fmod_mod.cpp index 8c968857f1f..992d88f2760 100644 --- a/test/verify/test_fmod_mod.cpp +++ b/test/verify/test_fmod_mod.cpp @@ -80,6 +80,5 @@ struct test_mod : verify_program> }; template struct test_mod; -// TODO: Fix half type test -// template struct test_mod; +template struct test_mod; template struct test_mod; diff --git a/test/verify/test_gemm.cpp b/test/verify/test_gemm.cpp index 7ca07598916..2f0b96ac959 100644 --- a/test/verify/test_gemm.cpp +++ b/test/verify/test_gemm.cpp @@ -26,6 +26,7 @@ #include #include #include + template struct test_gemm : verify_program> { diff --git a/tools/docker/ubuntu_2204.dockerfile b/tools/docker/ubuntu_2204.dockerfile index 906b256eebe..15afe6d56d7 100644 --- a/tools/docker/ubuntu_2204.dockerfile +++ b/tools/docker/ubuntu_2204.dockerfile @@ -36,7 +36,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- software-properties-common \ wget \ rocm-device-libs \ - hip-base \ + hip-dev \ libnuma-dev \ miopen-hip \ rocblas \