diff --git a/.bazelrc b/.bazelrc index 60e7326adf09..8b53bd475e5b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -104,6 +104,8 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 build:clang --copt=-Wno-gnu-offsetof-extensions # Disable clang extention that rejects unknown arguments. build:clang --copt=-Qunused-arguments +# Error on struct/class mismatches, since this causes link failures on Windows. +build:clang --copt=-Werror=mismatched-tags # Configs for CUDA build:cuda --repo_env TF_NEED_CUDA=1 @@ -183,6 +185,7 @@ build:macos_cache_push --config=macos_cache --remote_upload_local_results=true - build:ci_linux_x86_64 --config=avx_linux --config=avx_posix build:ci_linux_x86_64 --config=mkl_open_source_only build:ci_linux_x86_64 --config=clang --verbose_failures=true +build:ci_linux_x86_64 --color=yes # TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA # toolchain for both CPU and GPU builds. @@ -203,6 +206,7 @@ build:ci_linux_x86_64_cuda --config=ci_linux_x86_64 # Linux Aarch64 CI configs build:ci_linux_aarch64_base --config=clang --verbose_failures=true build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" +build:ci_linux_aarch64_base --color=yes build:ci_linux_aarch64 --config=ci_linux_aarch64_base build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" @@ -221,11 +225,13 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm build:ci_darwin_x86_64 --macos_minimum_os=10.14 build:ci_darwin_x86_64 --config=macos_cache_push build:ci_darwin_x86_64 --verbose_failures=true +build:ci_darwin_x86_64 --color=yes # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 build:ci_darwin_arm64 --config=macos_cache_push build:ci_darwin_arm64 --verbose_failures=true +build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows @@ -233,6 +239,7 @@ build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=tru build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE +build:ci_windows_amd64 --color=yes # ############################################################################# # RBE config options below. These inherit the CI configs above and set the @@ -379,4 +386,4 @@ build:debug --config debug_symbols -c fastbuild try-import %workspace%/.jax_configure.bazelrc # Load rc file with user-specific options. -try-import %workspace%/.bazelrc.user \ No newline at end of file +try-import %workspace%/.bazelrc.user diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index a4ec78f96c97..d0d889729448 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -12,7 +12,7 @@ on: branches: - main paths: - - '**/workflows/asan.yml' + - '**/workflows/asan.yaml' jobs: asan: @@ -25,14 +25,8 @@ jobs: run: shell: bash -l {0} steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - with: - path: jax - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - with: - repository: python/cpython - path: cpython - ref: v3.13.0 + # Install git before actions/checkout as otherwise it will download the code with the GitHub + # REST API and therefore any subsequent git commands will fail. - name: Install clang 18 env: DEBIAN_FRONTEND: noninteractive @@ -42,6 +36,14 @@ jobs: zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: jax + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: python/cpython + path: cpython + ref: v3.13.0 - name: Build CPython with ASAN enabled env: ASAN_OPTIONS: detect_leaks=0 @@ -65,7 +67,7 @@ jobs: run: | source ${GITHUB_WORKSPACE}/venv/bin/activate cd jax - python build/build.py \ + python build/build.py build --wheels=jaxlib --verbose \ --bazel_options=--color=yes \ --bazel_options=--copt=-fsanitize=address \ --clang_path=/usr/bin/clang-18 diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml new file mode 100644 index 000000000000..4a2e2ecb7fe6 --- /dev/null +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -0,0 +1,41 @@ +name: CI - Bazel CPU tests (RBE) + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run_tests: + if: github.event.repository.fork == false + strategy: + matrix: + runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"] + + runs-on: ${{ matrix.runner }} + # TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + + steps: + - uses: actions/checkout@v3 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CPU Tests with RBE + run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml new file mode 100644 index 000000000000..a7cf645b50b3 --- /dev/null +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -0,0 +1,39 @@ +name: CI - Bazel GPU tests (RBE) + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run_tests: + if: github.event.repository.fork == false + strategy: + matrix: + runner: ["linux-x86-n2-16"] + + runs-on: ${{ matrix.runner }} + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + + steps: + - uses: actions/checkout@v3 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel GPU Tests with RBE + run: ./ci/run_bazel_test_gpu_rbe.sh \ No newline at end of file diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index ef752c66b294..56b7f1f4a377 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,4 +1,4 @@ -name: CI +name: ROCm CPU CI # We test all supported Python versions as follows: # - 3.10 : Documentation build @@ -11,10 +11,10 @@ on: # but only for the main branch push: branches: - - main + - rocm-main pull_request: branches: - - main + - rocm-main permissions: contents: read # to fetch code @@ -29,18 +29,21 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: 3.11 - - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 + - run: python -m pip install pre-commit + - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} + - run: pre-commit run --show-diff-on-failure --color=always --all-files build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" - runs-on: linux-x86-n2-32 - container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 + runs-on: ROCM-Ubuntu timeout-minutes: 60 strategy: matrix: @@ -57,13 +60,9 @@ jobs: prng-upgrade: 0 num_generated_cases: 1 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - - name: Image Setup - run: | - apt update - apt install -y libssl-dev + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -72,7 +71,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -102,15 +101,15 @@ jobs: documentation: name: Documentation - test code snippets - runs-on: ubuntu-latest + runs-on: ROCM-Ubuntu timeout-minutes: 10 strategy: matrix: python-version: ['3.10'] steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -119,7 +118,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -140,14 +139,14 @@ jobs: documentation_render: name: Documentation - render documentation runs-on: ubuntu-latest - timeout-minutes: 10 + timeout-minutes: 20 strategy: matrix: python-version: ['3.10'] steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -156,7 +155,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -165,8 +164,7 @@ jobs: pip install -r docs/requirements.txt - name: Render documentation run: | - sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html - + sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html jax2tf_test: name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" @@ -181,9 +179,9 @@ jobs: enable-x64: 0 num_generated_cases: 10 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -192,7 +190,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -217,21 +215,21 @@ jobs: ffi: name: FFI example - runs-on: ubuntu-latest - timeout-minutes: 5 + runs-on: ROCM-Ubuntu + timeout-minutes: 30 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - - name: Set up Python 3.11 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: - python-version: 3.11 + python-version: 3.12 - name: Get pip cache dir id: pip-cache run: | python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} @@ -245,6 +243,10 @@ jobs: # a different toolchain. GCC is the default compiler on the # 'ubuntu-latest' runner, but we still set this explicitly just to be # clear. - CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ - - name: Run tests + CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON + - name: Run CPU tests + run: python -m pytest examples/ffi/tests + env: + JAX_PLATFORM_NAME: cpu + - name: Run GPU tests run: python -m pytest examples/ffi/tests diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 4bff1e87e7f3..fe879617c8a7 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -13,7 +13,7 @@ name: CI - Cloud TPU (nightly) on: schedule: - - cron: "0 14 * * *" # daily at 7am PST + - cron: "0 */2 * * *" # Run every 2 hours workflow_dispatch: # allows triggering the workflow run manually # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. @@ -24,17 +24,20 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] + jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - {type: "v3-8", cores: "4"}, - {type: "v4-8", cores: "4"}, - {type: "v5e-8", cores: "8"} + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] + python-version: ["3.10"] name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20240722 ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} - runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"] + PYTHON: python${{ matrix.python-version }} + runs-on: ${{ matrix.tpu.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" timeout-minutes: 120 defaults: run: @@ -43,40 +46,66 @@ jobs: # https://opensource.google/documentation/reference/github/services#actions # mandates using a specific commit for non-Google actions. We use # https://github.com/sethvargo/ratchet to pin specific versions. - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # Checkout XLA at head, if we're building jaxlib at head. + - name: Checkout XLA at head + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ matrix.jaxlib-version == 'head' }} + with: + repository: openxla/xla + path: xla + # We need to mark the GitHub workspace as safe as otherwise git commands will fail. + - name: Mark GitHub workspace as safe + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Install JAX test requirements run: | - pip install -U -r build/test-requirements.txt - pip install -U -r build/collect-profile-requirements.txt + $PYTHON -m pip install -U -r build/test-requirements.txt + $PYTHON -m pip install -U -r build/collect-profile-requirements.txt - name: Install JAX run: | - pip uninstall -y jax jaxlib libtpu - if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then - pip install .[tpu] \ + $PYTHON -m pip uninstall -y jax jaxlib libtpu + if [ "${{ matrix.jaxlib-version }}" == "head" ]; then + # Build and install jaxlib at head + $PYTHON build/build.py build --wheels=jaxlib \ + --bazel_options=--config=rbe_linux_x86_64 \ + --local_xla_path="$(pwd)/xla" \ + --verbose + + $PYTHON -m pip install dist/*.whl + + # Install "jax" at head + $PYTHON -m pip install -U -e . + + # Install libtpu + $PYTHON -m pip install --pre libtpu \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then + $PYTHON -m pip install .[tpu] \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - pip install --pre libtpu \ + $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $PYTHON -m pip install --pre libtpu \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $PYTHON -m pip install requests elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. - pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $PYTHON -m pip install requests else echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" exit 1 fi - python3 -c 'import sys; print("python version:", sys.version)' - python3 -c 'import jax; print("jax version:", jax.__version__)' - python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' - strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on' - python3 -c 'import jax; print("libtpu version:", + $PYTHON -c 'import sys; print("python version:", sys.version)' + $PYTHON -c 'import jax; print("jax version:", jax.__version__)' + $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' + strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' + $PYTHON -c 'import jax; print("libtpu version:", jax.lib.xla_bridge.get_backend().platform_version)' - name: Run tests env: @@ -84,14 +113,14 @@ jobs: PY_COLORS: 1 run: | # Run single-accelerator tests in parallel - JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ + JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ --maxfail=20 -m "not multiaccelerator" tests examples # Run Pallas printing tests, which need to run with I/O capturing disabled. - TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \ + TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \ tests/pallas/tpu_pallas_test.py::PallasCallPrintTest # Run multi-accelerator across all chips - python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests + $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests - name: Send chat on failure # Don't notify when testing the workflow from a branch. if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }} diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 010ebae78c43..92fef2cc29af 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -22,27 +22,27 @@ jobs: steps: - name: Checkout jax - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Checkout array-api-tests - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09 + ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install .[ci] - python -m pip install -r array-api-tests/requirements.txt + python -m pip install pytest-xdist -r array-api-tests/requirements.txt - name: Run the test suite env: ARRAY_API_TESTS_MODULE: jax.numpy JAX_ENABLE_X64: 'true' run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt + pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 3f6d4be94323..2b1100cc048b 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Get repo - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax - name: Setup build and test enviroment diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml new file mode 100644 index 000000000000..f309427df197 --- /dev/null +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -0,0 +1,51 @@ +# Pulls the latest changes from upstream into main and opens a PR to merge +# them into rocm-main branch. + +name: ROCm Nightly Upstream Sync +on: + workflow_dispatch: + schedule: + - cron: '0 6 * * 1-5' +permissions: + contents: write + pull-requests: write +env: + SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }} +jobs: + sync-main: + runs-on: ubuntu-latest + steps: + - run: | + gh auth status + gh repo sync rocm/jax -b main + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + create-sync-branch: + needs: sync-main + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Create branch + run: | + git fetch + git checkout origin/main + git checkout -b $SYNC_BRANCH_NAME + # Try and merge rocm-main into this new branch so that we don't run upstream's CI code + git config --global user.email "github-actions@github.com" + git config --global user.name "GitHub Actions" + git merge origin/rocm-main || true + # If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts + git merge --abort || true + git push origin HEAD + open-sync-pr: + needs: create-sync-branch + runs-on: ubuntu-latest + steps: + - run: | + gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml new file mode 100644 index 000000000000..bd14fa050577 --- /dev/null +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -0,0 +1,41 @@ +name: ROCm Open Upstream PR +on: + pull_request: + types: [ labeled ] + branches: [ rocm-main ] +jobs: + open-upstream: + if: ${{ github.event.label.name == 'open-upstream' }} + permissions: + contents: write + pull-requests: write + runs-on: ubuntu-latest + env: + NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream" + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Rebase code to main + run: | + git config --global user.email "github-actions@github.com" + git config --global user.name "Github Actions" + git fetch + git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }} + git rebase --onto origin/main origin/rocm-main + # Force push here so that we don't run into conflicts with the origin branch + git push origin HEAD --force + - name: Leave link to create PR + env: + GH_TOKEN: ${{ github.token }} + run: | + # Bash is not friendly with newline characters, so make our own + NL=$'\n' + # Encode the PR title and body for passing as URL get parameters + TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri') + BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri') + # Create a link to the that will open up a new PR form to upstream and autofill the fields + CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC" + # Add a comment with the link to the PR + COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK" + gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY" + diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 46cd3f335fc6..ada9b4e5825f 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -22,7 +22,7 @@ on: jobs: upstream-dev: - runs-on: ubuntu-20.04-16core + runs-on: ROCM-Ubuntu permissions: contents: read checks: write # for upload-artifact @@ -36,9 +36,9 @@ jobs: outputs: artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements @@ -106,8 +106,8 @@ jobs: run: shell: bash steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: "3.x" - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index c06a12922a05..3904bf1b8f10 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -25,9 +25,9 @@ jobs: - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -40,7 +40,7 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` + python.exe build\build.py build --wheels=jaxlib ` --bazel_options=--color=yes ` --bazel_options=--config=win_clang ` --verbose diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 795f1f6157ba..4c404ef4cb75 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -31,11 +31,11 @@ jobs: - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -49,9 +49,10 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` + python.exe build\build.py build --wheels=jaxlib ` --bazel_options=--color=yes ` - --bazel_options=--config=win_clang + --bazel_options=--config=win_clang ` + --verbose - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 87f706d3a404..ed38faa6774b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib] + additional_dependencies: [types-requests==2.31.0, jaxlib, numpy~=2.1.0] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/CHANGELOG.md b/CHANGELOG.md index d9da9a2bdc71..d8bb1478ad0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,118 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.35 +## jax 0.4.38 + +* Deprecations + * a number of APIs in the internal `jax.core` namespace have been deprecated, including + `ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`, + `Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by + APIs of the same name in {mod}`jax.extend.core`; see the documentation for + {mod}`jax.extend` for information on the compatibility guarantees of these + semi-public extensions. + +## jax 0.4.37 (Dec 9, 2024) + +This is a patch release of jax 0.4.36. Only "jax" was released at this version. + +* Bug fixes + * Fixed a bug where `jit` would error if an argument was named `f` (#25329). + * Fix a bug that will throw `index out of range` error in + {func}`jax.lax.while_loop` if the user register pytree node class with + different aux data for the flatten and flatten_with_path. + * Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e. + +## jax 0.4.36 (Dec 5, 2024) + +* Breaking Changes + * This release lands "stackless", an internal change to JAX's tracing + machinery. We made trace dispatch purely a function of context rather than a + function of both context and data. This let us delete a lot of machinery for + managing data-dependent tracing: levels, sublevels, `post_process_call`, + `new_base_main`, `custom_bind`, and so on. The change should only affect + users that use JAX internals. + + If you do use JAX internals then you may need to + update your code (see + https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f + for clues about how to do this). There might also be version skew + issues with JAX libraries that do this. If you find this change breaks your + non-JAX-internals-using code then try the + `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if + you need help updating your code then please file a bug. + * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` + or with `enable_xla=False` have been deprecated since July 2024, with + JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` + with native serialization will still be supported. + * In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed + after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, + `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. + * The deprecated module `jax.experimental.export` has been removed. It was replaced + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + for information on migrating to the new API. + * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` + has been removed, after being deprecated in v0.4.27. + * Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`) + now raises an error. Previously, this returned a scalar object array. + * The following deprecated methods and functions in {mod}`jax.export` have + been removed: + * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect + already. + * `jax.export.Exported.lowering_platforms`: use `platforms`. + * `jax.export.Exported.mlir_module_serialization_version`: + use `calling_convention_version`. + * `jax.export.Exported.uses_shape_polymorphism`: + use `uses_global_constants`. + * the `lowering_platforms` kwarg for {func}`jax.export.export`: use + `platforms` instead. + * The kwargs `symbolic_scope` and `symbolic_constraints` from + {func}`jax.export.symbolic_args_specs` have been removed. They were + deprecated in June 2024. Use `scope` and `constraints` instead. + * Hashing of tracers, which has been deprecated since version 0.4.30, now + results in a `TypeError`. + * Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and + replaces previous build.py usage. Run `python build/build.py --help` for + more details. Brief overview of the new subcommand options: + * `build`: Builds JAX wheel packages. For e.g., `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt` + * `requirements_update`: Updates requirements_lock.txt files. + * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional + inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` + on the function inputs. + * {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now + return NaN for negative integer inputs, to match the behavior of SciPy from + https://github.com/scipy/scipy/pull/21827. + * `jax.clear_backends` was removed after being deprecated in v0.4.26. + * We removed the custom call "__gpu$xla.gpu.triton" from the list of custom + call that we guarantee export stability. This is because this custom call + relies on Triton IR, which is not guaranteed to be stable. If you need + to export code that uses this custom call, you can use the `disabled_checks` + parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). + +* New Features + * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for + passing compilation options to XLA. For the moment it's undocumented and + may be in flux. + * {func}`jax.tree_util.register_dataclass` now allows metadata fields to be + declared inline via {func}`dataclasses.field`. See the function documentation + for examples. + * Added {func}`jax.numpy.put_along_axis`. + * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions + ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now + supported on GPU. See {jax-issue}`#24663` for more details. + * Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0. + +* Bug fixes + * Fixed a bug where the GPU implementations of LU and QR decomposition would + result in an indexing overflow for batch sizes close to int32 max. See + {jax-issue}`#24843` for more details. + +* Deprecations + * `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated; + use `jax.Array` instead. + * `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError` + instead. + +## jax 0.4.35 (Oct 22, 2024) * Breaking Changes * {func}`jax.numpy.isscalar` now returns True for any array-like object with @@ -40,7 +151,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The semi-public API `jax.lib.xla_client.register_custom_call_target` has been deprecated. Use the JAX FFI instead. * The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, - `jax.lib.xla_client.ops`, + `jax.lib.xla_client.ops`, `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO diff --git a/README.md b/README.md index c99d3db10a2a..b001a8ceeb15 100644 --- a/README.md +++ b/README.md @@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like. Using `jit` puts constraints on the kind of Python control flow the function can use; see -the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT) +the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` @@ -349,7 +348,7 @@ Some standouts: 1. [In-place mutating updates of arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. @@ -369,7 +368,7 @@ Some standouts: and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. 1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). + flow](https://jax.readthedocs.io/en/latest/control-flow.html). You'll always get loud errors if something goes wrong. You might have to use [`jit`'s `static_argnums` parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), @@ -390,6 +389,7 @@ Some standouts: | Google TPU | yes | n/a | n/a | n/a | n/a | n/a | | AMD GPU | yes | no | experimental | n/a | no | no | | Apple GPU | n/a | no | n/a | experimental | n/a | n/a | +| Intel GPU | experimental | n/a | n/a | n/a | no | no | ### Instructions @@ -401,6 +401,7 @@ Some standouts: | Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` | | AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). | | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | +| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) for information on alternative installation strategies. These include compiling @@ -411,23 +412,18 @@ community-supported conda build, and answers to some frequently-asked questions. ## Neural network libraries -Multiple Google research groups develop and share libraries for training neural -networks in JAX. If you want a fully featured library for neural network +Multiple Google research groups at Google DeepMind and Alphabet develop and share libraries +for training neural networks in JAX. If you want a fully featured library for neural network training with examples and how-to guides, try -[Flax](https://github.com/google/flax). Check out the new [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) API for a -simplified development experience. - -Google X maintains the neural network library -[Equinox](https://github.com/patrick-kidger/equinox). This is used as the -foundation for several other libraries in the JAX ecosystem. - -In addition, DeepMind has open-sourced an [ecosystem of libraries around -JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) -including [Optax](https://github.com/deepmind/optax) for gradient processing and -optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and -[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch -the NeurIPS 2020 JAX Ecosystem at DeepMind talk -[here](https://www.youtube.com/watch?v=iDxJxIyzSiM)) +[Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). + +Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) +on the JAX documentation site for a list of JAX-based network libraries, which includes +[Optax](https://github.com/deepmind/optax) for gradient processing and +optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and +[Equinox](https://github.com/patrick-kidger/equinox) for neural networks. +(Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk +[here](https://www.youtube.com/watch?v=iDxJxIyzSiM) for additional details.) ## Citing JAX diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index d26801d8dfe5..d365a6facd90 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -17,7 +17,6 @@ import jax from jax import core -from jax._src.numpy import lax_numpy from jax import export jax.config.parse_flags_with_absl() @@ -76,7 +75,7 @@ def inequalities_slice(state): while state: for _ in range(30): a.scope._clear_caches() - start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b) + start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b) _ = 0 <= slice_size <= b _ = start >= 0 _ = start + slice_size <= b diff --git a/build/build.py b/build/build.py index 44343ebab4ef..a6c1a7922b0e 100755 --- a/build/build.py +++ b/build/build.py @@ -14,303 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# Helper script for building JAX's libjax easily. - +# CLI for building JAX wheel packages from source and for updating the +# requirements_lock.txt files import argparse -import collections -import hashlib +import asyncio import logging import os -import pathlib import platform -import re -import shutil -import stat -import subprocess import sys -import textwrap -import urllib.request +import copy -logger = logging.getLogger(__name__) +from tools import command, utils -def is_windows(): - return sys.platform.startswith("win32") - - -def shell(cmd): - try: - logger.info("shell(): %s", cmd) - output = subprocess.check_output(cmd) - except subprocess.CalledProcessError as e: - logger.info("subprocess raised: %s", e) - if e.output: print(e.output) - raise - except Exception as e: - logger.info("subprocess raised: %s", e) - raise - return output.decode("UTF-8").strip() - - -# Python - -def get_python_bin_path(python_bin_path_flag): - """Returns the path to the Python interpreter to use.""" - path = python_bin_path_flag or sys.executable - return path.replace(os.sep, "/") - - -def get_python_version(python_bin_path): - version_output = shell( - [python_bin_path, "-c", - ("import sys; print(\"{}.{}\".format(sys.version_info[0], " - "sys.version_info[1]))")]) - major, minor = map(int, version_output.split(".")) - return major, minor - -def check_python_version(python_version): - if python_version < (3, 10): - print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) - sys.exit(-1) - - -def get_githash(): - try: - return subprocess.run( - ["git", "rev-parse", "HEAD"], - encoding='utf-8', - capture_output=True).stdout.strip() - except OSError: - return "" - -# Bazel - -BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" -BazelPackage = collections.namedtuple("BazelPackage", - ["base_uri", "file", "sha256"]) -bazel_packages = { - ("Linux", "x86_64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-linux-x86_64", - sha256= - "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"), - ("Linux", "aarch64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-linux-arm64", - sha256= - "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"), - ("Darwin", "x86_64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-darwin-x86_64", - sha256= - "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"), - ("Darwin", "arm64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-darwin-arm64", - sha256= - "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"), - ("Windows", "AMD64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-windows-x86_64.exe", - sha256= - "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"), -} - +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) -def download_and_verify_bazel(): - """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" - package = bazel_packages.get((platform.system(), platform.machine())) - if package is None: - return None - - if not os.access(package.file, os.X_OK): - uri = (package.base_uri or BAZEL_BASE_URI) + package.file - sys.stdout.write(f"Downloading bazel from: {uri}\n") - - def progress(block_count, block_size, total_size): - if total_size <= 0: - total_size = 170**6 - progress = (block_count * block_size) / total_size - num_chars = 40 - progress_chars = int(num_chars * progress) - sys.stdout.write("{} [{}{}] {}%\r".format( - package.file, "#" * progress_chars, - "." * (num_chars - progress_chars), int(progress * 100.0))) - - tmp_path, _ = urllib.request.urlretrieve( - uri, None, progress if sys.stdout.isatty() else None - ) - sys.stdout.write("\n") - - # Verify that the downloaded Bazel binary has the expected SHA256. - with open(tmp_path, "rb") as downloaded_file: - contents = downloaded_file.read() - - digest = hashlib.sha256(contents).hexdigest() - if digest != package.sha256: - print( - "Checksum mismatch for downloaded bazel binary (expected {}; got {})." - .format(package.sha256, digest)) - sys.exit(-1) - - # Write the file as the bazel file name. - with open(package.file, "wb") as out_file: - out_file.write(contents) - - # Mark the file as executable. - st = os.stat(package.file) - os.chmod(package.file, - st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) - - return os.path.join(".", package.file) - - -def get_bazel_paths(bazel_path_flag): - """Yields a sequence of guesses about bazel path. Some of sequence elements - can be None. The resulting iterator is lazy and potentially has a side - effects.""" - yield bazel_path_flag - yield shutil.which("bazel") - yield download_and_verify_bazel() - - -def get_bazel_path(bazel_path_flag): - """Returns the path to a Bazel binary, downloading Bazel if not found. Also, - checks Bazel's version is at least newer than 6.5.0 - - A manual version check is needed only for really old bazel versions. - Newer bazel releases perform their own version check against .bazelversion - (see for details - https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes). - """ - for path in filter(None, get_bazel_paths(bazel_path_flag)): - version = get_bazel_version(path) - if version is not None and version >= (6, 5, 0): - return path, ".".join(map(str, version)) - - print("Cannot find or download a suitable version of bazel." - "Please install bazel >= 6.5.0.") - sys.exit(-1) - - -def get_bazel_version(bazel_path): - try: - version_output = shell([bazel_path, "--version"]) - except (subprocess.CalledProcessError, OSError): - return None - match = re.search(r"bazel *([0-9\\.]+)", version_output) - if match is None: - return None - return tuple(int(x) for x in match.group(1).split(".")) - - -def get_clang_path_or_exit(): - which_clang_output = shutil.which("clang") - if which_clang_output: - # If we've found a clang on the path, need to get the fully resolved path - # to ensure that system headers are found. - return str(pathlib.Path(which_clang_output).resolve()) - else: - print( - "--clang_path is unset and clang cannot be found" - " on the PATH. Please pass --clang_path directly." - ) - sys.exit(-1) - -def get_clang_major_version(clang_path): - clang_version_proc = subprocess.run( - [clang_path, "-E", "-P", "-"], - input="__clang_major__", - check=True, - capture_output=True, - text=True, - ) - major_version = int(clang_version_proc.stdout) - - return major_version - - -def write_bazelrc(*, remote_build, - cuda_version, cudnn_version, rocm_toolkit_path, - cpu, cuda_compute_capabilities, - rocm_amdgpu_targets, target_cpu_features, - wheel_cpu, enable_mkl_dnn, use_clang, clang_path, - clang_major_version, python_version, - enable_cuda, enable_nccl, enable_rocm, - use_cuda_nvcc): - - with open("../.jax_configure.bazelrc", "w") as f: - if not remote_build: - f.write(textwrap.dedent("""\ - build --strategy=Genrule=standalone - """)) - - if use_clang: - f.write(f'build --action_env CLANG_COMPILER_PATH="{clang_path}"\n') - f.write(f'build --repo_env CC="{clang_path}"\n') - f.write(f'build --repo_env BAZEL_COMPILER="{clang_path}"\n') - f.write('build --copt=-Wno-error=unused-command-line-argument\n') - if clang_major_version in (16, 17, 18): - # Necessary due to XLA's old version of upb. See: - # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505 - f.write("build --copt=-Wno-gnu-offsetof-extensions\n") - - if rocm_toolkit_path: - f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" - .format(rocm_toolkit_path=rocm_toolkit_path)) - if rocm_amdgpu_targets: - f.write( - f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"\n') - if cpu is not None: - f.write(f"build --cpu={cpu}\n") - - if target_cpu_features == "release": - if wheel_cpu == "x86_64": - f.write("build --config=avx_windows\n" if is_windows() - else "build --config=avx_posix\n") - elif target_cpu_features == "native": - if is_windows(): - print("--target_cpu_features=native is not supported on Windows; ignoring.") - else: - f.write("build --config=native_arch_posix\n") - - if enable_mkl_dnn: - f.write("build --config=mkl_open_source_only\n") - if enable_cuda: - f.write("build --config=cuda\n") - if use_cuda_nvcc: - f.write("build --config=build_cuda_with_nvcc\n") - else: - f.write("build --config=build_cuda_with_clang\n") - f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if cuda_version: - f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" - .format(cuda_version=cuda_version)) - if cudnn_version: - f.write("build --repo_env HERMETIC_CUDNN_VERSION=\"{cudnn_version}\"\n" - .format(cudnn_version=cudnn_version)) - if cuda_compute_capabilities: - f.write( - f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') - if enable_rocm: - f.write("build --config=rocm_base\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if use_clang: - f.write("build --config=rocm\n") - f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n") - if python_version: - f.write( - "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( - python_version=python_version)) BANNER = r""" _ _ __ __ | | / \ \ \/ / @@ -321,418 +43,577 @@ def write_bazelrc(*, remote_build, """ EPILOG = """ +From the root directory of the JAX repository, run + `python build/build.py build --wheels=` to build JAX + artifacts. -From the 'build' directory in the JAX repository, run - python build.py -or - python3 build.py -to download and build JAX's XLA (jaxlib) dependency. -""" + Multiple wheels can be built with a single invocation of the CLI. + E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin + To update the requirements_lock.txt files, run + `python build/build.py requirements_update` +""" -def _parse_string_as_bool(s): - """Parses a string as a boolean argument.""" - lower = s.lower() - if lower == "true": - return True - elif lower == "false": - return False - else: - raise ValueError(f"Expected either 'true' or 'false'; got {s}") +# Define the build target for each wheel. +WHEEL_BUILD_TARGET_DICT = { + "jaxlib": "//jaxlib/tools:build_wheel", + "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", + "jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", +} -def add_boolean_argument(parser, name, default=False, help_str=None): - """Creates a boolean flag.""" - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--" + name, - nargs="?", - default=default, - const=True, - type=_parse_string_as_bool, - help=help_str) - group.add_argument("--no" + name, dest=name, action="store_false") +def add_global_arguments(parser: argparse.ArgumentParser): + """Adds all the global arguments that applies to all the CLI subcommands.""" + parser.add_argument( + "--python_version", + type=str, + choices=["3.10", "3.11", "3.12", "3.13"], + default=f"{sys.version_info.major}.{sys.version_info.minor}", + help= + """ + Hermetic Python version to use. Default is to use the version of the + Python binary that executed the CLI. + """, + ) + bazel_group = parser.add_argument_group('Bazel Options') + bazel_group.add_argument( + "--bazel_path", + type=str, + default="", + help=""" + Path to the Bazel binary to use. The default is to find bazel via the + PATH; if none is found, downloads a fresh copy of Bazel from GitHub. + """, + ) -def _get_editable_output_paths(output_path): - """Returns the paths to the editable wheels.""" - return ( - os.path.join(output_path, "jaxlib"), - os.path.join(output_path, "jax_gpu_pjrt"), - os.path.join(output_path, "jax_gpu_plugin"), + bazel_group.add_argument( + "--bazel_startup_options", + action="append", + default=[], + help=""" + Additional startup options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. --bazel_startup_options='--nobatch' + """, ) + bazel_group.add_argument( + "--bazel_options", + action="append", + default=[], + help=""" + Additional build options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. --bazel_options='--local_resources=HOST_CPUS' + """, + ) -def main(): - cwd = os.getcwd() - parser = argparse.ArgumentParser( - description="Builds jaxlib from source.", epilog=EPILOG) - add_boolean_argument( - parser, - "verbose", - default=False, - help_str="Should we produce verbose debugging output?") parser.add_argument( - "--bazel_path", - help="Path to the Bazel binary to use. The default is to find bazel via " - "the PATH; if none is found, downloads a fresh copy of bazel from " - "GitHub.") - parser.add_argument( - "--python_bin_path", - help="Path to Python binary whose version to match while building with " - "hermetic python. The default is the Python interpreter used to run the " - "build script. DEPRECATED: use --python_version instead.") - parser.add_argument( - "--target_cpu_features", - choices=["release", "native", "default"], - default="release", - help="What CPU features should we target? 'release' enables CPU " - "features that should be enabled for a release build, which on " - "x86-64 architectures enables AVX. 'native' enables " - "-march=native, which generates code targeted to use all " - "features of the current machine. 'default' means don't opt-in " - "to any architectural features and use whatever the C compiler " - "generates by default.") - add_boolean_argument( - parser, - "use_clang", - default = "true", - help_str=( - "DEPRECATED: This flag is redundant because clang is " - "always used as default compiler." - ), + "--dry_run", + action="store_true", + help="Prints the Bazel command that is going to be executed.", ) + parser.add_argument( - "--clang_path", - help=( - "Path to clang binary to use. The default is " - "to find clang via the PATH." - ), + "--verbose", + action="store_true", + help="Produce verbose output for debugging.", ) - add_boolean_argument( - parser, - "enable_mkl_dnn", - default=True, - help_str="Should we build with MKL-DNN enabled?", + + parser.add_argument( + "--detailed_timestamped_log", + action="store_true", + help=""" + Enable detailed logging of the Bazel command with timestamps. The logs + will be stored and can be accessed as artifacts. + """, ) - add_boolean_argument( - parser, - "enable_cuda", - help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN." + + +def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): + """Adds all the arguments that applies to the artifact subcommands.""" + parser.add_argument( + "--wheels", + type=str, + default="jaxlib", + help= + """ + A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib", + --wheels="jaxlib,jax-cuda-plugin", etc. + Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt, + jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt + """, ) - add_boolean_argument( - parser, - "use_cuda_nvcc", - default=True, - help_str=( - "Should we build CUDA code using NVCC compiler driver? The default value " - "is true. If --nouse_cuda_nvcc flag is used then CUDA code is built " - "by clang compiler." - ), + + parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' build instead of a wheel.", ) - add_boolean_argument( - parser, - "build_gpu_plugin", - default=False, - help_str=( - "Are we building the gpu plugin in addition to jaxlib? The GPU " - "plugin is still experimental and is not ready for use yet." - ), + + parser.add_argument( + "--output_path", + type=str, + default=os.path.join(os.getcwd(), "dist"), + help="Directory to which the JAX wheel packages should be written.", ) + parser.add_argument( - "--build_gpu_kernel_plugin", - choices=["cuda", "rocm"], - default="", - help=( - "Specify 'cuda' or 'rocm' to build the respective kernel plugin." - " When this flag is set, jaxlib will not be built." - ), + "--configure_only", + action="store_true", + help=""" + If true, writes the Bazel options to the .jax_configure.bazelrc file but + does not build the artifacts. + """, ) - add_boolean_argument( - parser, - "build_gpu_pjrt_plugin", - default=False, - help_str=( - "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built " - "when this flag is True." - ), + + # CUDA Options + cuda_group = parser.add_argument_group('CUDA Options') + cuda_group.add_argument( + "--cuda_version", + type=str, + help= + """ + Hermetic CUDA version to use. Default is to use the version specified + in the .bazelrc. + """, ) - parser.add_argument( - "--gpu_plugin_cuda_version", - choices=["12"], + + cuda_group.add_argument( + "--cuda_major_version", + type=str, default="12", - help="Which CUDA major version the gpu plugin is for.") - parser.add_argument( - "--gpu_plugin_rocm_version", - choices=["60"], - default="60", - help="Which ROCM major version the gpu plugin is for.") - add_boolean_argument( - parser, - "enable_rocm", - help_str="Should we build with ROCm enabled?") - add_boolean_argument( - parser, - "enable_nccl", - default=True, - help_str="Should we build with NCCL enabled? Has no effect for non-CUDA " - "builds.") - add_boolean_argument( - parser, - "remote_build", - default=False, - help_str="Should we build with RBE (Remote Build Environment)?") - parser.add_argument( - "--cuda_version", - default=None, - help="CUDA toolkit version, e.g., 12.3.2") - parser.add_argument( + help= + """ + Which CUDA major version should the wheel be tagged as? Auto-detected if + --cuda_version is set. When --cuda_version is not set, the default is to + set the major version to 12 to match the default in .bazelrc. + """, + ) + + cuda_group.add_argument( "--cudnn_version", - default=None, - help="CUDNN version, e.g., 8.9.7.29") - # Caution: if changing the default list of CUDA capabilities, you should also - # update the list in .bazelrc, which is used for wheel builds. - parser.add_argument( + type=str, + help= + """ + Hermetic cuDNN version to use. Default is to use the version specified + in the .bazelrc. + """, + ) + + cuda_group.add_argument( + "--disable_nccl", + action="store_true", + help="Should NCCL be disabled?", + ) + + cuda_group.add_argument( "--cuda_compute_capabilities", + type=str, default=None, - help="A comma-separated list of CUDA compute capabilities to support.") - parser.add_argument( + help= + """ + A comma-separated list of CUDA compute capabilities to support. Default + is to use the values specified in the .bazelrc. + """, + ) + + cuda_group.add_argument( + "--build_cuda_with_clang", + action="store_true", + help=""" + Should CUDA code be compiled using Clang? The default behavior is to + compile CUDA with NVCC. + """, + ) + + # ROCm Options + rocm_group = parser.add_argument_group('ROCm Options') + rocm_group.add_argument( + "--rocm_version", + type=str, + default="60", + help="ROCm version to use", + ) + + rocm_group.add_argument( "--rocm_amdgpu_targets", + type=str, default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100", - help="A comma-separated list of ROCm amdgpu targets to support.") - parser.add_argument( + help="A comma-separated list of ROCm amdgpu targets to support.", + ) + + rocm_group.add_argument( "--rocm_path", - default=None, - help="Path to the ROCm toolkit.") - parser.add_argument( - "--bazel_startup_options", - action="append", default=[], - help="Additional startup options to pass to bazel.") - parser.add_argument( - "--bazel_options", - action="append", default=[], - help="Additional options to pass to the main Bazel command to be " - "executed, e.g. `run`.") - parser.add_argument( - "--output_path", - default=os.path.join(cwd, "dist"), - help="Directory to which the jaxlib wheel should be written") - parser.add_argument( - "--target_cpu", - default=None, - help="CPU platform to target. Default is the same as the host machine. " - "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.") - parser.add_argument( - "--editable", + type=str, + default="", + help="Path to the ROCm toolkit.", + ) + + # Compile Options + compile_group = parser.add_argument_group('Compile Options') + + compile_group.add_argument( + "--use_clang", + type=utils._parse_string_as_bool, + default="true", + const=True, + nargs="?", + help=""" + Whether to use Clang as the compiler. Not recommended to set this to + False as JAX uses Clang as the default compiler. + """, + ) + + compile_group.add_argument( + "--clang_path", + type=str, + default="", + help=""" + Path to the Clang binary to use. + """, + ) + + compile_group.add_argument( + "--disable_mkl_dnn", action="store_true", - help="Create an 'editable' jaxlib build instead of a wheel.") - parser.add_argument( - "--python_version", + help=""" + Disables MKL-DNN. + """, + ) + + compile_group.add_argument( + "--target_cpu_features", + choices=["release", "native", "default"], + default="release", + help=""" + What CPU features should we target? Release enables CPU features that + should be enabled for a release build, which on x86-64 architectures + enables AVX. Native enables -march=native, which generates code targeted + to use all features of the current machine. Default means don't opt-in + to any architectural features and use whatever the C compiler generates + by default. + """, + ) + + compile_group.add_argument( + "--target_cpu", default=None, - help="hermetic python version, e.g., 3.10") - add_boolean_argument( - parser, - "configure_only", - default=False, - help_str="If true, writes a .bazelrc file but does not build jaxlib.") - add_boolean_argument( - parser, - "requirements_update", - default=False, - help_str="If true, writes a .bazelrc and updates requirements_lock.txt " - "for a corresponding version of Python but does not build " - "jaxlib.") - add_boolean_argument( - parser, - "requirements_nightly_update", - default=False, - help_str="Same as update_requirements, but will consider dev, nightly " - "and pre-release versions of packages.") + help="CPU platform to target. Default is the same as the host machine.", + ) + + compile_group.add_argument( + "--local_xla_path", + type=str, + default=os.environ.get("JAXCI_XLA_GIT_DIR", ""), + help=""" + Path to local XLA repository to use. If not set, Bazel uses the XLA at + the pinned version in workspace.bzl. + """, + ) + +async def main(): + parser = argparse.ArgumentParser( + description=r""" + CLI for building JAX wheel packages from source and for updating the + requirements_lock.txt files + """, + epilog=EPILOG, + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + # Create subparsers for build and requirements_update + subparsers = parser.add_subparsers(dest="command", required=True) + + # requirements_update subcommand + requirements_update_parser = subparsers.add_parser( + "requirements_update", help="Updates the requirements_lock.txt files" + ) + requirements_update_parser.add_argument( + "--nightly_update", + action="store_true", + help=""" + If true, updates requirements_lock.txt for a corresponding version of + Python and will consider dev, nightly and pre-release versions of + packages. + """, + ) + add_global_arguments(requirements_update_parser) + + # Artifact build subcommand + build_artifact_parser = subparsers.add_parser( + "build", help="Builds the jaxlib, plugin, and pjrt artifact" + ) + add_artifact_subcommand_arguments(build_artifact_parser) + add_global_arguments(build_artifact_parser) + + arch = platform.machine() + os_name = platform.system().lower() args = parser.parse_args() - logging.basicConfig() + logger.info("%s", BANNER) + if args.verbose: - logger.setLevel(logging.DEBUG) + logging.getLogger().setLevel(logging.DEBUG) + logger.info("Verbose logging enabled") + + bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) + + logging.debug("Bazel path: %s", bazel_path) + logging.debug("Bazel version: %s", bazel_version) + + executor = command.SubprocessExecutor() + + # Start constructing the Bazel command + bazel_command_base = command.CommandBuilder(bazel_path) - if args.enable_cuda and args.enable_rocm: - parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.") + if args.bazel_startup_options: + logging.debug( + "Additional Bazel startup options: %s", args.bazel_startup_options + ) + for option in args.bazel_startup_options: + bazel_command_base.append(option) - print(BANNER) + bazel_command_base.append("run") - output_path = os.path.abspath(args.output_path) - os.chdir(os.path.dirname(__file__ or args.prog) or '.') + if args.python_version: + logging.debug("Hermetic Python version: %s", args.python_version) + bazel_command_base.append( + f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}" + ) + + # Enable verbose failures. + bazel_command_base.append("--verbose_failures=true") + + # Requirements update subcommand execution + if args.command == "requirements_update": + requirements_command = copy.deepcopy(bazel_command_base) + if args.bazel_options: + logging.debug( + "Using additional build options: %s", args.bazel_options + ) + for option in args.bazel_options: + requirements_command.append(option) + + if args.nightly_update: + logging.info( + "--nightly_update is set. Bazel will run" + " //build:requirements_nightly.update" + ) + requirements_command.append("//build:requirements_nightly.update") + else: + requirements_command.append("//build:requirements.update") + + result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) + if result.return_code != 0: + raise RuntimeError(f"Command failed with return code {result.return_code}") + else: + sys.exit(0) - host_cpu = platform.machine() wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", "ppc": "ppc64le", "aarch64": "aarch64", } - # TODO(phawkins): support other bazel cpu overrides. - wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None - else host_cpu) - - # Find a working Bazel. - bazel_path, bazel_version = get_bazel_path(args.bazel_path) - print(f"Bazel binary path: {bazel_path}") - print(f"Bazel version: {bazel_version}") - - if args.python_version: - python_version = args.python_version - else: - python_bin_path = get_python_bin_path(args.python_bin_path) - print(f"Python binary path: {python_bin_path}") - python_version = get_python_version(python_bin_path) - print("Python version: {}".format(".".join(map(str, python_version)))) - check_python_version(python_version) - python_version = ".".join(map(str, python_version)) - - print("Use clang: {}".format("yes" if args.use_clang else "no")) - clang_path = args.clang_path - clang_major_version = None - if args.use_clang: - if not clang_path: - clang_path = get_clang_path_or_exit() - print(f"clang path: {clang_path}") - clang_major_version = get_clang_major_version(clang_path) - - print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no")) - print(f"Target CPU: {wheel_cpu}") - print(f"Target CPU features: {args.target_cpu_features}") - - rocm_toolkit_path = args.rocm_path - print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no")) - if args.enable_cuda: - if args.cuda_compute_capabilities is not None: - print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}") - if args.cuda_version: - print(f"CUDA version: {args.cuda_version}") - if args.cudnn_version: - print(f"CUDNN version: {args.cudnn_version}") - print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no")) - - print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no")) - if args.enable_rocm: - if rocm_toolkit_path: - print(f"ROCm toolkit path: {rocm_toolkit_path}") - print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}") - - write_bazelrc( - remote_build=args.remote_build, - cuda_version=args.cuda_version, - cudnn_version=args.cudnn_version, - rocm_toolkit_path=rocm_toolkit_path, - cpu=args.target_cpu, - cuda_compute_capabilities=args.cuda_compute_capabilities, - rocm_amdgpu_targets=args.rocm_amdgpu_targets, - target_cpu_features=args.target_cpu_features, - wheel_cpu=wheel_cpu, - enable_mkl_dnn=args.enable_mkl_dnn, - use_clang=args.use_clang, - clang_path=clang_path, - clang_major_version=clang_major_version, - python_version=python_version, - enable_cuda=args.enable_cuda, - enable_nccl=args.enable_nccl, - enable_rocm=args.enable_rocm, - use_cuda_nvcc=args.use_cuda_nvcc, + target_cpu = ( + wheel_cpus[args.target_cpu] if args.target_cpu is not None else arch ) - if args.requirements_update or args.requirements_nightly_update: - if args.requirements_update: - task = "//build:requirements.update" - else: # args.requirements_nightly_update - task = "//build:requirements_nightly.update" - update_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", task, *args.bazel_options]) - print(" ".join(update_command)) - shell(update_command) - return - - if args.configure_only: - return - - print("\nBuilding XLA and installing it in the jaxlib source tree...") - - command_base = ( - bazel_path, - *args.bazel_startup_options, - "run", - "--verbose_failures=true", - *args.bazel_options, - ) - - if args.build_gpu_plugin and args.editable: - output_path_jaxlib, output_path_jax_pjrt, output_path_jax_kernel = ( - _get_editable_output_paths(output_path) + if args.local_xla_path: + logging.debug("Local XLA path: %s", args.local_xla_path) + bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") + + if args.target_cpu: + logging.debug("Target CPU: %s", args.target_cpu) + bazel_command_base.append(f"--cpu={args.target_cpu}") + + if args.disable_nccl: + logging.debug("Disabling NCCL") + bazel_command_base.append("--config=nonccl") + + git_hash = utils.get_githash() + + # Wheel build command execution + for wheel in args.wheels.split(","): + # Allow CUDA/ROCm wheels without the "jax-" prefix. + if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: + wheel = "jax-" + wheel + + if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + logging.error( + "Incorrect wheel name provided, valid choices are jaxlib," + " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," + " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" + ) + sys.exit(1) + + wheel_build_command = copy.deepcopy(bazel_command_base) + print("\n") + logger.info( + "Building %s for %s %s...", + wheel, + os_name, + arch, ) - else: - output_path_jaxlib = output_path - output_path_jax_pjrt = output_path - output_path_jax_kernel = output_path - - if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: - build_cpu_wheel_command = [ - *command_base, - "//jaxlib/tools:build_wheel", "--", - f"--output_path={output_path_jaxlib}", - f"--jaxlib_git_hash={get_githash()}", - f"--cpu={wheel_cpu}" - ] - if args.build_gpu_plugin: - build_cpu_wheel_command.append("--skip_gpu_kernels") - if args.editable: - build_cpu_wheel_command.append("--editable") - print(" ".join(build_cpu_wheel_command)) - shell(build_cpu_wheel_command) - - if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ - (args.build_gpu_kernel_plugin == "rocm"): - build_gpu_kernels_command = [ - *command_base, - "//jaxlib/tools:build_gpu_kernels_wheel", "--", - f"--output_path={output_path_jax_kernel}", - f"--jaxlib_git_hash={get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.enable_cuda: - build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") - elif args.enable_rocm: - build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + + clang_path = "" + if args.use_clang: + clang_path = args.clang_path or utils.get_clang_path_or_exit() + clang_major_version = utils.get_clang_major_version(clang_path) + logging.debug( + "Using Clang as the compiler, clang path: %s, clang version: %s", + clang_path, + clang_major_version, + ) + + # Use double quotes around clang path to avoid path issues on Windows. + wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") else: - raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") - if args.editable: - build_gpu_kernels_command.append("--editable") - print(" ".join(build_gpu_kernels_command)) - shell(build_gpu_kernels_command) - - if args.build_gpu_plugin or args.build_gpu_pjrt_plugin: - build_pjrt_plugin_command = [ - *command_base, - "//jaxlib/tools:build_gpu_plugin_wheel", "--", - f"--output_path={output_path_jax_pjrt}", - f"--jaxlib_git_hash={get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.enable_cuda: - build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}") - build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") - elif args.enable_rocm: - build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}") - build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + logging.debug("Use Clang: False") + + # Do not apply --config=clang on Mac as these settings do not apply to + # Apple Clang. + if os_name != "darwin": + wheel_build_command.append("--config=clang") + + if not args.disable_mkl_dnn: + logging.debug("Enabling MKL DNN") + wheel_build_command.append("--config=mkl_open_source_only") + + if args.target_cpu_features == "release": + if arch in ["x86_64", "AMD64"]: + logging.debug( + "Using release cpu features: --config=avx_%s", + "windows" if os_name == "windows" else "posix", + ) + wheel_build_command.append( + "--config=avx_windows" + if os_name == "windows" + else "--config=avx_posix" + ) + elif wheel_build_command == "native": + if os_name == "windows": + logger.warning( + "--target_cpu_features=native is not supported on Windows;" + " ignoring." + ) + else: + logging.debug("Using native cpu features: --config=native_arch_posix") + wheel_build_command.append("--config=native_arch_posix") else: - raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") - if args.editable: - build_pjrt_plugin_command.append("--editable") - print(" ".join(build_pjrt_plugin_command)) - shell(build_pjrt_plugin_command) + logging.debug("Using default cpu features") + + if "cuda" in wheel: + wheel_build_command.append("--config=cuda") + wheel_build_command.append( + f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" + ) + if args.build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + wheel_build_command.append("--config=build_cuda_with_clang") + else: + logging.debug("Building CUDA with NVCC") + wheel_build_command.append("--config=build_cuda_with_nvcc") + + if args.cuda_version: + logging.debug("Hermetic CUDA version: %s", args.cuda_version) + wheel_build_command.append( + f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}" + ) + if args.cudnn_version: + logging.debug("Hermetic cuDNN version: %s", args.cudnn_version) + wheel_build_command.append( + f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}" + ) + if args.cuda_compute_capabilities: + logging.debug( + "Hermetic CUDA compute capabilities: %s", + args.cuda_compute_capabilities, + ) + wheel_build_command.append( + f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" + ) + + if "rocm" in wheel: + wheel_build_command.append("--config=rocm_base") + if args.use_clang: + wheel_build_command.append("--config=rocm") + wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + if args.rocm_path: + logging.debug("ROCm tookit path: %s", args.rocm_path) + wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") + if args.rocm_amdgpu_targets: + logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) + wheel_build_command.append( + f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" + ) + + # Append additional build options at the end to override any options set in + # .bazelrc or above. + if args.bazel_options: + logging.debug( + "Additional Bazel build options: %s", args.bazel_options + ) + for option in args.bazel_options: + wheel_build_command.append(option) + + with open(".jax_configure.bazelrc", "w") as f: + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) + if not jax_configure_options: + logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + sys.exit(1) + f.write(jax_configure_options) + logging.info("Bazel options written to .jax_configure.bazelrc") + + if args.configure_only: + logging.info("--configure_only is set so not running any Bazel commands.") + else: + # Append the build target to the Bazel command. + build_target = WHEEL_BUILD_TARGET_DICT[wheel] + wheel_build_command.append(build_target) + + wheel_build_command.append("--") + + output_path = args.output_path + logger.debug("Artifacts output directory: %s", output_path) + + if args.editable: + logger.info("Building an editable build") + output_path = os.path.join(output_path, wheel) + wheel_build_command.append("--editable") + + wheel_build_command.append(f'--output_path="{output_path}"') + wheel_build_command.append(f"--cpu={target_cpu}") + + if "cuda" in wheel: + wheel_build_command.append("--enable-cuda=True") + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = args.cuda_major_version + wheel_build_command.append(f"--platform_version={cuda_major_version}") + + if "rocm" in wheel: + wheel_build_command.append("--enable-rocm=True") + wheel_build_command.append(f"--platform_version={args.rocm_version}") + + wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") + + result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) + # Exit with error if any wheel build fails. + if result.return_code != 0: + raise RuntimeError(f"Command failed with return code {result.return_code}") - shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) + # Exit with success if all wheels in the list were built successfully. + sys.exit(0) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/build/requirements.in b/build/requirements.in index a8d81fa5c670..e122aaa4ad78 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -3,11 +3,6 @@ # -r test-requirements.txt -# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement -# below. -matplotlib~=3.8.4; python_version<="3.10" -matplotlib; python_version>="3.11" - # # build deps # diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index adabb0dd2e70..ccffa247f36d 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -12,6 +12,10 @@ attrs==23.2.0 \ --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 @@ -295,7 +299,7 @@ matplotlib==3.8.4 ; python_version <= "3.10" \ --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -323,7 +327,7 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0 \ +numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ @@ -371,7 +375,6 @@ numpy==2.0.0 \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -380,84 +383,93 @@ numpy==2.0.0 \ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 # via + # auditwheel # build # matplotlib # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a +pillow==11.0.0 \ + --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ + --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ + --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ + --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ + --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ + --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ + --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ + --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ + --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ + --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ + --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ + --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ + --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ + --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ + --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ + --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ + --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ + --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ + --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ + --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ + --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ + --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ + --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ + --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ + --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ + --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ + --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ + --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ + --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ + --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ + --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ + --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ + --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ + --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ + --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ + --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ + --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ + --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ + --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ + --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ + --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ + --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ + --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ + --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ + --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ + --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ + --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ + --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ + --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ + --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ + --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ + --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ + --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ + --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ + --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ + --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ + --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ + --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ + --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ + --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ + --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ + --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ + --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ + --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ + --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ + --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ + --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ + --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ + --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ + --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ + --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ + --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ + --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ + --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ + --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 # via # -r build/test-requirements.txt # matplotlib @@ -487,6 +499,10 @@ psutil==5.9.8 \ --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a @@ -621,4 +637,6 @@ zstandard==0.22.0 \ setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 053e996cefad..7f3ee61ff7f6 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -12,6 +12,10 @@ attrs==23.2.0 \ --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 @@ -290,7 +294,7 @@ matplotlib==3.9.0 ; python_version >= "3.11" \ --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -318,7 +322,7 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0 \ +numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ @@ -366,7 +370,6 @@ numpy==2.0.0 \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -375,84 +378,93 @@ numpy==2.0.0 \ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 # via + # auditwheel # build # matplotlib # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a +pillow==11.0.0 \ + --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ + --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ + --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ + --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ + --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ + --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ + --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ + --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ + --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ + --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ + --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ + --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ + --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ + --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ + --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ + --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ + --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ + --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ + --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ + --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ + --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ + --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ + --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ + --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ + --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ + --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ + --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ + --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ + --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ + --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ + --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ + --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ + --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ + --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ + --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ + --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ + --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ + --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ + --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ + --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ + --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ + --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ + --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ + --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ + --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ + --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ + --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ + --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ + --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ + --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ + --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ + --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ + --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ + --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ + --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ + --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ + --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ + --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ + --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ + --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ + --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ + --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ + --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ + --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ + --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ + --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ + --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ + --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ + --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ + --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ + --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ + --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ + --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ + --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ + --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 # via # -r build/test-requirements.txt # matplotlib @@ -482,6 +494,10 @@ psutil==5.9.8 \ --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a @@ -610,4 +626,6 @@ zstandard==0.22.0 \ setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 1468e64c29cd..bf22c3623b47 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -12,6 +12,10 @@ attrs==23.2.0 \ --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 @@ -290,7 +294,7 @@ matplotlib==3.9.0 ; python_version >= "3.11" \ --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -318,7 +322,7 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0 \ +numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ @@ -366,7 +370,6 @@ numpy==2.0.0 \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -375,84 +378,93 @@ numpy==2.0.0 \ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 # via + # auditwheel # build # matplotlib # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a +pillow==11.0.0 \ + --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ + --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ + --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ + --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ + --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ + --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ + --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ + --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ + --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ + --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ + --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ + --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ + --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ + --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ + --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ + --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ + --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ + --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ + --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ + --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ + --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ + --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ + --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ + --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ + --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ + --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ + --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ + --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ + --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ + --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ + --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ + --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ + --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ + --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ + --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ + --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ + --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ + --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ + --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ + --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ + --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ + --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ + --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ + --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ + --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ + --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ + --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ + --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ + --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ + --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ + --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ + --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ + --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ + --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ + --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ + --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ + --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ + --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ + --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ + --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ + --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ + --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ + --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ + --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ + --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ + --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ + --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ + --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ + --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ + --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ + --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ + --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ + --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ + --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ + --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 # via # -r build/test-requirements.txt # matplotlib @@ -482,6 +494,10 @@ psutil==5.9.8 \ --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a @@ -610,4 +626,6 @@ zstandard==0.22.0 \ setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 019c088fbd91..9fa78c062ce9 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -12,6 +12,10 @@ attrs==24.2.0 \ --hash=sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346 \ --hash=sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.2.post1 \ --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 @@ -338,7 +342,7 @@ matplotlib==3.9.2 ; python_version >= "3.11" \ --hash=sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49 \ --hash=sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c \ --hash=sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413 - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -426,7 +430,6 @@ numpy==2.1.2 ; python_version >= "3.13" \ --hash=sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -434,11 +437,14 @@ numpy==2.1.2 ; python_version >= "3.13" \ opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 # via + # auditwheel # build # matplotlib # pytest @@ -553,6 +559,10 @@ psutil==6.0.0 \ --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index e20291cefd63..575dce87664e 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -1,5 +1,6 @@ ################################################################################ -FROM ubuntu:20.04 AS rocm_base +ARG BASE_DOCKER=ubuntu:22.04 +FROM $BASE_DOCKER AS rocm_base ################################################################################ RUN --mount=type=cache,target=/var/cache/apt \ diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 1ec5c6e7f36f..a32f502f377a 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -89,9 +89,9 @@ def dist_wheels( mounts = [ "-v", - "./:/jax", + os.path.abspath("./") + ":/jax", "-v", - "./wheelhouse:/wheelhouse", + os.path.abspath("./wheelhouse") + ":/wheelhouse", ] if xla_path: @@ -143,6 +143,7 @@ def _fetch_jax_metadata(xla_path): def dist_docker( rocm_version, + base_docker, python_versions, xla_path, rocm_build_job="", @@ -168,6 +169,7 @@ def dist_docker( "--build-arg=ROCM_VERSION=%s" % rocm_version, "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, + "--build-arg=BASE_DOCKER=%s" % base_docker, "--build-arg=PYTHON_VERSION=%s" % python_version, "--build-arg=JAX_VERSION=%(jax_version)s" % md, "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, @@ -210,7 +212,7 @@ def test(image_name): # JAX and jaxlib are already installed from wheels mounts = [ "-v", - "./:/jax", + os.path.abspath("./") + ":/jax", ] cmd.extend(mounts) @@ -231,6 +233,12 @@ def test(image_name): def parse_args(): p = argparse.ArgumentParser() + p.add_argument( + "--base-docker", + default="", + help="Argument to override base docker in dockerfile", + ) + p.add_argument( "--python-versions", type=lambda x: x.split(","), @@ -308,6 +316,7 @@ def main(): ) dist_docker( args.rocm_version, + args.base_docker, args.python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 302a0449b19e..386f70ee1a96 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -48,12 +48,12 @@ PYTHON_VERSION="3.10" ROCM_VERSION="6.1.3" ROCM_BUILD_JOB="" ROCM_BUILD_NUM="" -BASE_DOCKER="ubuntu:20.04" +BASE_DOCKER="ubuntu:22.04" CUSTOM_INSTALL="" JAX_USE_CLANG="" POSITIONAL_ARGS=() -RUNTIME_FLAG=1 +RUNTIME_FLAG=0 while [[ $# -gt 0 ]]; do case $1 in @@ -90,6 +90,10 @@ while [[ $# -gt 0 ]]; do ROCM_BUILD_NUM="$2" shift 2 ;; + --base_docker) + BASE_DOCKER="$2" + shift 2 + ;; --use_clang) JAX_USE_CLANG="$2" shift 2 @@ -113,11 +117,13 @@ function upsearch (){ } # Set up WORKSPACE. -WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" -BUILD_TAG="${BUILD_TAG:-jax}" - -# Determine the docker image name and BUILD_TAG. -DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}" +if [ ${RUNTIME_FLAG} -eq 0 ]; then + DOCKER_IMG_NAME="${BUILD_TAG}" +else + WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" + BUILD_TAG="${BUILD_TAG:-jax}" + DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}" +fi # Under Jenkins matrix build, the build tag may contain characters such as # commas (,) and equal signs (=), which are not valid inside docker image names. @@ -152,6 +158,7 @@ fi # which is the ROCm image that is shipped for users to use (i.e. distributable). ./build/rocm/ci_build \ --rocm-version $ROCM_VERSION \ + --base-docker $BASE_DOCKER \ --python-versions $PYTHON_VERSION \ --xla-source-dir=$XLA_CLONE_DIR \ --rocm-build-job=$ROCM_BUILD_JOB \ diff --git a/build/rocm/dev_build_rocm.py b/build/rocm/dev_build_rocm.py index 2be64152f667..aa5754b789d3 100755 --- a/build/rocm/dev_build_rocm.py +++ b/build/rocm/dev_build_rocm.py @@ -77,13 +77,14 @@ def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path): build_command = [ "python3", "./build/build.py", - "--enable_rocm", - "--build_gpu_plugin", - "--gpu_plugin_rocm_version=60", + "build" f"--use_clang={str(use_clang).lower()}", + "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt" + "--rocm_path=%/opt/rocm-{rocm_version}/", + "--rocm_version=60", f"--rocm_amdgpu_targets={rocm_target}", - f"--rocm_path=/opt/rocm-{rocm_version}/", bazel_options, + "--verbose" ] if clang_option: diff --git a/build/rocm/run_multi_gpu.sh b/build/rocm/run_multi_gpu.sh index b5d5798e7920..aa1d4d0f38ed 100755 --- a/build/rocm/run_multi_gpu.sh +++ b/build/rocm/run_multi_gpu.sh @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -eu +set -xu # Function to run tests with specified GPUs run_tests() { diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index deb6ab703391..ec825f40b7d2 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -93,11 +93,12 @@ def build_jaxlib_wheel( cmd = [ "python", "build/build.py", - "--enable_rocm", - "--build_gpu_plugin", - "--gpu_plugin_rocm_version=60", + "build" + "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt" "--rocm_path=%s" % rocm_path, + "--rocm_version=60", "--use_clang=%s" % use_clang, + "--verbose" ] # Add clang path if clang is used. diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py index 993c2f94b558..d9cb9ea3f25a 100644 --- a/build/rocm/tools/get_rocm.py +++ b/build/rocm/tools/get_rocm.py @@ -229,7 +229,7 @@ def _build_installer_url(rocm_version, metadata): rv = parse_version(rocm_version) - base_url = "http://artifactory-cdn.amd.com/artifactory/list" + base_url = "https://artifactory-cdn.amd.com/artifactory/list" if md["ID"] == "ubuntu": fmt = "amdgpu-install-internal_%(rocm_major)s.%(rocm_minor)s-%(os_version)s-1_all.deb" diff --git a/build/test-requirements.txt b/build/test-requirements.txt index bec6afce1853..94b2bbb965dc 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -6,7 +6,6 @@ filelock flatbuffers hypothesis mpmath>=1.3 -numpy>=1.22 pillow>=10.4.0 portpicker pytest-xdist @@ -14,3 +13,9 @@ wheel rich # TODO(ybaturina): remove setuptools version setuptools<71.0.0 +# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement +# below. +matplotlib~=3.8.4; python_version=="3.10" +matplotlib; python_version>="3.11" +opt-einsum +auditwheel diff --git a/build/tools/command.py b/build/tools/command.py new file mode 100644 index 000000000000..cc95d7eea4af --- /dev/null +++ b/build/tools/command.py @@ -0,0 +1,112 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Helper script for the JAX build CLI for running subprocess commands. +import asyncio +import dataclasses +import datetime +import os +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + +class CommandBuilder: + def __init__(self, base_command: str): + self.command = [base_command] + + def append(self, parameter: str): + self.command.append(parameter) + return self + + def get_command_as_string(self) -> str: + return " ".join(self.command) + + def get_command_as_list(self) -> list[str]: + return self.command + +@dataclasses.dataclass +class CommandResult: + """ + Represents the result of executing a subprocess command. + """ + + command: str + return_code: int = 2 # Defaults to not successful + logs: str = "" + start_time: datetime.datetime = dataclasses.field( + default_factory=datetime.datetime.now + ) + end_time: Optional[datetime.datetime] = None + + +async def _process_log_stream(stream, result: CommandResult): + """Logs the output of a subprocess stream.""" + while True: + line_bytes = await stream.readline() + if not line_bytes: + break + line = line_bytes.decode().rstrip() + result.logs += line + logger.info("%s", line) + + +class SubprocessExecutor: + """ + Manages execution of subprocess commands with reusable environment and logging. + """ + + def __init__(self, environment: Dict[str, str] = None): + """ + + Args: + environment: + """ + self.environment = environment or dict(os.environ) + + async def run(self, cmd: str, dry_run: bool = False, detailed_timestamped_log: bool = False) -> CommandResult: + """ + Executes a subprocess command. + + Args: + cmd: The command to execute. + dry_run: If True, prints the command instead of executing it. + + Returns: + A CommandResult instance. + """ + result = CommandResult(command=cmd) + if dry_run: + logger.info("[DRY RUN] %s", cmd) + result.return_code = 0 # Dry run is a success + return result + + logger.info("[EXECUTING] %s", cmd) + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE if detailed_timestamped_log else None, + stderr=asyncio.subprocess.PIPE if detailed_timestamped_log else None, + env=self.environment, + ) + + if detailed_timestamped_log: + await asyncio.gather( + _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) + ) + + result.return_code = await process.wait() + result.end_time = datetime.datetime.now() + logger.debug("Command finished with return code %s", result.return_code) + return result diff --git a/build/tools/utils.py b/build/tools/utils.py new file mode 100644 index 000000000000..5d7c8e0f20b2 --- /dev/null +++ b/build/tools/utils.py @@ -0,0 +1,236 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Helper script for tools/utilities used by the JAX build CLI. +import collections +import hashlib +import logging +import os +import pathlib +import platform +import re +import shutil +import stat +import subprocess +import sys +import urllib.request + +logger = logging.getLogger(__name__) + +BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" +BazelPackage = collections.namedtuple( + "BazelPackage", ["base_uri", "file", "sha256"] +) +bazel_packages = { + ("Linux", "x86_64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-linux-x86_64", + sha256=( + "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307" + ), + ), + ("Linux", "aarch64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-linux-arm64", + sha256=( + "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f" + ), + ), + ("Darwin", "x86_64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-darwin-x86_64", + sha256=( + "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29" + ), + ), + ("Darwin", "arm64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-darwin-arm64", + sha256=( + "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb" + ), + ), + ("Windows", "AMD64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-windows-x86_64.exe", + sha256=( + "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6" + ), + ), +} + +def download_and_verify_bazel(): + """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" + package = bazel_packages.get((platform.system(), platform.machine())) + if package is None: + return None + + if not os.access(package.file, os.X_OK): + uri = (package.base_uri or BAZEL_BASE_URI) + package.file + sys.stdout.write(f"Downloading bazel from: {uri}\n") + + def progress(block_count, block_size, total_size): + if total_size <= 0: + total_size = 170**6 + progress = (block_count * block_size) / total_size + num_chars = 40 + progress_chars = int(num_chars * progress) + sys.stdout.write( + "{} [{}{}] {}%\r".format( + package.file, + "#" * progress_chars, + "." * (num_chars - progress_chars), + int(progress * 100.0), + ) + ) + + tmp_path, _ = urllib.request.urlretrieve( + uri, None, progress if sys.stdout.isatty() else None + ) + sys.stdout.write("\n") + + # Verify that the downloaded Bazel binary has the expected SHA256. + with open(tmp_path, "rb") as downloaded_file: + contents = downloaded_file.read() + + digest = hashlib.sha256(contents).hexdigest() + if digest != package.sha256: + print( + "Checksum mismatch for downloaded bazel binary (expected {}; got {})." + .format(package.sha256, digest) + ) + sys.exit(-1) + + # Write the file as the bazel file name. + with open(package.file, "wb") as out_file: + out_file.write(contents) + + # Mark the file as executable. + st = os.stat(package.file) + os.chmod( + package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH + ) + + return os.path.join(".", package.file) + +def get_bazel_paths(bazel_path_flag): + """Yields a sequence of guesses about bazel path. + + Some of sequence elements can be None. The resulting iterator is lazy and + potentially has a side effects. + """ + yield bazel_path_flag + yield shutil.which("bazel") + yield download_and_verify_bazel() + +def get_bazel_path(bazel_path_flag): + """Returns the path to a Bazel binary, downloading Bazel if not found. + + Also, checks Bazel's version is at least newer than 6.5.0 + + A manual version check is needed only for really old bazel versions. + Newer bazel releases perform their own version check against .bazelversion + (see for details + https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes). + """ + for path in filter(None, get_bazel_paths(bazel_path_flag)): + version = get_bazel_version(path) + if version is not None and version >= (6, 5, 0): + return path, ".".join(map(str, version)) + + print( + "Cannot find or download a suitable version of bazel." + "Please install bazel >= 6.5.0." + ) + sys.exit(-1) + +def get_bazel_version(bazel_path): + try: + version_output = subprocess.run( + [bazel_path, "--version"], + encoding="utf-8", + capture_output=True, + check=True, + ).stdout.strip() + except (subprocess.CalledProcessError, OSError): + return None + match = re.search(r"bazel *([0-9\\.]+)", version_output) + if match is None: + return None + return tuple(int(x) for x in match.group(1).split(".")) + +def get_clang_path_or_exit(): + which_clang_output = shutil.which("clang") + if which_clang_output: + # If we've found a clang on the path, need to get the fully resolved path + # to ensure that system headers are found. + return str(pathlib.Path(which_clang_output).resolve()) + else: + print( + "--clang_path is unset and clang cannot be found" + " on the PATH. Please pass --clang_path directly." + ) + sys.exit(-1) + +def get_clang_major_version(clang_path): + clang_version_proc = subprocess.run( + [clang_path, "-E", "-P", "-"], + input="__clang_major__", + check=True, + capture_output=True, + text=True, + ) + major_version = int(clang_version_proc.stdout) + + return major_version + +def get_jax_configure_bazel_options(bazel_command: list[str]): + """Returns the bazel options to be written to .jax_configure.bazelrc.""" + # Get the index of the "run" parameter. Build options will come after "run" so + # we find the index of "run" and filter everything after it. + start = bazel_command.index("run") + jax_configure_bazel_options = "" + try: + for i in range(start + 1, len(bazel_command)): + bazel_flag = bazel_command[i] + # On Windows, replace all backslashes with double backslashes to avoid + # unintended escape sequences. + if platform.system() == "Windows": + bazel_flag = bazel_flag.replace("\\", "\\\\") + jax_configure_bazel_options += f"build {bazel_flag}\n" + return jax_configure_bazel_options + except ValueError: + logging.error("Unable to find index for 'run' in the Bazel command") + return "" + +def get_githash(): + try: + return subprocess.run( + ["git", "rev-parse", "HEAD"], + encoding="utf-8", + capture_output=True, + check=True, + ).stdout.strip() + except OSError: + return "" + +def _parse_string_as_bool(s): + """Parses a string as a boolean value.""" + lower = s.lower() + if lower == "true": + return True + elif lower == "false": + return False + else: + raise ValueError(f"Expected either 'true' or 'false'; got {s}") diff --git a/ci/README.md b/ci/README.md new file mode 100644 index 000000000000..ea867df52f97 --- /dev/null +++ b/ci/README.md @@ -0,0 +1,10 @@ +# JAX continuous integration + +> [!WARNING] +> This folder is still under construction. It is part of an ongoing +> effort to improve the structure of CI and build related files within the +> JAX repo. This warning will be removed when the contents of this +> directory are stable and appropriate documentation around its usage is in +> place. + +******************************************************************************** \ No newline at end of file diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh new file mode 100644 index 000000000000..698de38418b7 --- /dev/null +++ b/ci/build_artifacts.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +## +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Build JAX artifacts. +# Usage: ./ci/build_artifacts.sh "" +# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt +# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib" +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +artifact="$1" + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt") + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Adjust the values when running on Windows x86 to match the config in +# .bazelrc +if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then + os="windows" + arch="amd64" +fi + +if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then + + # Build the jax artifact + if [[ "$artifact" == "jax" ]]; then + python -m build --outdir $JAXCI_OUTPUT_DIR + else + + # Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_" + # flags in the .bazelrc depending upon the platform we are building for. + bazelrc_config="${os}_${arch}" + + # TODO(b/379903748): Add remote cache options for Linux and Windows. + if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then + bazelrc_config="rbe_${bazelrc_config}" + else + bazelrc_config="ci_${bazelrc_config}" + fi + + # Use the "_cuda" configs when building the CUDA artifacts. + if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then + bazelrc_config="${bazelrc_config}_cuda" + fi + + # Build the artifact. + python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log + + # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we + # run `auditwheel show` to verify manylinux compliance. + if [[ "$os" == "linux" ]]; then + ./ci/utilities/run_auditwheel.sh + fi + + fi + +else + echo "Error: Invalid artifact: $artifact. Allowed values are: ${allowed_artifacts[@]}" + exit 1 +fi \ No newline at end of file diff --git a/ci/envs/default.env b/ci/envs/default.env new file mode 100644 index 000000000000..ae434dc61c8f --- /dev/null +++ b/ci/envs/default.env @@ -0,0 +1,69 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file contains all the default values for the "JAXCI_" environment +# variables used in the CI scripts. These variables are used to control the +# behavior of the CI scripts such as the Python version used, path to JAX/XLA +# repo, if to clone XLA repo, etc. + +# The path to the JAX git repository. +export JAXCI_JAX_GIT_DIR=$(pwd) + +# Controls the version of Hermetic Python to use. Use system default if not +# set. +export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} + +# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local +# copy of XLA instead of the pinned version in the WORKSPACE. When +# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically. +export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} + +# If set to 1, the builds will clone the XLA repository at HEAD and set its +# path in JAXCI_XLA_GIT_DIR. +export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} + +# Allows overriding the XLA commit that is used. +export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} + +# Controls the location where the artifacts are written to. +export JAXCI_OUTPUT_DIR="$(pwd)/dist" + +# When enabled, artifacts will be built with RBE. Requires gcloud authentication +# and only certain platforms support RBE. Therefore, this flag is enabled only +# for CI builds where RBE is supported. +export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} + +# ############################################################################# +# Test script specific environment variables. +# ############################################################################# +# The maximum number of tests to run per GPU when running single accelerator +# tests with parallel execution with Bazel. The GPU limit is set because we +# need to allow about 2GB of GPU RAM per test. Default is set to 12 because we +# use L4 machines which have 24GB of RAM but can be overriden if we use a +# different GPU type. +export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12} + +# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override +# this value in the Github action workflow files. +export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} + +# Pytest specific environment variables below. Used in run_pytest_*.sh scripts. +# Sets the number of TPU cores for the TPU machine type. These values are +# defined in the TPU GitHub Actions workflow. +export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} + +# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels +# on the system. By default, it is set to match the version of the hermetic +# Python used by Bazel for building the wheels. +export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh new file mode 100755 index 000000000000..248111e0247a --- /dev/null +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel CPU tests with RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel CPU tests with RBE. +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# When running on Mac or Linux Aarch64, we only build the test targets and +# not run them. These platforms do not have native RBE support so we +# RBE cross-compile them on remote Linux x86 machines. As the tests still +# need to be run on the host machine and because running the tests on a +# single machine can take a long time, we skip running them on these +# platforms. +if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then + echo "Building RBE CPU tests..." + bazel build --config=rbe_cross_compile_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --test_output=errors \ + --color=yes \ + //tests:cpu_tests //tests:backend_independent_tests +else + echo "Running RBE CPU tests..." + bazel test --config=rbe_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --test_output=errors \ + --color=yes \ + //tests:cpu_tests //tests:backend_independent_tests +fi \ No newline at end of file diff --git a/ci/run_bazel_test_gpu_non_rbe.sh b/ci/run_bazel_test_gpu_non_rbe.sh new file mode 100755 index 000000000000..7828cf41c60e --- /dev/null +++ b/ci/run_bazel_test_gpu_non_rbe.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Run Bazel GPU tests without RBE. This runs two commands: single accelerator +# tests with one GPU a piece, multiaccelerator tests with all GPUS. +# Requires that jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are stored +# inside the ../dist folder +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel GPU tests (single accelerator and multiaccelerator tests) directly +# on the VM without RBE. +nvidia-smi +echo "Running single accelerator tests (without RBE)..." + +# Set up test environment variables. +export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +export num_test_jobs=$((gpu_count * JAXCI_MAX_TESTS_PER_GPU)) +export num_cpu_cores=$(nproc) + +# tests_jobs = max(gpu_count * max_tests_per_gpu, num_cpu_cores) +if [[ $num_test_jobs -gt $num_cpu_cores ]]; then + num_test_jobs=$num_cpu_cores +fi +# End of test environment variables setup. + +# Runs single accelerator tests with one GPU apiece. +# It appears --run_under needs an absolute path. +# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR` +# should match the VM's CPU core count (set in `--local_test_jobs`). +bazel test --config=ci_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --//jax:build_jaxlib=false \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_env=JAX_ACCELERATOR_COUNT=$gpu_count \ + --test_env=JAX_TESTS_PER_ACCELERATOR=$JAXCI_MAX_TESTS_PER_GPU \ + --local_test_jobs=$num_test_jobs \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests + +echo "Running multi-accelerator tests (without RBE)..." +# Runs multiaccelerator tests with all GPUs directly on the VM without RBE.. +bazel test --config=ci_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --//jax:build_jaxlib=false \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --jobs=8 \ + --test_tag_filters=multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests/pallas:gpu_tests \ No newline at end of file diff --git a/ci/run_bazel_test_gpu_rbe.sh b/ci/run_bazel_test_gpu_rbe.sh new file mode 100755 index 000000000000..17bd8d9db4f8 --- /dev/null +++ b/ci/run_bazel_test_gpu_rbe.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one +# GPU apiece on RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece). +echo "Running RBE GPU tests..." + +bazel test --config=rbe_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh new file mode 100644 index 000000000000..2b19ca5ddaa5 --- /dev/null +++ b/ci/run_pytest_cpu.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pyest CPU tests. Requires a jaxlib wheel to be present +# inside the $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +# Set up all test environment variables +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true +export TF_CPP_MIN_LOG_LEVEL=0 +# End of test environment variable setup + +echo "Running CPU tests..." +"$JAXCI_PYTHON" -m pytest -n auto --tb=short --maxfail=20 tests examples \ No newline at end of file diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh new file mode 100644 index 000000000000..7bc2492781b2 --- /dev/null +++ b/ci/run_pytest_gpu.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt +# wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels inside the +# $JAXCI_OUTPUT_DIR directory on the system. +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +nvidia-smi + +# Set up all test environment variables +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true +export NCCL_DEBUG=WARN +export TF_CPP_MIN_LOG_LEVEL=0 + +# Set the number of processes to run to be 4x the number of GPUs. +export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +export num_processes=`expr 4 \* $gpu_count` + +export XLA_PYTHON_CLIENT_ALLOCATOR=platform +export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 +# End of test environment variable setup + +echo "Running GPU tests..." +"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ +tests examples \ +--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ +--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \ +--deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ +--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ No newline at end of file diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh new file mode 100644 index 000000000000..783d2f9feca5 --- /dev/null +++ b/ci/run_pytest_tpu.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pyest CPU tests. Requires a jaxlib wheel to be present +# inside $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' +"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' +"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' +strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on' +"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' + +echo "Running TPU tests..." +export JAX_PLATFORMS=tpu,cpu +# Run single-accelerator tests in parallel +export JAX_ENABLE_TPU_XDIST=true + +"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ +--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ +--maxfail=20 -m "not multiaccelerator" tests examples + +# Run Pallas printing tests, which need to run with I/O capturing disabled. +export TPU_STDERR_LOG_LEVEL=0 +"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + +# Run multi-accelerator across all chips +"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests \ No newline at end of file diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py new file mode 100644 index 000000000000..6164e6a5e29d --- /dev/null +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -0,0 +1,80 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Converts MSYS Linux-like paths stored in env variables to Windows paths. + +This is necessary on Windows, because some applications do not understand/handle +Linux-like paths MSYS uses, for example, Bazel. +""" +import argparse +import os +import subprocess + +def msys_to_windows_path(msys_path): + """Converts an MSYS path to a Windows path using cygpath. + + Args: + msys_path: The MSYS path to convert. + + Returns: + The corresponding Windows path. + """ + try: + # Use cygpath with the -w flag to convert to Windows format + process = subprocess.run(['cygpath', '-w', msys_path], capture_output=True, text=True, check=True) + windows_path = process.stdout.strip() + return windows_path + except FileNotFoundError: + print("Error: cygpath not found. Make sure it's in your PATH.") + return None + except subprocess.CalledProcessError as e: + print(f"Error converting path: {e}") + return None + +def should_convert(var: str, + convert: list[str] | None): + """Check the variable name against convert list""" + if var in convert: + return True + else: + return False + +def main(parsed_args: argparse.Namespace): + converted_paths = {} + + for var, value in os.environ.items(): + if not value or not should_convert(var, + parsed_args.convert): + continue + converted_path = msys_to_windows_path(value) + converted_paths[var] = converted_path + + var_str = '\n'.join(f'export {k}="{v}"' + for k, v in converted_paths.items()) + # The string can then be piped into `source`, to re-set the + # 'converted' variables. + print(var_str) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=( + 'Convert MSYS paths in environment variables to Windows paths.')) + parser.add_argument('--convert', + nargs='+', + required=True, + help='Space separated list of environment variables to convert. E.g: --convert env_var1 env_var2') + args = parser.parse_args() + + main(args) diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh new file mode 100644 index 000000000000..181256b90804 --- /dev/null +++ b/ci/utilities/install_wheels_locally.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python +# binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to +# avoid using the Windows version of `find` on Msys. +WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) + +if [[ -z "$WHEELS" ]]; then + echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" + exit 1 +fi + +echo "Installing the following wheels:" +echo "${WHEELS[@]}" +"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" + +echo "Installing the JAX package in editable mode at the current commit..." +# Install JAX package at the current commit. +"$JAXCI_PYTHON" -m pip install -U -e . diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh new file mode 100644 index 000000000000..30b6a3b51865 --- /dev/null +++ b/ci/utilities/run_auditwheel.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Runs auditwheel to verify manylinux compatibility. + +# Get a list of all the wheels in the output directory. Only look for wheels +# that need to be verified for manylinux compliance. +WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" \)) + +if [[ -z "$WHEELS" ]]; then + echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" + exit 1 +fi + +for wheel in $WHEELS; do + printf "\nRunning auditwheel on the following wheel:" + ls $wheel + OUTPUT_FULL=$(python -m auditwheel show $wheel) + # Remove the wheel name from the output to avoid false positives. + wheel_name=$(basename $wheel) + OUTPUT=${OUTPUT_FULL//${wheel_name}/} + + # If a wheel is manylinux2014 compliant, `auditwheel show` will return the + # platform tag as manylinux_2_17. manylinux2014 is an alias for + # manylinux_2_17. + if echo "$OUTPUT" | grep -q "manylinux_2_17"; then + printf "\n$wheel_name is manylinux2014 compliant.\n" + else + echo "$OUTPUT_FULL" + printf "\n$wheel_name is NOT manylinux2014 compliant.\n" + exit 1 + fi +done \ No newline at end of file diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh new file mode 100644 index 000000000000..964a6e4ac679 --- /dev/null +++ b/ci/utilities/setup_build_environment.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Set up the build environment for JAX CI jobs. This script depends on the +# "JAXCI_" environment variables set or sourced in the build script. + +# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI +# jobs running on Linux runners in GitHub Actions. Without this, git complains +# that the directory has dubious ownership and refuses to run any commands. +# Avoid running on Windows runners as git runs into issues with not being able +# to lock the config file. Other git commands seem to work on the Windows +# runners so we can skip this step for Windows. +# TODO(b/375073267): Remove this once we understand why git repositories are +# being marked as unsafe inside the self-hosted runners. +if [[ ! $(uname -s) =~ "MSYS_NT" ]]; then + git config --global --add safe.directory $JAXCI_JAX_GIT_DIR +fi + +function clone_main_xla() { + echo "Cloning XLA at HEAD to $(pwd)/xla" + git clone --depth=1 https://github.com/openxla/xla.git $(pwd)/xla + export JAXCI_XLA_GIT_DIR=$(pwd)/xla +} + +# Clone XLA at HEAD if required. +if [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then + # Clone only if $(pwd)/xla does not exist to avoid failure on re-runs. + if [[ ! -d $(pwd)/xla ]]; then + clone_main_xla + else + echo "JAXCI_CLONE_MAIN_XLA set but local XLA folder already exists: $(pwd)/xla so using that instead." + # Set JAXCI_XLA_GIT_DIR if local XLA already exists + export JAXCI_XLA_GIT_DIR=$(pwd)/xla + fi +fi + +# If a XLA commit is provided, check out XLA at that commit. +if [[ ! -z "$JAXCI_XLA_COMMIT" ]]; then + # Clone XLA at HEAD if a path to local XLA is not provided. + if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + clone_main_xla + fi + pushd "$JAXCI_XLA_GIT_DIR" + + git fetch --depth=1 origin "$JAXCI_XLA_COMMIT" + echo "JAXCI_XLA_COMMIT is set. Checking out XLA at $JAXCI_XLA_COMMIT" + git checkout "$JAXCI_XLA_COMMIT" + + popd +fi + +if [[ ! -z ${JAXCI_XLA_GIT_DIR} ]]; then + echo "INFO: Overriding XLA to be read from $JAXCI_XLA_GIT_DIR instead of the" + echo "pinned version in the WORKSPACE." + echo "If you would like to revert this behavior, unset JAXCI_CLONE_MAIN_XLA" + echo "and JAXCI_XLA_COMMIT in your environment. Note that the Bazel RBE test" + echo "commands overrides the XLA repository and thus require a local copy of" + echo "XLA to run." +fi + +# On Windows, convert MSYS Linux-like paths to Windows paths. +if [[ $(uname -s) =~ "MSYS_NT" ]]; then + echo 'Converting MSYS Linux-like paths to Windows paths (for Bazel, Python, etc.)' + # Convert all "JAXCI.*DIR" variables + source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py --convert $(env | grep "JAXCI.*DIR" | awk -F= '{print $1}')) +fi \ No newline at end of file diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index fcb7b570e493..2163272e2542 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -623,16 +623,16 @@ be used with the custom_partitioning registration and for the gradient. (And if you implement the interface to support vmat, it will also be on the outer primitive). -JAX custom_partitioning implementation are callbacks from XLA to Python during XLA sharding logic. +JAX custom_partitioning implementations are callbacks from XLA to Python during XLA sharding logic. XLA sharding goes in two phases: a sharding propagation phase and a partition phase. -The propagation phase is when XLA plan the sharding to be created. It is the partition phase that create the sharded graph. +The propagation phase is when XLA plan the sharding to be created. It is the partition phase that creates the sharded graph. For XLA to be able to shard our custom operations, it needs us to define 2 extra functions: infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively. The infer_sharding_from_operands() function must do what its name say: infer the output sharding from the input sharding. The partition() function will do a few things: -- tell which input sharding will be expected. XLA will reshad if needed. +- tell which input sharding will be expected. XLA will reshard if needed. - tell the final version of the output sharding. - give a function that will create the new instruction from the sharded inputs. @@ -679,7 +679,7 @@ class RmsNormFwdClass: NamedSharding(mesh, PartitionSpec(None, None))) invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) output_shardings = (arg_shardings[0], invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables impl = partial(RmsNormFwdClass.impl, eps=eps) @@ -739,7 +739,7 @@ class RmsNormBwdClass: output_shardings = (output_sharding, invvar_sharding, invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables def impl(g, invvar, x, weight): grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind( diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py index 31a00c49071e..1cdf67c41a90 100644 --- a/docs/Custom_Operation_for_GPUs.py +++ b/docs/Custom_Operation_for_GPUs.py @@ -353,7 +353,7 @@ def partition(eps: float, mesh : jax.sharding.Mesh, NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything. invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) output_shardings = (arg_shardings[0], invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables impl = partial(RmsNormFwdClass.impl, eps=eps) diff --git a/docs/_static/jax-hero.svg b/docs/_static/jax-hero.svg new file mode 100644 index 000000000000..04626f43eacd --- /dev/null +++ b/docs/_static/jax-hero.svg @@ -0,0 +1,118 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/_static/pallas/vector_layout_example.svg b/docs/_static/pallas/vector_layout_example.svg new file mode 100644 index 000000000000..f1c9403573d8 --- /dev/null +++ b/docs/_static/pallas/vector_layout_example.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/style.css b/docs/_static/style.css index 296912ace2c8..36b54b8432f0 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1,34 +1,280 @@ @import url("theme.css"); +@import url('https://fonts.googleapis.com/css2?family=Google+Sans'); + +/* Base LP sidebar modifications */ +body:has(.hero) .sidebar-toggle, +body:has(.hero) .bd-sidebar-secondary { + display: none !important; +} + +body:has(.hero) .search-button { + display: flex !important; +} + +body:has(.hero) .primary-toggle { + display: inline-block !important; +} + +body:has(.hero) .prev-next-footer { + display: none; +} + +body:has(.hero) .bd-article-container { + max-width: unset !important; +} + +body:has(.hero) .bd-page-width { + max-width: unset !important; +} + +body:has(.hero) .bd-article { + display: flex; + flex-direction: column; + padding: 0; +} + +body:has(.hero) .bd-container { + flex-direction: column; +} + +@media (min-width: 960px) { + body:has(.hero) .bd-header-article { + justify-content: center; + } + + body:has(.hero) .header-article-items, + body:has(.hero) .bd-article > section { + max-width: 65rem !important; + align-self: center; + width: -moz-available; + width: -webkit-fill-available; + width: fill-available; + } +} + +/* Custom CSS */ :root { --block-bg-opacity: .5; } +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) { + padding: 0; +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * { + padding-inline: 2rem !important; +} + +@media (max-width: 768px) { + .bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * { + padding-inline: 1rem !important; + } +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) h1 { + display: none; +} + .wy-side-nav-search { background-color: #fff; } -.getting-started { - background-color: rgba(78, 150, 253, var(--block-bg-opacity)); +.getting-started, +.user-guides, +.installation { + background: #3C4043; + color: white; + height: 170px; + border: none !important; + border-radius: 12px; +} + +.getting-started:hover, +.user-guides:hover, +.installation:hover { + background: #AECBFA; + color: #202124; + transform: unset !important; +} + +.getting-started .sd-card-body, +.user-guides .sd-card-body, +.installation .sd-card-body { + display: flex; + align-items: center; + justify-content: center; + font: 500 24px 'Roboto', sans-serif; +} + +.getting-started .sd-card-title, +.user-guides .sd-card-title, +.installation .sd-card-title { + display: flex; + flex-direction: column; + align-items: center; + gap: 12px; +} + +.getting-started svg, +.user-guides svg, +.installation svg { + color: #8AB4F8; +} + +.getting-started:hover svg, +.user-guides:hover svg, +.installation:hover svg { + color: #3C4043; +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > .hero { + padding-inline: 2rem 0 !important; } -.user-guides { - background-color: rgba(0, 169, 154, var(--block-bg-opacity)); +.hero { + display: grid; + grid: auto-flow / 1fr .6fr; + align-items: center; + background: rgb(32,33,36); + background: linear-gradient(90deg, rgba(32,33,36,1) 0%, rgba(39,45,56,1) 100%); + position: relative; + overflow: hidden; + border-radius: 24px; } -.developer-docs { - background-color: rgba(171, 0, 182, var(--block-bg-opacity)); +.hero > img { + position: absolute; + top: 0; + right: 0; + height: 100%; + background: transparent !important; +} + +.hero-left { + padding-block: 24px; + display: flex; + flex-direction: column; +} + +.hero-left img { + width: 100px; + height: auto; + position: relative; + margin-bottom: 16px; + background: transparent !important; +} + +.hero-left h2 { + font: 500 32px 'Google Sans', 'Roboto', sans-serif; + color: white; + margin-top: 0; +} + +.hero-left p { + font: 400 16px 'Roboto', sans-serif; + color: white; +} + +@media (max-width: 1295px) { + .hero > img { + right: -75px; + } +} + +@media (max-width: 750px) { + .hero { + grid: auto-flow / 1fr; + } + + .hero-left { + padding-right: 2rem; + } + + .hero > img { + display: none; + } +} + +.product-offerings { + margin-block: 32px !important; +} + +.product-offerings .sd-card-title { + font: 400 24px 'Google Sans', 'Roboto', sans-serif; +} + +.color-cards { + background: #E8EAED; + color: #222832; + padding: 48px 12px 0 12px; + margin-bottom: 0 !important; + border-radius: 24px 24px 0 0; +} + +.color-cards > div { + gap: 24px 0; +} + +.color-cards + p { + background: #E8EAED; + padding: 24px 12px 48px 12px; + font-weight: 600; + color: #222832; + border-radius: 0 0 24px 24px; +} + +.color-cards + p > a { + color: #222832; +} + +.color-cards + p > a:hover, +html[data-theme="dark"] .color-cards + p > a:hover { + color: #e89217; +} + +html[data-theme="dark"] .color-cards, +html[data-theme="dark"] .hero, +html[data-theme="dark"] .color-cards + p, +html[data-theme="dark"] .color-cards + p > a { + background: #202124; + color: white; } .ecosystem-grid { font-size: smaller; } +.ecosystem-grid > div { + gap: 20px; +} + +.ecosystem-grid .sd-col { + border: 1px solid #dadce0; + border-radius: 8px; + width: calc(50% - 10px); + padding: 16px; +} + +.ecosystem-grid .sd-col > p { + display: flex; + flex-direction: column; + gap: 10px; +} + +.ecosystem-grid .sd-col > p > svg { + color: #00897B; +} + .ecosystem-grid ul { list-style-type: none; padding-inline-start: 0.5em; } +.ecosystem-grid a { + text-decoration: none; +} + div.red-background pre { background-color: rgba(244, 204, 204, var(--block-bg-opacity)); } diff --git a/docs/about.md b/docs/about.md new file mode 100644 index 000000000000..c4bc93140fbc --- /dev/null +++ b/docs/about.md @@ -0,0 +1,123 @@ +(about-the-project)= + +# About the project + +The JAX project is led by the JAX core team. We develop in the open, +and welcome open-source contributions from across the community. We +frequently see contributions from [Google +DeepMind](https://deepmind.google/), Alphabet more broadly, +[NVIDIA](https://docs.nvidia.com/deeplearning/frameworks/jax-release-notes/overview.html), +and elsewhere. + +At the heart of the project is the [JAX +core](http://github.com/google/jax) library, which focuses on the +fundamentals of machine learning and numerical computing, at scale. + +When [developing](#development) the core, we want to maintain agility +and a focused scope, so we lean heavily on a surrounding [modular +technology stack](#components). First, we design the `jax` module +to be +[composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) +and +[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +that a wide variety of domain-specific libraries can thrive outside of +it in a decentralized manner. Second, we lean heavily on a modular +backend stack (compiler and runtime) to target different +accelerators. Whether you are [writing a new domain-specific library +built with JAX](#upstack), or looking to [support +new hardware](#downstack), you can often +contribute these with *minimal to no modifications* to the JAX core +codebase. + +Many of JAX's core contributors have roots in open-source software and +in research, in fields spanning computer science and the natural +sciences. We strive to continuously enable the cutting edge of machine +learning and numerical computing---across all compute platforms and +accelerators---and to discover the truths of array programming at +scale. + +(development)= +## Open development + +JAX's day-to-day development takes place in the open on GitHub, using +pull requests, the issue tracker, discussions, and [JAX Enhancement +Proposals +(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +and participating in these is a good way to get involved. We also +maintain [developer +notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +that cover JAX's internal design. + +The JAX core team determines whether to accept changes and +enhancements. Maintaining a simple decision-making structure currently +helps us develop at the speed of the research frontier. Open +development is a core value of ours, and we may adapt to a more +intricate decision structure over time (e.g. with designated area +owners) if/when it becomes useful to do so. + +For more see [contributing to +JAX](https://jax.readthedocs.io/en/latest/contributing.html). + +(components)= +## A modular stack + +To enable (a) a growing community of users across numerical domains, +and (b) an advancing hardware landscape, we lean heavily on +**modularity**. + +(upstack)= +### Libraries built on JAX + +While the JAX core library focuses on the fundamentals, we want to +encourage domain-specific libraries and tools to be built on top of +JAX. Indeed, [many +libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +emerged around JAX to offer higher-level features and extensions. + +How do we encourage such decentralized development? We guide it with +several technical choices. First, JAX's main API focuses on basic +building blocks (e.g. numerical primitives, NumPy operations, arrays, +and transformations), encouraging auxiliary libraries to develop +utilities as needed for their domain. In addition, JAX exposes a +handful of more advanced APIs for +[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +and +[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +can [lean on these +APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +order to use JAX as an internal means of implementation, to integrate +more with its transformations like autodiff, and more. + +Projects across the JAX ecosystem are developed in a distributed and +often open fashion. They are not governed by the JAX core team, even +though sometimes team members contribute to them or maintain contact +with their developers. + +(downstack)= +### A pluggable backend + +We want JAX to run on CPUs, GPUs, TPUs, and other hardware platforms +as they emerge. To encourage unhindered support of JAX on new +platforms, the JAX core emphasizes modularity in its backend too. + +To manage hardware devices and memory, and for compilation to such +devices, JAX calls out to the open [XLA +compiler](https://openxla.org/) and the [PJRT +runtime](https://github.com/openxla/xla/tree/main/xla/pjrt/c#pjrt---uniform-device-api). Both +of these are projects external to the JAX core, governed and +maintained by OpenXLA (again, with frequent contributions from and +discussion with the JAX core developers). + +XLA aims for interoperability across accelerators (e.g. by ingesting +[StableHLO](https://openxla.org/stablehlo) as input) and PJRT offers +extensibility through a plug-in device API. Adding support for new +devices is done by implementing a backend lowering for XLA, and +implementing a plug-in device API defined by PJRT. If you're looking +to contribute to compilation, or to supporting new hardware, we +encourage you to contribute at the XLA and PJRT layers. + +These open system components allow third parties to support JAX on new +accelerator platforms, *without requiring changes in the JAX +core*. There are several plug-ins in development today. For example, a +team at Apple is working on a PJRT plug-in to get [JAX running on +Apple Metal](https://developer.apple.com/metal/jax/). diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index 023dc8040954..c56e82c77450 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -350,7 +350,7 @@ This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \math and so on. -To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. +To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. ## How it's made: Two foundational autodiff functions @@ -475,7 +475,7 @@ where we use `CT a` to denote the type for the cotangent space for `a`. In words This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. -There's a cost, though: though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). +There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/). @@ -1762,7 +1762,6 @@ print(grad(app, 1)(lambda x: x ** 2, 4.)) Refer to `fixed_point` above for another usage example. **You don't need to use** `nondiff_argnums` **with array-valued arguments**, such as, for example, ones with the integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments. -s ## Next steps diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 8b418b16f878..e620967de4b7 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2797,7 +2797,7 @@ "representing unknown outputs, we need avals, which we get from the abstract\n", "eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n", "`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n", - "weakrefs.)\n", + "`weakref`s.)\n", "\n", "That `process_primitive` logic applies to most primitives, but `xla_call_p`\n", "requires recursive treatment. So we special-case its rule in a\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 9e726e5ed82e..1c16db80f608 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -2195,7 +2195,7 @@ output. If instead any input is unknown then we instead stage out into a representing unknown outputs, we need avals, which we get from the abstract eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using -weakrefs.) +`weakref`s.) That `process_primitive` logic applies to most primitives, but `xla_call_p` requires recursive treatment. So we special-case its rule in a diff --git a/docs/autodidax.py b/docs/autodidax.py index f57af2cd96f2..f74617f31416 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -2187,7 +2187,7 @@ def full_lower(self): # representing unknown outputs, we need avals, which we get from the abstract # eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and # `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using -# weakrefs.) +# `weakref`s.) # # That `process_primitive` logic applies to most primitives, but `xla_call_p` # requires recursive treatment. So we special-case its rule in a diff --git a/docs/conf.py b/docs/conf.py index d57420dec881..8007c0b3d828 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,7 +49,7 @@ def _do_not_evaluate_in_jax( # -- Project information ----------------------------------------------------- project = 'JAX' -copyright = '2024, The JAX Authors. NumPy and SciPy documentation are copyright the respective authors.' +copyright = '2024, The JAX Authors' author = 'The JAX authors' # The short X.Y version diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index 55094fc88958..f89122f944cc 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -25,4 +25,3 @@ some of JAX's (extensible) internals. autodidax jep/index - jax_internal_api diff --git a/docs/control-flow.md b/docs/control-flow.md new file mode 100644 index 000000000000..7cb959f3e434 --- /dev/null +++ b/docs/control-flow.md @@ -0,0 +1,394 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + ++++ {"id": "rg4CpMZ8c3ri"} + +(control-flow)= +# Control flow and logical operators with JIT + + + +When executing eagerly (outside of `jit`), JAX code works with Python control flow and logical operators just like Numpy code. Using control flow and logical operators with `jit` is more complicated. + +In a nutshell, Python control flow and logical operators are evaluated at JIT compile time, such that the compiled function represents a single path through the [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph) (logical operators affect the path via short-circuiting). If the path depends on the values of the inputs, the function (by default) cannot be JIT compiled. The path may depend on the shape or dtype of the inputs, and the function is re-compiled every time it is called on an input with a new shape or dtype. + +```{code-cell} +from jax import grad, jit +import jax.numpy as jnp +``` + +For example, this works: + +```{code-cell} +:id: OZ_BJX0CplNC +:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c + +@jit +def f(x): + for i in range(3): + x = 2 * x + return x + +print(f(3)) +``` + ++++ {"id": "22RzeJ4QqAuX"} + +So does this: + +```{code-cell} +:id: pinVnmRWp6w6 +:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422 + +@jit +def g(x): + y = 0. + for i in range(x.shape[0]): + y = y + x[i] + return y + +print(g(jnp.array([1., 2., 3.]))) +``` + ++++ {"id": "TStltU2dqf8A"} + +But this doesn't, at least by default: + +```{code-cell} +:id: 9z38AIKclRNM +:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac +:tags: [raises-exception] + +@jit +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +# This will fail! +f(2) +``` + +Neither does this: + +```{code-cell} +:tags: [raises-exception] + +@jit +def g(x): + return (x > 0) and (x < 3) + +# This will fail! +g(2) +``` + ++++ {"id": "pIbr4TVPqtDN"} + +__What gives!?__ + +When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation. + +For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. + +To get a view of your Python code that is valid for many different argument values, JAX traces it with the `ShapedArray` abstraction as input, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. + +But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace. + +The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnames` (or `static_argnums`) argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again: + +```{code-cell} +:id: -Tzp0H7Bt1Sn +:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a + +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +f = jit(f, static_argnames='x') + +print(f(2.)) +``` + ++++ {"id": "MHm1hIQAvBVs"} + +Here's another example, this time involving a loop: + +```{code-cell} +:id: iwY86_JKvD6b +:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937 + +def f(x, n): + y = 0. + for i in range(n): + y = y + x[i] + return y + +f = jit(f, static_argnames='n') + +f(jnp.array([2., 3., 4.]), 2) +``` + ++++ {"id": "nSPTOX8DvOeO"} + +In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation + ++++ {"id": "wWdg8LTYwCW3"} + +️⚠️ **functions with argument-__value__ dependent shapes** + +These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`. + +```{code-cell} +:id: Tqe9uLmUI_Gv +:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883 + +def example_fun(length, val): + return jnp.ones((length,)) * val +# un-jit'd works fine +print(example_fun(5, 4)) +``` + +```{code-cell} +:id: fOlR54XRgHpd +:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71 +:tags: [raises-exception] + +bad_example_jit = jit(example_fun) +# this will fail: +bad_example_jit(10, 4) +``` + +```{code-cell} +:id: kH0lOD4GgFyI +:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade + +# static_argnames tells JAX to recompile on changes at these argument positions: +good_example_jit = jit(example_fun, static_argnames='length') +# first compile +print(good_example_jit(10, 4)) +# recompiles +print(good_example_jit(5, 4)) +``` + ++++ {"id": "MStx_r2oKxpp"} + +`static_argnames` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! + +Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions: + +```{code-cell} +:id: m2ABpRd8K094 +:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3 + +@jit +def f(x): + print(x) + y = 2 * x + print(y) + return y +f(2) +``` + ++++ {"id": "uCDcWG4MnVn-"} + +## Structured control flow primitives + +There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives: + + - `lax.cond` _differentiable_ + - `lax.while_loop` __fwd-mode-differentiable__ + - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static. + - `lax.scan` _differentiable_ + ++++ {"id": "Sd9xrLMXeK3A"} + +### `cond` +python equivalent: + +```python +def cond(pred, true_fun, false_fun, operand): + if pred: + return true_fun(operand) + else: + return false_fun(operand) +``` + +```{code-cell} +:id: SGxz9JOWeiyH +:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3 + +from jax import lax + +operand = jnp.array([0.]) +lax.cond(True, lambda x: x+1, lambda x: x-1, operand) +# --> array([1.], dtype=float32) +lax.cond(False, lambda x: x+1, lambda x: x-1, operand) +# --> array([-1.], dtype=float32) +``` + ++++ {"id": "lIYdn1woOS1n"} + +`jax.lax` provides two other functions that allow branching on dynamic predicates: + +- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is + like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays + rather than as functions. +- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is + like `lax.cond`, but allows switching between any number of callable choices. + +In addition, `jax.numpy` provides several numpy-style interfaces to these functions: + +- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with + three arguments is the numpy-style wrapper of `lax.select`. +- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) + is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. +- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has + an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather + than as functions. It is implemented in terms of multiple calls to `lax.select`. + ++++ {"id": "xkOFAw24eOMg"} + +### `while_loop` + +python equivalent: +``` +def while_loop(cond_fun, body_fun, init_val): + val = init_val + while cond_fun(val): + val = body_fun(val) + return val +``` + +```{code-cell} +:id: jM-D39a-c436 +:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e + +init_val = 0 +cond_fun = lambda x: x < 10 +body_fun = lambda x: x+1 +lax.while_loop(cond_fun, body_fun, init_val) +# --> array(10, dtype=int32) +``` + ++++ {"id": "apo3n3HAeQY_"} + +### `fori_loop` +python equivalent: +``` +def fori_loop(start, stop, body_fun, init_val): + val = init_val + for i in range(start, stop): + val = body_fun(i, val) + return val +``` + +```{code-cell} +:id: dt3tUpOmeR8u +:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380 + +init_val = 0 +start = 0 +stop = 10 +body_fun = lambda i,x: x+i +lax.fori_loop(start, stop, body_fun, init_val) +# --> array(45, dtype=int32) +``` + ++++ {"id": "SipXS5qiqk8e"} + +### Summary + +$$ +\begin{array} {r|rr} +\hline \ +\textrm{construct} +& \textrm{jit} +& \textrm{grad} \\ +\hline \ +\textrm{if} & ❌ & ✔ \\ +\textrm{for} & ✔* & ✔\\ +\textrm{while} & ✔* & ✔\\ +\textrm{lax.cond} & ✔ & ✔\\ +\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ +\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ +\textrm{lax.scan} & ✔ & ✔\\ +\hline +\end{array} +$$ + +
+ +$\ast$ = argument-value-independent loop condition - unrolls the loop + +
+ +## Logical operators + +`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`. + +For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar. + +```{code-cell} +def python_check_positive_even(x): + is_even = x % 2 == 0 + # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated. + return is_even and (x > 0) + +@jit +def jax_check_positive_even(x): + is_even = x % 2 == 0 + # `logical_and` does not short circuit, so `x > 0` is always evaluated. + return jnp.logical_and(is_even, x > 0) + +print(python_check_positive_even(24)) +print(jax_check_positive_even(24)) +``` + +When the JAX version with `logical_and` is applied to an array, it returns elementwise values. + +```{code-cell} +x = jnp.array([-1, 2, 5]) +print(jax_check_positive_even(x)) +``` + +Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior. + +```{code-cell} +:tags: [raises-exception] + +print(python_check_positive_even(x)) +``` + ++++ {"id": "izLTvT24dAq0"} + +## Python control flow + autodiff + +Remember that the above constraints on control flow and logical operators are relevant only with `jit`. If you just want to apply `grad` to your python functions, without `jit`, you can use regular Python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). + +```{code-cell} +:id: aAx0T3F8lLtu +:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52 + +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +print(grad(f)(2.)) # ok! +print(grad(f)(4.)) # ok! +``` diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD deleted file mode 100644 index 4954ce3db4fa..000000000000 --- a/docs/cuda_custom_call/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load( - "//jaxlib:jax.bzl", - "cuda_library", - "jax_generate_backend_suites", - "jax_multiplatform_test", -) - -licenses(["notice"]) - -package( - default_applicable_licenses = [], - default_visibility = ["//visibility:private"], -) - -jax_generate_backend_suites() - -jax_multiplatform_test( - name = "cuda_custom_call_test", - srcs = ["cuda_custom_call_test.py"], - data = [":foo"], - enable_backends = ["gpu"], - tags = ["notap"], - deps = [ - "//jax:extend", - ], -) - -# this second target is needed to properly link in CUDA runtime symbols -# such as cudaLaunchKernel, even though we are only building one library. -cc_shared_library( - name = "foo", - deps = [ - ":foo_", - "@xla//xla/tsl/cuda:cudart", - ], -) - -cuda_library( - name = "foo_", - srcs = ["foo.cu.cc"], - deps = [ - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - ], -) diff --git a/docs/cuda_custom_call/Makefile b/docs/cuda_custom_call/Makefile deleted file mode 100644 index ca51b63b5eaf..000000000000 --- a/docs/cuda_custom_call/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# This Makefile is not used by Bazel for this test, it is intended to serve as -# documentation of build instructions for JAX users that are not using Bazel to -# build their custom call code. For that reason, this Makefile is likely subject -# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in -# this directory no longer runs the test to completion. -NVCC = nvcc -NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())') -NVCCFLAGS += -arch native -# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu -NVCCFLAGS += -x cu - -# depends on libfoo.so being in the same directory as cuda_custom_call_test.py -check: libfoo.so - python cuda_custom_call_test.py - -lib%.so: %.cu.cc - $(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $< - -clean: - rm -rf *.so diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py deleted file mode 100644 index f63bbd670bf5..000000000000 --- a/docs/cuda_custom_call/cuda_custom_call_test.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# This test is intentionally structured to stay close to what a standalone JAX -# custom call integration might look like. JAX test harness is in a separate -# section towards the end of this file. The test can be run standalone by typing -# "make" in the directory containing this file. - -import os -import ctypes -import unittest - -import numpy as np - -import jax -import jax.numpy as jnp -from jax.extend import ffi - -# start test boilerplate -from absl.testing import absltest -from jax._src import config -from jax._src import test_util as jtu - -config.parse_flags_with_absl() -# end test boilerplate - -# XLA needs uppercase, "cuda" isn't recognized -XLA_PLATFORM = "CUDA" - -# JAX needs lowercase, "CUDA" isn't recognized -JAX_PLATFORM = "cuda" - -# 0 = original ("opaque"), 1 = FFI -XLA_CUSTOM_CALL_API_VERSION = 1 - -# these strings are how we identify kernels to XLA: -# - first we register a pointer to the kernel with XLA under this name -# - then we "tell" JAX to emit StableHLO specifying this name to XLA -XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd" -XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd" - -# load the shared library with the FFI target definitions -if jtu.is_running_under_pytest(): - raise unittest.SkipTest("libfoo.so hasn't been built") -SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so") -library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) - -# register the custom calls targets with XLA, api_version=1 by default -ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_FWD, - fn=ffi.pycapsule(library.FooFwd), - platform=XLA_PLATFORM) -ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_BWD, - fn=ffi.pycapsule(library.FooBwd), - platform=XLA_PLATFORM) - -def foo_fwd(a, b): - assert a.dtype == jnp.float32 - assert a.shape == b.shape - assert a.dtype == b.dtype - n = np.prod(a.shape).astype(np.uint64) - out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) - c, b_plus_1 = ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_FWD, (out_type, out_type), - a, b, n=n) - return c, (a, b_plus_1) - - -def foo_bwd(res, c_grad): - a, b_plus_1 = res - assert c_grad.dtype == jnp.float32 - assert c_grad.shape == a.shape - assert a.shape == b_plus_1.shape - assert c_grad.dtype == a.dtype - assert a.dtype == b_plus_1.dtype - n = np.prod(a.shape).astype(np.uint64) - out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) - return ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_BWD, (out_type, out_type), - c_grad, a, b_plus_1, n=n) - - -@jax.custom_vjp -def foo(a, b): - c, _ = foo_fwd(a, b) - return c - - -foo.defvjp(foo_fwd, foo_bwd) - -#-----------------------------------------------------------------------------# -# Test # -#-----------------------------------------------------------------------------# - - -class CustomCallTest(jtu.JaxTestCase): - - def test_fwd_interpretable(self): - shape = (2, 3) - a = 2. * jnp.ones(shape) - b = 3. * jnp.ones(shape) - observed = jax.jit(foo)(a, b) - expected = (2. * (3. + 1.)) - self.assertArraysEqual(observed, expected) - - def test_bwd_interpretable(self): - shape = (2, 3) - a = 2. * jnp.ones(shape) - b = 3. * jnp.ones(shape) - - def loss(a, b): - return jnp.sum(foo(a, b)) - - da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b) - da_expected = b + 1 - db_expected = a - self.assertArraysEqual(da_observed, da_expected) - self.assertArraysEqual(db_observed, db_expected) - - def test_fwd_random(self): - shape = (2, 3) - akey, bkey = jax.random.split(jax.random.key(0)) - a = jax.random.normal(key=akey, shape=shape) - b = jax.random.normal(key=bkey, shape=shape) - observed = jax.jit(foo)(a, b) - expected = a * (b + 1) - self.assertAllClose(observed, expected) - - def test_bwd_random(self): - shape = (2, 3) - akey, bkey = jax.random.split(jax.random.key(0)) - a = jax.random.normal(key=akey, shape=shape) - b = jax.random.normal(key=bkey, shape=shape) - jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",)) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/docs/deprecation.md b/docs/deprecation.md index 385d31271421..603a027f5efc 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -18,6 +18,7 @@ This means we support at least: * **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**. * **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**. * **Python 3.12** was released October 2023, and will be supported in new JAX releases at least until **July 2027**. + * **Python 3.13** was released October 2024, and will be supported in new JAX releases at least until **July 2028**. * All NumPy feature releases in the 24 months prior to each JAX release. For example: @@ -25,6 +26,7 @@ This means we support at least: * **NumPy 1.25** was released June 2023, and will be supported in new JAX releases at least until **June 2025** * **NumPy 1.26** was released September 2023, and will be supported in new JAX releases at least until **September 2025** * **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026** + * **NumPy 2.1** was released August 2024, and will be supported in new JAX releases at least until **August 2026** * All SciPy feature releases in the 24 months prior to each JAX release. For example: @@ -32,6 +34,7 @@ This means we support at least: * **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**. * **Scipy 1.12** was released January 2024, and will be supported in new JAX releases at least until **January 2026**. * **Scipy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**. + * **Scipy 1.14** was released June 2024, and will be supported in new JAX releases at least until **June 2026**. JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed diff --git a/docs/developer.md b/docs/developer.md index 68e8e931e2e5..e8069b3b5fe6 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -63,7 +63,7 @@ To build `jaxlib` from source, you must also install some prerequisites: To build `jaxlib` for CPU or TPU, you can run: ``` -python build/build.py +python build/build.py build --wheels=jaxlib --verbose pip install dist/*.whl # installs jaxlib (includes XLA) ``` @@ -71,7 +71,7 @@ To build a wheel for a version of Python different from your current system installation pass `--python_version` flag to the build command: ``` -python build/build.py --python_version=3.12 +python build/build.py build --wheels=jaxlib --python_version=3.12 --verbose ``` The rest of this document assumes that you are building for Python version @@ -81,13 +81,13 @@ version, simply append `--python_version=` flag every time you call installation regardless of whether the `--python_version` parameter is passed or not. -There are two ways to build `jaxlib` with CUDA support: (1) use -`python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda -support, or (2) use -`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` +If you would like to build `jaxlib` and the CUDA plugins: Run +``` +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt +``` to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and -jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and -clang, but it can be restricted to clang via the `--nouse_cuda_nvcc` flag. +jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and +clang, but it can be restricted to clang via the `--build_cuda_with_clang` flag. See `python build/build.py --help` for configuration options. Here `python` should be the name of your Python 3 interpreter; on some systems, you @@ -102,18 +102,28 @@ current directory. target dependencies. To download the specific versions of CUDA/CUDNN redistributions, you can use - the following command: + the `--cuda_version` and `--cudnn_version` flags: ```bash - python build/build.py --enable_cuda \ - --cuda_version=12.3.2 --cudnn_version=9.1.1 + python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 + ``` + or + ```bash + python build/build.py build --wheels=jax-cuda-pjrt --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 ``` + Please note that these parameters are optional: by default Bazel will + download CUDA and CUDNN redistribution versions provided in `.bazelrc` in the + environment variables `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` + respectively. + To point to CUDA/CUDNN/NCCL redistributions on local file system, you can use the following command: ```bash - python build/build.py --enable_cuda \ + python build/build.py build --wheels=jax-cuda-plugin \ --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" @@ -136,7 +146,7 @@ ways to do this: line flag to `build.py` as follows: ``` - python build/build.py --bazel_options=--override_repository=xla=/path/to/xla + python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla ``` - modify the `WORKSPACE` file in the root of the JAX source tree to point to @@ -178,7 +188,7 @@ path of the current session. Ensure `bazel`, `patch` and `realpath` are accessible. Activate the conda environment. ``` -python .\build\build.py +python .\build\build.py build --wheels=jaxlib ``` To build with debug information, add the flag `--bazel_options='--copt=/Z7'`. @@ -198,12 +208,14 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \ The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`, and selecting the appropriate options. -To build jaxlib with ROCM support, you can run the following build command, +To build jaxlib with ROCM support, you can run the following build commands, suitably adjusted for your paths and ROCM version. ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 ``` +to generate three wheels (jaxlib without rocm, jax-rocm-plugin, and +jax-rocm-pjrt) AMD's fork of the XLA repository may include fixes not present in the upstream XLA repository. If you experience problems with the upstream repository, you can @@ -216,7 +228,7 @@ git clone https://github.com/ROCm/xla.git and override the XLA repository with which JAX is built: ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --bazel_options=--override_repository=xla=/rel/xla/ --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py build --wheels=jax-rocm-plugin --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 --local_xla_path=/rel/xla/ ``` For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`. @@ -241,7 +253,7 @@ run `build/build.py` script. To choose a specific version explicitly you may pass `--python_version` argument to the tool: ``` -python build/build.py --python_version=3.12 +python build/build.py build --python_version=3.12 ``` Under the hood, the hermetic Python version is controlled @@ -279,7 +291,7 @@ direct dependencies list and then execute the following command (which will call [pip-compile](https://pypi.org/project/pip-tools/) under the hood): ``` -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` Alternatively, if you need more control, you may run the bazel command @@ -323,7 +335,7 @@ For example: ``` echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` ### Specifying dependencies on nightly wheels @@ -333,7 +345,7 @@ dependencies we provide a special version of the dependency updater command as follows: ``` -python build/build.py --requirements_nightly_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 --nightly_update ``` Or, if you run `bazel` directly (the two commands are equivalent): @@ -347,99 +359,162 @@ accept pre-release, dev and nightly packages, it will also search https://pypi.anaconda.org/scientific-python-nightly-wheels/simple as an extra index url and will not put hashes in the resultant requirements lock file. -### Building with pre-release Python version - -We support all of the current versions of Python out of the box, but if you need -to build and test against a different version (for example the latest unstable -version which hasn't been released officially yet) please follow the -instructions below. - -1) Make sure you have installed necessary linux packages needed to build Python - interpreter itself and key packages (like `numpy` or `scipy`) from source. On - a typical Debian system you may need to install the following packages: - -``` -sudo apt-get update -sudo apt-get build-dep python3 -y -sudo apt-get install pkg-config zlib1g-dev libssl-dev -y -# to build scipy -sudo apt-get install libopenblas-dev -y -``` - -2) Check your `WORKSPACE` file and make sure it - has `custom_python_interpreter()` entry there, pointing to the version of - Python you want to build. - -3) Run `bazel build @python_dev//:python_dev -repo_env=HERMETIC_PYTHON_VERSION=3.12` - to build Python interpreter. Note, it is easy to confuse Python version used - to conduct the build (which is needed for technical reasons and is defined by - `HERMETIC_PYTHON_VERSION=3.12`) and the version of Python you are building - (defined by whichever version you specified in `custom_python_interpreter()` - on step 2). For build to succeed, please make sure that hermetic Python you - choose to conduct the build already exists in your configuraiton (the actual - version does not matter, as long as it is a working one). By default, Python - binary will be built with GCC compiler. If you wish to build it with clang, - you need to set corresponding env variables to do so ( - e.g. `--repo_env=CC=/usr/lib/llvm-17/bin/clang --repo_env=CXX=/usr/lib/llvm-17/bin/clang++`). - -4) Check the output of the previous command. At the very end of it you will find - a code snippet for `python_register_toolchains()` entry with your newly built - Python in it. Copy that code snippet in your `WORKSPACE` file either right - after `python_init_toolchains()` entry (to add the new version of Python) or - instead of it (to replace an existing version, like replacing `3.12` with - custom built variant of `3.12`). The code snippet is generated to match your - actual setup, so it should work as is, but you can customize it if you choose - so (for example to change location of Python's `.tgz` file so it could be - downloaded remotely instead of being on local machine). - -5) Make sure there is an entry for your Python's version in `requirements` - parameter for `python_init_repositories()` in your WORKSPACE file. For - example for `Python 3.13` it should have something - like `"3.13": "//build:requirements_lock_3_13.txt"`. Note, the key in the - `requirements` parameter must always be in `"major.minor"` version format, so - even if you are building Python version `3.13.0rc1` the corresponding - `requirements` entry must still be `"3.13": "//build:requirements_lock_3_13.txt"`, - **not** `"3.13.0rc1": "//build:requirements_lock_3_13_0rc1.txt"`. - -6) For unstable versions of Python, optionally (but highly recommended) - run `bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13"`, - where `3.13` is the version of Python interpreter you built on step 3. - This will make `pip` pull and build from sources (for packages which don't - have binaries published yet, for - example `numpy`, `scipy`, `matplotlib`, `zstandard`) all of the JAX's python - dependencies. It is recommended to do this step first (i.e. independently of - actual JAX build) for all unstable versions of Python to avoid conflict - between building JAX itself and building of its Python dependencies. For - example, we normally build JAX with clang but building `matplotlib` from - sources with clang fails out of the box due to differences in LTO behavior ( - Link Time Optimization, triggered by `-flto` flag) between GCC and clang, and - matplotlib assumes GCC by default. - If you build against a stable version of Python, or in general you do not - expect any of your Python dependencies to be built from sources (i.e. binary - distributions for the corresponding Python version already exist in the - repository) this step is not needed. - -7) Congrats, you've built and configured your custom Python for JAX project! You - may now execute your built/test commands as usual, just make - sure `HERMETIC_PYTHON_VERSION` environment variable is set and points to your - new version. - -8) Note, if you were building a pre-release version of Python, updating of - `requirements_lock_.txt` files with your newly built Python - is likely to fail, because package repositories will not have matching - binary packages. When there are no binary packages available `pip-compile` - proceeds with building them from sources, which is likely to fail because it - is more restrictive than doing the same thing during `pip` installation. - The recommended way to update requirements lock file for unstable versions of - Python is to update requirements for the latest stable version (e.g. `3.12`) - without hashes (therefore special `//build:requirements_dev.update` target) - and then copy the results to the unstable Python's lock file (e.g. `3.13`): -``` -bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="3.12" -cp build/requirements_lock_3_12.txt build/requirements_lock_3_13.txt -bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13" -# You may need to edit manually the resultant lock file, depending on how ready -# your dependencies are for the new version of Python. +### Customizing hermetic Python (Advanced Usage) + +We support all of the current versions of Python out of the box, so unless your +workflow has very special requirements (such as ability to use your own custom +Python interpreter) you may safely skip this section entirely. + +In short, if you rely on a non-standard Python workflow you still can achieve +the great level of flexibility in hermetic Python setup. Conceptually there will +be only one difference compared to non-hermetic case: you will need to think in +terms of files, not installations (i.e. think what files your build actually +depends on, not what files need to be installed on your system), the rest is +pretty much the same. + +So, in practice, to gain full control over your Python environment, hermetic or +not you need to be able to do the following three things: + +1) Specify which python interpreter to use (i.e. pick actual `python` or + `python3` binary and libs that come with it in the same folder). +2) Specify a list of Python dependencies (e.g. `numpy`) and their actual + versions. +3) Be able to add/remove/update dependencies in the list easily. Each + dependency itself could be custom too (self-built for example). + +You already know how to do all of the steps above in a non-hermetic Python +environment, here is how you do the same in the hermetic one (by approaching it +in terms of files, not installations): + +1) Instead of installing Python, get Python interpreter in a `tar` or `zip` + file. Depending on your case you may simply pull one of many existing ones + (such as [python-build-standalone](https://github.com/indygreg/python-build-standalone/releases)), + or build your own and pack it in an archive (following official + [build instructions](https://devguide.python.org/getting-started/setup-building/#compile-and-build) + will do just fine). E.g. on Linux it will look something like the following: + ``` + ./configure --prefix python + make -j12 + make altinstall + tar -czpf my_python.tgz python + ``` + Once you have the tarball ready, plug it in the build by pointing + `HERMETIC_PYTHON_URL` env var to the archive (either local one or from the + internet): + ``` + --repo_env=HERMETIC_PYTHON_URL="file:///local/path/to/my_python.tgz" + --repo_env=HERMETIC_PYTHON_SHA256= + + # OR + --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" + --repo_env=HERMETIC_PYTHON_SHA256= + + # We assume that top-level folder in the tarbal is called "python", if it is + # something different just pass additional HERMETIC_PYTHON_PREFIX parameter + --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" + --repo_env=HERMETIC_PYTHON_SHA256= + --repo_env=HERMETIC_PYTHON_PREFIX="my_python/install" + ``` + +2) Instead of doing `pip install` create `requirements_lock.txt` file with + full transitive closure of your dependencies. You may also depend on the + existing ones already checked in this repo (as long as they work with your + custom Python version). There are no special instructions on how you do it, + you may follow steps recommended in [Specifying Python dependencies](#specifying-python-dependencies) + from this doc, just call pip-compile directly (note, the lock file must be + hermetic, but you can always generate it from non-hermetic python if you'd + like) or even create it manually (note, hashes are optional in lock files). + + +3) If you need to update or customize your dependencies list, you may once again + follow the [Specifying Python dependencies](#specifying-python-dependencies) + instructions to update `requirements_lock.txt`, call pip-compile directly or + modify it manually. If you have a custom package you want to use just point + to its `.whl` file directly (remember, work in terms of files, not + installations) from your lock (note, `requirements.txt` and + `requirements_lock.txt` files support local wheel references). If your + `requirements_lock.txt` is already specified as a dependency to + `python_init_repositories()` in `WORKSPACE` file you don't have to do + anything else. Otherwise you can point to your custom file as follows: + ``` + --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/custom_requirements_lock.txt" + ``` + Also note if you use `HERMETIC_REQUIREMENTS_LOCK` then it fully controls list + of your dependencies and the automatic local wheels resolution logic + described in [Specifying dependencies on local wheels](#specifying-dependencies-on-local-wheels) + gets disabled to not interfere with it. + +That is it. To summarize: if you have an archive with Python interpreter in it +and a requirements_lock.txt file with full transitive closure of your +dependencies then you fully control your Python environment. + +#### Custom hermetic Python examples + +Note, for all of the examples below you may also set the environment variables +globally (i.e. `export` in your shell instead of `--repo_env` argument to your +command) so calling bazel via `build/build.py` will work just fine. + +Build with custom `Python 3.13` from the internet, using default +`requirements_lock_3_13.txt` already checked in this repo (i.e. custom +interpreter but default dependencies): +``` +bazel build + --repo_env=HERMETIC_PYTHON_VERSION=3.13 + --repo_env=HERMETIC_PYTHON_URL="https://github.com/indygreg/python-build-standalone/releases/download/20241016/cpython-3.13.0+20241016-x86_64-unknown-linux-gnu-install_only.tar.gz" + --repo_env=HERMETIC_PYTHON_SHA256="2c8cb15c6a2caadaa98af51df6fe78a8155b8471cb3dd7b9836038e0d3657fb4" +``` + +Build with custom Python 3.13 from local file system and custom lock file +(assuming the lock file was put in `jax/build` folder of this repo before +running the command): +``` +bazel test + --repo_env=HERMETIC_PYTHON_VERSION=3.13 + --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz" + --repo_env=HERMETIC_PYTHON_PREFIX="prefix/to/strip/in/cython/tar/gz/archive" + --repo_env=HERMETIC_PYTHON_SHA256= + --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt" +``` + +If default python interpreter is good enough for you and you just need a custom +set of dependencies: +``` +bazel test + --repo_env=HERMETIC_PYTHON_VERSION=3.13 + --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt" +``` + +Note, you can have multiple different `requirement_lock.txt` files corresponding +to the same Python version to support different scenarios. You can control +which one is selected by specifying `HERMETIC_PYTHON_VERSION`. For example in +`WORKSPACE` file: +``` +requirements = { + "3.10": "//build:requirements_lock_3_10.txt", + "3.11": "//build:requirements_lock_3_11.txt", + "3.12": "//build:requirements_lock_3_12.txt", + "3.13": "//build:requirements_lock_3_13.txt", + "3.13-scenario1": "//build:scenario1_requirements_lock_3_13.txt", + "3.13-scenario2": "//build:scenario2_requirements_lock_3_13.txt", +}, +``` +Then you can build and test different combinations of stuff without changing +anything in your environment: +``` +# To build with scenario1 dependendencies: +bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 + +# To build with scenario2 dependendencies: +bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2 + +# To build with default dependendencies: +bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13 + +# To build with scenario1 dependendencies and custom Python 3.13 interpreter: +bazel test + --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 + --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz" + --repo_env=HERMETIC_PYTHON_SHA256= ``` ## Installing `jax` @@ -464,10 +539,13 @@ or using pytest. ### Using Bazel -First, configure the JAX build by running: +First, configure the JAX build by using the `--configure_only` flag. Pass +`--wheel_list=jaxlib` for CPU tests and CUDA/ROCM for GPU for GPU tests: ``` -python build/build.py --configure_only +python build/build.py build --wheels=jaxlib --configure_only +python build/build.py build --wheels=jax-cuda-plugin --configure_only +python build/build.py build --wheels=jax-rocm-plugin --configure_only ``` You may pass additional options to `build.py` to configure the build; see the @@ -489,14 +567,14 @@ make it available in the hermetic Python. To install a specific version of ``` echo -e "\njaxlib >= 0.4.26" >> build/requirements.in -python build/build.py --requirements_update +python build/build.py requirements_update ``` Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): ``` echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` Once you have `jaxlib` installed hermetically, run: @@ -611,22 +689,21 @@ minimization phase. ### Doctests JAX uses pytest in doctest mode to test the code examples within the documentation. -You can run this using +You can find the up-to-date command to run doctests in +[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml). +E.g., you can run: ``` -pytest docs +JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst ``` Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in function docstrings will run correctly. You can run this locally using, for example: ``` -pytest --doctest-modules jax/_src/numpy/lax_numpy.py +JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py ``` -Keep in mind that there are several files that are marked to be skipped when the -doctest command is run on the full package; you can see the details in -[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml) ## Type checking diff --git a/docs/export/export.md b/docs/export/export.md index 5960fcaea65a..aa686b03e2b2 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -247,7 +247,7 @@ for which the code was exported. You can specify explicitly for what platforms the code should be exported. This allows you to specify a different accelerator than you have available at export time, -and it even allows you to specify multi-platform lexport to +and it even allows you to specify multi-platform export to obtain an `Exported` object that can be compiled and executed on multiple platforms. @@ -273,7 +273,7 @@ ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used o >>> # compilation platform (which is the case for `cos` in this >>> # example): >>> exp_unsafe = export.export(jax.jit(lax.cos), -... lowering_platforms=['tpu'], +... platforms=['tpu'], ... disabled_checks=[export.DisabledSafetyCheck.platform()])(1.) >>> exp_unsafe.call(1.) @@ -281,7 +281,7 @@ Array(0.5403023, dtype=float32, weak_type=True) # and similarly with multi-platform lowering >>> exp_multi = export.export(jax.jit(lax.cos), -... lowering_platforms=['tpu', 'cpu', 'cuda'])(1.) +... platforms=['tpu', 'cpu', 'cuda'])(1.) >>> exp_multi.call(1.) Array(0.5403023, dtype=float32, weak_type=True) @@ -293,7 +293,7 @@ resulting module size should be only marginally larger than the size of a module with default export. As an extreme case, when serializing a module without any primitives with platform-specific lowering, you will get -the same StableHLO as for the single-plaform export. +the same StableHLO as for the single-platform export. ```python >>> import jax @@ -310,7 +310,7 @@ the same StableHLO as for the single-plaform export. 9220 >>> exp_multi = export.export(jax.jit(f), -... lowering_platforms=["cpu", "tpu", "cuda"])(1.) +... platforms=["cpu", "tpu", "cuda"])(1.) >>> len(exp_multi.mlir_module_serialized) # doctest: +SKIP 9282 diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index b1ce80638706..9254030a4e1c 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -44,7 +44,7 @@ following example: ``` Note that such functions are still re-compiled on demand for -each concrete input shapes they are invoked on. Only the +each concrete input shape they are invoked on. Only the tracing and the lowering are saved. The {func}`jax.export.symbolic_shape` is used in the above @@ -98,7 +98,7 @@ A few examples of shape specifications: arguments. Note that the same specification would work if the first argument is a pytree of 3D arrays, all with the same leading dimension but possibly with different trailing dimensions. - The value `None` for the second arugment means that the argument + The value `None` for the second argument means that the argument is not symbolic. Equivalently, one can use `...`. * `("(batch, ...)", "(batch,)")` specifies that the two arguments @@ -159,7 +159,7 @@ new shape: It is possible to convert dimension expressions explicitly to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`. The result of these operations can be used as regular JAX arrays, -bug cannot be used anymore as dimensions in shapes. +but cannot be used anymore as dimensions in shapes, e.g., in `reshape`: ```python >>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))( @@ -256,7 +256,7 @@ as follows: integers. E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`, `a >= b`, `a - b >= 0` are inconclusive and result in an exception. -In cases where a comparison operation cannot be resolve to a boolean, +In cases where a comparison operation cannot be resolved to a boolean, we raise {class}`InconclusiveDimensionOperation`. E.g., ```python @@ -351,7 +351,7 @@ symbolic constraints: is encountered, it is rewritten to the expression on the right. E.g., `floordiv(a, b) == c` works by replacing all - occurences of `floordiv(a, b)` with `c`. + occurrences of `floordiv(a, b)` with `c`. Equality constraints must not contain addition or subtraction at the top-level on the left-hand-side. Examples of valid left-hand-sides are `a * b`, or `4 * a`, or @@ -498,11 +498,11 @@ This works well for most use cases, and it mirrors the calling convention of JIT functions. Sometimes you may want to export a function parameterized -by an integer values that determines some shapes in the program. +by an integer value that determines some shapes in the program. For example, we may want to export the function `my_top_k` defined below, parameterized by the -value of `k`, which determined the shape of the result. +value of `k`, which determines the shape of the result. The following attempt will lead to an error since the dimension variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`: @@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass These errors arise in a pre-processing step before the compilation. -### Division of symbolic dimensions is partially supported - -JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. -In particular, JAX will handle the cases when either (a) there -is no remainder, or (b) the divisor is a constant -in which case there may be a constant remainder. - -For example, the code below results in a division error when trying to -compute the inferred dimension for a `reshape` operation: - -```python ->>> b, = export.symbolic_shape("b") ->>> export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b,), dtype=np.int32)) -Traceback (most recent call last): -jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1). -The remainder mod(b, - 2) should be 0. - -``` - -Note that the following will succeed: - -```python ->>> b, = export.symbolic_shape("b") ->>> # We specify that the first dimension is a multiple of 4 ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((4*b,), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,2*b]),) - ->>> # We specify that some other dimension is even ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,15*b]),) - -``` - (shape_poly_debugging)= ## Debugging diff --git a/docs/faq.rst b/docs/faq.rst index af14f382b1d7..44267f6f5f7d 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). @@ -847,6 +847,6 @@ see the page on `JAX GPU memory allocation`_. .. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function -.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266 +.. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 .. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html .. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index ea8a86fa80f1..72a2a6914fc0 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -26,10 +26,7 @@ "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", "\n", - "This tutorial comes with two supplementary files:\n", - "\n", - "* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and\n", - "* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n", + "The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n", "\n", "## A simple example\n", "\n", @@ -101,7 +98,7 @@ "\n", "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n", "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n", - "The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:\n", + "The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n", "\n", "```c++\n", "#include \n", @@ -129,12 +126,11 @@ "// A wrapper function providing the interface between the XLA FFI call and our\n", "// library function `ComputeRmsNorm` above. This function handles the batch\n", "// dimensions by calling `ComputeRmsNorm` within a loop.\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y) {\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y) {\n", " auto [totalSize, lastDim] = GetDims(x);\n", " if (lastDim == 0) {\n", - " return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n", - " \"RmsNorm input must be an array\");\n", + " return ffi::Error::InvalidArgument(\"RmsNorm input must be an array\");\n", " }\n", " for (int64_t n = 0; n < totalSize; n += lastDim) {\n", " ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n", @@ -143,14 +139,14 @@ "}\n", "\n", "// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare\n", - "// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`\n", - "// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.\n", + "// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`\n", + "// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.\n", "XLA_FFI_DEFINE_HANDLER_SYMBOL(\n", " RmsNorm, RmsNormImpl,\n", " ffi::Ffi::Bind()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", @@ -173,8 +169,7 @@ "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n", "In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n", "\n", - "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.\n", - "The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt)." + "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble." ] }, { @@ -433,7 +428,7 @@ "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", "2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", "\n", - "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.\n", + "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n", "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", "\n", "This custom derivative rule can be wired in as follows:" @@ -508,16 +503,16 @@ "When defining our FFI wrapper for CPU, the function signature that we used was:\n", "\n", "```c++\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y)\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "To update this to interface with a CUDA kernel, this signature becomes:\n", "\n", "```c++\n", "ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n", - " ffi::Buffer x,\n", - " ffi::Result> y)\n", + " ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "And the handler definition is updated to include a `Ctx` in its binding:\n", @@ -528,8 +523,8 @@ " ffi::Ffi::Bind()\n", " .Ctx>()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", diff --git a/docs/ffi.md b/docs/ffi.md index 5afc8f809d4d..96b627675004 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -34,10 +34,7 @@ JAX's FFI support is provided in two parts: In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. -This tutorial comes with two supplementary files: - -* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and -* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code. +The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi). ## A simple example @@ -96,7 +93,7 @@ and, for our example, this is the function that we want to expose to JAX via the To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla). For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call). -The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here: +The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here: ```c++ #include @@ -124,12 +121,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -138,14 +134,14 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x, } // Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare -// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL` -// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`. +// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL` +// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`. XLA_FFI_DEFINE_HANDLER_SYMBOL( RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` @@ -166,7 +162,6 @@ Now that we have our minimal FFI wrapper implemented, we need to expose this fun In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below. To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble. -The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt). ```{code-cell} ipython3 :tags: [hide-output] @@ -357,7 +352,7 @@ In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. 2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. -We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end. +We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end. The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. This custom derivative rule can be wired in as follows: @@ -422,16 +417,16 @@ Since this documentation page is automatically generated on a machine without ac When defining our FFI wrapper for CPU, the function signature that we used was: ```c++ -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) ``` To update this to interface with a CUDA kernel, this signature becomes: ```c++ ffi::Error RmsNormImpl(cudaStream_t stream, float eps, - ffi::Buffer x, - ffi::Result> y) + ffi::Buffer x, + ffi::ResultBuffer y) ``` And the handler definition is updated to include a `Ctx` in its binding: @@ -442,8 +437,8 @@ XLA_FFI_DEFINE_HANDLER( ffi::Ffi::Bind() .Ctx>() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` diff --git a/docs/ffi/rms_norm.cc b/docs/ffi/rms_norm.cc index 4dc8a890410c..467f13d44ac2 100644 --- a/docs/ffi/rms_norm.cc +++ b/docs/ffi/rms_norm.cc @@ -56,12 +56,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -75,17 +74,16 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x, XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); -ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -94,13 +92,12 @@ ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormFwd, RmsNormFwdImpl, - ffi::Ffi::Bind() - .Attr("eps") - .Arg>() // x - .Ret>() // y - .Ret>() // res +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y + .Ret>() // res ); void ComputeRmsNormBwd(int64_t size, float res, const float *x, @@ -115,14 +112,12 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, } } -ffi::Error RmsNormBwdImpl(ffi::Buffer res, - ffi::Buffer x, - ffi::Buffer ct_y, - ffi::Result> ct_x) { +ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, + ffi::Buffer ct_y, + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), @@ -131,11 +126,10 @@ ffi::Error RmsNormBwdImpl(ffi::Buffer res, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormBwd, RmsNormBwdImpl, - ffi::Ffi::Bind() - .Arg>() // res - .Arg>() // x - .Arg>() // ct_y - .Ret>() // ct_x +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>() // res + .Arg>() // x + .Arg>() // ct_y + .Ret>() // ct_x ); diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 1fde02a14655..6667589e7b72 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -60,3 +60,41 @@ Common causes of OOM failures **Running JAX on the display GPU.** Use :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` or :code:`XLA_PYTHON_CLIENT_PREALLOCATE`. + +**Disabling rematerialization HLO pass** + Sometimes disabling the automatic rematerialization HLO pass is favorable to avoid + poor remat choices by the compiler. The pass can be enable/disable by setting + :code:`jax.config.update('enable_remat_opt_pass', True)` or + :code:`jax.config.update('enable_remat_opt_pass', False)` respectively. Enabling or + disabling the automatic remat pass produces different trade-offs between compute and + memory. Note however, that the algorithm is basic and you can often get better + trade-off between compute and memory by disabling the automatic remat pass and doing + it manually with `the jax.remat API `_ + + +Experimental features +--------------------- + +Features here are experimental and must be tried with caution. + +``TF_GPU_ALLOCATOR=cuda_malloc_async`` + This replace XLA's own BFC memory allocator with `cudaMallocAsync + `_. + This will remove the big fixed pre-allocation and use a memory pool that grows. + The expected benefit is no need to set `XLA_PYTHON_CLIENT_MEM_FRACTION`. + + The risk are: + + - that memory fragmentation is different, so if you are close to the + limit, the exact OOM case due to fragmentation will be different. + - The allocation time won't be all paid at the start, but be incurred + when the memory pool need to be increased. So you could + experience less speed stability at the start and for benchmarks + it will be even more important to ignore the first few iterations. + + The risks can be mitigated by pre-allocating a signigicant chunk and + still get the benefit of having a growing memory pool. This can be + done with `TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N`. If N is `-1` + it will preallocate the same as what was allocatedy by + default. Otherwise, it is the size in bytes that you want to + preallocate. diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index 1f5cc0727605..5a760db98684 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -112,7 +112,7 @@ don't seem useful for multi-host communication yet. ## Multi-Process -We recommand using one process per GPU and not one per node. In some +We recommend using one process per GPU and not one per node. In some cases, this can speed up jitted computation. The {func}`jax.distributed.initialize` API will automatically understand that configuration when run under SLURM. However, this only a rule of diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 33efaed6274b..3ef927e056f2 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -443,7 +443,7 @@ print_fwd_bwd(f, 3.) When differentiated functions are staged out to XLA for compilation — for example by applying {func}`jax.jit` to a function which contains a {func}`jax.grad` call — XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **{func}`jax.checkpoint` often isn't needed for differentiated functions under a {func}`jax.jit`**. XLA will optimize things for you. -One exception is when using staged-out control flow, like {func}`jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass `scan` and the corresponding backward-pass `scan`), typically aren't aren't as thorough. As a result, it's often a good idea to use {func}`jax.checkpoint` on the body function passed to {func}`jax.lax.scan`. +One exception is when using staged-out control flow, like {func}`jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass `scan` and the corresponding backward-pass `scan`), typically aren't as thorough. As a result, it's often a good idea to use {func}`jax.checkpoint` on the body function passed to {func}`jax.lax.scan`. For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a {func}`jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this: diff --git a/docs/hero.html b/docs/hero.html new file mode 100644 index 000000000000..a2ee3b8e206f --- /dev/null +++ b/docs/hero.html @@ -0,0 +1,8 @@ +
+
+ +

High performance array computing

+

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

+
+ +
\ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 2dd856ab88ef..ba8ebcbdd128 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,10 +1,22 @@ JAX: High performance array computing ===================================== -JAX is a Python library for accelerator-oriented array computation and program transformation, -designed for high-performance numerical computing and large-scale machine learning. +.. raw:: html + + + + +.. raw:: html + :file: hero.html .. grid:: 3 + :class-container: product-offerings :margin: 0 :padding: 0 :gutter: 0 @@ -31,6 +43,13 @@ designed for high-performance numerical computing and large-scale machine learni The same code executes on multiple backends, including CPU, GPU, & TPU .. grid:: 3 + :class-container: color-cards + + .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Installation + :columns: 12 6 6 4 + :link: installation + :link-type: ref + :class-card: installation .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting started :columns: 12 6 6 4 @@ -44,12 +63,6 @@ designed for high-performance numerical computing and large-scale machine learni :link-type: ref :class-card: user-guides - .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer notes - :columns: 12 6 6 4 - :link: contributor-guide - :link-type: ref - :class-card: developer-docs - If you're looking to train neural networks, use Flax_ and start with its tutorials. For an end-to-end transformer library built on JAX, see MaxText_. @@ -59,13 +72,12 @@ JAX itself is narrowly-scoped and focuses on efficient array operations & progra transformations. Built around JAX is an evolving ecosystem of machine learning and numerical computing tools; the following is just a small sample of what is out there: -.. grid:: 4 +.. grid:: 2 :class-container: ecosystem-grid .. grid-item:: :material-outlined:`hub;2em` **Neural networks** - Flax_ - - NNX_ - Equinox_ - Keras_ @@ -79,8 +91,8 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-outlined:`storage;2em` **Data loading** - Grain_ - - `Tensorflow datasets`_ - - `Hugging Face datasets`_ + - `TensorFlow Datasets`_ + - `Hugging Face Datasets`_ .. grid-item:: :material-regular:`construction;2em` **Miscellaneous tools** @@ -95,7 +107,7 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling** - - `Tensorflow probabilty`_ + - `TensorFlow Probabilty`_ - Distrax_ .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation** @@ -143,6 +155,7 @@ maintains an up-to-date list. extensions notes jax + about .. toctree:: @@ -164,17 +177,16 @@ maintains an up-to-date list. .. _Equinox: https://docs.kidger.site/equinox/ .. _Flax: https://flax.readthedocs.io/ .. _Grain: https://github.com/google/grain -.. _Hugging Face datasets: https://huggingface.co/docs/datasets/ +.. _Hugging Face Datasets: https://huggingface.co/docs/datasets/ .. _JAX MD: https://jax-md.readthedocs.io/ .. _Keras: https://keras.io/ .. _Levanter: https://github.com/stanford-crfm/levanter .. _Lineax: https://github.com/patrick-kidger/lineax .. _MaxText: https://github.com/google/maxtext/ -.. _NNX: https://flax.readthedocs.io/en/latest/nnx/ .. _Numpyro: https://num.pyro.ai/en/latest/index.html .. _Optax: https://optax.readthedocs.io/ .. _Optimistix: https://github.com/patrick-kidger/optimistix .. _Orbax: https://orbax.readthedocs.io/ .. _PyMC: https://www.pymc.io/ -.. _Tensorflow datasets: https://www.tensorflow.org/datasets -.. _Tensorflow probabilty: https://www.tensorflow.org/probability +.. _TensorFlow Datasets: https://www.tensorflow.org/datasets +.. _TensorFlow Probabilty: https://www.tensorflow.org/probability diff --git a/docs/installation.md b/docs/installation.md index 5b8893628d85..6686eac41186 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -35,6 +35,7 @@ The table below shows all supported platforms and installation options. Check if | Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | | AMD GPU | {ref}`experimental ` | no | {ref}`experimental ` | n/a | no | no | | Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a | +| Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no | (install-cpu)= @@ -230,6 +231,17 @@ JAX has experimental ROCm support. There are two ways to install JAX: * Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or * Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_). +(install-intel-gpu)= +## Intel GPU + +Intel provides an experimental OneAPI plugin: intel-extension-for-openxla for Intel GPU hardware. For more details and installation instructions, refer to one of the following two methods: +1. Pip installation: [JAX acceleration on Intel GPU](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). +2. Using [Intel's XLA Docker container](https://hub.docker.com/r/intel/intel-optimized-xla). + +Please report any issues related to: +* JAX: [JAX issue tracker](https://github.com/jax-ml/jax/issues). +* Intel's OpenXLA plugin: [Intel-extension-for-openxla issue tracker](https://github.com/intel/intel-extension-for-openxla/issues). + ## Conda (community-supported) ### Conda installation @@ -241,18 +253,14 @@ simply run: conda install jax -c conda-forge ``` -To install it on a machine with an NVIDIA GPU, run: +If you run this command on machine with an NVIDIA GPU, this should install a CUDA-enabled package of `jaxlib`. + +To ensure that the jax version you are installing is indeed CUDA-enabled, run: ```bash -conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia +conda install "jaxlib=*=*cuda*" jax -c conda-forge ``` -Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which -JAX requires. You must therefore either install the `cuda-nvcc` package from -the `nvidia` channel, or install CUDA on your machine separately so that `ptxas` -is in your path. The channel order above is important (`conda-forge` before -`nvidia`). - If you would like to override which release of CUDA is used by JAX, or to install the CUDA build on a machine without GPUs, follow the instructions in the [Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch) diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 82c9f08145eb..2d3452609c75 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -27,8 +27,10 @@ Functions barrier_arrive barrier_wait + commit_smem copy_gmem_to_smem copy_smem_to_gmem + emit_pipeline layout_cast set_max_registers wait_smem_to_gmem diff --git a/docs/jax.export.rst b/docs/jax.export.rst index d458b6c64e8e..c8feb1d169bd 100644 --- a/docs/jax.export.rst +++ b/docs/jax.export.rst @@ -14,8 +14,11 @@ Classes .. autosummary:: :toctree: _autosummary - Exported - DisabledSafetyCheck +.. autoclass:: Exported + :members: + +.. autoclass:: DisabledSafetyCheck + :members: Functions --------- @@ -28,6 +31,8 @@ Functions minimum_supported_calling_convention_version maximum_supported_calling_convention_version default_export_platform + register_pytree_node_serialization + register_namedtuple_serialization Functions related to shape polymorphism --------------------------------------- diff --git a/docs/jax.extend.core.rst b/docs/jax.extend.core.rst new file mode 100644 index 000000000000..5f3ff0558af6 --- /dev/null +++ b/docs/jax.extend.core.rst @@ -0,0 +1,18 @@ +``jax.extend.core`` module +========================== + +.. automodule:: jax.extend.core + +.. autosummary:: + :toctree: _autosummary + + ClosedJaxpr + Jaxpr + JaxprEqn + Literal + Primitive + Token + Var + array_types + jaxpr_as_fun + primitives diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 9cbee08e8e50..0d68013c9261 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -11,6 +11,7 @@ Modules .. toctree:: :maxdepth: 1 + jax.extend.core jax.extend.ffi jax.extend.linear_util jax.extend.mlir diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 9eb518464b4e..30553a360155 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -138,6 +138,7 @@ namespace; they are listed below. csingle cumprod cumsum + cumulative_prod cumulative_sum deg2rad degrees @@ -336,6 +337,7 @@ namespace; they are listed below. promote_types ptp put + put_along_axis quantile r_ rad2deg diff --git a/docs/jax.rst b/docs/jax.rst index a5e0dcad5b50..042804792f8a 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -102,6 +102,9 @@ Automatic differentiation closure_convert checkpoint +Customization +------------- + ``custom_jvp`` ~~~~~~~~~~~~~~ @@ -121,6 +124,16 @@ Automatic differentiation custom_vjp custom_vjp.defvjp +``custom_batching`` +~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: _autosummary + + custom_batching.custom_vmap + custom_batching.custom_vmap.def_vmap + custom_batching.sequential_vmap + jax.Array (:code:`jax.Array`) ----------------------------- diff --git a/docs/jax_internal_api.rst b/docs/jax_internal_api.rst deleted file mode 100644 index 1ece596d88ef..000000000000 --- a/docs/jax_internal_api.rst +++ /dev/null @@ -1,14 +0,0 @@ -Internal API reference -====================== - -core ----- - -.. currentmodule:: jax.core -.. automodule:: jax.core - -.. autosummary:: - :toctree: _autosummary - - Jaxpr - ClosedJaxpr diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 8e66a675a522..63742bc852c6 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -3,6 +3,9 @@ *January 2023* +**This was the design doc proposing `shard_map`. You may instead want +[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** + ## Motivation JAX supports two schools of thought for multi-device programming: @@ -374,114 +377,8 @@ One philosophy is: it is almost always simpler to write a program in `jit==pjit` — but if a given part of the program is less optimized by the compiler than it could be, drop into `shmap`! -### A realistic transformer example - -In fact, we can implement a simple version of the ["collective -matmul"](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) algorithm -recently introduced in XLA to overlap communication and computation using `shmap` -and 30 lines of Python. The basic idea of the algorithm can be grasped with a -simple example. - -Suppose we want to compute `C = A @ B` where `A` is sharded by a 1D mesh on the -0-th dimension while `B` and `C` are replicated. - -```python -M, K, N = 4096, 2048, 1024 -A = jnp.arange(np.prod((M, K))).reshape((M, K)) -B = jnp.arange(np.prod((K, N))).reshape((K, N)) - -mesh = Mesh(np.array(jax.devices()), axis_names=('i')) -A_x = jax.device_put(A, NamedSharding(mesh, P('i', None))) - -@jax.jit -def f(lhs, rhs): - return lhs @ rhs - -C = f(A_x, B) -``` - -A profile shows the blocking all-gather across 8 devices before the matmul can -start. This is suboptimal because `A` is sharded on a non-contracting dimension, -and each shard of `A` can be matmul'ed with `B` independently and this chunked -computation can be overlapped with fetching of the next shard of `A` from -another device. - -image - -This overlap can be implemented using `shmap` and explicit collectives. - -```python -def collective_matmul_allgather_lhs_non_contracting(lhs, rhs): - # lhs is the looped operand; rhs is the local operand - axis_size = jax.lax.psum(1, axis_name='i') - axis_index = jax.lax.axis_index(axis_name='i') - chunk_size = lhs.shape[0] - - def f(i, carrys): - accum, lhs = carrys - # matmul for a chunk - update = lhs @ rhs - # circular shift to the left - lhs = jax.lax.ppermute( - lhs, - axis_name='i', - perm=[(j, (j - 1) % axis_size) for j in range(axis_size)] - ) - # device 0 computes chunks 0, 1, ... - # device 1 computes chunks 1, 2, ... - update_index = (((axis_index + i) % axis_size) * chunk_size, 0) - accum = jax.lax.dynamic_update_slice(accum, update, update_index) - return accum, lhs - - accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype) - # fori_loop cause a crash: hlo_sharding.cc:817 Check failed: !IsManual() - # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs)) - for i in range(0, axis_size - 1): - accum, lhs = f(i, (accum, lhs)) - - # compute the last chunk, without the ppermute - update = lhs @ rhs - i = axis_size - 1 - update_index = (((axis_index + i) % axis_size) * chunk_size, 0) - accum = jax.lax.dynamic_update_slice(accum, update, update_index) - - return accum -``` - -``` -jit_sharded_f = jax.jit(shard_map( - collective_matmul_allgather_lhs_non_contracting, mesh, - in_specs=(P('i', None), P()), out_specs=P())) -C = jit_sharded_f(A_x, B) -``` -A profile shows that the all-gather is gone, and replaced with overlapped matmul -with async collective permute. This profile matches very closely with the -collective matmul paper result. - -image - -This collective matmul technique can be used to speed up feedforward blocks in -transformer layers. This typically consists of two matrix multiplications -followed by a `ReduceScatter` (to resolve partial sums from a parallelized -matrix multiplication) and preceded by an `AllGather` (to collect the sharded -dimensions along some axes and allow partial sum computation). Together, the -`ReduceScatter` from one layer and the `AllGather` for the next amount to an -`AllReduce`. - -In a typical profile, the two matmuls will be followed by an `AllReduce`, and -they will not be overlapped. Collective matmul can be used to achieve the -overlap, but is difficult to trigger, has a minimum slice size and does not yet -cover all topologies, tensor shapes and variants of collective matmul (i.e -latency and throughput optimized variants). [In a recent -paper](https://arxiv.org/abs/2211.05102), we found a ~40% gain in many -circumstances from manually implementing collective matmul variants in `shmap` -style. - -But it isn’t always more complex! We expect this to be a much more natural way -to think about pipelined computation, and plan to do some demos of that soon! - -### Another realistic example +### A realistic example Here's how `shmap` might look in a transformer layer pass with a 2D weight gathered pattern ([paper](https://arxiv.org/abs/2211.05102), Sec 3.2.3 on p. 5): diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 59c7bbd8fb90..5e5be308068a 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -170,7 +170,7 @@ jax.jit(g)(10, 20) # Raises an error The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values. Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as `shape` or `dtype`, and not via their values. -For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). +For more detail on the interaction between Python control flow and JAX, see {ref}`control-flow`. One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical. In that case, you can consider JIT-compiling only part of the function. @@ -192,6 +192,8 @@ def g_inner_jitted(x, n): g_inner_jitted(10, 20) ``` +(jit-marking-arguments-as-static)= + ## Marking arguments as static If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. diff --git a/docs/key-concepts.md b/docs/key-concepts.md index daab2c9fdde4..91f0c953462e 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -189,3 +189,43 @@ tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the le in a tree. You can learn more in the {ref}`working-with-pytrees` tutorial. + +(key-concepts-prngs)= +## Pseudorandom numbers + +Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`: + +```{code-cell} +from jax import random + +key = random.key(43) +print(key) +``` + +The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {func}`jax.random` functions. +Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated. + +```{code-cell} +print(random.normal(key)) +print(random.normal(key)) +``` + +**The rule of thumb is: never reuse keys (unless you want identical outputs).** + +In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: + +```{code-cell} +for i in range(3): + new_key, subkey = random.split(key) + del key # The old key is consumed by split() -- we must never use it again. + + val = random.normal(subkey) + del subkey # The subkey is consumed by normal(). + + print(f"draw {i}: {val}") + key = new_key # new_key is safe to use in the next iteration. +``` + +Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. + +For more on pseudo random numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 71bd4527644a..8823fac13042 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -34,7 +34,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "from jax import grad, jit\n", + "from jax import jit\n", "from jax import lax\n", "from jax import random\n", "import jax\n", @@ -202,7 +202,7 @@ "id": "cDpQ5u63Ba_H" }, "source": [ - "It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results." + "It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results." ] }, { @@ -865,920 +865,21 @@ "id": "MUycRNh6e50W" }, "source": [ - "## 🔪 Random numbers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O8vvaVt3MRG2" - }, - "source": [ - "> _If all scientific papers whose results are in doubt because of bad\n", - "> `rand()`s were to disappear from library shelves, there would be a\n", - "> gap on each shelf about as big as your fist._ - Numerical Recipes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qikt9pPW9L5K" - }, - "source": [ - "### RNGs and state\n", - "You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "rr9FeP41fynt", - "outputId": "df0ceb15-96ec-4a78-e327-c77f7ea3a745" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.2726690048900553\n", - "0.6304191979771206\n", - "0.6933648856441533\n" - ] - } - ], - "source": [ - "print(np.random.random())\n", - "print(np.random.random())\n", - "print(np.random.random())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ORMVVGZJgSVi" - }, - "source": [ - "Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "7Pyp2ajzfPO2" - }, - "outputs": [], - "source": [ - "np.random.seed(0)\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n", - "# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n", - "# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aJIxHVXCiM6m" - }, - "source": [ - "This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "GAHaDCYafpAF" - }, - "outputs": [], - "source": [ - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n", - "\n", - "# Let's exhaust the entropy in this PRNG statevector\n", - "for i in range(311):\n", - " _ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n", - "\n", - "# Next call iterates the RNG state for a new batch of fake \"entropy\".\n", - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n", - "# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N_mWnleNogps" - }, - "source": [ - "The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n", - "\n", - "The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Uvq7nV-j4vKK" - }, - "source": [ - "### JAX PRNG" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "COjzGBpO4tzL" - }, - "source": [ - "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", - "\n", - "The random state is described by a special array element that we call a __key__:" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "id": "yPHE7KTWgAWs", - "outputId": "ae8af0ee-f19e-474e-81b6-45e894eb2fc3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([0, 0], dtype=uint32)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = random.key(0)\n", - "key" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XjYyWYNfq0hW" - }, - "source": [ - "JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n", + "## 🔪 Random numbers\n", "\n", - "Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "7zUdQMynoE5e", - "outputId": "23a86b72-dfb9-410a-8e68-22b48dc10805" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.20584226]\n", - "[0 0]\n", - "[-0.20584226]\n", - "[0 0]\n" - ] - } - ], - "source": [ - "print(random.normal(key, shape=(1,)))\n", - "print(key)\n", - "# No no no!\n", - "print(random.normal(key, shape=(1,)))\n", - "print(key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hQN9van8rJgd" - }, - "source": [ - "Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "ASj0_rSzqgGh", - "outputId": "2f13f249-85d1-47bb-d503-823eca6961aa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [0 0]\n", - " \\---SPLIT --> new key [4146024105 967050713]\n", - " \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tqtFVE4MthO3" - }, - "source": [ - "We propagate the __key__ and make new __subkeys__ whenever we need a new random number:" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "jbC34XLor2Ek", - "outputId": "4059a2e2-0205-40bc-ad55-17709d538871" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [4146024105 967050713]\n", - " \\---SPLIT --> new key [2384771982 3928867769]\n", - " \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0KLYUluz3lN3" - }, - "source": [ - "We can generate more than one __subkey__ at a time:" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "id": "lEi08PJ4tfkX", - "outputId": "1f280560-155d-4c04-98e8-c41d72ee5b01" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.37533438]\n", - "[0.98645043]\n", - "[0.14553197]\n" - ] - } - ], - "source": [ - "key, *subkeys = random.split(key, 4)\n", - "for subkey in subkeys:\n", - " print(random.normal(subkey, shape=(1,)))" + "JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial." ] }, { "cell_type": "markdown", + "id": "1dc0e6b2", "metadata": { "id": "rg4CpMZ8c3ri" }, "source": [ - "## 🔪 Control flow" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "izLTvT24dAq0" - }, - "source": [ - "### ✔ Python control_flow + autodiff ✔\n", - "\n", - "If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "id": "aAx0T3F8lLtu", - "outputId": "383b7bfa-1634-4d23-8497-49cb9452ca52" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.0\n", - "-4.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "print(grad(f)(2.)) # ok!\n", - "print(grad(f)(4.)) # ok!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hIfPT7WMmZ2H" - }, - "source": [ - "### Python control flow + JIT\n", - "\n", - "Using control flow with `jit` is more complicated, and by default it has more constraints.\n", - "\n", - "This works:" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "id": "OZ_BJX0CplNC", - "outputId": "60c902a2-eba1-49d7-c8c8-2f68616d660c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "24\n" - ] - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " for i in range(3):\n", - " x = 2 * x\n", - " return x\n", - "\n", - "print(f(3))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "22RzeJ4QqAuX" - }, - "source": [ - "So does this:" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "id": "pinVnmRWp6w6", - "outputId": "25e06cf2-474f-4782-af7c-4f5514b64422" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "6.0\n" - ] - } - ], - "source": [ - "@jit\n", - "def g(x):\n", - " y = 0.\n", - " for i in range(x.shape[0]):\n", - " y = y + x[i]\n", - " return y\n", - "\n", - "print(g(jnp.array([1., 2., 3.])))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TStltU2dqf8A" - }, - "source": [ - "But this doesn't, at least by default:" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "id": "9z38AIKclRNM", - "outputId": "38dd2075-92fc-4b81-fee0-b9dff8da1fac", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ConcretizationTypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n" - ] - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "# This will fail!\n", - "f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pIbr4TVPqtDN" - }, - "source": [ - "__What gives!?__\n", - "\n", - "When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n", - "\n", - "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n", - "\n", - "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", - "\n", - "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n", - "\n", - "But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n", - "\n", - "The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": { - "id": "-Tzp0H7Bt1Sn", - "outputId": "f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "f = jit(f, static_argnums=(0,))\n", - "\n", - "print(f(2.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MHm1hIQAvBVs" - }, - "source": [ - "Here's another example, this time involving a loop:" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "id": "iwY86_JKvD6b", - "outputId": "48f9b51f-bd32-466f-eac1-cd23444ce937" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(5., dtype=float32)" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def f(x, n):\n", - " y = 0.\n", - " for i in range(n):\n", - " y = y + x[i]\n", - " return y\n", - "\n", - "f = jit(f, static_argnums=(1,))\n", - "\n", - "f(jnp.array([2., 3., 4.]), 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nSPTOX8DvOeO" - }, - "source": [ - "In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wWdg8LTYwCW3" - }, - "source": [ - "️⚠️ **functions with argument-__value__ dependent shapes**\n", - "\n", - "These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`." - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": { - "id": "Tqe9uLmUI_Gv", - "outputId": "989be121-dfce-4bb3-c78e-a10829c5f883" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[4. 4. 4. 4. 4.]\n" - ] - } - ], - "source": [ - "def example_fun(length, val):\n", - " return jnp.ones((length,)) * val\n", - "# un-jit'd works fine\n", - "print(example_fun(5, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": { - "id": "fOlR54XRgHpd", - "outputId": "cf31d798-a4ce-4069-8e3e-8f9631ff4b71", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Shapes must be 1D sequences of concrete values of integer type, got (Tracedwith,).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n" - ] - } - ], - "source": [ - "bad_example_jit = jit(example_fun)\n", - "# this will fail:\n", - "bad_example_jit(10, 4)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": { - "id": "kH0lOD4GgFyI", - "outputId": "d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n", - "[4. 4. 4. 4. 4.]\n" - ] - } - ], - "source": [ - "# static_argnums tells JAX to recompile on changes at these argument positions:\n", - "good_example_jit = jit(example_fun, static_argnums=(0,))\n", - "# first compile\n", - "print(good_example_jit(10, 4))\n", - "# recompiles\n", - "print(good_example_jit(5, 4))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MStx_r2oKxpp" - }, - "source": [ - "`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!\n", - "\n", - "Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "id": "m2ABpRd8K094", - "outputId": "4f7ebe17-ade4-4e18-bd8c-4b24087c33c3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tracedwith\n", - "Tracedwith\n" - ] - }, - { - "data": { - "text/plain": [ - "Array(4, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " print(x)\n", - " y = 2 * x\n", - " print(y)\n", - " return y\n", - "f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uCDcWG4MnVn-" - }, - "source": [ - "### Structured control flow primitives\n", - "\n", - "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n", - "\n", - " - `lax.cond` _differentiable_\n", - " - `lax.while_loop` __fwd-mode-differentiable__\n", - " - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.\n", - " - `lax.scan` _differentiable_" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sd9xrLMXeK3A" - }, - "source": [ - "#### `cond`\n", - "python equivalent:\n", - "\n", - "```python\n", - "def cond(pred, true_fun, false_fun, operand):\n", - " if pred:\n", - " return true_fun(operand)\n", - " else:\n", - " return false_fun(operand)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": { - "id": "SGxz9JOWeiyH", - "outputId": "942a8d0e-5ff6-4702-c499-b3941f529ca3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([-1.], dtype=float32)" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import lax\n", - "\n", - "operand = jnp.array([0.])\n", - "lax.cond(True, lambda x: x+1, lambda x: x-1, operand)\n", - "# --> array([1.], dtype=float32)\n", - "lax.cond(False, lambda x: x+1, lambda x: x-1, operand)\n", - "# --> array([-1.], dtype=float32)" - ] - }, - { - "cell_type": "markdown", - "id": "e6622244", - "metadata": { - "id": "lIYdn1woOS1n" - }, - "source": [ - "`jax.lax` provides two other functions that allow branching on dynamic predicates:\n", - "\n", - "- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is\n", - " like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays\n", - " rather than as functions.\n", - "- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is\n", - " like `lax.cond`, but allows switching between any number of callable choices.\n", - "\n", - "In addition, `jax.numpy` provides several numpy-style interfaces to these functions:\n", - "\n", - "- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with\n", - " three arguments is the numpy-style wrapper of `lax.select`.\n", - "- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)\n", - " is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.\n", - "- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has\n", - " an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather\n", - " than as functions. It is implemented in terms of multiple calls to `lax.select`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xkOFAw24eOMg" - }, - "source": [ - "#### `while_loop`\n", - "\n", - "python equivalent:\n", - "```\n", - "def while_loop(cond_fun, body_fun, init_val):\n", - " val = init_val\n", - " while cond_fun(val):\n", - " val = body_fun(val)\n", - " return val\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": { - "id": "jM-D39a-c436", - "outputId": "552fe42f-4d32-4e25-c8c2-b951160a3f4e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(10, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "init_val = 0\n", - "cond_fun = lambda x: x < 10\n", - "body_fun = lambda x: x+1\n", - "lax.while_loop(cond_fun, body_fun, init_val)\n", - "# --> array(10, dtype=int32)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "apo3n3HAeQY_" - }, - "source": [ - "#### `fori_loop`\n", - "python equivalent:\n", - "```\n", - "def fori_loop(start, stop, body_fun, init_val):\n", - " val = init_val\n", - " for i in range(start, stop):\n", - " val = body_fun(i, val)\n", - " return val\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "id": "dt3tUpOmeR8u", - "outputId": "7819ca7c-1433-4d85-b542-f6159b0e8380" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(45, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "init_val = 0\n", - "start = 0\n", - "stop = 10\n", - "body_fun = lambda i,x: x+i\n", - "lax.fori_loop(start, stop, body_fun, init_val)\n", - "# --> array(45, dtype=int32)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SipXS5qiqk8e" - }, - "source": [ - "#### Summary\n", - "\n", - "$$\n", - "\\begin{array} {r|rr}\n", - "\\hline \\\n", - "\\textrm{construct}\n", - "& \\textrm{jit}\n", - "& \\textrm{grad} \\\\\n", - "\\hline \\\n", - "\\textrm{if} & ❌ & ✔ \\\\\n", - "\\textrm{for} & ✔* & ✔\\\\\n", - "\\textrm{while} & ✔* & ✔\\\\\n", - "\\textrm{lax.cond} & ✔ & ✔\\\\\n", - "\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n", - "\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n", - "\\textrm{lax.scan} & ✔ & ✔\\\\\n", - "\\hline\n", - "\\end{array}\n", - "$$\n", - "\n", - "
\n", - "\n", - "$\\ast$ = argument-value-independent loop condition - unrolls the loop\n", + "## 🔪 Control flow\n", "\n", - "
" + "Moved to {ref}`control-flow`." ] }, { @@ -2209,6 +1310,9 @@ " ```\n", " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n", "\n", + "## 🔪 Sharp bits covered in tutorials\n", + "- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.\n", + "- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions.\n", "\n", "## Fin.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 741fa3af063c..1529dcef5e37 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -31,7 +31,7 @@ JAX works great for many numerical and scientific programs, but __only if they a :id: GoK_PCxPeYcy import numpy as np -from jax import grad, jit +from jax import jit from jax import lax from jax import random import jax @@ -121,7 +121,7 @@ print(jit(pure_uses_internal_state)(5.)) +++ {"id": "cDpQ5u63Ba_H"} -It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results. +It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results. ```{code-cell} ipython3 :id: w99WXa6bBa_H @@ -384,480 +384,13 @@ jnp.sum(jnp.array(x)) ## 🔪 Random numbers -+++ {"id": "O8vvaVt3MRG2"} - -> _If all scientific papers whose results are in doubt because of bad -> `rand()`s were to disappear from library shelves, there would be a -> gap on each shelf about as big as your fist._ - Numerical Recipes - -+++ {"id": "Qikt9pPW9L5K"} - -### RNGs and state -You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness: - -```{code-cell} ipython3 -:id: rr9FeP41fynt -:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 - -print(np.random.random()) -print(np.random.random()) -print(np.random.random()) -``` - -+++ {"id": "ORMVVGZJgSVi"} - -Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up. - -```{code-cell} ipython3 -:id: 7Pyp2ajzfPO2 - -np.random.seed(0) -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044, -# 2481403966, 4042607538, 337614300, ... 614 more numbers..., -# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0) -``` - -+++ {"id": "aJIxHVXCiM6m"} - -This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne twister state vector: - -```{code-cell} ipython3 -:id: GAHaDCYafpAF - -_ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0) - -# Let's exhaust the entropy in this PRNG statevector -for i in range(311): - _ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0) - -# Next call iterates the RNG state for a new batch of fake "entropy". -_ = np.random.uniform() -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([1499117434, 2949980591, 2242547484, -# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0) -``` - -+++ {"id": "N_mWnleNogps"} - -The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user. - -The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow. - -+++ {"id": "Uvq7nV-j4vKK"} - -### JAX PRNG - -+++ {"id": "COjzGBpO4tzL"} - -JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. - -The random state is described by a special array element that we call a __key__: - -```{code-cell} ipython3 -:id: yPHE7KTWgAWs -:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 - -key = random.key(0) -key -``` - -+++ {"id": "XjYyWYNfq0hW"} - -JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! - -Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__: - -```{code-cell} ipython3 -:id: 7zUdQMynoE5e -:outputId: 23a86b72-dfb9-410a-8e68-22b48dc10805 - -print(random.normal(key, shape=(1,))) -print(key) -# No no no! -print(random.normal(key, shape=(1,))) -print(key) -``` - -+++ {"id": "hQN9van8rJgd"} - -Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number: - -```{code-cell} ipython3 -:id: ASj0_rSzqgGh -:outputId: 2f13f249-85d1-47bb-d503-823eca6961aa - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "tqtFVE4MthO3"} - -We propagate the __key__ and make new __subkeys__ whenever we need a new random number: - -```{code-cell} ipython3 -:id: jbC34XLor2Ek -:outputId: 4059a2e2-0205-40bc-ad55-17709d538871 - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "0KLYUluz3lN3"} - -We can generate more than one __subkey__ at a time: - -```{code-cell} ipython3 -:id: lEi08PJ4tfkX -:outputId: 1f280560-155d-4c04-98e8-c41d72ee5b01 - -key, *subkeys = random.split(key, 4) -for subkey in subkeys: - print(random.normal(subkey, shape=(1,))) -``` +JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial. +++ {"id": "rg4CpMZ8c3ri"} ## 🔪 Control flow -+++ {"id": "izLTvT24dAq0"} - -### ✔ Python control_flow + autodiff ✔ - -If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). - -```{code-cell} ipython3 -:id: aAx0T3F8lLtu -:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52 - -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -print(grad(f)(2.)) # ok! -print(grad(f)(4.)) # ok! -``` - -+++ {"id": "hIfPT7WMmZ2H"} - -### Python control flow + JIT - -Using control flow with `jit` is more complicated, and by default it has more constraints. - -This works: - -```{code-cell} ipython3 -:id: OZ_BJX0CplNC -:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c - -@jit -def f(x): - for i in range(3): - x = 2 * x - return x - -print(f(3)) -``` - -+++ {"id": "22RzeJ4QqAuX"} - -So does this: - -```{code-cell} ipython3 -:id: pinVnmRWp6w6 -:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422 - -@jit -def g(x): - y = 0. - for i in range(x.shape[0]): - y = y + x[i] - return y - -print(g(jnp.array([1., 2., 3.]))) -``` - -+++ {"id": "TStltU2dqf8A"} - -But this doesn't, at least by default: - -```{code-cell} ipython3 -:id: 9z38AIKclRNM -:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac -:tags: [raises-exception] - -@jit -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -# This will fail! -f(2) -``` - -+++ {"id": "pIbr4TVPqtDN"} - -__What gives!?__ - -When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation. - -For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. - -To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. - -By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. - -But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace. - -The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again: - -```{code-cell} ipython3 -:id: -Tzp0H7Bt1Sn -:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a - -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -f = jit(f, static_argnums=(0,)) - -print(f(2.)) -``` - -+++ {"id": "MHm1hIQAvBVs"} - -Here's another example, this time involving a loop: - -```{code-cell} ipython3 -:id: iwY86_JKvD6b -:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937 - -def f(x, n): - y = 0. - for i in range(n): - y = y + x[i] - return y - -f = jit(f, static_argnums=(1,)) - -f(jnp.array([2., 3., 4.]), 2) -``` - -+++ {"id": "nSPTOX8DvOeO"} - -In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation - -+++ {"id": "wWdg8LTYwCW3"} - -️⚠️ **functions with argument-__value__ dependent shapes** - -These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`. - -```{code-cell} ipython3 -:id: Tqe9uLmUI_Gv -:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883 - -def example_fun(length, val): - return jnp.ones((length,)) * val -# un-jit'd works fine -print(example_fun(5, 4)) -``` - -```{code-cell} ipython3 -:id: fOlR54XRgHpd -:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71 -:tags: [raises-exception] - -bad_example_jit = jit(example_fun) -# this will fail: -bad_example_jit(10, 4) -``` - -```{code-cell} ipython3 -:id: kH0lOD4GgFyI -:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade - -# static_argnums tells JAX to recompile on changes at these argument positions: -good_example_jit = jit(example_fun, static_argnums=(0,)) -# first compile -print(good_example_jit(10, 4)) -# recompiles -print(good_example_jit(5, 4)) -``` - -+++ {"id": "MStx_r2oKxpp"} - -`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! - -Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions: - -```{code-cell} ipython3 -:id: m2ABpRd8K094 -:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3 - -@jit -def f(x): - print(x) - y = 2 * x - print(y) - return y -f(2) -``` - -+++ {"id": "uCDcWG4MnVn-"} - -### Structured control flow primitives - -There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives: - - - `lax.cond` _differentiable_ - - `lax.while_loop` __fwd-mode-differentiable__ - - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static. - - `lax.scan` _differentiable_ - -+++ {"id": "Sd9xrLMXeK3A"} - -#### `cond` -python equivalent: - -```python -def cond(pred, true_fun, false_fun, operand): - if pred: - return true_fun(operand) - else: - return false_fun(operand) -``` - -```{code-cell} ipython3 -:id: SGxz9JOWeiyH -:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3 - -from jax import lax - -operand = jnp.array([0.]) -lax.cond(True, lambda x: x+1, lambda x: x-1, operand) -# --> array([1.], dtype=float32) -lax.cond(False, lambda x: x+1, lambda x: x-1, operand) -# --> array([-1.], dtype=float32) -``` - -+++ {"id": "lIYdn1woOS1n"} - -`jax.lax` provides two other functions that allow branching on dynamic predicates: - -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is - like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays - rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is - like `lax.cond`, but allows switching between any number of callable choices. - -In addition, `jax.numpy` provides several numpy-style interfaces to these functions: - -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with - three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) - is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has - an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather - than as functions. It is implemented in terms of multiple calls to `lax.select`. - -+++ {"id": "xkOFAw24eOMg"} - -#### `while_loop` - -python equivalent: -``` -def while_loop(cond_fun, body_fun, init_val): - val = init_val - while cond_fun(val): - val = body_fun(val) - return val -``` - -```{code-cell} ipython3 -:id: jM-D39a-c436 -:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e - -init_val = 0 -cond_fun = lambda x: x < 10 -body_fun = lambda x: x+1 -lax.while_loop(cond_fun, body_fun, init_val) -# --> array(10, dtype=int32) -``` - -+++ {"id": "apo3n3HAeQY_"} - -#### `fori_loop` -python equivalent: -``` -def fori_loop(start, stop, body_fun, init_val): - val = init_val - for i in range(start, stop): - val = body_fun(i, val) - return val -``` - -```{code-cell} ipython3 -:id: dt3tUpOmeR8u -:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380 - -init_val = 0 -start = 0 -stop = 10 -body_fun = lambda i,x: x+i -lax.fori_loop(start, stop, body_fun, init_val) -# --> array(45, dtype=int32) -``` - -+++ {"id": "SipXS5qiqk8e"} - -#### Summary - -$$ -\begin{array} {r|rr} -\hline \ -\textrm{construct} -& \textrm{jit} -& \textrm{grad} \\ -\hline \ -\textrm{if} & ❌ & ✔ \\ -\textrm{for} & ✔* & ✔\\ -\textrm{while} & ✔* & ✔\\ -\textrm{lax.cond} & ✔ & ✔\\ -\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ -\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ -\textrm{lax.scan} & ✔ & ✔\\ -\hline -\end{array} -$$ - -
- -$\ast$ = argument-value-independent loop condition - unrolls the loop - -
+Moved to {ref}`control-flow`. +++ {"id": "OxLsZUyRt_kF"} @@ -1145,6 +678,9 @@ Many such cases are discussed in detail in the sections above; here we list seve ``` This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa. +## 🔪 Sharp bits covered in tutorials +- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators. +- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions. ## Fin. diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 82381838a5aa..feb906546341 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -1129,7 +1129,7 @@ "source": [ "When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.\n", "\n", - "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n", + "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n", "\n", "For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:" ] diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 0a6c84b2d88f..8ba87dcfee18 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -490,7 +490,7 @@ print_fwd_bwd(f, 3.) When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you. -One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`. +One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`. For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this: diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 37c27ce2728a..d73b0d4c0f3e 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -864,7 +864,7 @@ "Indeed, this implementation is often used on both TPU and GPU!\n", "\n", "The reason `psum_scatter` can require about half the communication as a full\n", - "`psum` is illustrated the `ppermute` section.\n", + "`psum` is illustrated in the `ppermute` section.\n", "\n", "Another intuition is that we can use `psum_scatter` to implement a distributed\n", "matrix multiplication with inputs and outputs sharded over the same axis. In\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 47b11079e27d..c52cf0e6d22b 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -627,7 +627,7 @@ def psum(x, axis_name): Indeed, this implementation is often used on both TPU and GPU! The reason `psum_scatter` can require about half the communication as a full -`psum` is illustrated the `ppermute` section. +`psum` is illustrated in the `ppermute` section. Another intuition is that we can use `psum_scatter` to implement a distributed matrix multiplication with inputs and outputs sharded over the same axis. In diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index d7ed91011a95..94dbeb3aa70d 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -11,7 +11,30 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c Remember to align the itemized text with the first line of an item within a list. --> -## Released with jax 0.4.34 +## Released with jax 0.4.37 + +* New functionality + + * Added support for `DotAlgorithmPreset` precision arguments for `dot` + lowering on Triton backend. + +## Released with jax 0.4.36 (December 6, 2024) + +## Released with jax 0.4.35 (October 22, 2024) + +* Removals + + * Removed previously deprecated aliases + {class}`jax.experimental.pallas.tpu.CostEstimate` and + {func}`jax.experimental.tpu.run_scoped`. Both are now available in + {mod}`jax.experimental.pallas`. + +* New functionality + + * Added a cost estimate tool {func}`pl.estimate_cost` for automatically + constructing a kernel cost estimate from a JAX reference function. + +## Released with jax 0.4.34 (October 4, 2024) * Changes diff --git a/docs/pallas/async_note.md b/docs/pallas/design/async_note.md similarity index 100% rename from docs/pallas/async_note.md rename to docs/pallas/design/async_note.md diff --git a/docs/pallas/design.md b/docs/pallas/design/design.md similarity index 99% rename from docs/pallas/design.md rename to docs/pallas/design/design.md index f6fc8f5926cb..17c7a6dbdc0f 100644 --- a/docs/pallas/design.md +++ b/docs/pallas/design/design.md @@ -94,7 +94,7 @@ Pallas kernels via JAX transformations.
-![Pallas lowering path](../_static/pallas/pallas_flow.png) +![Pallas lowering path](../../_static/pallas/pallas_flow.png) Visualization of Pallas lowering paths
@@ -413,10 +413,10 @@ verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU. -### Examples +### GPU Examples -Note all the following examples are for GPU only. They will require some small -changes to work on TPUs. +Note all the following examples are for GPU only. They will require tweaks to +the block sizes to work on TPUs. #### `add` diff --git a/docs/pallas/design/index.rst b/docs/pallas/design/index.rst new file mode 100644 index 000000000000..d11a13d39fe8 --- /dev/null +++ b/docs/pallas/design/index.rst @@ -0,0 +1,9 @@ +Pallas Design Notes +=================== + +.. toctree:: + :caption: Design + :maxdepth: 2 + + design + async_note diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index cde200528785..c1b2c2b95229 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -44,39 +44,7 @@ For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and You can also use {func}`jax.experimental.pallas.num_programs` to get the grid size for a given axis. -Here's an example kernel that uses a `grid` and `program_id`. - -```python ->>> import jax ->>> from jax.experimental import pallas as pl - ->>> def iota_kernel(o_ref): -... i = pl.program_id(0) -... o_ref[i] = i - -``` - -We now execute it using `pallas_call` with an additional `grid` argument. - -```python ->>> def iota(size: int): -... return pl.pallas_call(iota_kernel, -... out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), -... grid=(size,), interpret=True)() ->>> iota(8) -Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) - -``` - -On GPUs, each program is executed in parallel on separate thread blocks. -Thus, we need to think about race conditions on writes to HBM. -A reasonable approach is to write our kernels in such a way that different -programs write to disjoint places in HBM to avoid these parallel writes. - -On TPUs, programs are executed in a combination of parallel and sequential -(depending on the architecture) so there are slightly different considerations. - -See {ref}`pallas_tpu_noteworthy_properties`. +See {ref}`grids_by_example` for a simple kernel that uses this API. (pallas_blockspec)= @@ -131,6 +99,8 @@ shape `x_shape` are computed as in the function `slice_for_invocation` below: ```python +>>> import jax +>>> from jax.experimental import pallas as pl >>> def slices_for_invocation(x_shape: tuple[int, ...], ... x_spec: pl.BlockSpec, ... grid: tuple[int, ...], diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 5969349c962a..b2e2fca6c82e 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -22,7 +22,6 @@ See also the :class:`jax.experimental.pallas` module API documentation. :maxdepth: 2 quickstart - design grid_blockspec @@ -34,9 +33,9 @@ See also the :class:`jax.experimental.pallas` module API documentation. .. toctree:: :caption: Design Notes - :maxdepth: 1 + :maxdepth: 2 - async_note + design/index .. toctree:: :caption: Other diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 50464ce8ffd4..11dd2108e405 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -72,8 +72,9 @@ "\n", "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n", "it does not take in `jax.Array`s as inputs and doesn't return any values.\n", - "Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n", - "but we are given an `o_ref`, which corresponds to the desired output.\n", + "Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.\n", + "Note that we also don't have any outputs but we are given an `o_ref`, which corresponds\n", + "to the desired output.\n", "\n", "**Reading from `Ref`s**\n", "\n", @@ -150,7 +151,8 @@ "**What's actually happening here?**\n", "\n", "Thus far we've described how to think about Pallas kernels but what we've actually\n", - "accomplished is we're writing a function that's executed very close to the compute units.\n", + "accomplished is we're writing a function that's executed very close to the compute units\n", + "since values are loaded into the innermost (fastest) portion of the memory hierarchy.\n", "\n", "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n", "we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n", @@ -195,6 +197,8 @@ "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n", "that operate on \"blocks\" of those arrays that can fit in SRAM.\n", "\n", + "(grids_by_example)=\n", + "\n", "### Grids by example\n", "\n", "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n", @@ -240,7 +244,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We now execute it using `pallas_call` with an additional `grid` argument." + "We now execute it using `pallas_call` with an additional `grid` argument.\n", + "On GPUs, we can call the kernel directly like so:" ] }, { @@ -260,6 +265,7 @@ } ], "source": [ + "# GPU version\n", "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", @@ -272,16 +278,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "On GPUs, each program is executed in parallel on separate threads.\n", - "Thus, we need to think about race conditions on writes to HBM.\n", - "A reasonable approach is to write our kernels in such a way that different\n", - "programs write to disjoint places in HBM to avoid these parallel writes.\n", - "On the other hand, parallelizing the computation is how we can execute\n", - "operations like matrix multiplications really quickly.\n", - "\n", - "On TPUs, programs are executed in a combination of parallel and sequential\n", - "(depending on the architecture) so there are slightly different considerations.\n", - "\n", + "TPUs distinguish between vector and scalar memory spaces and in this case the\n", + "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n", "To call the above kernel on TPU, run:" ] }, @@ -292,6 +291,7 @@ "metadata": {}, "outputs": [], "source": [ + "# TPU version\n", "from jax.experimental.pallas import tpu as pltpu\n", "\n", "def iota(size: int):\n", @@ -307,11 +307,22 @@ "id": "68f97b4e", "metadata": {}, "source": [ - "TPUs distinguish between vector and scalar memory spaces and in this case the\n", - "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", - "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n", + "### Grid semantics\n", + "\n", + "On GPUs, each program is executed in parallel on separate threads.\n", + "Thus, we need to think about race conditions on writes to HBM.\n", + "A reasonable approach is to write our kernels in such a way that different\n", + "programs write to disjoint locations in HBM to avoid these parallel writes.\n", + "On the other hand, parallelizing the computation is how we can execute\n", + "operations like matrix multiplications really quickly.\n", + "\n", + "In contrast, TPUs operate like a very wide SIMD machine.\n", + "Some TPU models contain multiple cores, but in many cases a TPU can be\n", + "treated as a single-threaded processor. The grid on a TPU can be\n", + "specified in a combination of parallel and sequential dimensions, where sequential\n", + "dimensions are guaranteed to run serially.\n", "\n", - "You can read more details at {ref}`pallas_grid`." + "You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`." ] }, { @@ -412,7 +423,7 @@ "\n", "For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this\n", "carves `x` up into \"row\" blocks.\n", - "To see this see how both program instances\n", + "To see this, see how both program instances\n", "`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n", "For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.\n", "Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.\n", diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index b9acd6497fb5..fff1dcb730f3 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -53,8 +53,9 @@ def add_vectors_kernel(x_ref, y_ref, o_ref): Let's dissect this function a bit. Unlike most JAX functions you've probably written, it does not take in `jax.Array`s as inputs and doesn't return any values. -Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs -but we are given an `o_ref`, which corresponds to the desired output. +Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory. +Note that we also don't have any outputs but we are given an `o_ref`, which corresponds +to the desired output. **Reading from `Ref`s** @@ -101,7 +102,8 @@ thereof). **What's actually happening here?** Thus far we've described how to think about Pallas kernels but what we've actually -accomplished is we're writing a function that's executed very close to the compute units. +accomplished is we're writing a function that's executed very close to the compute units +since values are loaded into the innermost (fastest) portion of the memory hierarchy. On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) @@ -134,6 +136,8 @@ Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on "blocks" of those arrays that can fit in SRAM. +(grids_by_example)= + ### Grids by example To automatically "carve" up the inputs and outputs, you provide a `grid` and @@ -169,8 +173,10 @@ def iota_kernel(o_ref): ``` We now execute it using `pallas_call` with an additional `grid` argument. +On GPUs, we can call the kernel directly like so: ```{code-cell} ipython3 +# GPU version def iota(size: int): return pl.pallas_call(iota_kernel, out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), @@ -178,19 +184,13 @@ def iota(size: int): iota(8) ``` -On GPUs, each program is executed in parallel on separate threads. -Thus, we need to think about race conditions on writes to HBM. -A reasonable approach is to write our kernels in such a way that different -programs write to disjoint places in HBM to avoid these parallel writes. -On the other hand, parallelizing the computation is how we can execute -operations like matrix multiplications really quickly. - -On TPUs, programs are executed in a combination of parallel and sequential -(depending on the architecture) so there are slightly different considerations. - +TPUs distinguish between vector and scalar memory spaces and in this case the +output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +a scalar. For more details read {ref}`tpu_and_its_memory_spaces`. To call the above kernel on TPU, run: ```{code-cell} ipython3 +# TPU version from jax.experimental.pallas import tpu as pltpu def iota(size: int): @@ -201,11 +201,22 @@ def iota(size: int): iota(8) ``` -TPUs distinguish between vector and scalar memory spaces and in this case the -output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is -a scalar. For more details read {ref}`pallas_tpu_pipelining`. +### Grid semantics + +On GPUs, each program is executed in parallel on separate threads. +Thus, we need to think about race conditions on writes to HBM. +A reasonable approach is to write our kernels in such a way that different +programs write to disjoint locations in HBM to avoid these parallel writes. +On the other hand, parallelizing the computation is how we can execute +operations like matrix multiplications really quickly. + +In contrast, TPUs operate like a very wide SIMD machine. +Some TPU models contain multiple cores, but in many cases a TPU can be +treated as a single-threaded processor. The grid on a TPU can be +specified in a combination of parallel and sequential dimensions, where sequential +dimensions are guaranteed to run serially. -You can read more details at {ref}`pallas_grid`. +You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`. +++ @@ -294,7 +305,7 @@ To express this, we'd first use a `(2, 2)` grid (one block for each program). For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this carves `x` up into "row" blocks. -To see this see how both program instances +To see this, see how both program instances `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`. Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`. diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index b7ce10d564f6..0575806e6037 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -119,24 +119,44 @@ The output reference can be then used as an accumulator for partial results. spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a low-level compiler error message complaining about an out-of-memory error. -Dimension ordering is meaningful -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Array Layouts +^^^^^^^^^^^^^ +Dimension ordering of arrays is meaningful in Pallas. In JAX programs, the ordering of intermediate arrays inside ``jax.jit`` usually has no impact on performance, as the compiler is free to rearrange them. However, as Pallas is meant to expose lower-level capabilities, the dimension order can have great impact on the quality of generated code. -Recall that the TPUs perform bulk of the computation on 2D vector registers. -Pallas TPU will only ever consider mapping the last two dimensions of -intermediate arrays to those vector register dimensions (sublanes and lanes -respectively). An array of shape ``(n, 1, 1)`` is guaranteed to require at least -``n`` vector registers to represent. If ``n`` becomes too large, this can lead -to spills, and potential VMEM OOM errors due to an overly large memory footprint. -But it also might not --- the low-level compiler is free to rearrange the -instructions to lower the register pressure, and is in fact very good at it. -Still, it is a good rule of thumb to keep the last two dimensions large -(especially the last dimension), while keeping the leading dimensions small. +TPUs perform bulk of the computation on 2D vector registers, which are typically of +size 8x128 for 32-bit values (as of TPU v6). +When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``), +the last two dimensions of the array will be tiled into the registers. +Pallas will only ever consider mapping the last two dimensions of +intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes +respectively). + +Here is a graphical example of how a 12x320 array can be tiled using 6 8x128 +tiles: + +.. image:: ../../_static/pallas/vector_layout_example.svg + +Tiled layouts have several import ramifications for kernel writers: + +* The last two axes of an array are treated differently than other + axes. For example, reductions, reshapes, and transposes are generally + more expensive when involving the last two axes. Some reshapes + involving the last two dimensions are not supported and will result in a compiler + error, but are "free" and performed at compile time for other dimensions. +* While sometimes unavoidable, it is generally wasteful to have singleton + dimensions in the last two axes, since they will occupy 1 element out of + the entire tile dimension. Consuming too many registers can + also potentially cause register spills into VMEM which degrades kernel + performance. +* Related to the above point, all vector computation is padded up to the tile + size. Adding a two 1x1 arrays costs as much as adding two 8x128 arrays, and + adding two 8x128x1x1 arrays will be 1024 times as expensive as adding two + 8x128 arrays, since the 8x128x1x1 array will be padded to 8x128x8x128. Multicore TPU configurations ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -196,18 +216,19 @@ for those arguments. But, the ``BlockSpec``\s for all subsequent arguments will receive not only the grid indices, but also the SMEM references to the leading operands. -.. note:: - We are working on implementing examples for this feature. Stay tuned! +See :ref:`pallas_scalar_prefetch_guide` for examples on using this +feature. Supported data types ^^^^^^^^^^^^^^^^^^^^ -At the moment Pallas TPU only supports the following data types: +At the moment Pallas TPU supports the following data types: * ``jnp.float32`` * ``jnp.bfloat16`` * ``jnp.int*`` (all precisions, except for ``jnp.int4``) * ``jnp.uint*`` (all precisions) +* ``jnp.bool_`` Computation placement ^^^^^^^^^^^^^^^^^^^^^ @@ -306,14 +327,13 @@ Array constructors ^^^^^^^^^^^^^^^^^^ All constant array constructors are supported (``jnp.ones``, ``jnp.zeros``, -``jnp.full``). Notably, the ``jax.random`` module is **not** compatible with -Pallas as of today. +``jnp.full``). Reductions ^^^^^^^^^^ -Sum, maximum and minimum reductions are supported, but only on a single array -axis at a time. +``sum``, ``max``, ``min`` (for floating point values) reductions are supported, as well +as ``any`` and ``all`` for boolean values. Integer reductions are not supported. Reductions over the last array dimension are generally the slowest. Reductions over the second last dimension are faster, but still slower than @@ -338,6 +358,14 @@ of an array is when (1) some leading dimensions are flattened onto the second to last dimension, or (2) it adds a dimension that was just removed by a reduction. +Random Number Generation +^^^^^^^^^^^^^^^^^^^^^^^^ + +Pallas supports the most commonly used functions from the ``jax.random`` module, +such as ``uniform``, ``normal``, and ``bernoulli``. The key should be a ``threefry2x32`` key, +which is the default setting in JAX. Keys can be directly passed into a kernel, +or generated inside of a kernel. + Control flow ^^^^^^^^^^^^ diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 9774e08dcda8..10de587105f2 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -48,12 +48,20 @@ }, { "cell_type": "markdown", + "id": "0e212a5e", "metadata": { "id": "TWKESTKAlyjT" }, "source": [ - "## TPU and its memory spaces\n", + "(tpu_and_its_memory_spaces)=\n", "\n", + "## TPU and its memory spaces" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n", "registers (which temporarily store scalar and array values) and compute units\n", "(that do computation with values in registers).\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 21865430178d..df570cf0806c 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -38,8 +38,12 @@ import numpy as np +++ {"id": "TWKESTKAlyjT"} +(tpu_and_its_memory_spaces)= + ## TPU and its memory spaces ++++ + A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index a80ba4ebedbb..5b37e7b0574b 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -6,6 +6,8 @@ "id": "ZHuzXqQ-9JUQ" }, "source": [ + "(pallas_scalar_prefetch_guide)=\n", + "\n", "# Scalar Prefetch and Block-Sparse Computation\n", "\n", "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory." diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 2ac25edb5064..36a6e07e9192 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -14,6 +14,8 @@ kernelspec: +++ {"id": "ZHuzXqQ-9JUQ"} +(pallas_scalar_prefetch_guide)= + # Scalar Prefetch and Block-Sparse Computation In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory. diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 47a7587b620f..37afa2f594e3 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -1,11 +1,18 @@ # Persistent compilation cache - + JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly. +Note: if the compilation cache is not on a local filesystem, +[etils](https://pypi.org/project/etils/) needs to be installed. + +```python +pip install etils +``` + ## Usage ### Quick start @@ -17,6 +24,7 @@ import jax.numpy as jnp jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) +jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") @jax.jit def f(x): @@ -70,16 +78,33 @@ cc.set_cache_dir("/tmp/jax_cache") * `jax_persistent_cache_min_entry_size_bytes`: The minimum size (in bytes) of an entry that will be cached in the persistent compilation cache: - * `-1`: disable the size restriction and prevent overrides. + * `-1`: disable the size restriction and prevent overrides. * Leave at default (`0`) to allow for overrides. The override will typically ensure that the minimum size is optimal for the file system - being used for the cache. + being used for the cache. * `> 0`: the actual minimum size desired; no overrides. Note that both criteria need to be satisfied for a function to be cached. +### Additional caching + +XLA supports additional caching mechanism which can be enabled alongside JAX's +persistent compilation cache to further improve recompilation time. + +* `jax_persistent_cache_enable_xla_caches`: Possible values: + + * `all`: enable all XLA caching features + + * `none`: don't enable any extra XLA caching features + + * `xla_gpu_kernel_cache_file`: only enable the kernel cache + + * `xla_gpu_per_fusion_autotune_cache_dir`: (default value) only enable the + autotuning cache + + ### Google Cloud When running on Google Cloud, the compilation cache can be placed on a Google @@ -155,7 +180,14 @@ import os os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache" ``` -on the top of the script. +on the top of the script. Alternatively, you can change the global jax logging level with + +```python +import os +os.environ["JAX_LOGGING_LEVEL"] = "DEBUG" +# or locally with +jax.config.update("jax_logging_level", "DEBUG") +``` ### Examining cache misses diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 2ad1eadb0968..00f77e3473bb 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -17,6 +17,10 @@ kernelspec: +> _If all scientific papers whose results are in doubt because of bad +> `rand()`s were to disappear from library shelves, there would be a +> gap on each shelf about as big as your fist._ - Numerical Recipes + In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. @@ -35,6 +39,19 @@ import numpy as np np.random.seed(0) ``` +Repeated calls to NumPy's stateful pseudorandom number generators (PRNGs) mutate the global state and give a stream of pseudorandom numbers: + +```{code-cell} +:id: rr9FeP41fynt +:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 + +print(np.random.random()) +print(np.random.random()) +print(np.random.random()) +``` + +Underneath the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up. + You can inspect the content of the state using the following command. ```{code-cell} @@ -109,7 +126,7 @@ Further, when executing in multi-device environments, execution efficiency would ### Explicit random state -To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: +To avoid these issues, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: ```{code-cell} from jax import random @@ -137,6 +154,7 @@ Re-using the same key, even with different {mod}`~jax.random` APIs, can result i **The rule of thumb is: never reuse keys (unless you want identical outputs).** +JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: ```{code-cell} diff --git a/docs/requirements.txt b/docs/requirements.txt index 41d8aa6d9ee7..bfbb4e271d42 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,8 @@ absl-py ipython>=8.8.0 # 8.7.0 has ipython3 lexer error +pydata-sphinx-theme==0.14.4 # v0.15 breaks sidebar toggling sphinx>=7.3.2,<8.0 # 7.3.0 breaks sphinx-book-theme; 8.0 breaks myst-nb 1.1 -sphinx-book-theme>=1.0.1 # Older versions fail to pin pydata-sphinx-theme +sphinx-book-theme==1.1.1 # v1.1.2 requires pydata-sphinx-theme v0.15 sphinx-copybutton>=0.5.0 sphinx-remove-toctrees sphinx-design diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 2ff82e0431e2..fe84fc0d7f0a 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -12,6 +12,7 @@ kernelspec: name: python3 --- +(stateful-computations)= # Stateful computations diff --git a/docs/tutorials.rst b/docs/tutorials.rst index a31517155e1a..c9c2fdb1dcc7 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -16,6 +16,7 @@ Tutorials working-with-pytrees sharded-computation stateful-computations + control-flow .. toctree:: :maxdepth: 1 diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md index e41179996bc4..537a4df3e5a6 100644 --- a/docs/working-with-pytrees.md +++ b/docs/working-with-pytrees.md @@ -272,6 +272,49 @@ jax.tree.leaves([ Notice that the `name` field now appears as a leaf, because all tuple elements are children. This is what happens when you don't have to register the class the hard way. +Unlike `NamedTuple` subclasses, classes decorated with `@dataclass` are not automatically pytrees. However, they can be registered as pytrees using the {func}`jax.tree_util.register_dataclass` decorator: + +```{code-cell} +from dataclasses import dataclass +import functools + +@functools.partial(jax.tree_util.register_dataclass, + data_fields=['a', 'b', 'c'], + meta_fields=['name']) +@dataclass +class MyDataclassContainer(object): + name: str + a: Any + b: Any + c: Any + +# MyDataclassContainer is now a pytree node. +jax.tree.leaves([ + MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])), + MyDataclassContainer('banana', np.array([3, 4]), -1., 0.) +]) +``` + +Notice that the `name` field does not appear as a leaf. This is because we included it in the `meta_fields` argument to {func}`jax.tree_util.register_dataclass`, indicating that it should be treated as metadata/auxiliary data, just like `aux_data` in `RegisteredSpecial` above. Now instances of `MyDataclassContainer` can be passed into JIT-ed functions, and `name` will be treated as static (see {ref}`jit-marking-arguments-as-static` for more information on static args): + +```{code-cell} +@jax.jit +def f(x: MyDataclassContainer | MyOtherContainer): + return x.a + x.b + +# Works fine! `mdc.name` is static. +mdc = MyDataclassContainer('mdc', 1, 2, 3) +y = f(mdc) +``` + +Contrast this with `MyOtherContainer`, the `NamedTuple` subclass. Since the `name` field is a pytree leaf, JIT expects it to be convertible to {class}`jax.Array`, and the following raises an error: + +```{code-cell} +:tags: [raises-exception] + +moc = MyOtherContainer('moc', 1, 2, 3) +y = f(moc) +``` (pytree-and-jax-transformations)= ## Pytrees and JAX transformations diff --git a/docs/xla_flags.md b/docs/xla_flags.md index b332940ccb9d..fd351a7966b2 100644 --- a/docs/xla_flags.md +++ b/docs/xla_flags.md @@ -59,7 +59,7 @@ XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py | Flag | Type | Notes | | ---- | ---- | ----- | | `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding. | -| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes doesn't match what can Be saved in place in the stacked variables. Can increase memory pressure. | +| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. | | `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions. | | `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass.
If set to `auto`, it will be enabled based on the target. | | `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. | diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 62142fd49034..843c2cda0e3b 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -1,6 +1,8 @@ cmake_minimum_required(VERSION 3.15...3.30) project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) +option(JAX_FFI_EXAMPLE_ENABLE_CUDA "Enable CUDA support" OFF) + find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) execute_process( COMMAND "${Python_EXECUTABLE}" @@ -10,10 +12,23 @@ message(STATUS "XLA include directory: ${XLA_DIR}") find_package(nanobind CONFIG REQUIRED) -nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc") -target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) -install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +set( + JAX_FFI_EXAMPLE_PROJECTS + "rms_norm" + "cpu_examples" +) + +foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS}) + nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc") + target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR}) + install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +endforeach() -nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc") -target_include_directories(_attrs PUBLIC ${XLA_DIR}) -install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +if(JAX_FFI_EXAMPLE_ENABLE_CUDA) + enable_language(CUDA) + add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu") + set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON + CUDA_STANDARD 17) + target_include_directories(_cuda_examples PUBLIC ${XLA_DIR}) + install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +endif() diff --git a/examples/ffi/README.md b/examples/ffi/README.md index cc7018782a25..bd45408e50d8 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -3,7 +3,27 @@ This directory includes an example project demonstrating the use of JAX's foreign function interface (FFI). The JAX docs provide more information about this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), -but the example in this directory explicitly demonstrates: +but the example in this directory complements that document by demonstrating +(and testing!) the full packaging workflow, and some more advanced use cases. +Within the example project, there are several example calls: -1. One way to package and distribute FFI targets, and -2. Some more advanced use cases. +1. `rms_norm`: This is the example from the tutorial on the JAX docs, and it + demonstrates the most basic use of the FFI. It also includes customization of + behavior under automatic differentiation using `jax.custom_vjp`. + +2. `cpu_examples`: This submodule includes several smaller examples: + + * `counter`: This example demonstrates a common pattern for how an FFI call + can use global cache to maintain state between calls. This pattern is + useful when an FFI call requires an expensive initialization step which + shouldn't be run on every execution, or if there is other shared state + that could be reused between calls. In this simple example we just count + the number of times the call was executed. + * `attrs`: An example demonstrating the different ways that attributes can be + passed to the FFI. For example, we can pass arrays, variadic attributes, + and user-defined types. Full support of user-defined types isn't yet + supported by XLA, so that example will be added in the future. + +3. `cuda_examples`: An end-to-end example demonstrating the use of the JAX FFI + with CUDA. The specifics of the kernels are not very important, but the + general structure, and packaging of the extension are useful for testing. diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/cpu_examples.cc similarity index 52% rename from examples/ffi/src/jax_ffi_example/attrs.cc rename to examples/ffi/src/jax_ffi_example/cpu_examples.cc index 2a6e8d847cf4..3832c86b29b2 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.cc +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" @@ -21,8 +24,19 @@ limitations under the License. namespace nb = nanobind; namespace ffi = xla::ffi; +// ---------- +// Attributes +// ---------- +// +// An example demonstrating the different ways that attributes can be passed to +// the FFI. +// +// For example, we can pass arrays, variadic attributes, and user-defined types. +// Full support of user-defined types isn't yet supported by XLA, so that +// example will be added in the future. + ffi::Error ArrayAttrImpl(ffi::Span array, - ffi::Result> res) { + ffi::ResultBufferR0 res) { int64_t total = 0; for (int32_t x : array) { total += x; @@ -37,8 +51,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, .Ret>()); ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, - ffi::Result> secret, - ffi::Result> count) { + ffi::ResultBufferR0 secret, + ffi::ResultBufferR0 count) { auto maybe_secret = attrs.get("secret"); if (maybe_secret.has_error()) { return maybe_secret.error(); @@ -54,13 +68,52 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, .Ret>() .Ret>()); -NB_MODULE(_attrs, m) { +// ------- +// Counter +// ------- +// +// An example demonstrating how an FFI call can maintain "state" between calls +// +// In this case, the ``Counter`` call simply accumulates the number of times it +// was executed, but this pattern can also be used for more advanced use cases. +// For example, this pattern is used in jaxlib for: +// +// 1. The GPU solver linear algebra kernels which require an expensive "handler" +// initialization, and +// 2. The ``triton_call`` function which caches the compiled triton modules +// after their first use. + +ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0 out) { + static std::mutex mutex; + static auto &cache = *new std::unordered_map(); + { + const std::lock_guard lock(mutex); + auto it = cache.find(index); + if (it != cache.end()) { + out->typed_data()[0] = ++it->second; + } else { + cache.insert({index, 0}); + out->typed_data()[0] = 0; + } + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Counter, CounterImpl, + ffi::Ffi::Bind().Attr("index").Ret>()); + +// Boilerplate for exposing handlers to Python +NB_MODULE(_cpu_examples, m) { m.def("registrations", []() { nb::dict registrations; registrations["array_attr"] = nb::capsule(reinterpret_cast(ArrayAttr)); registrations["dictionary_attr"] = nb::capsule(reinterpret_cast(DictionaryAttr)); + + registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); + return registrations; }); } diff --git a/examples/ffi/src/jax_ffi_example/attrs.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py similarity index 73% rename from examples/ffi/src/jax_ffi_example/attrs.py rename to examples/ffi/src/jax_ffi_example/cpu_examples.py index 2f215e8e25b1..7771237e41d1 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.py +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.py @@ -12,22 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An example demonstrating the different ways that attributes can be passed to -the FFI. - -For example, we can pass arrays, variadic attributes, and user-defined types. -Full support of user-defined types isn't yet supported by XLA, so that example -will be added in the future. -""" - import numpy as np import jax import jax.extend as jex -from jax_ffi_example import _attrs +from jax_ffi_example import _cpu_examples -for name, target in _attrs.registrations().items(): +for name, target in _cpu_examples.registrations().items(): jex.ffi.register_ffi_target(name, target) @@ -43,3 +35,8 @@ def dictionary_attr(**kwargs): "dictionary_attr", (jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)), )(**kwargs) + + +def counter(index): + return jex.ffi.ffi_call( + "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/docs/cuda_custom_call/foo.cu.cc b/examples/ffi/src/jax_ffi_example/cuda_examples.cu similarity index 79% rename from docs/cuda_custom_call/foo.cu.cc rename to examples/ffi/src/jax_ffi_example/cuda_examples.cu index 858b5f8a888a..240adb6d6a8c 100644 --- a/docs/cuda_custom_call/foo.cu.cc +++ b/examples/ffi/src/jax_ffi_example/cuda_examples.cu @@ -44,11 +44,9 @@ __global__ void FooFwdKernel(const float *a, const float *b, float *c, // Buffer type provides buffer dimensions, so the "n" argument here is not // strictly necessary, but it allows us to demonstrate the use of attributes // (.Attr in the FFI handler definition above). -ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, - ffi::Buffer b, - ffi::Result> c, - ffi::Result> b_plus_1, - size_t n) { +ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, + ffi::Buffer b, ffi::ResultBuffer c, + ffi::ResultBuffer b_plus_1, size_t n) { const int block_dim = 128; const int grid_dim = 1; // Note how we access regular Buffer data vs Result Buffer data: @@ -73,12 +71,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooFwd, FooFwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // a - .Arg>() // b - .Ret>() // c - .Ret>() // b_plus_1 + .Arg>() // a + .Arg>() // b + .Ret>() // c + .Ret>() // b_plus_1 .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled //----------------------------------------------------------------------------// // Backward pass // @@ -106,11 +104,11 @@ __global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c } ffi::Error FooBwdHost(cudaStream_t stream, - ffi::Buffer c_grad, - ffi::Buffer a, - ffi::Result> b_plus_1, - ffi::Result> a_grad, - ffi::Result> b_grad, + ffi::Buffer c_grad, + ffi::Buffer a, + ffi::ResultBuffer b_plus_1, + ffi::ResultBuffer a_grad, + ffi::ResultBuffer b_grad, size_t n) { const int block_dim = 128; const int grid_dim = 1; @@ -131,10 +129,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooBwd, FooBwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // c_grad - .Arg>() // a - .Arg>() // b_plus_1 - .Ret>() // a_grad - .Ret>() // b_grad + .Arg>() // c_grad + .Arg>() // a + .Arg>() // b_plus_1 + .Ret>() // a_grad + .Ret>() // b_grad .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled diff --git a/examples/ffi/src/jax_ffi_example/cuda_examples.py b/examples/ffi/src/jax_ffi_example/cuda_examples.py new file mode 100644 index 000000000000..b60b12af577e --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/cuda_examples.py @@ -0,0 +1,68 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An end-to-end example demonstrating the use of the JAX FFI with CUDA. + +The specifics of the kernels are not very important, but the general structure, +and packaging of the extension are useful for testing. +""" + +import os +import ctypes + +import numpy as np + +import jax +import jax.numpy as jnp +import jax.extend as jex + +# Load the shared library with the FFI target definitions +SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so") +library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) + +jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd), + platform="CUDA") +jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd), + platform="CUDA") + + +def foo_fwd(a, b): + assert a.dtype == jnp.float32 + assert a.shape == b.shape + assert a.dtype == b.dtype + n = np.prod(a.shape).astype(np.uint64) + out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) + c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n) + return c, (a, b_plus_1) + + +def foo_bwd(res, c_grad): + a, b_plus_1 = res + assert c_grad.dtype == jnp.float32 + assert c_grad.shape == a.shape + assert a.shape == b_plus_1.shape + assert c_grad.dtype == a.dtype + assert a.dtype == b_plus_1.dtype + n = np.prod(a.shape).astype(np.uint64) + out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) + return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1, + n=n) + + +@jax.custom_vjp +def foo(a, b): + c, _ = foo_fwd(a, b) + return c + + +foo.defvjp(foo_fwd, foo_bwd) diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 2fb8d96c8461..455a0e557620 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -59,11 +59,10 @@ std::pair GetDims(const ffi::Buffer &buffer) { // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -82,12 +81,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ); ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -118,11 +116,10 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, ffi::Buffer ct_y, - ffi::Result> ct_x) { + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), diff --git a/examples/ffi/tests/attrs_test.py b/examples/ffi/tests/cpu_examples_test.py similarity index 50% rename from examples/ffi/tests/attrs_test.py rename to examples/ffi/tests/cpu_examples_test.py index 0288b31cf9fa..cb2653d2e928 100644 --- a/examples/ffi/tests/attrs_test.py +++ b/examples/ffi/tests/cpu_examples_test.py @@ -18,18 +18,23 @@ import jax.numpy as jnp from jax._src import test_util as jtu -from jax_ffi_example import attrs +from jax_ffi_example import cpu_examples jax.config.parse_flags_with_absl() class AttrsTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + def test_array_attr(self): - self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum()) - self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum()) + self.assertEqual(cpu_examples.array_attr(5), jnp.arange(5).sum()) + self.assertEqual(cpu_examples.array_attr(3), jnp.arange(3).sum()) def test_array_attr_jit_cache(self): - jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,)) + jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,)) with jtu.count_jit_and_pmap_lowerings() as count: jit_array_attr(5) self.assertEqual(count[0], 1) # compiles once the first time @@ -39,22 +44,51 @@ def test_array_attr_jit_cache(self): def test_array_attr_no_jit(self): with jax.disable_jit(): - attrs.array_attr(5) # doesn't crash + cpu_examples.array_attr(5) # doesn't crash def test_dictionary_attr(self): - secret, count = attrs.dictionary_attr(secret=5) + secret, count = cpu_examples.dictionary_attr(secret=5) self.assertEqual(secret, 5) self.assertEqual(count, 1) - secret, count = attrs.dictionary_attr(secret=3, a_string="hello") + secret, count = cpu_examples.dictionary_attr(secret=3, a_string="hello") self.assertEqual(secret, 3) self.assertEqual(count, 2) with self.assertRaisesRegex(Exception, "Unexpected attribute"): - attrs.dictionary_attr() + cpu_examples.dictionary_attr() with self.assertRaisesRegex(Exception, "Wrong attribute type"): - attrs.dictionary_attr(secret="invalid") + cpu_examples.dictionary_attr(secret="invalid") + + +class CounterTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + def test_basic(self): + self.assertEqual(cpu_examples.counter(0), 0) + self.assertEqual(cpu_examples.counter(0), 1) + self.assertEqual(cpu_examples.counter(0), 2) + self.assertEqual(cpu_examples.counter(1), 0) + self.assertEqual(cpu_examples.counter(0), 3) + + def test_jit(self): + @jax.jit + def counter_fun(x): + return x, cpu_examples.counter(2) + + self.assertEqual(counter_fun(0)[1], 0) + self.assertEqual(counter_fun(0)[1], 1) + + # Persists across different cache hits + self.assertEqual(counter_fun(1)[1], 2) + + # Persists after the cache is cleared + counter_fun.clear_cache() + self.assertEqual(counter_fun(0)[1], 3) if __name__ == "__main__": diff --git a/examples/ffi/tests/cuda_examples_test.py b/examples/ffi/tests/cuda_examples_test.py new file mode 100644 index 000000000000..f4a736599ce4 --- /dev/null +++ b/examples/ffi/tests/cuda_examples_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +jax.config.parse_flags_with_absl() + + +class CudaE2eTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Unsupported platform") + + # Import here to avoid trying to load the library when it's not built. + from jax_ffi_example import cuda_examples + self.foo = cuda_examples.foo + + def test_fwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. * jnp.ones(shape) + observed = jax.jit(self.foo)(a, b) + expected = (2. * (3. + 1.)) + self.assertArraysEqual(observed, expected) + + def test_bwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. * jnp.ones(shape) + + def loss(a, b): + return jnp.sum(self.foo(a, b)) + + da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b) + da_expected = b + 1 + db_expected = a + self.assertArraysEqual(da_observed, da_expected) + self.assertArraysEqual(db_observed, db_expected) + + def test_fwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + observed = jax.jit(self.foo)(a, b) + expected = a * (b + 1) + self.assertAllClose(observed, expected) + + def test_bwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + jtu.check_grads(f=jax.jit(self.foo), args=(a, b), order=1, + modes=("rev",)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/examples/ffi/tests/rms_norm_test.py b/examples/ffi/tests/rms_norm_test.py index aad5562629ed..bccd696c601e 100644 --- a/examples/ffi/tests/rms_norm_test.py +++ b/examples/ffi/tests/rms_norm_test.py @@ -29,6 +29,11 @@ def rms_norm_ref(x, eps=1e-5): class RmsNormTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + def test_basic(self): x = jnp.linspace(-0.5, 0.5, 15) self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x)) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index 6e4647b5e491..b3cb995aae21 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -26,8 +26,13 @@ cc_binary( "@tsl//tsl/platform:platform_port", "@xla//xla:literal", "@xla//xla:literal_util", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt/cpu:cpu_client", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/service:hlo_module_config", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 2a8f8d4debba..ceac2cd2d7c9 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -36,15 +36,21 @@ limitations under the License. // } // ) +#include #include #include #include #include "third_party/absl/status/statusor.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" @@ -66,8 +72,10 @@ int main(int argc, char** argv) { // Run it using JAX C++ Runtime (PJRT). // Get a CPU client. + xla::CpuClientOptions options; + options.asynchronous = true; std::unique_ptr client = - xla::GetTfrtCpuClient(/*asynchronous=*/true).value(); + xla::GetXlaPjrtCpuClient(options).value(); // Compile XlaComputation to PjRtExecutable. xla::XlaComputation xla_computation(test_module_proto); diff --git a/jax/BUILD b/jax/BUILD index 12c239a2d63e..31020eb1d385 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -193,6 +193,7 @@ py_library_providing_imports_info( "_src/custom_batching.py", "_src/custom_derivatives.py", "_src/custom_partitioning.py", + "_src/custom_partitioning_sharding_rule.py", "_src/custom_transpose.py", "_src/debugging.py", "_src/dispatch.py", @@ -227,6 +228,7 @@ py_library_providing_imports_info( "_src/state/**/*.py", "_src/third_party/**/*.py", "experimental/key_reuse/**/*.py", + "experimental/roofline/**/*.py", "image/**/*.py", "interpreters/**/*.py", "lax/**/*.py", @@ -426,10 +428,12 @@ pytype_strict_library( name = "compiler", srcs = ["_src/compiler.py"], deps = [ + ":cache_key", ":compilation_cache_internal", ":config", ":mlir", ":monitoring", + ":path", ":profiler", ":traceback_util", ":xla_bridge", @@ -450,6 +454,7 @@ pytype_strict_library( ":deprecations", ":dtypes", ":effects", + ":mesh", ":pretty_printer", ":source_info_util", ":traceback_util", @@ -627,7 +632,7 @@ pytype_strict_library( pytype_strict_library( name = "pallas_gpu_ops", - srcs = glob(["experimental/pallas/ops/gpu/**/*.py"]), + srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"], visibility = [ ":pallas_gpu_users", ], @@ -638,6 +643,22 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "pallas_experimental_gpu_ops", + testonly = True, + srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"], + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + ":jax", + ":mosaic_gpu", + ":pallas", + ":pallas_mosaic_gpu", + ":test_util", # This is only to make them runnable as jax_multiplatform_test... + ] + py_deps("numpy"), +) + pytype_strict_library( name = "pallas_tpu_ops", srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]), @@ -688,6 +709,7 @@ pytype_strict_library( deps = [ "//jax/_src/pallas/mosaic_gpu:core", "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic_gpu:pipeline", "//jax/_src/pallas/mosaic_gpu:primitives", ], ) @@ -705,6 +727,7 @@ py_library( ":jax", ":mlir", "//jax/_src/lib", + "//jax/extend:ffi", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:func_dialect", @@ -931,6 +954,7 @@ pytype_strict_library( ":mlir", ":sharding_impls", "//jax/_src/lib", + "//jax/_src/pallas", ] + if_building_jaxlib([ "//jaxlib/mlir:ir", "//jaxlib/mlir:mhlo_dialect", @@ -1027,7 +1051,7 @@ pytype_library( "experimental/array_api/*.py", ], ), - visibility = [":internal"] + jax_visibility("array_api"), + visibility = [":internal"], deps = [ ":jax", ], @@ -1157,3 +1181,25 @@ pytype_library( visibility = ["//visibility:public"], deps = [":jax"], ) + +pytype_library( + name = "experimental_colocated_python", + srcs = [ + "experimental/colocated_python/__init__.py", + "experimental/colocated_python/api.py", + "experimental/colocated_python/func.py", + "experimental/colocated_python/func_backend.py", + "experimental/colocated_python/serialization.py", + ], + visibility = ["//visibility:public"], + deps = [ + ":api_util", + ":jax", + ":traceback_util", + ":tree_util", + ":util", + ":xla_bridge", + "//jax/_src/lib", + "//jax/extend:ifrt_programs", + ] + py_deps("numpy") + py_deps("cloudpickle"), +) diff --git a/jax/__init__.py b/jax/__init__.py index 4f5c256b0c9d..8ca7721da445 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -83,7 +83,6 @@ from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies -from jax._src.api import clear_backends as _deprecated_clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import custom_gradient as custom_gradient @@ -218,23 +217,15 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), - # Added Mar 18, 2024 + # Finalized Nov 12 2024; remove after Feb 12 2025 "clear_backends": ( - "jax.clear_backends is deprecated.", - _deprecated_clear_backends - ), - # Remove after jax 0.4.35 release. - "xla_computation": ( - "jax.xla_computation is deleted. Please use the AOT APIs; see " - "https://jax.readthedocs.io/en/latest/aot.html. For example, replace " - "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See " - "CHANGELOG.md for 0.4.30 for more examples.", None + "jax.clear_backends was removed in JAX v0.4.36", + None ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.api import clear_backends as clear_backends from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 9a49a09c7483..95216fb6fcb2 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -24,9 +24,7 @@ from jax._src import traceback_util traceback_util.register_exclusion(__file__) -UnshapedArray = core.UnshapedArray ShapedArray = core.ShapedArray -ConcreteArray = core.ConcreteArray AbstractToken = core.AbstractToken abstract_token = core.abstract_token canonicalize_shape = core.canonicalize_shape @@ -47,8 +45,11 @@ array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic def canonical_concrete_aval(val, weak_type=None): - return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val, - weak_type=weak_type) + weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type + dtype = dtypes.canonicalize_dtype(np.result_type(val)) + dtypes.check_valid_dtype(dtype) + sharding = core._get_abstract_sharding(val) + return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding) def masked_array_error(*args, **kwargs): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 5160104e2141..93376c7bd170 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -410,17 +410,15 @@ def _trace_to_jaxpr(fun, in_tree, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) except core.ConcretizationTypeError as e: msg, = e.args - if 'for checkpoint' not in msg: - raise - new_msg = msg + "\n\n" + ( - "Consider using the `static_argnums` parameter for `jax.remat` or " - "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example " - "involving `static_argnums`:\n" - "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html" - "\n") - new_e = core.ConcretizationTypeError.__new__(core.ConcretizationTypeError) - new_e.args = (new_msg,) - raise new_e from None + if 'for checkpoint' in msg: + msg += "\n\n" + ( + "Consider using the `static_argnums` parameter for `jax.remat` or " + "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example " + "involving `static_argnums`:\n" + "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html" + "\n") + e.args = msg, + raise return pe.convert_constvars_jaxpr(jaxpr), consts, out_tree() @@ -654,7 +652,7 @@ def remat_transpose(out_cts, *in_primals, jaxpr, **params): for x in in_primals] assert next(in_cts_nz_, None) is next(in_zeros_, None) is None return in_cts -ad.reducing_transposes[remat_p] = remat_transpose +ad.primitive_transposes[remat_p] = remat_transpose # TODO(mattjj): move this to ad.py def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool], @@ -703,24 +701,23 @@ def transposed(*args_flat): transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error -def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, - jaxpr, **params): +def remat_vmap(axis_data, args, dims, *, jaxpr, **params): assert not jaxpr.constvars jaxpr_batched_, out_batched = batching.batch_jaxpr_axes( - pe.close_jaxpr(jaxpr), axis_size, dims, - [batching.zero_if_mapped] * len(jaxpr.outvars), - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + pe.close_jaxpr(jaxpr), axis_data, dims, + [batching.zero_if_mapped] * len(jaxpr.outvars)) jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts if consts: jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched) out_dims = [0 if b else None for b in out_batched] return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims -batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None) -batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap +batching.fancy_primitive_batchers[remat_p] = remat_vmap # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) new_params = dict(eqn.params, jaxpr=new_jaxpr) if (not any(used_inputs) and not any(used_outputs) and diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index bd1427f59e01..02f3b0405e38 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -43,7 +43,8 @@ def add_impl(x, y): @add_jaxvals_p.def_abstract_eval def add_abstract(x, y): - return core.lattice_join(x, y) + assert core.typematch(x, y) + return x def zeros_like_aval(aval: core.AbstractValue) -> Array: return aval_zeros_likers[type(aval)](aval) diff --git a/jax/_src/api.py b/jax/_src/api.py index d2ac5465eded..308e7c230dc2 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -34,7 +34,7 @@ import weakref import numpy as np -from contextlib import contextmanager, ExitStack +from contextlib import contextmanager from jax._src import linear_util as lu from jax._src import stages @@ -56,7 +56,7 @@ from jax._src import traceback_util from jax._src import pjit from jax._src import xla_bridge as xb -from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray +from jax._src.core import eval_jaxpr, ShapedArray from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, flatten_axes, donation_vector, @@ -123,8 +123,8 @@ def _update_debug_special_global(_): jax_jit.global_state().post_hook = None def _update_debug_special_thread_local(_): - if (getattr(config._thread_local_state, "jax_debug_nans", False) or - getattr(config._thread_local_state, "jax_debug_infs", False)): + if (config.debug_nans.get_local() == True or + config.debug_infs.get_local() == True): jax_jit.thread_local_state().post_hook = _nan_check_posthook else: jax_jit.thread_local_state().post_hook = None @@ -151,6 +151,7 @@ def jit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, + compiler_options: dict[str, Any] | None = None, ) -> pjit.JitWrapped: """Sets up ``fun`` for just-in-time compilation with XLA. @@ -280,7 +281,7 @@ def jit( return pjit.make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env=False) + keep_unused, inline, compiler_options, use_resource_env=False) @contextmanager @@ -989,10 +990,10 @@ def vmap_f(*args, **kwargs): axis_size_ = (axis_size if axis_size is not None else _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) try: + axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name) out_flat = batching.batch( - flat_fun, axis_name, axis_size_, in_axes_flat, - lambda: flatten_axes("vmap out_axes", out_tree(), out_axes), - spmd_axis_name=spmd_axis_name + flat_fun, axis_data, in_axes_flat, + lambda: flatten_axes("vmap out_axes", out_tree(), out_axes) ).call_wrapped(*args_flat) except batching.SpecMatchError as e: out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes) @@ -1038,27 +1039,35 @@ def _get_axis_size(name: str, shape: tuple[core.AxisSize, ...], axis: int def _get_argument_type(x): try: return shaped_abstractify(x).str_short() - except TypeError: #Catch all for user specified objects that can't be interpreted as a data type + except TypeError: # Catch all for user specified objects that can't be interpreted as a data type return "unknown" msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"] args, kwargs = tree_unflatten(tree, vals) try: ba = inspect.signature(fn).bind(*args, **kwargs) + signature_parameters: list[str] = list(ba.signature.parameters.keys()) except (TypeError, ValueError): - ba = None - if ba is None: - args_paths = [f'args{keystr(p)} ' - f'of type {_get_argument_type(x)}' - for p, x in generate_key_paths(args)] - kwargs_paths = [f'kwargs{keystr(p)} ' - f'of type {_get_argument_type(x)}' - for p, x in generate_key_paths(kwargs)] - key_paths = [*args_paths, *kwargs_paths] - else: - key_paths = [f'argument {name}{keystr(p)} ' - f'of type {_get_argument_type(x)}' - for name, arg in ba.arguments.items() - for p, x in generate_key_paths(arg)] + signature_parameters = None + + def arg_name(key_path): + if signature_parameters is None: + return f"args{keystr(key_path)}" + # args is a tuple, so key_path[0].idx is the index into args. + i = key_path[0].idx + res = f"argument {signature_parameters[i]}" + if len(key_path) > 1: + res += keystr(key_path[1:]) + return res + + args_paths = [ + f"{arg_name(p)} of type {_get_argument_type(x)}" + for (p, x) in generate_key_paths(args) + ] + kwargs_paths = [ + f"kwargs{keystr(p)} of type {_get_argument_type(x)}" + for p, x in generate_key_paths(kwargs) + ] + key_paths = [*args_paths, *kwargs_paths] all_sizes = [_get_axis_size(name, np.shape(x), d) if d is not None else None for x, d in zip(vals, dims)] size_counts = collections.Counter(s for s in all_sizes if s is not None) @@ -1538,16 +1547,13 @@ def cache_miss(*args, **kwargs): is_explicit_global_axis_size=p.is_explicit_global_axis_size, ) - map_bind_continuation, top_trace, fun_, tracers, params = ( - core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun, - *p.flat_args, **params)) execute: Callable | None = None - if isinstance(top_trace, core.EvalTrace): - execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params) - out = map_bind_continuation(execute(*tracers)) - else: - out = map_bind_continuation( - pxla.xla_pmap_p.process(top_trace, fun_, tracers, params)) + with core.take_current_trace() as trace: + if isinstance(trace, core.EvalTrace): + execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params) + out = execute(*p.flat_args) + else: + out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) out_tree, out_flat = p.out_tree, out out_pytree_def = out_tree() @@ -1593,7 +1599,7 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - lambda x, s: pxla.shard_args([s], [None], [x])[0], + lambda x, s: pxla.shard_args([s], [None], [None], [x])[0], pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) @@ -1794,7 +1800,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) - (Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True)) + (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True)) >>> y, f_jvp = jax.linearize(f, 2.) >>> print(y) 3.2681944 @@ -2152,9 +2158,7 @@ def make_jaxpr( @wraps(fun) @api_boundary def make_jaxpr_f(*args, **kwargs): - with ExitStack() as stack: - for axis_name, size in axis_env or []: - stack.enter_context(core.extend_axis_env(axis_name, size, None)) + with core.extend_axis_env_nd(axis_env or []): traced = jit(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes).trace(*args, **kwargs) # `jit` converts tracers in consts to args but that breaks the semantics of @@ -2180,14 +2184,15 @@ def make_jaxpr_f(*args, **kwargs): def _infer_src_sharding(src, x) -> Sharding | None: if src is not None: - # TODO(slebedev): This looks like an error and needs investigation. return src # pytype: disable=bad-return-type if isinstance(x, array.ArrayImpl): return x.sharding - elif isinstance(x, core.Tracer): - aval = core.get_aval(x) - if isinstance(aval, ConcreteArray) and isinstance(aval.val, array.ArrayImpl): - return aval.val.sharding + if config.sharding_in_types.value and hasattr(x, 'sharding'): + return x.sharding + if isinstance(x, core.Tracer): + val = x.to_concrete_value() + if val is not None and isinstance(val, array.ArrayImpl): + return val.sharding return None @@ -2437,6 +2442,13 @@ def _device_put_replicated(x): def _device_get(x): if isinstance(x, core.Tracer): return x + + # Extended dtypes dispatch via their device_get rule. + if isinstance(x, basearray.Array) and dtypes.issubdtype(x.dtype, dtypes.extended): + bufs, tree = tree_util.dispatch_registry.flatten(x) + return tree.unflatten(device_get(bufs)) + + # Other types dispatch via their __array__ method. try: toarray = x.__array__ except AttributeError: @@ -2765,11 +2777,11 @@ def clear_backends(): dispatch.xla_primitive_callable.cache_clear() util.clear_all_caches() pjit._infer_params_cached.cache_clear() - pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache_fun_only.clear() pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() + xc._xla.jax_jit.thread_local_state().extra_jit_context = None @atexit.register def clean_up(): diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 329abd6b7570..eb5e7e8bf8de 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]: else: return tuple(map(_ensure_str, x)) -@lu.transformation_with_aux -def flatten_fun(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun(fun, io_tree, *py_args): in_tree_expected, out_tree = io_tree @@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args): ans = fun(*args) return tree_unflatten(out_tree, ans) -@lu.transformation_with_aux -def flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} - yield tree_flatten(ans) + ans = f(*py_args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun_nokwargs(fun, io_tree, py_args): in_tree_expected, out_tree = io_tree @@ -118,17 +122,18 @@ def flattened_fun_in_tree( else: return in_tree, lambda: out_tree_store.val, has_kwargs -@lu.transformation_with_aux -def flatten_fun_nokwargs2(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs2(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - pair = yield py_args, {} + pair = f(*py_args) if not isinstance(pair, (list, tuple)) or len(pair) != 2: raise TypeError("expected function with aux output to return a two-element " f"tuple, but got type {type(pair)} with value {pair!r}") ans, aux = pair ans_flat, ans_tree = tree_flatten(ans) aux_flat, aux_tree = tree_flatten(aux) - yield (ans_flat, aux_flat), (ans_tree, aux_tree) + store.store((ans_tree, aux_tree)) + return ans_flat, aux_flat class _HashableWithStrictTypeEquality: """Box object used when comparing static arguments as a jit key. @@ -277,18 +282,16 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args -@lu.transformation -def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): +@lu.transformation2 +def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): sentinel = object() - args = [sentinel] * (len(fixed_args) + len(dyn_args)) - for i, arg in zip(dyn_argnums, dyn_args): + args = [sentinel] * (len(_fixed_args) + len(dyn_args)) + for i, arg in zip(_dyn_argnums, dyn_args): args[i] = arg - fixed_args_ = iter(fixed_args) + fixed_args_ = iter(_fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel - ans = yield args, kwargs - yield ans - + return _fun(*args, **kwargs) def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], kwargs: dict[str, Any]): @@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs -@lu.transformation -def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): - kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs) - ans = yield args, kwargs - yield ans +@lu.transformation2 +def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): + kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs) + return _fun(*args, **kwargs) @lru_cache(maxsize=4096) @@ -435,10 +437,10 @@ def flat_out_axes( f, out_axes = _flat_out_axes(f, tuple(leaves), treedef) return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) -@lu.transformation_with_aux -def _flat_out_axes(leaves, treedef, *args, **kwargs): - ans = yield args, kwargs - spec = tree_unflatten(treedef, leaves) +@lu.transformation_with_aux2 +def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs): + ans = _fun(*args, **kwargs) + spec = tree_unflatten(_treedef, _leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) except ValueError: @@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs): "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None - yield ans, spec_flat + _store.store(spec_flat) + return ans def check_callable(fun): # In Python 3.10+, the only thing stopping us from supporting staticmethods @@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() for path, l in generate_key_paths(x) if l is not static) -@lu.transformation_with_aux -def result_paths(*args, **kwargs): +@lu.transformation_with_aux2 +def result_paths(_fun, _store, *args, **kwargs): "linear_util transform to get output pytree paths of pre-flattened function." - ans = yield args, kwargs - yield ans, [keystr(path) for path, _ in generate_key_paths(ans)] + ans = _fun(*args, **kwargs) + _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + return ans def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, result_paths: tuple[str, ...] | None = None, diff --git a/jax/_src/array.py b/jax/_src/array.py index 2f29f137675b..7c5385f97e40 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1034,7 +1034,7 @@ def _get_aval_array(self): if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): return self.aval.update(sharding=NamedSharding( self.sharding.mesh.abstract_mesh, - self.sharding._normalized_spec(self.ndim))) + self.sharding.spec._normalized_spec(self.ndim))) else: return self.aval api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array @@ -1110,7 +1110,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): # Look up all buffers that contain the correct slice of the logical array. candidates_list = candidates[hashed_index(idx)] if not candidates_list: - return pxla.shard_args([sharding], [None], [x._value], + return pxla.shard_args([sharding], [None], [None], [x._value], canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. @@ -1119,7 +1119,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): bufs.append(buf) break else: - bufs.append(buf) + bufs.append(candidates_list[-1]) return pxla.batched_device_put(x.aval, sharding, bufs, devices) @@ -1130,11 +1130,13 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): return dst_indices, tuple(src_indices) == tuple(dst_indices) -def _array_shard_arg(xs, shardings, layouts): +def _array_shard_arg(xs, shardings, layouts, copy_semantics): results = [] batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] + batch_cs = [] - for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)): + for i, (x, sharding, layout, cs) in enumerate( + safe_zip(xs, shardings, layouts, copy_semantics)): x._check_if_deleted() indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) same_layout = (True if layout is None else @@ -1156,6 +1158,7 @@ def _array_shard_arg(xs, shardings, layouts): batch_devs.append(list(devices)) batch_shardings.append(sharding) batch_indices.append(i) + batch_cs.append(cs) # Resharding starts here: elif not same_layout: results.append(api.device_put(x, Layout(layout, sharding))) @@ -1166,7 +1169,7 @@ def _array_shard_arg(xs, shardings, layouts): shard_sharded_device_array_slow_path(x, devices, indices, sharding)) copy_outs = xc.batched_copy_array_to_devices_with_sharding( - batch_xs, batch_devs, batch_shardings) + batch_xs, batch_devs, batch_shardings, batch_cs) for i, copy_out in safe_zip(batch_indices, copy_outs): assert results[i] is None results[i] = copy_out @@ -1184,7 +1187,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed): global_aval, out_sharding, committed=committed, _skip_checks=True ) pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler -pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler # Only used for Arrays that come out of pmap. def _array_local_result_handler(aval, sharding, indices): @@ -1197,13 +1199,13 @@ def _array_local_result_handler(aval, sharding, indices): aval, sharding, committed=True, _skip_checks=True ) pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler -pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler # Token handlers -def _token_shard_arg(xs, shardings, layouts): - return _array_shard_arg([x._buf for x in xs], shardings, layouts) +def _token_shard_arg(xs, shardings, layouts, copy_semantics): + return _array_shard_arg([x._buf for x in xs], shardings, layouts, + copy_semantics) pxla.shard_arg_handlers[core.Token] = _token_shard_arg diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index c3145f32e8bf..a89d4a2949be 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -53,6 +53,7 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace # associated basearray.pyi file. __slots__ = ['__weakref__'] + __hash__ = None @property @abc.abstractmethod diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py index 16da61d75b3f..3bc592d88246 100644 --- a/jax/_src/blocked_sampler.py +++ b/jax/_src/blocked_sampler.py @@ -23,7 +23,7 @@ Shape = random.Shape class SampleFn(Protocol): - def __call__(self, key: random.KeyArrayLike, *args, shape: Shape, + def __call__(self, key: ArrayLike, *args, shape: Shape, **kwargs) -> Array: ... @@ -43,7 +43,7 @@ def _compute_scalar_index(iteration_index: Sequence[int], def blocked_fold_in( - global_key: random.KeyArrayLike, + global_key: ArrayLike, total_size: Shape, block_size: Shape, tile_size: Shape, diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 9bce9d0e4308..e4b6e7a2669c 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +import enum import hashlib import io import logging @@ -61,11 +62,23 @@ def custom_hook() -> str: return "" -def get(module: ir.Module, - devices: np.ndarray, - compile_options: xla_client.CompileOptions, - backend: xla_client.Client, - compression_algorithm: str = "zstandard") -> str: +class IgnoreCallbacks(enum.IntEnum): + # Do not remove any callback pointers from precompiled IR. + NO = enum.auto() + # Remove all callback pointers from precompiled IR. + ALL = enum.auto() + # Remove only custom_partitioning callback pointer from precompiled IR. + CUSTOM_PARTITIONING = enum.auto() + + +def get( + module: ir.Module, + devices: np.ndarray, + compile_options: xla_client.CompileOptions, + backend: xla_client.Client, + compression_algorithm: str = "zstandard", + ignore_callbacks: IgnoreCallbacks = IgnoreCallbacks.NO, +) -> str: """Creates a hashed string to use as a key to the compilation cache. Creates a cache key that is a hex-encoded string of a unique hash based on @@ -78,28 +91,47 @@ def get(module: ir.Module, backend: description of the platform (e.g., TPU version) compression_algorithm: a string representing the compression algorithm used for the executable before persisting in the cache + ignore_callbacks: whether to remove the all callback pointer from the + computation. Typical return value example: 'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf' """ entries = [ - ("computation", - lambda hash_obj: _hash_computation(hash_obj, module)), - ("jax_lib version", - lambda hash_obj: hash_obj.update( - bytes(jaxlib_version_str.encode("utf-8")))), - ("XLA flags", - lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())), - ("compile_options", - lambda hash_obj: _hash_serialized_compile_options( - hash_obj, compile_options, - # In case of GPU multi-process tasks we need to strip device - # assignment to use cache key as invariant between processes. - strip_device_assignment=(backend.platform == "gpu"))), - ("accelerator_config", - lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)), - ("compression", - lambda hash_obj: _hash_string(hash_obj, compression_algorithm)), + ( + "computation", + lambda hash_obj: _hash_computation( + hash_obj, module, ignore_callbacks + ), + ), + ( + "jax_lib version", + lambda hash_obj: hash_obj.update( + bytes(jaxlib_version_str.encode("utf-8")) + ), + ), + ( + "XLA flags", + lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()), + ), + ( + "compile_options", + lambda hash_obj: _hash_serialized_compile_options( + hash_obj, + compile_options, + # In case of GPU multi-process tasks we need to strip device + # assignment to use cache key as invariant between processes. + strip_device_assignment=(backend.platform == "gpu"), + ), + ), + ( + "accelerator_config", + lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend), + ), + ( + "compression", + lambda hash_obj: _hash_string(hash_obj, compression_algorithm), + ), ("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())), ] @@ -130,45 +162,56 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): ) -def _remove_custom_partitioning_ptr(m: ir.Module): - """ - Removes custom_partitioning callback pointer from precompiled IR. +def _remove_callbacks(m: ir.Module, ignore_callbacks: IgnoreCallbacks): + """Removes callback pointers from precompiled IR. + Python function pointers are not deterministic across executions. """ def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult: - if (op.name == "stablehlo.custom_call" and - op.attributes["call_target_name"].value == "CustomSPMDPartitioning"): + if op.name == "stablehlo.custom_call" and ( + ( + ignore_callbacks == IgnoreCallbacks.ALL + and op.attributes["call_target_name"].value.endswith("callback") + ) + or op.attributes["call_target_name"].value == "CustomSPMDPartitioning" + ): op.attributes["backend_config"] = ir.StringAttr.get("REMOVED") return ir.WalkResult.ADVANCE + if ignore_callbacks == IgnoreCallbacks.NO: + return m + m.operation.walk(_update_bc_attribute) return m -def _serialize_ir(m: ir.Module) -> bytes: +def _serialize_ir(m: ir.Module, ignore_callbacks: IgnoreCallbacks) -> bytes: output = io.BytesIO() - if config.remove_custom_partitioning_ptr_from_cache_key.value: - m = _remove_custom_partitioning_ptr(type_cast(ir.Module, - m.operation.clone())) + if ignore_callbacks != IgnoreCallbacks.NO: + m = _remove_callbacks( + type_cast(ir.Module, m.operation.clone()), ignore_callbacks + ) m.operation.write_bytecode(file=output) return output.getvalue() -def _canonicalize_ir(m_original: ir.Module) -> bytes: +def _canonicalize_ir( + m_original: ir.Module, ignore_callbacks: IgnoreCallbacks +) -> bytes: with m_original.context: m = type_cast(ir.Module, m_original.operation.clone()) passes = pm.PassManager.parse( "builtin.module(strip-debuginfo)" ) passes.run(m.operation) - return _serialize_ir(m) + return _serialize_ir(m, ignore_callbacks) -def _hash_computation(hash_obj, module): +def _hash_computation(hash_obj, module, ignore_callbacks: IgnoreCallbacks): if config.compilation_cache_include_metadata_in_key.value: - canonical_ir = _serialize_ir(module) + canonical_ir = _serialize_ir(module, ignore_callbacks) else: - canonical_ir = _canonicalize_ir(module) + canonical_ir = _canonicalize_ir(module, ignore_callbacks) hash_obj.update(canonical_ir) @@ -194,6 +237,38 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): _hash_devices(hash_obj, accelerators) _hash_platform(hash_obj, backend) +# LINT.IfChange(xla_flags) +xla_flags_to_exclude_from_cache_key = [ + "--xla_dump_compress_protos", + "--xla_dump_module_metadata", + "--xla_dump_max_hlo_modules", + "--xla_dump_include_timestamp", + "--xla_dump_hlo_pass_re", + "--xla_dump_hlo_module_re", + "--xla_dump_hlo_snapshots", + "--xla_dump_fusion_visualization", + "--xla_dump_hlo_as_url", + "--xla_dump_hlo_as_proto", + "--xla_dump_hlo_as_text", + "--xla_dump_hlo_as_long_text", + "--xla_dump_hlo_as_html", + "--xla_dump_hlo_as_dot", + "--xla_dump_to", + "--xla_force_host_platform_device_count", + "--xla_dump_disable_metadata", + "--xla_dump_hlo_pipeline_re", + "--xla_tpu_sdc_checker_streamz_metric", + "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", + "--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks", + "--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present", + "--xla_gpu_cuda_data_dir", + "--xla_gpu_experimental_autotune_cache_mode", +] + +env_override_flags_to_exclude_from_cache_key = { + x.strip("-") for x in xla_flags_to_exclude_from_cache_key +} +# LINT.ThenChange(:debug_options) def _hash_serialized_compile_options(hash_obj, compile_options_obj, strip_device_assignment=False): @@ -225,6 +300,8 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_dump_hlo_as_long_text = False debug_options.xla_dump_disable_metadata = False debug_options.xla_dump_hlo_pipeline_re = "" + debug_options.xla_gpu_experimental_autotune_cache_mode = 0 + # Optional way to specify the cuda install path to be used by the compiler. # This could possibly affect the cuda version compiled with, but this should # already be included in the platform information (and might not be reflected @@ -235,6 +312,11 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_gpu_cuda_data_dir = "" # LINT.ThenChange(:xla_flags) + compile_options_copy.env_option_overrides = [ + flag_value + for flag_value in compile_options_copy.env_option_overrides + if flag_value[0] not in env_override_flags_to_exclude_from_cache_key + ] if strip_device_assignment and compile_options_copy.device_assignment: replica_count = compile_options_copy.device_assignment.replica_count() computation_count = compile_options_copy.device_assignment.computation_count() @@ -252,32 +334,6 @@ def _hash_platform(hash_obj, backend): def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]): - # LINT.IfChange(xla_flags) - xla_flags_to_exclude_from_cache_key = [ - "--xla_dump_compress_protos", - "--xla_dump_module_metadata", - "--xla_dump_max_hlo_modules", - "--xla_dump_include_timestamp", - "--xla_dump_hlo_pass_re", - "--xla_dump_hlo_module_re", - "--xla_dump_hlo_snapshots", - "--xla_dump_fusion_visualization", - "--xla_dump_hlo_as_url", - "--xla_dump_hlo_as_proto", - "--xla_dump_hlo_as_text", - "--xla_dump_hlo_as_long_text", - "--xla_dump_hlo_as_html", - "--xla_dump_hlo_as_dot", - "--xla_dump_to", - "--xla_force_host_platform_device_count", - "--xla_dump_disable_metadata", - "--xla_dump_hlo_pipeline_re", - "--xla_tpu_sdc_checker_streamz_metric", - "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", - "--xla_gpu_cuda_data_dir", - ] - # LINT.ThenChange(:debug_options) - xla_flags = [] xla_flags_env_var = os.getenv("XLA_FLAGS") diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 9630418ae76c..013b766b8550 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -160,9 +160,22 @@ def callback_batching_rule( batched_result_avals = tuple( core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) for aval in result_avals) + + # For FFI calls we must update the layouts. We handle the output layouts + # here, but the input layout updates depend on the vmap_method parameter. + if vmap_method != "sequential" and kwargs.get("output_layouts") is not None: + kwargs["output_layouts"] = tuple( + None if layout is None else tuple(n + 1 for n in layout) + (0,) + for layout in kwargs["output_layouts"]) + if vmap_method == "legacy_vectorized": # This method is kept to support the behavior that was previously exposed # when using `vectorized=True`. + if kwargs.get("input_layouts") is not None: + kwargs["input_layouts"] = tuple( + layout if d is batching.not_mapped else + (None if layout is None else tuple(n + 1 for n in layout) + (0,)) + for layout, d in zip(kwargs["input_layouts"], dims)) outvals = prim.bind( *new_args, vectorized=vectorized, @@ -175,6 +188,10 @@ def callback_batching_rule( bcast_args = [ lax.broadcast(x, (size,)) if d is batching.not_mapped else x for x, d in zip(new_args, dims)] + if kwargs.get("input_layouts") is not None: + kwargs["input_layouts"] = tuple( + None if layout is None else tuple(n + 1 for n in layout) + (0,) + for layout in kwargs["input_layouts"]) outvals = prim.bind( *bcast_args, vectorized=vectorized, @@ -326,7 +343,7 @@ def pure_callback( * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` is deprecated and it will eventually raise ``NotImplementedError``. * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over - the batched arugments, calling ``callback`` once for each batch element. + the batched arguments, calling ``callback`` once for each batch element. * ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1`` added as the leading dimension unbatched inputs. * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the @@ -616,7 +633,6 @@ def io_callback( flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype), flat_shape_dtypes) - flat_args = map(core.raise_as_much_as_possible, flat_args) out_flat = io_callback_p.bind( *flat_args, callback=_FlatCallback(callback, in_tree), diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 944bf303b8f6..22fde8bd1cb5 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -330,11 +330,12 @@ def update_error(error, pred, code, metadata, payload, effect_type): ## Checkify transformation for plumbing functional error values. -@lu.transformation_with_aux -def _flatten_and_get_error_metadata_thunk(*invals): - error, out = yield invals, {} +@lu.transformation_with_aux2 +def _flatten_and_get_error_metadata_thunk(f, store, *invals): + error, out = f(*invals) out_vals, out_tree = jtu.tree_flatten((error, out)) - yield out_vals, (out_tree, set(error._pred.keys())) + store.store((out_tree, set(error._pred.keys()))) + return out_vals def default_checkify_rule(primitive: core.Primitive, error: Error, enabled_errors, *invals: core.Value, @@ -438,10 +439,12 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors, consts = tuple(c.x for c in hashable_consts) return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args) -@lu.transformation_with_aux -def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) +@lu.transformation_with_aux2 +def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def _reduce_any_error(error: Error): @@ -898,7 +901,8 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, inline, keep_unused): + resource_env, donated_invars, name, inline, keep_unused, + compiler_options_kvs): # jaxpr to checked_jaxpr err_vals, err_tree = jtu.tree_flatten(error) new_vals_in = [*err_vals, *vals_in] @@ -929,6 +933,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, name=name, inline=inline, keep_unused=keep_unused, + compiler_options_kvs=compiler_options_kvs, ) return tree_unflatten(out_tree, err_and_out) error_checks[pjit.pjit_p] = pjit_error_check diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index c7665da961af..a2f137686dae 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,8 @@ def cloud_tpu_init() -> None: os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" + if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''): + os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true' # this makes tensorstore serialization work better on TPU os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index 02aea2cd64d5..c8aa765c181c 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -180,6 +180,9 @@ def is_env_present(cls) -> bool: if not running_in_cloud_tpu_vm: logger.debug("Did not detect cloud TPU VM") return False + if os.environ.get("TPU_SKIP_MDS_QUERY") is not None: + logger.debug("TPU_SKIP_MDS_QUERY is set to True, so it's probably not a GCE TPU cluster.") + return False metadata_response, metadata_code = get_metadata('agent-worker-number') if metadata_code == metadata_response_code_success: logger.debug("Gce Tpu Cluster detected for Jax Distributed System") diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 2fb13fde72cf..69ef77a6421d 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -49,12 +49,6 @@ def auto_detect_unset_distributed_params(cls, initialization_timeout: int | None, ) -> tuple[str | None, int | None, int | None, Sequence[int] | None]: - - if all(p is not None for p in (coordinator_address, num_processes, - process_id, local_device_ids)): - return (coordinator_address, num_processes, process_id, - local_device_ids) - # First, we check the spec detection method because it will ignore submitted values # If if succeeds. if cluster_detection_method is not None: diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index c75d1783f356..d8724e42975e 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -84,6 +84,8 @@ def is_cache_used(backend: xla_client.Client) -> bool: _cache_used = True return _cache_used + return False + def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" @@ -265,12 +267,21 @@ def put_executable_and_time( cache.put(cache_key, executable_and_time) -def get_cache_key(module: ir.Module, - devices: np.ndarray, - compile_options, - backend) -> str: - return cache_key.get(module, devices, compile_options, backend, - "zstandard" if zstandard is not None else "zlib") +def get_cache_key( + module: ir.Module, + devices: np.ndarray, + compile_options, + backend, + ignore_callbacks: cache_key.IgnoreCallbacks = cache_key.IgnoreCallbacks.NO, +) -> str: + return cache_key.get( + module, + devices, + compile_options, + backend, + "zstandard" if zstandard is not None else "zlib", + ignore_callbacks, + ) def is_initialized() -> bool: diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 8a2d6047e9b8..6fbd9ab4e3a5 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -24,14 +24,17 @@ from typing import Any, Callable import warnings +from jax._src import cache_key as cache_key_type from jax._src import compilation_cache from jax._src import config as config from jax._src import distributed from jax._src import lib from jax._src import monitoring +from jax._src import path as pathlib from jax._src import profiler from jax._src import traceback_util from jax._src.interpreters import mlir +from jax._src.lib import version as jaxlib_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir import numpy as np @@ -188,7 +191,17 @@ def get_compile_options( assert device_assignment.computation_count() == num_partitions compile_options.device_assignment = device_assignment + build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value + build_options.memory_fitting_effort = config.memory_fitting_effort.value + if env_options_overrides is not None: + # Some overrides are passed directly on build_options. + overrides_on_build_options = [ + 'exec_time_optimization_effort', 'memory_fitting_effort'] + env_options_overrides = dict(env_options_overrides) + for name in overrides_on_build_options: + if name in env_options_overrides: + setattr(build_options, name, env_options_overrides.pop(name)) compile_options.env_option_overrides = list(env_options_overrides.items()) debug_options = compile_options.executable_build_options.debug_options @@ -234,6 +247,31 @@ def get_compile_options( debug_options.xla_detailed_logging = detailed_logging + # If persistent cache is enabled, also enable additional XLA caching features. + if compilation_cache.is_persistent_cache_enabled() and jaxlib_version > (0, 4, 35): + # compilation_cache_dir can't be None here, but the type checker is a bit + # strict. + path = pathlib.Path(config.compilation_cache_dir.value or "") + enabled_flags = config.persistent_cache_enable_xla_caches.value or "" + + if enabled_flags == "all" or "xla_gpu_kernel_cache_file" in enabled_flags: + kernel_cache_path = path / "xla_gpu_kernel_cache_file" + debug_options.xla_gpu_kernel_cache_file = str(kernel_cache_path) + # This option is required to use the kernel cache. + debug_options.xla_gpu_enable_llvm_module_compilation_parallelism = True + logger.debug("Enabling XLA kernel cache at '%s'", kernel_cache_path) + + if enabled_flags == "all" or "xla_gpu_per_fusion_autotune_cache_dir" in enabled_flags: + autotune_cache_path = path / "xla_gpu_per_fusion_autotune_cache_dir" + debug_options.xla_gpu_per_fusion_autotune_cache_dir = str(autotune_cache_path) + logger.debug("Enabling XLA autotuning cache at '%s'", autotune_cache_path) + + # Set caching mode so that only process 0 can write to the cache. + if distributed.global_state.process_id == 0: + debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.UPDATE + else: + debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.READ + return compile_options @@ -310,71 +348,69 @@ def compile_or_get_cached( use_compilation_cache = compilation_cache.is_cache_used(backend) + is_multi_process = ( + len({device.process_index for device in devices.flatten()}) > 1 + ) + min_device_process_id = min( + devices.flatten(), key=lambda device: device.id + ).process_index + is_auto_pgle_used = ( + config.enable_pgle.value and config.pgle_profiling_runs.value > 0 + ) + if not use_compilation_cache: + if ( + is_multi_process + and is_auto_pgle_used + and distributed.global_state.client is not None + ): + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, + ) + ) + return backend_compile(backend, computation, compile_options, host_callbacks) monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') try: + if config.remove_custom_partitioning_ptr_from_cache_key.value: + ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING + else: + ignore_callbacks = cache_key_type.IgnoreCallbacks.NO + cache_key = compilation_cache.get_cache_key( - computation, devices, compile_options, backend) + computation, + devices, + compile_options, + backend, + ignore_callbacks=ignore_callbacks, + ) except xc._xla.XlaRuntimeError as ex: logger.error("compile_or_get_cached: unable to generate cache key, " "skipping the cache: %s", ex) return backend_compile(backend, computation, compile_options, host_callbacks) - is_multi_process = ( - len({device.process_index for device in devices.flatten()}) > 1) - min_device_process_id = ( - min(devices.flatten(), key=lambda device: device.id).process_index) - - # When PGLE is enabled there might be 3 types of situations: - # 1. PGLE profiled module (the one which was recompiled with FDO profile) is - # in the persistent cache. In this case the module should be returned from - # cache and PGLE should be disabled for this module. Is module is stored in - # the persistent cache under the "pgle_profiled_module_key" which calculated - # with replacing FDO profile with flag which identify that module were PGLE - # profiled. - # 2. PGLE profiled module is not in the persistent cache and the module is - # getting built with an FDO profile. In this case we need to share FDO profile - # with other processes and store the result under the - # "pgle_profiled_module_key" so later in case 1 we will be able to find the - # module. - # 3. PGLE profiled module is not in the persistent cache and the module is - # getting compiled to be PGLEd (FDO profile is empty). In this case we need to - # simply return the non-PGLE profiled module from the persistent cache. - if (config.enable_pgle.value - and config.pgle_profiling_runs.value > 0): - fdo_profile = compile_options.executable_build_options.fdo_profile - compile_options.executable_build_options.fdo_profile = b"pgle profiled" - - pgle_profiled_module_key = compilation_cache.get_cache_key( - computation, devices, compile_options, backend) - compile_options.executable_build_options.fdo_profile = fdo_profile - - if _is_executable_in_cache(backend, pgle_profiled_module_key): - # Load PGLE profiled module from the persistent cache. - cache_key = pgle_profiled_module_key - if pgle_profiler is not None: - pgle_profiler.disable() - elif fdo_profile is not None and len(fdo_profile) > 0: - # Store module under PGLE profiled module cache key. - cache_key = pgle_profiled_module_key - if is_multi_process and distributed.global_state.client is not None: - compile_options.executable_build_options.fdo_profile = _share_fdo_profiles( - computation, devices, compile_options, backend, - distributed.global_state.client, - min_device_process_id - ) - else: - compile_options.executable_build_options.fdo_profile = fdo_profile - logger.debug( - "Compiling module %s with FDO profile: %s", - module_name, - compile_options.executable_build_options.fdo_profile, - ) + if is_auto_pgle_used: + cache_key = _resolve_pgle_module_cache_key( + computation, + devices, + compile_options, + backend, + pgle_profiler, + is_multi_process, + cache_key, + module_name, + min_device_process_id, + ) cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( @@ -440,6 +476,75 @@ def compile_or_get_cached( cache_key, ) + +# When PGLE is enabled there might be 3 types of situations: +# 1. PGLE profiled module (the one which was recompiled with FDO profile) is +# in the persistent cache. In this case the module should be returned from +# cache and PGLE should be disabled for this module. Is module is stored in +# the persistent cache under the "pgle_profiled_module_key" which calculated +# with replacing FDO profile with flag which identify that module were PGLE +# profiled. +# 2. PGLE profiled module is not in the persistent cache and the module is +# getting built with an FDO profile. In this case we need to share FDO profile +# with other processes and store the result under the +# "pgle_profiled_module_key" so later in case 1 we will be able to find the +# module. +# 3. PGLE profiled module is not in the persistent cache and the module is +# getting compiled to be PGLEd (FDO profile is empty). In this case we need to +# simply return the non-PGLE profiled module from the persistent cache. +def _resolve_pgle_module_cache_key( + computation: ir.Module, + devices: np.ndarray, + compile_options: xc.CompileOptions, + backend: xc.Client, + pgle_profiler: profiler.PGLEProfiler | None, + is_multi_process: bool, + cache_key: str, + module_name: str, + min_device_process_id: int, +) -> str: + fdo_profile = compile_options.executable_build_options.fdo_profile + compile_options.executable_build_options.fdo_profile = b"pgle profiled" + + pgle_profiled_module_key = compilation_cache.get_cache_key( + computation, + devices, + compile_options, + backend, + cache_key_type.IgnoreCallbacks.ALL, + ) + compile_options.executable_build_options.fdo_profile = fdo_profile + + result_key = cache_key + if _is_executable_in_cache(backend, pgle_profiled_module_key): + # Load PGLE profiled module from the persistent cache. + result_key = pgle_profiled_module_key + if pgle_profiler is not None: + pgle_profiler.disable() + elif fdo_profile is not None and len(fdo_profile) > 0: + # Store module under PGLE profiled module cache key. + result_key = pgle_profiled_module_key + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, + ) + ) + else: + compile_options.executable_build_options.fdo_profile = fdo_profile + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(compile_options.executable_build_options.fdo_profile), + ) + return result_key + + # The process that has the lowest device ID should share FDO profile before # compilation with other processes. def _share_fdo_profiles( @@ -457,28 +562,39 @@ def _share_fdo_profiles( return fdo_profile compile_options.executable_build_options.fdo_profile = b"" - profile_key = ( - compilation_cache.get_cache_key( - computation, devices, compile_options, backend - ) - + "_fdo_sync" - ) + try: + profile_key = ( + compilation_cache.get_cache_key( + computation, + devices, + compile_options, + backend, + cache_key_type.IgnoreCallbacks.ALL, + ) + + "_fdo_sync" + ) + except xc._xla.XlaRuntimeError as ex: + logger.error( + "compile_or_get_cached: unable to generate cache key, " + "skipping the fdo profile sharing: %s", + ex, + ) + return fdo_profile + if profile_key in _share_fdo_profiles.modules_profiles: return _share_fdo_profiles.modules_profiles[profile_key] share_timeout = config.share_binary_between_hosts_timeout_ms.value if distributed.global_state.process_id == min_process_id: logger.debug( - "Sharing FDO profile: %s. For module %s. Process %d.", - fdo_profile, + "Module %s. Sharing FDO profile. Process %d.", module_name, min_process_id, ) global_client.key_value_set_bytes(profile_key, fdo_profile) else: logger.debug( - "Waiting for FDO profile: %s. For module %s. Should be set by process %d.", - fdo_profile, + "Module %s. Waiting for FDO profile which should be set by process %d.", module_name, min_process_id, ) diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 4495d38f9da8..7bd9b9b08b7b 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -29,8 +29,8 @@ def __init__(self): @contextmanager def extend_compute_type(c_type: str): compute_on_context.stack.append(c_type) - config.update_thread_local_jit_state( - compute_on_context_manager=tuple(compute_on_context.stack)) + config.compute_on_context_manager.set_local( + tuple(compute_on_context.stack)) try: if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: raise NotImplementedError( @@ -39,16 +39,16 @@ def extend_compute_type(c_type: str): yield compute_on_context.stack[-1] finally: compute_on_context.stack.pop() - config.update_thread_local_jit_state( - compute_on_context_manager=tuple(compute_on_context.stack)) + config.compute_on_context_manager.set_local(tuple(compute_on_context.stack)) def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None def _check_valid(c_type: str): - if c_type not in {'device_host', 'device'}: - raise ValueError('Invalid compute type received. Current supported values ' - f'are `device_host` and `device`. Got {c_type}') + if c_type not in {'device_host', 'device', 'tpu_sparsecore'}: + raise ValueError( + 'Invalid compute type received. Current supported values ' + f'are `device_host`, `device` and `tpu_sparsecore`. Got {c_type}') @contextmanager def compute_on(compute_type: str): diff --git a/jax/_src/config.py b/jax/_src/config.py index 8bebd7d904a6..f1b170050a6b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -14,22 +14,22 @@ from __future__ import annotations -from collections.abc import Callable, Hashable, Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence import contextlib import functools import itertools import logging import os import sys -import threading -from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast +from typing import Any, Generic, NoReturn, Optional, Protocol, TypeVar, cast -from jax._src import lib from jax._src.lib import guard_lib from jax._src.lib import jax_jit from jax._src.lib import xla_client from jax._src import logging_config +config_ext = xla_client._xla.config + logger = logging.getLogger(__name__) _T = TypeVar('_T') @@ -199,29 +199,20 @@ def trace_context(): Values included in this set should also most likely be included in the C++ JIT state, which is handled separately. """ - tls = jax_jit.thread_local_state() - axis_env_state = () - mesh_context_manager = () - xla_metadata_context_manager = () - compute_on_context_manager = () - - context: Any = tls.extra_jit_context - if context and context.axis_env_state is not None: - axis_env_state = context.axis_env_state - if context and context.mesh_context_manager: - mesh_context_manager = context.mesh_context_manager - if context and context.xla_metadata_context_manager: - xla_metadata_context_manager = context.xla_metadata_context_manager - if context and context.compute_on_context_manager: - compute_on_context_manager = context.compute_on_context_manager - return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, - compute_on_context_manager, enable_x64.value, + return (axis_env_state.value, mesh_context_manager.value, + xla_metadata_context_manager.value, + abstract_mesh_context_manager.value, + device_context.value, + compute_on_context_manager.value, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, numpy_dtype_promotion.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, default_device.value, random_seed_offset.value, threefry_partitionable.value, threefry_gpu_kernel_lowering.value, sharding_in_types.value, + use_direct_linearize.value, softmax_custom_jvp.value, enable_memories.value, disable_jit.value, @@ -243,16 +234,10 @@ def trace_context(): class NoDefault: pass no_default = NoDefault() - -class _Unset: pass -unset = _Unset() - -_thread_local_state = threading.local() - -class State(Generic[_T]): +class State(config_ext.Config[_T]): __slots__ = ( - '_name', '_value', '_update_thread_local_hook', '_update_global_hook', + '_name', '_update_thread_local_hook', '_update_global_hook', '_validator', '_default_context_manager_value', '__doc__', '__name__', ) @@ -266,7 +251,9 @@ def __init__( validator: Callable[[Any], None] | None = None, extra_description: str = '', default_context_manager_value: Any = no_default, + include_in_jit_key: bool = False, ): + super().__init__(default, include_in_jit_key) self._name = name self.__name__ = name[4:] if name.startswith('jax_') else name self.__doc__ = (f"Context manager for `{name}` config option" @@ -275,7 +262,10 @@ def __init__( self._update_thread_local_hook = update_thread_local_hook self._validator = validator self._default_context_manager_value = default_context_manager_value - self._set(default) + if self._validator: + self._validator(default) + if self._update_global_hook: + self._update_global_hook(default) def __bool__(self) -> NoReturn: raise TypeError( @@ -286,15 +276,10 @@ def __bool__(self) -> NoReturn: def _set(self, value: _T) -> None: if self._validator: self._validator(value) - self._value = value + self.set_global(value) if self._update_global_hook: self._update_global_hook(value) - @property - def value(self) -> _T: - val = _thread_local_state.__dict__.get(self._name, unset) - return cast(_T, val) if val is not unset else self._value - @contextlib.contextmanager def __call__(self, new_val: Any = no_default): if new_val is no_default: @@ -308,21 +293,18 @@ def __call__(self, new_val: Any = no_default): "the config option.") if self._validator: self._validator(new_val) - prev_val = getattr(_thread_local_state, self._name, unset) - setattr(_thread_local_state, self._name, new_val) + prev_val = self.swap_local(new_val) if self._update_thread_local_hook: self._update_thread_local_hook(new_val) try: yield finally: - if prev_val is unset: - delattr(_thread_local_state, self._name) - if self._update_thread_local_hook: + self.set_local(prev_val) + if self._update_thread_local_hook: + if prev_val is config_ext.unset: self._update_thread_local_hook(None) - else: - setattr(_thread_local_state, self._name, prev_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(cast(_T, prev_val)) + else: + self._update_thread_local_hook(cast(Optional[Any], prev_val)) def _add_hooks(self, update_global_hook, update_thread_local_hook): """Private method that adds hooks to an existing context-manager. @@ -330,7 +312,7 @@ def _add_hooks(self, update_global_hook, update_thread_local_hook): Used to avoid cyclic import dependencies.""" self._update_thread_local_hook = update_thread_local_hook self._update_global_hook = update_global_hook - update_global_hook(self._value) + update_global_hook(self.get_global()) UPGRADE_BOOL_HELP = ( @@ -351,6 +333,7 @@ def bool_state( update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', + include_in_jit_key: bool = False, ) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. @@ -415,7 +398,8 @@ def bool_state( s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, - extra_description=extra_description, default_context_manager_value=True) + extra_description=extra_description, default_context_manager_value=True, + include_in_jit_key=include_in_jit_key) config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -429,6 +413,7 @@ def enum_state( *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. @@ -468,6 +453,7 @@ def validator(new_val): update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator, + include_in_jit_key=include_in_jit_key, ) config.add_option( name, s, 'enum', @@ -486,6 +472,7 @@ def optional_enum_state( *, update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. @@ -521,7 +508,7 @@ def validate(new_val): s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, - validate + validate, include_in_jit_key=include_in_jit_key, ) config.add_option( name, s, 'enum', @@ -539,6 +526,7 @@ def int_state( *, update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. @@ -573,7 +561,8 @@ def validate(new_val): f'got {new_val} of type {type(new_val)}') s = State[int](name, default, help, update_global_hook, - update_thread_local_hook, validate) + update_thread_local_hook, validate, + include_in_jit_key=include_in_jit_key) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -824,90 +813,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: already_configured_with_absl = False -# The C++ JIT maintains its own copy of several configuration items as -# a global/thread-local state. These methods allow updates to part of the -# state when a configuration value changes. -class _GlobalExtraJitContext(NamedTuple): - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool = False - random_seed_offset: int = 0 - threefry_partitionable: bool = False - threefry_gpu_kernel_lowering: bool = False - sharding_in_types: bool = False - softmax_custom_jvp: bool = False - xla_profile_version: int = 0 - pgle_profiling_runs: int = 0 - enable_pgle: bool = False - use_shardy_partitioner: bool = False - - -def _update_global_jit_state(**kw): - gs = jax_jit.global_state() - context = gs.extra_jit_context or _GlobalExtraJitContext() - gs.extra_jit_context = context._replace(**kw) - - -class _ThreadLocalExtraJitContext(NamedTuple): - """A namedtuple containing states to add to the cache key. - - Just in time compilation (for jit, pmap, etc) behavior is configurable through - global and thread-local options, used in the cache key. - - The initialization, which uses both config.py and core.py is done using - `_update_thread_local_jit_state` in core.py to prevent circular imports. - """ - dynamic_trace_state: Any | None = None - axis_env_state: Hashable = () - mesh_context_manager: Hashable = () - compute_on_context_manager: Hashable = () - xla_metadata_context_manager: Hashable = () - - # Values set by _StateContextManager context managers. - # CAUTION: these must be initialized to `None`! The state context manager - # restores these to None on exit. If the object default is not `None`, the - # context manager is not a no-op, which leads to problems with stale state - # (e.g. spurious cache misses in tests). - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool | None = None - random_seed_offset: int | None = None - threefry_partitionable: bool | None = None - threefry_gpu_kernel_lowering: bool | None = None - sharding_in_types: bool | None = None - softmax_custom_jvp: bool | None = None - xla_profile_version: int | None = None - pgle_profiling_runs: int | None = None - enable_pgle: bool | None = None - use_shardy_partitioner: bool | None = None - - -class _ThreadLocalStateCache(threading.local): - """"A thread local cache for _ThreadLocalExtraJitContext - - The extra_jit_context in jax_jit.thread_local_state() may get updated and thus - incurring dispatch overhead for comparing this python object during jit calls. - We want to deduplicate the objects that have the same hash/equality to also - have the same object ID, since the equality check is much faster if the object - IDs match. - """ - def __init__(self): - self.canonicalize = functools.lru_cache(128)(lambda x: x) - - -_thread_local_state_cache = _ThreadLocalStateCache() - - -def update_thread_local_jit_state(**kw): - tls = jax_jit.thread_local_state() - # After xla_client._version >= 70, the thread_local object will necessarily - # be initialized when accessed. The following line can be removed when the - # minimum jaxlib version is past version 70 - context = tls.extra_jit_context or _ThreadLocalExtraJitContext() - tmp = context._replace(**kw) - tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) +trace_state = config_ext.Config(None, include_in_jit_key=True) +axis_env_state = config_ext.Config((), include_in_jit_key=True) +mesh_context_manager = config_ext.Config((), include_in_jit_key=True) +abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True) +device_context = config_ext.Config((), include_in_jit_key=True) +compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) +xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) # TODO(b/214340779): remove flag when XLA:CPU is improved. @@ -1061,10 +973,10 @@ def update_thread_local_jit_state(**kw): help='If True, pmap and shard_map API will be merged.') def _update_jax_memories_global(val): - lib.jax_jit.global_state().enable_memories = val + jax_jit.global_state().enable_memories = val def _update_jax_memories_thread_local(val): - lib.jax_jit.thread_local_state().enable_memories = val + jax_jit.thread_local_state().enable_memories = val enable_memories = bool_state( 'jax_enable_memories', @@ -1099,10 +1011,7 @@ def _update_jax_memories_thread_local(val): name='jax_random_seed_offset', default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), - update_global_hook=lambda val: _update_global_jit_state( - random_seed_offset=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - random_seed_offset=val) + include_in_jit_key=True, ) legacy_prng_key = enum_state( @@ -1137,10 +1046,7 @@ def _update_jax_memories_thread_local(val): 'may result in extraneous communication and/or redundant distributed ' 'computation. With this flag, the communication overheads disappear ' 'in some cases.'), - update_global_hook=lambda val: _update_global_jit_state( - threefry_partitionable=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - threefry_partitionable=val)) + include_in_jit_key=True) threefry_gpu_kernel_lowering = bool_state( name='jax_threefry_gpu_kernel_lowering', @@ -1148,21 +1054,26 @@ def _update_jax_memories_thread_local(val): help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' 'This makes compile times faster at a potential runtime memory ' 'cost.'), - update_global_hook=lambda val: _update_global_jit_state( - threefry_gpu_kernel_lowering=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - threefry_gpu_kernel_lowering=val)) + include_in_jit_key=True) sharding_in_types = bool_state( name='jax_sharding_in_types', default=False, help=('When True, enables forward only sharding propagation in JAX and ' 'avals have sharding on them.'), - update_global_hook=lambda val: _update_global_jit_state( - sharding_in_types=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - sharding_in_types=val)) + include_in_jit_key=True) + +use_direct_linearize = bool_state( + name='jax_use_direct_linearize', + default=False, + help=('Use direct linearization instead JVP followed by partial eval'), + include_in_jit_key=True) +data_dependent_tracing_fallback = bool_state( + name='jax_data_dependent_tracing_fallback', + default=False, + help=('When True, falls back to trace dispatch based on data dependence ' + 'instead of throwing an escaped tracer error.')) softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', @@ -1171,10 +1082,7 @@ def _update_jax_memories_thread_local(val): help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' 'behavior. See https://github.com/jax-ml/jax/pull/15677'), - update_global_hook=lambda val: _update_global_jit_state( - softmax_custom_jvp=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - softmax_custom_jvp=val)) + include_in_jit_key=True) enable_custom_vjp_by_custom_transpose = bool_state( @@ -1212,6 +1120,15 @@ def _update_jax_memories_thread_local(val): ' filesystem being used for the cache. ' '* > 0: the actual minimum size desired; no overrides.')) +# TODO: Change default to all +persistent_cache_enable_xla_caches = optional_string_state( + name='jax_persistent_cache_enable_xla_caches', + default='xla_gpu_per_fusion_autotune_cache_dir', + help=('When the persistent cache is enabled, additional XLA caching will ' + 'also be enabled automatically. This option can be used to configure' + 'which XLA caching methods will be enabled.'), +) + compilation_cache_include_metadata_in_key = bool_state( name='jax_compilation_cache_include_metadata_in_key', default=False, @@ -1290,9 +1207,7 @@ def _update_jax_memories_thread_local(val): 'number times with collected data provided to the profile guided latency ' 'estimator.' ), - update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - enable_pgle=val), + include_in_jit_key=True, ) pgle_profiling_runs = int_state( @@ -1302,12 +1217,7 @@ def _update_jax_memories_thread_local(val): 'Amount of times module should be profiled before recompilation when ' 'PGLE is used.' ), - update_global_hook=lambda val: _update_global_jit_state( - pgle_profiling_runs=val - ), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - pgle_profiling_runs=val - ), + include_in_jit_key=True, ) pgle_aggregation_percentile = int_state( @@ -1373,10 +1283,7 @@ def _update_jax_memories_thread_local(val): 'between arrays. Options are "standard" or "strict"; in strict-mode, ' 'binary operations between arrays of differing strongly-specified ' 'dtypes will result in an error.'), - update_global_hook=lambda val: \ - _update_global_jit_state(numpy_dtype_promotion=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(numpy_dtype_promotion=val)) + include_in_jit_key=True) disallow_mesh_context_manager = bool_state( name='jax_disallow_mesh_context_manager', @@ -1388,10 +1295,10 @@ def _update_jax_memories_thread_local(val): ) def _update_x64_global(val): - lib.jax_jit.global_state().enable_x64 = val + jax_jit.global_state().enable_x64 = val def _update_x64_thread_local(val): - lib.jax_jit.thread_local_state().enable_x64 = val + jax_jit.thread_local_state().enable_x64 = val enable_x64 = bool_state( name='jax_enable_x64', @@ -1406,15 +1313,17 @@ def _update_x64_thread_local(val): setattr(Config, "x64_enabled", property(lambda _: enable_x64.value)) def _update_default_device_global(val): - lib.jax_jit.global_state().default_device = val + jax_jit.global_state().default_device = val def _update_default_device_thread_local(val): - lib.jax_jit.thread_local_state().default_device = val + jax_jit.thread_local_state().default_device = val def _validate_default_device(val): - if val is not None and not isinstance(val, xla_client.Device): + if (val is not None and + not isinstance(val, xla_client.Device) and + val not in ['cpu', 'gpu', 'tpu']): # TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when # all JAX backends use a single C++ device interface. if 'Device' in str(type(val)): @@ -1422,12 +1331,11 @@ def _validate_default_device(val): 'Allowing non-`xla_client.Device` default device: %s, type: %s', repr(val), type(val)) return - raise ValueError('jax.default_device must be passed a Device object (e.g. ' - f"`jax.devices('cpu')[0]`), got: {val!r}") + raise ValueError('jax.default_device must be passed either a Device object (e.g. ' + f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'" + f", got: {val!r}") -# TODO(skye): default_device only accepts devices for now. Make it work with -# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). default_device = string_or_object_state( name='jax_default_device', default=None, @@ -1443,10 +1351,10 @@ def _validate_default_device(val): validator=_validate_default_device) def _update_disable_jit_global(val): - lib.jax_jit.global_state().disable_jit = val + jax_jit.global_state().disable_jit = val def _update_disable_jit_thread_local(val): - lib.jax_jit.thread_local_state().disable_jit = val + jax_jit.thread_local_state().disable_jit = val disable_jit = bool_state( name='jax_disable_jit', @@ -1462,14 +1370,20 @@ def _update_disable_jit_thread_local(val): default='allow', help=('Control NumPy-style automatic rank promotion broadcasting ' '("allow", "warn", or "raise").'), - update_global_hook=lambda val: \ - _update_global_jit_state(numpy_rank_promotion=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(numpy_rank_promotion=val)) + include_in_jit_key=True) default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', - enum_values=['default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32'], + enum_values=[ + # Legacy precision API values + 'default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32', + # Dot algorithm presets + 'ANY_F8_ANY_F8_F32', 'ANY_F8_ANY_F8_F32_FAST_ACCUM', 'ANY_F8_ANY_F8_ANY', + 'ANY_F8_ANY_F8_ANY_FAST_ACCUM', 'F16_F16_F16', 'F16_F16_F32', + 'BF16_BF16_BF16', 'BF16_BF16_F32', 'BF16_BF16_F32_X3', + 'BF16_BF16_F32_X6', 'TF32_TF32_F32', 'TF32_TF32_F32_X3', 'F32_F32_F32', + 'F64_F64_F64', + ], default=None, help=('Control the default matmul and conv precision for 32bit inputs.\n\n' @@ -1486,11 +1400,14 @@ def _update_disable_jit_thread_local(val): 'convolution on 32bit inputs. The levels roughly describe the ' "precision at which scalar products are computed. The 'bfloat16' " "option is the fastest and least precise; 'float32' is similar to " - "full float32 precision; 'tensorfloat32' is intermediate.\n\n"), - update_global_hook=lambda val: \ - _update_global_jit_state(default_matmul_precision=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(default_matmul_precision=val)) + "full float32 precision; 'tensorfloat32' is intermediate.\n\n" + + 'This parameter can also be used to specify an accumulation ' + '"algorithm" for functions that perform matrix multiplications, like ' + ':func:`jax.lax.dot`. To specify an algorithm, set this option to ' + 'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'), + include_in_jit_key=True) + traceback_filtering = enum_state( name = 'jax_traceback_filtering', @@ -1525,10 +1442,14 @@ def _update_disable_jit_thread_local(val): default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), help=('Enables experimental features for staging out computations with ' 'dynamic shapes.'), - update_global_hook=lambda val: \ - _update_global_jit_state(dynamic_shapes=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(dynamic_shapes=val)) + include_in_jit_key=True) + +# This is for stackless backward compat with e.g. equinox +eager_constant_folding = bool_state( + name='eager_constant_folding', + default=False, + help=('Attempt constant folding during staging.'), + include_in_jit_key=True) # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. @@ -1587,10 +1508,7 @@ def _update_disable_jit_thread_local(val): 'Optional profile version for XLA compilation. This is meaningful ' 'only when XLA is configured to support the remote compilation ' 'profile feature.'), - update_global_hook=lambda val: _update_global_jit_state( - xla_profile_version=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - xla_profile_version=val), + include_in_jit_key=True, ) @contextlib.contextmanager @@ -1719,59 +1637,43 @@ def transfer_guard(new_val: str) -> Iterator[None]: yield -if lib.xla_extension_version < 293: - - def array_garbage_collection_guard(_val): - raise NotImplementedError( - 'jaxlib version is too low for garbage collection guard' - ) - -else: - def _update_garbage_collection_guard(state, key, val): - """Applies the transfer guard level within guard_lib.""" - if val is None: - setattr(state, key, None) - elif val == 'allow': - setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW) - elif val == 'log': - setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG) - elif val == 'fatal': - setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL) - else: - assert False, f'Invalid garbage collection guard level {val}' - - array_garbage_collection_guard = optional_enum_state( - name='jax_array_garbage_collection_guard', - enum_values=['allow', 'log', 'fatal'], - # The default is applied by guard_lib. - default=None, - help=( - 'Select garbage collection guard level for "jax.Array" objects.\nThis' - ' option can be used to control what happens when a "jax.Array"' - ' object is garbage collected. It is desirable for "jax.Array"' - ' objects to be freed by Python reference couting rather than garbage' - ' collection in order to avoid device memory being held by the arrays' - ' until garbage collection occurs.\n\nValid values are:\n * "allow":' - ' do not log garbage collection of "jax.Array" objects.\n * "log":' - ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' - ' fatal error if a "jax.Array" is garbage collected.\nDefault is' - ' "allow".' - ), - update_global_hook=lambda val: _update_garbage_collection_guard( - guard_lib.global_state(), 'garbage_collect_array', val - ), - update_thread_local_hook=lambda val: _update_garbage_collection_guard( - guard_lib.thread_local_state(), 'garbage_collect_array', val - ), - ) +def _update_garbage_collection_guard(state, key, val): + """Applies the transfer guard level within guard_lib.""" + if val is None: + setattr(state, key, None) + elif val == 'allow': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW) + elif val == 'log': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG) + elif val == 'fatal': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL) + else: + assert False, f'Invalid garbage collection guard level {val}' -def _update_debug_log_modules(module_names_str: str | None): - logging_config.disable_all_debug_logging() - if not module_names_str: - return - module_names = module_names_str.split(',') - for module_name in module_names: - logging_config.enable_debug_logging(module_name) +array_garbage_collection_guard = optional_enum_state( + name='jax_array_garbage_collection_guard', + enum_values=['allow', 'log', 'fatal'], + # The default is applied by guard_lib. + default=None, + help=( + 'Select garbage collection guard level for "jax.Array" objects.\nThis' + ' option can be used to control what happens when a "jax.Array"' + ' object is garbage collected. It is desirable for "jax.Array"' + ' objects to be freed by Python reference couting rather than garbage' + ' collection in order to avoid device memory being held by the arrays' + ' until garbage collection occurs.\n\nValid values are:\n * "allow":' + ' do not log garbage collection of "jax.Array" objects.\n * "log":' + ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' + ' fatal error if a "jax.Array" is garbage collected.\nDefault is' + ' "allow".' + ), + update_global_hook=lambda val: _update_garbage_collection_guard( + guard_lib.global_state(), 'garbage_collect_array', val + ), + update_thread_local_hook=lambda val: _update_garbage_collection_guard( + guard_lib.thread_local_state(), 'garbage_collect_array', val + ), +) # Don't define a context manager since this isn't threadsafe. string_state( @@ -1780,7 +1682,20 @@ def _update_debug_log_modules(module_names_str: str | None): help=('Comma-separated list of module names (e.g. "jax" or ' '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging ' 'for.'), - update_global_hook=_update_debug_log_modules) + update_global_hook=logging_config.update_debug_log_modules) + +# Don't define a context manager since this isn't threadsafe. +optional_enum_state( + name='jax_logging_level', + enum_values=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default=logging.getLevelName(logging.getLogger("jax").level), + help=('Set the corresponding logging level on all jax loggers. Only string' + ' values from ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR",' + ' "CRITICAL"] are accepted. If None, the logging level will not be' + ' set. Includes C++ logging.'), + update_global_hook=lambda logging_level: \ + logging_config.update_logging_level_global(logging_level=logging_level) +) pmap_no_rank_reduction = bool_state( name='jax_pmap_no_rank_reduction', @@ -1797,10 +1712,28 @@ def _update_debug_log_modules(module_names_str: str | None): 'framework for MLIR. Currently Shardy is experimental in JAX. See ' 'www.github.com/openxla/shardy' ), - update_global_hook=lambda val: _update_global_jit_state( - use_shardy_partitioner=val - ), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - use_shardy_partitioner=val + include_in_jit_key=True, +) + +gpu_use_magma = enum_state( + name='jax_use_magma', + enum_values=['off', 'on', 'auto'], + default='auto', + help=( + 'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. ' + 'See the documentation for lax.linalg.eig for more details about how ' + 'to use this feature.' ), ) + +exec_time_optimization_effort = float_state( + name='jax_exec_time_optimization_effort', + default=0.0, + help='Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].' +) + +memory_fitting_effort = float_state( + name='jax_memory_fitting_effort', + default=0.0, + help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].' +) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7674ba76da38..0c2949de07af 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -14,13 +14,12 @@ from __future__ import annotations from collections import Counter, defaultdict, deque, namedtuple -from collections.abc import (Callable, Collection, Generator, Hashable, - Iterable, Iterator, Set, Sequence, MutableSet, - MutableMapping) +from collections.abc import (Callable, Collection, Hashable, Iterable, Iterator, + Sequence, MutableSet, MutableMapping) from contextlib import contextmanager, ExitStack from dataclasses import dataclass import functools -from functools import partial, partialmethod, total_ordering +from functools import partial, total_ordering import gc import inspect import itertools as it @@ -29,17 +28,17 @@ import threading import types from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar, - cast, overload, Union) + overload, Union) import warnings from weakref import ref import numpy as np -from jax._src import deprecations from jax._src import dtypes from jax._src import config from jax._src import effects from jax._src import compute_on +from jax._src import mesh as mesh_lib from jax._src.errors import ( ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) @@ -47,7 +46,7 @@ from jax._src import source_info_util from jax._src.util import (safe_zip, safe_map, curry, tuple_insert, - tuple_delete, as_hashable_function, + tuple_delete, HashableFunction, HashableWrapper, weakref_lru_cache, partition_list, StrictABCMeta) import jax._src.pretty_printer as pp @@ -257,7 +256,14 @@ def _repr_pretty_(self, p, cycle): @curry def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): - return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args) + # TODO(dougalm): remove this hack when we add contexts to jaxpr. + # debug_nans is sometimes disabled locally at the traceable level by ops that + # work with nans internally, like jnp.var. The right thing to do is to add + # contexts to our jaxpr representation so that we can capture these local + # context modifications. In the meantime, disabling the checks when we + # round-trip prevents those ops producing spurious errors. + with config.debug_nans(False): + return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args) class JaxprEqnContext: @@ -362,7 +368,7 @@ class Var: def __init__(self, suffix: str, aval: AbstractValue): self.count = next(_var_counter) self.suffix = suffix - self.aval = raise_to_shaped(aval) + self.aval = aval # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not # care about variable ordering, but the downstream package kfac_jax does. @@ -425,6 +431,8 @@ class Primitive: call_primitive: bool = False # set for map primitives processed in final style. map_primitive: bool = False + # set for ref primitives + ref_primitive: bool = False def __init__(self, name: str): self.name = name @@ -433,14 +441,27 @@ def __repr__(self): return f'{self.name}' def bind(self, *args, **params): - assert (not config.enable_checks.value or - all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args - return self.bind_with_trace(find_top_trace(args), args, params) + for arg in args: + if (isinstance(arg, Tracer) + and not arg._trace.is_valid() + and not config.data_dependent_tracing_fallback.value): + raise escaped_tracer_error(arg) + # TODO: figure out how to handle function arguments + # assert (not config.enable_checks.value or + # all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args + + # This is equivalent to "with take_current_trace()", but the bind() code + # is called frequently and it's slightly faster to avoid using a context + # manager object. + prev_trace = trace_ctx.trace + trace_ctx.set_trace(eval_trace) + try: + return self.bind_with_trace(prev_trace, args, params) + finally: + trace_ctx.set_trace(prev_trace) def bind_with_trace(self, trace, args, params): - with pop_level(trace.level): - out = trace.process_primitive(self, map(trace.full_raise, args), params) - return map(full_lower, out) if self.multiple_results else full_lower(out) + return trace.process_primitive(self, args, params) def def_impl(self, impl): self.impl = impl @@ -454,9 +475,9 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval): self.abstract_eval = effectful_abstract_eval return effectful_abstract_eval - def def_custom_bind(self, bind): - self.bind = bind - return bind + def def_bind_with_trace(self, bind_with_trace): + self.bind_with_trace = bind_with_trace + return bind_with_trace def impl(self, *args, **params): raise NotImplementedError("Evaluation rule for '{}' not implemented" @@ -519,65 +540,18 @@ def write(v: Var, val: Any) -> None: TracerType = TypeVar('TracerType', bound='Tracer') class Trace(Generic[TracerType]): - __slots__ = ['main', 'level', 'sublevel'] - - main: MainTrace - level: int - sublevel: Sublevel - - def __init__(self, main: MainTrace, sublevel: Sublevel) -> None: - self.main = main - self.level = main.level - self.sublevel = sublevel - - def full_raise(self, val) -> TracerType: - if not isinstance(val, Tracer): - # This check is only applied to non-Tracers, because the hasattr() is - # expensive (Tracer.__getattr__) in the common case that val is a Tracer. - if hasattr(val, "dimension_as_value"): # Used for shape_poly._DimExpr - val = val.dimension_as_value() - if not isinstance(val, Tracer): - return self.pure(val) - else: - return self.pure(val) - val._assert_live() - level = self.level - sublevel = self.sublevel - if val._trace.main is self.main: - if val._trace.sublevel == sublevel: - return cast(TracerType, val) - elif val._trace.sublevel < sublevel: - return self.sublift(val) - else: - raise escaped_tracer_error( - val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}") - elif val._trace.level < level: - if val._trace.sublevel > sublevel: - raise escaped_tracer_error( - val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}") - return self.lift(val) - elif val._trace.level > level: - raise escaped_tracer_error( - val, f"Can't lift level {val} to {self}") - else: # val._trace.level == self.level: - raise escaped_tracer_error( - val, f"Different traces at same level: {val}, {self}") - - def pure(self, val) -> TracerType: - raise NotImplementedError("must override") - def lift(self, tracer) -> TracerType: + def process_primitive(self, primitive, tracers, params): raise NotImplementedError("must override") - def sublift(self, tracer) -> TracerType: - raise NotImplementedError("must override") + def invalidate(self): + self._invalidated = True - def process_primitive(self, primitive, tracers, params): - raise NotImplementedError("must override") + def is_valid(self): + return not hasattr(self, "_invalidated") def __repr__(self): - return '{}(level={}/{})'.format( - self.__class__.__name__, self.level, self.sublevel) + return '{}'.format(self.__class__.__name__) def process_call(self, call_primitive, f, tracers, params): msg = (f"{type(self)} must override process_call to handle call-like " @@ -606,24 +580,14 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, "to handle custom_vjp primitives") raise NotImplementedError(msg) + # TODO(dougalm): deprecate/delete + def full_raise(self, x): + return x -def raise_as_much_as_possible(tracer) -> Tracer: - # Find effective bottom of trace stack (highest dynamic Trace on the stack). - trace_stack = thread_local_state.trace_state.trace_stack.stack - idx = next(i for i, m in enumerate(trace_stack) if m is - thread_local_state.trace_state.trace_stack.dynamic) - - # Only pay attention to effective part of trace stack. - trace_stack = trace_stack[idx:] - - # Lift tracer into everything in the effective stack higher than its level - for trace in trace_stack: - trace = trace.with_cur_sublevel() - if (not isinstance(tracer, Tracer) or tracer._trace.level < trace.level): - tracer = trace.full_raise(tracer) - - return tracer - + # TODO(dougalm): deprecate/delete + @property + def main(self): + return getattr(self, "tag", None) def escaped_tracer_error(tracer, detail=None): num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value @@ -684,31 +648,20 @@ def _aval_property(name): class Tracer(typing.Array, metaclass=StrictABCMeta): __array_priority__ = 1000 __slots__ = ['_trace', '_line_info'] + __hash__ = None # type: ignore dtype = _aval_property('dtype') ndim = _aval_property('ndim') size = _aval_property('size') shape = _aval_property('shape') - def __hash__(self): - # TODO(jakevdp) finalize this deprecation and set __hash__ = None - # Warning added 2024-06-13 - if deprecations.is_accelerated('tracer-hash'): - raise TypeError(f"unhashable type: {type(self)}") - # Use FutureWarning rather than DeprecationWarning because hash is likely - # not called directly by the user, so we want to warn at all stacklevels. - warnings.warn( - f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an" - " error in a future JAX release.", category=FutureWarning) - return super().__hash__() - def __init__(self, trace: Trace): self._trace = trace def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" - return f"traced array with shape {raise_to_shaped(self.aval).str_short()}" + return f"traced array with shape {self.aval.str_short()}" def __array__(self, *args, **kw): raise TracerArrayConversionError(self) @@ -729,6 +682,10 @@ def tobytes(self, order="C"): f"The tobytes() method was called on {self._error_repr()}." f"{self._origin_msg()}") + # TODO(dougalm): deprecate/delete + def full_lower(self): + raise NotImplementedError("must override: ", type(self)) + def __iter__(self): return iter(self.aval._iter(self)) @@ -738,6 +695,10 @@ def __reversed__(self): def __len__(self): return self.aval._len(self) + def to_concrete_value(self): + # Should return the concrete value if there is one, or else None. + return None + @property def sharding(self): # This attribute is part of the jax.Array API, but only defined on concrete arrays. @@ -777,17 +738,16 @@ def at(self): def aval(self): raise NotImplementedError("must override") - def _assert_live(self) -> None: - pass # Override for liveness checking - def get_referent(self) -> Any: return self # Override for object equivalence checking def __bool__(self): + if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_bool_conversion(self) return self.aval._bool(self) def __int__(self): + if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_scalar_conversion(self) return self.aval._int(self) @@ -800,16 +760,19 @@ def __complex__(self): return self.aval._complex(self) def __hex__(self): + if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._hex(self) def __oct__(self): + if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._oct(self) def __index__(self): + if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) - raise self.aval._index(self) + return self.aval._index(self) # raises a useful error on attempts to pickle a Tracer. def __reduce__(self): @@ -940,19 +903,27 @@ def unsafe_buffer_pointer(self): aval_property = namedtuple("aval_property", ["fget"]) aval_method = namedtuple("aval_method", ["fun"]) +def check_eval_args(args): + for arg in args: + if isinstance(arg, Tracer): + raise escaped_tracer_error(arg) class EvalTrace(Trace): - # See comments in https://github.com/jax-ml/jax/pull/3370 - def pure(self, x): return x - lift = sublift = pure - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, args, params): if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error - return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params) + return call_impl_with_key_reuse_checks(primitive, primitive.impl, *args, **params) else: - return primitive.impl(*tracers, **params) + # TODO(dougalm): delete. this shouldn't be necessary + args = map(full_lower, args) + if config.data_dependent_tracing_fallback.value: + for arg in args: + if isinstance(arg, Tracer): + return primitive.bind_with_trace(arg._trace, args, params) + check_eval_args(args) + return primitive.impl(*args, **params) def process_call(self, primitive, f, tracers, params): if config.debug_key_reuse.value: @@ -965,128 +936,147 @@ def process_call(self, primitive, f, tracers, params): def process_custom_transpose(self, primitive, call, tracers, **_): del primitive, _ - with new_sublevel(): - return call.call_wrapped(*tracers) + return call.call_wrapped(*tracers) def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_): del primitive, jvp, _ # Unused. - with new_sublevel(): - return fun.call_wrapped(*tracers) + return fun.call_wrapped(*tracers) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # pytype: disable=signature-mismatch del primitive, fwd, bwd, _ # Unused. - with new_sublevel(): - return fun.call_wrapped(*tracers) - - -class MainTrace: - level: int - trace_type: type[Trace] - payload: dict[str, Any] - - def __init__(self, level, trace_type, **payload) -> None: - self.level = level - self.trace_type = trace_type - self.payload = payload - - def __repr__(self) -> str: - return f"MainTrace({self.level},{self.trace_type.__name__})" - - def __hash__(self) -> int: - return hash((self.level, self.trace_type)) - - def __eq__(self, other: object) -> bool: - return (isinstance(other, MainTrace) and - self.level == other.level and - self.trace_type == other.trace_type and - self.payload == other.payload) - - def with_cur_sublevel(self): - return self.trace_type(self, cur_sublevel(), **self.payload) + return fun.call_wrapped(*tracers) + + +class TraceTag: + # TODO: this works for surprisingly subtle reasons. Function transformations + # like `jvp_subtrace` are parameterized by a tag that identifies the set of + # pre-existing tracers we want to unpack during the transformation. A function + # defined in an outer scope can't have any closed-over traces, so the tag is + # irrelevant. A function defined in the current scope may have closed-over + # traces, but the tag will never change so we'll never get a spurious cache + # hit. The plan is to do away with `lu.cache` altogether, and use a simpler + # caching scheme that only caches top-level functions. Then we can remove this + # hack. + def __hash__(self): + return hash(TraceTag) + def __eq__(self, other): + return isinstance(other, TraceTag) -class TraceStack: - # See comments in https://github.com/jax-ml/jax/pull/3370 - stack: list[MainTrace] - dynamic: MainTrace +ParamDict = dict[str, Any] +AxisName = Hashable - def __init__(self): - eval_trace = MainTrace(0, EvalTrace) - self.stack = [eval_trace] - self.dynamic = eval_trace +no_axis_name = object() - def next_level(self) -> int: - return len(self.stack) +@dataclass(frozen=True) +class AxisEnv: + axis_sizes : dict[AxisName, int] + spmd_axis_names : set[AxisName] - def push(self, main_trace: MainTrace) -> None: - self.stack.append(main_trace) + def axis_size(self, axis_name): + if axis_name not in self.axis_sizes: + raise NameError(f"unbound axis name: {axis_name}") + else: + return self.axis_sizes[axis_name] - def pop(self) -> None: - self.stack.pop() + def axis_exists(self, axis_name): + return axis_name in self.axis_sizes - def __repr__(self) -> str: - stack_str = map(' {}\n'.format, self.stack[::-1]) - return f'Trace stack\n{stack_str}\n{self.dynamic}' + def axis_names(self): + return tuple(k for k in self.axis_sizes) - def copy(self): - new = self.__new__(TraceStack) - new.stack = self.stack[:] - new.dynamic = self.dynamic - return new + def pop_pure(self, axis_name): + new_sizes = self.axis_sizes.copy() + new_sizes.pop(axis_name) + return AxisEnv(new_sizes, self.spmd_axis_names) + def extend_pure(self, name_size_pairs): + new_sizes = self.axis_sizes.copy() + new_sizes.update((name, size) for name, size in name_size_pairs + if name is not no_axis_name) + return AxisEnv(new_sizes, self.spmd_axis_names) -@total_ordering -class Sublevel: + def add_spmd_axis_names(self, axis_names): + new_spmd_axis_names = self.spmd_axis_names | set(axis_names) + return AxisEnv(self.axis_sizes, new_spmd_axis_names) - def __init__(self, level: int): - self.level = level + def as_hashable_key(self): + return tuple((name, size) for (name, size) in self.axis_sizes.items() + if name is not no_axis_name) - def __repr__(self): - return str(self.level) +eval_trace = EvalTrace() +top_axis_env = AxisEnv({}, set()) - def __eq__(self, other): - return type(other) is Sublevel and self.level == other.level +class TracingContext(threading.local): + trace: Trace | None + axis_env : AxisEnv - def __lt__(self, other): - return type(other) is Sublevel and self.level < other.level + def __init__(self): + self.reset() + def reset(self): + self.trace = eval_trace + self.axis_env = top_axis_env -AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace']) -AxisName = Hashable + def is_top_level(self) -> bool: + return (self.trace is eval_trace and + self.axis_env is top_axis_env) -no_axis_name = object() + def set_trace(self, trace): + self.trace = trace + ts = ref(trace) if trace is not None else None + config.trace_state.set_local(ts) -class TraceState: - trace_stack: TraceStack - substack: list[Sublevel] - axis_env: list[AxisEnvFrame] + def set_axis_env(self, axis_env): + self.axis_env = axis_env + config.axis_env_state.set_local(axis_env.as_hashable_key()) - def __init__(self) -> None: - self.trace_stack = TraceStack() - self.substack = [Sublevel(0)] - self.axis_env = [] + def update_thread_local_jit_state(self): + ts = ref(self.trace) if self.trace is not None else None + config.trace_state.set_local(ts) + config.axis_env_state.set_local(self.axis_env.as_hashable_key()) - def copy(self): - new = self.__new__(TraceState) - new.trace_stack = self.trace_stack.copy() - new.substack = self.substack[:] - new.axis_env = self.axis_env[:] - return new +trace_ctx = TracingContext() -def _update_thread_local_jit_state(dynamic): - state = (dynamic.level, dynamic.trace_type) - config.update_thread_local_jit_state(dynamic_trace_state=state) +@contextmanager +def take_current_trace(): + prev = trace_ctx.trace + try: + trace_ctx.set_trace(eval_trace) + yield prev + finally: + trace_ctx.set_trace(prev) +@contextmanager +def set_current_trace(new): + prev = trace_ctx.trace + try: + trace_ctx.set_trace(new) + yield + finally: + trace_ctx.set_trace(prev) -# The global state of the tracer is accessed by a thread-local object. -# This allows concurrent tracing in separate threads; passing traced objects -# between threads is forbidden. -class ThreadLocalState(threading.local): - def __init__(self): - self.trace_state = TraceState() +@contextmanager +def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]): + prev = trace_ctx.axis_env + try: + trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs)) + yield + finally: + trace_ctx.set_axis_env(prev) -thread_local_state = ThreadLocalState() +@contextmanager +def add_spmd_axis_names(axis_names: AxisName | None): + prev = trace_ctx.axis_env + try: + if axis_names is not None: + trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names)) + yield + finally: + trace_ctx.set_axis_env(prev) +def get_axis_env(): + return trace_ctx.axis_env def _initialize_jax_jit_thread_local_state(): """Initializes the C++ thread-local context. @@ -1097,34 +1087,23 @@ def _initialize_jax_jit_thread_local_state(): This function does not live in `config.py`, to prevent circular imports. """ - tls = jax_jit.thread_local_state() - if tls.extra_jit_context is None: - dynamic = thread_local_state.trace_state.trace_stack.dynamic - state = (dynamic.level, dynamic.trace_type) - config.update_thread_local_jit_state(dynamic_trace_state=state) - + trace_ctx.update_thread_local_jit_state() jax_jit.set_thread_local_state_initialization_callback( _initialize_jax_jit_thread_local_state) def trace_state_clean() -> bool: - trace_state = thread_local_state.trace_state - return (trace_state.substack == [Sublevel(0)] and - trace_state.axis_env == [] and - trace_state.trace_stack.stack == [MainTrace(0, EvalTrace)] and - trace_state.trace_stack.dynamic == MainTrace(0, EvalTrace)) + return trace_ctx.is_top_level() def reset_trace_state() -> bool: """Resets the global trace state and returns True if it was already clean.""" - if not trace_state_clean(): - thread_local_state.trace_state.__init__() + if not trace_ctx.is_top_level(): + trace_ctx.reset() + trace_ctx.update_thread_local_jit_state() return False else: return True -def cur_sublevel() -> Sublevel: - return thread_local_state.trace_state.substack[-1] - TRACER_LEAK_DEBUGGER_WARNING = """\ JAX check_tracer_leaks behavior can trigger false positives when used with a debugger. To avoid false positives and silence this warning, you can disable thread tracing using @@ -1134,13 +1113,21 @@ def cur_sublevel() -> Sublevel: threading.current_thread().pydev_do_not_trace = True """ -def maybe_find_leaked_tracers(x: MainTrace | Sublevel | None - ) -> list[Tracer]: - """Find the leaked tracers holding a reference to the MainTrace or SubLevel. +@contextmanager +def ensure_no_leaks(trace:Trace): + yield + trace.invalidate() + if config.check_tracer_leaks.value: + trace_ref = ref(trace) + del trace + live_trace = trace_ref() + if live_trace is not None: + leaked_tracers = maybe_find_leaked_tracers(live_trace) + if leaked_tracers: + raise leaked_tracer_error("trace", live_trace, leaked_tracers) - It's possible there's none! eg. there's some cases where JAX itself holds a - reference to `x` inside of a lambda closure, and no tracers were leaked - by the user. In this case an empty list is returned. +def maybe_find_leaked_tracers(trace: Trace) -> list[Tracer]: + """Find the leaked tracers holding a reference to the Trace """ if not getattr(threading.current_thread(), 'pydev_do_not_trace', True): warnings.warn(TRACER_LEAK_DEBUGGER_WARNING) @@ -1148,8 +1135,7 @@ def maybe_find_leaked_tracers(x: MainTrace | Sublevel | None # only due to cyclical dependencies. (We don't care about unreachable leaked # tracers since they can't interact with user code and cause a problem.) gc.collect() - traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x))) - tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces))) + tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(trace))) return tracers def leaked_tracer_error(name: str, t, tracers: list[Tracer]) -> Exception: @@ -1216,83 +1202,6 @@ def _why_alive_container_info(container, obj_id) -> str: return f' named {container.__name__}' return name - -@contextmanager -def new_main(trace_type: type[Trace], dynamic: bool = False, - **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/jax-ml/jax/pull/3370 - stack = thread_local_state.trace_state.trace_stack - level = stack.next_level() - main = MainTrace(level, trace_type, **payload) - stack.push(main) - if dynamic: - prev_dynamic, stack.dynamic = stack.dynamic, main - _update_thread_local_jit_state(stack.dynamic) - - try: - yield main - finally: - stack.pop() - if dynamic: - stack.dynamic = prev_dynamic - _update_thread_local_jit_state(stack.dynamic) - - if config.check_tracer_leaks.value: - t = ref(main) - del main - if t() is not None: - leaked_tracers = maybe_find_leaked_tracers(t()) - if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers) - -@contextmanager -def new_dynamic(level: int) -> Generator[None, None, None]: - stack = thread_local_state.trace_state.trace_stack - prev_dynamic, stack.dynamic = stack.dynamic, stack.stack[level] - _update_thread_local_jit_state(stack.dynamic) - try: - yield - finally: - stack.dynamic = prev_dynamic - _update_thread_local_jit_state(stack.dynamic) - -def dynamic_level() -> int: - return thread_local_state.trace_state.trace_stack.dynamic.level - -@contextmanager -def new_base_main(trace_type: type[Trace], - **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/jax-ml/jax/pull/3370 - stack = thread_local_state.trace_state.trace_stack - main = MainTrace(0, trace_type, **payload) - prev_dynamic, stack.dynamic = stack.dynamic, main - prev_base, stack.stack[0] = stack.stack[0], main - _update_thread_local_jit_state(stack.dynamic) - try: - yield main - finally: - stack.dynamic = prev_dynamic - stack.stack[0] = prev_base - _update_thread_local_jit_state(stack.dynamic) - - if config.check_tracer_leaks.value: - t = ref(main) - del main - if t() is not None: - leaked_tracers = maybe_find_leaked_tracers(t()) - if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers) - -@contextmanager -def pop_level(level: int): - if level == 0: - return (yield) # noqa: B901 - prev, thread_local_state.trace_state.trace_stack.stack = \ - thread_local_state.trace_state.trace_stack.stack, \ - thread_local_state.trace_state.trace_stack.stack[:level] - try: - yield - finally: - thread_local_state.trace_state.trace_stack.stack = prev - @contextmanager def ensure_compile_time_eval(): """Context manager to ensure evaluation at trace/compile time (or error). @@ -1353,50 +1262,21 @@ def jax_fn(x): But in some cases it can be more convenient to use this context manager. """ - with new_base_main(EvalTrace): + with config.eager_constant_folding(True): yield -eval_context = ensure_compile_time_eval # alias, backward compatibility @contextmanager -def new_sublevel() -> Generator[None, None, None]: - sublevel = Sublevel(len(thread_local_state.trace_state.substack)) - thread_local_state.trace_state.substack.append(sublevel) - try: +def eval_context(): + with set_current_trace(eval_trace): yield - finally: - thread_local_state.trace_state.substack.pop() - - if config.check_tracer_leaks.value: - t = ref(sublevel) - del sublevel - if t() is not None: - leaked_tracers = maybe_find_leaked_tracers(t()) - if leaked_tracers: - raise leaked_tracer_error("sublevel", t(), leaked_tracers) +# TODO(dougalm): deprecate/delete def full_lower(val): if isinstance(val, Tracer): return val.full_lower() else: return val - -def _get_trace_level(t: Tracer) -> int: return t._trace.level - - -def find_top_trace(xs) -> Trace: - top_tracer = max((x for x in xs if isinstance(x, Tracer)), - default=None, key=_get_trace_level) - if top_tracer is not None: - top_tracer._assert_live() - top_main = top_tracer._trace.main - else: - top_main = None - dynamic = thread_local_state.trace_state.trace_stack.dynamic - top_main = (dynamic if top_main is None or dynamic.level > top_main.level - else top_main) - return top_main.with_cur_sublevel() - def get_referent(x: Any) -> Any: return x.get_referent() if isinstance(x, Tracer) else x @@ -1435,11 +1315,14 @@ def __repr__(self): except AttributeError: return self.__class__.__name__ - def strip_weak_type(self) -> AbstractValue: + def update_weak_type(self, weak_type): return self - def join(self, other): - raise NotImplementedError("must override") + def strip_weak_type(self) -> AbstractValue: + return self.update_weak_type(False) + + def normalize(self) -> AbstractValue: + return self.strip_weak_type() def update(self, **kwargs): raise NotImplementedError("must override") @@ -1447,7 +1330,6 @@ def update(self, **kwargs): def str_short(self, short_dtypes=False): return str(self) - # For type signatures involving dynamic shapes, we use lists of abstract values # which may contain (reverse) de Bruijn indices in their shapes. class DBIdx(NamedTuple): @@ -1481,26 +1363,10 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -class Bot(AbstractValue): pass -bot = Bot() - - -def lattice_join(x: AbstractValue | None, - y: AbstractValue | None) -> AbstractValue: - if x is None: - assert y is not None - return y - elif y is None: - return x - elif isinstance(x, type(y)): - return y.join(x) - elif isinstance(y, type(x)): - return x.join(y) - elif isinstance(x, DShapedArray) and isinstance(y, ShapedArray): - # TODO(mattjj): remove this special case after dynamic shapes are integrated - return x.join(y) - else: - raise TypeError(x, y) +# TODO(dougalm): Deprecate. This is here for backwards compat. +def lattice_join(x, y): + assert typematch(x, y) + return x # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1535,12 +1401,16 @@ def get_aval(x): else: return concrete_aval(x) -def get_type(x): - aval = get_aval(x) - if isinstance(aval, ConcreteArray): - return raise_to_shaped(aval) +get_type = get_aval + +def is_concrete(x): + return to_concrete_value(x) is not None + +def to_concrete_value(x): + if isinstance(x, Tracer): + return x.to_concrete_value() else: - return aval + return x def concretization_function_error(fun, suggest_astype=False): fname = getattr(fun, "__name__", fun) @@ -1565,10 +1435,11 @@ def concrete_or_error(force: Any, val: Any, context=""): if force is None: force = lambda x: x if isinstance(val, Tracer): - if isinstance(val.aval, ConcreteArray): - return force(val.aval.val) - else: + maybe_concrete = val.to_concrete_value() + if maybe_concrete is None: raise ConcretizationTypeError(val, context) + else: + return force(maybe_concrete) else: return force(val) @@ -1624,15 +1495,11 @@ class UnshapedArray(AbstractValue): array_abstraction_level = 4 def __init__(self, dtype, weak_type=False): + # Is it silly to initialize this object and then complain that we should + # never create one? Yes. But otherwise pytype complains. self.dtype = _dtype_object(dtype) self.weak_type = weak_type - - def update(self, dtype=None, weak_type=None): - if dtype is None: - dtype = self.dtype - if weak_type is None: - weak_type = self.weak_type - return UnshapedArray(dtype, weak_type) + raise Exception("We should never create an UnshapedArray object") def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and @@ -1659,32 +1526,11 @@ def __repr__(self): _oct = concretization_function_error(oct) _index = concretization_function_error(operator.index) - def to_tangent_aval(self) -> AbstractValue: - return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type) - - def join(self, other): - if self.dtype == other.dtype: - if self.weak_type == other.weak_type: - return self - else: - return UnshapedArray(self.dtype, weak_type=False) - else: - raise TypeError(self, other) - def str_short(self, short_dtypes=False) -> str: return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - def strip_weak_type(self): - """Returns a copy of the aval with weak_type=False.""" - return self.update(weak_type=False) - - @property - def shape(self): - msg = ("UnshapedArray has no shape. Please open an issue at " - "https://github.com/jax-ml/jax/issues because it's unexpected for " - "UnshapedArray instances to ever be produced.") - raise TypeError(msg) + def update_weak_type(self, weak_type): + return self.update(weak_type=weak_type) def _canonicalize_dimension(dim: DimSize) -> DimSize: # Dimensions are most commonly integral (by far), so we check that first. @@ -1744,7 +1590,7 @@ def _invalid_shape_error(shape: Shape, context: str=""): msg += f" {context}." if not config.dynamic_shapes.value and any( isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) - and not isinstance(get_aval(x), ConcreteArray) for x in shape): + and not is_concrete(x) for x in shape): msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to " "smaller subfunctions.") for x in shape: @@ -1753,6 +1599,23 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) + +def get_sharding(sharding, ndim): + from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore + + if sharding is not None: + assert len(sharding.spec) == ndim + return sharding + + context_mesh = mesh_lib.get_abstract_mesh() + # TODO(yashkatariya): Error out and ask users to set the context mesh in their + # code. + if not context_mesh: + return None + assert sharding is None + return NamedSharding(context_mesh, P(*[None] * ndim)) + + class ShapedArray(UnshapedArray): __slots__ = ['shape', 'sharding'] # inherits slots from parent array_abstraction_level = 2 @@ -1762,20 +1625,18 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None): self.dtype = _dtype_object(dtype) self.weak_type = weak_type if config.sharding_in_types.value: - if sharding is not None: - assert len(sharding.spec) == len(self.shape) - self.sharding = sharding + self.sharding = get_sharding(sharding, len(self.shape)) - def update(self, shape=None, dtype=None, weak_type=None, sharding=None): + def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: shape = self.shape if dtype is None: dtype = self.dtype if weak_type is None: weak_type = self.weak_type - if sharding is None: - sharding = getattr(self, 'sharding', None) - return ShapedArray(shape, dtype, weak_type, sharding=sharding) + if 'sharding' not in kwargs: + kwargs['sharding'] = getattr(self, 'sharding', None) + return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) size = property(lambda self: @@ -1801,25 +1662,22 @@ def __hash__(self): getattr(self, 'sharding', None))) def to_tangent_aval(self): - return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type) - - def join(self, other): - if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype: - weak_type = self.weak_type and other.weak_type - return self.update(weak_type=weak_type) - elif self.dtype == other.dtype: - return UnshapedArray(self.dtype) + if config.sharding_in_types.value: + return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type, self.sharding) else: - raise TypeError(self, other) + return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type) def str_short(self, short_dtypes=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name) dt_str = dt_str.replace('void', 'float0') if hasattr(self, 'sharding') and self.sharding is not None: - shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec)) - return f'{dt_str}[{shapestr}]' + shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) + axis_types = self.sharding.mesh.axis_types + axt = _get_axis_type_str(axis_types) if axis_types is not None else '' + return f'{dt_str}[{shapestr}]{axt}' else: shapestr = ','.join(map(str, self.shape)) return f'{dt_str}[{shapestr}]' @@ -1831,74 +1689,41 @@ def _len(self, ignored_tracer): raise TypeError("len() of unsized object") from err # same as numpy error +def _get_axis_type_str(axis_types): + from jax._src.mesh import AxisTypes # type: ignore + + out = [] + for t, axes in axis_types.items(): + a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes + if t == AxisTypes.Collective: + out.append(f"C:{a}") + elif t == AxisTypes.User: + out.append(f"U:{a}") + else: + assert t == AxisTypes.Auto + out.append(f"A:{a}") + return f"{{{', '.join(out)}}}" + def _get_shape_sharding_str(shape, spec): + out = [] for s1, s2 in zip(shape, spec): if s2 is None: - yield f"{s1}" + out.append(f"{s1}") elif isinstance(s2, tuple): ss = ','.join(s for s in s2) - yield f"{s1}@({ss})" + out.append(f"{s1}@({ss})") else: - yield f"{s1}@{s2}" + out.append(f"{s1}@{s2}") + return ','.join(out) +def _get_abstract_sharding(val): + from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error -def _forward_to_value(self, fun, ignored_tracer, *args): - return fun(self.val, *args) - - -class ConcreteArray(ShapedArray): - __slots__ = ['val'] - array_abstraction_level = 0 - - def __init__(self, dtype, val, weak_type=None): - super().__init__( - np.shape(val), dtype, - weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type) - dtypes.check_valid_dtype(self.dtype) - # Note: canonicalized self.dtype doesn't necessarily match self.val - assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype) - self.val = val - - def update(self, dtype=None, val=None, weak_type=None): - dtype = self.dtype if dtype is None else dtype - val = self.val if val is None else val - weak_type = self.weak_type if weak_type is None else weak_type - return ConcreteArray(dtype, val, weak_type) - - def __eq__(self, other): - if (type(self) is type(other) and self.dtype == other.dtype - and self.shape == other.shape and self.weak_type == other.weak_type): - with eval_context(): # in case self.val is an Array - return (self.val == other.val).all() - else: - return False - - def __hash__(self): - return id(self.val) - - def join(self, other) -> AbstractValue: - if self == other: - return self - elif self.shape == other.shape and self.dtype == other.dtype: - weak_type = self.weak_type and other.weak_type - return ShapedArray(self.shape, self.dtype, weak_type=weak_type) - elif self.dtype == other.dtype: - return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type) - else: - raise TypeError(self, other) - - def str_short(self, short_dtypes=False) -> str: - dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - return f'{self.val}, dtype={dt_str}' - - _bool = partialmethod(_forward_to_value, bool) - _int = partialmethod(_forward_to_value, int) - _hex = partialmethod(_forward_to_value, hex) - _oct = partialmethod(_forward_to_value, oct) - _index = partialmethod(_forward_to_value, operator.index) - - _float = concretization_function_error(float, True) - _complex = concretization_function_error(complex, True) + if (config.sharding_in_types.value and hasattr(val, 'sharding') and + isinstance(val.sharding, NamedSharding)): + return NamedSharding(val.sharding.mesh.abstract_mesh, + val.sharding.spec._normalized_spec(val.ndim)) + return None def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): @@ -1962,28 +1787,10 @@ def __eq__(self, other): def __hash__(self): return hash((self.shape, self.dtype, self.weak_type)) - def join(self, other): - if (definitely_equal_shape(self.shape, other.shape) and - self.dtype == other.dtype): - weak_type = self.weak_type and other.weak_type - return self.update(weak_type=weak_type) - elif self.dtype == other.dtype: - return UnshapedArray(self.dtype) - else: - raise TypeError(self, other) - def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) -class DConcreteArray(DShapedArray): - __slots__ = ['val'] - array_abstraction_level = 1 - def __init__(self, shape, dtype, weak_type, val): - super().__init__(shape, dtype, weak_type) - self.val = val - - pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} @@ -2040,8 +1847,7 @@ def data(self): pytype_aval_mappings[DArray] = \ - lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, - x._data) + lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) @dataclass(frozen=True) class bint(dtypes.ExtendedDType): @@ -2078,6 +1884,7 @@ def __repr__(self) -> str: return 'Mutable' + repr(self[...]) def mutable_array(init_val): return mutable_array_p.bind(init_val) mutable_array_p = Primitive('mutable_array') +mutable_array_p.ref_primitive = True class InternalMutableArrayEffect(effects.Effect): pass @@ -2092,16 +1899,23 @@ def mutable_array_abstract_eval(init_aval): @mutable_array_p.def_impl def _mutable_array_impl(init_val): from jax._src.state.types import AbstractRef # pytype: disable=import-error - aval = raise_to_shaped(get_aval(init_val)) + aval = get_aval(init_val) return MutableArray(AbstractRef(aval), init_val) +def freeze(ref): + return freeze_p.bind(ref) +freeze_p = Primitive('freeze') +freeze_p.ref_primitive = True + +@freeze_p.def_effectful_abstract_eval +def freeze_abstract_eval(ref_aval): + return ref_aval.inner_aval, {internal_mutable_array_effect} + +@freeze_p.def_impl +def _freeze_impl(ref): + return ref[()] class AbstractToken(AbstractValue): - def join(self, other): - if isinstance(other, AbstractToken): - return self - else: - assert False, f"Cannot join {self} with {other}" def str_short(self, short_dtypes=False): return 'Tok' def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() @@ -2121,27 +1935,10 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -def raise_to_shaped(aval: AbstractValue, weak_type=None): - aval_type = type(aval) - if aval_type is ShapedArray and weak_type is None: - return aval - if weak_type is None: - weak_type = getattr(aval, 'weak_type', False) - for typ in aval_type.__mro__: - handler = raise_to_shaped_mappings.get(typ) - if handler: return handler(aval, weak_type) - raise TypeError(type(aval)) - -raise_to_shaped_mappings: dict[type, Callable] = { - AbstractToken: lambda aval, _: aval, - Bot: lambda aval, _: aval, - UnshapedArray: lambda aval, _: aval, - ShapedArray: lambda aval, weak_type: ShapedArray( - aval.shape, aval.dtype, weak_type), - DConcreteArray: lambda aval, weak_type: DShapedArray( - aval.shape, aval.dtype, weak_type - ), -} +# TODO(dougalm): Deprecate these. They're just here for backwards compat. +def raise_to_shaped(aval): + return aval +raise_to_shaped_mappings: dict[type, Callable] = {} ### Operations on shapes and dimension sizes. @@ -2281,6 +2078,70 @@ def dimension_as_value(d: DimSize): if hasattr(d, "dimension_as_value"): return d.dimension_as_value() return operator.index(d) +def canonicalize_slice( + s: slice, + axis_size: DimSize + ) -> tuple[DimSize, DimSize, DimSize]: + """Computes the start index, step, and size of the slice `x[s]`. + + This is similar to `s.indices(axis_size)`, except that it returns + `(start, step, size)`, and it works when the slice and/or the + `axis_size` are symbolic. + + See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding + """ + def convert_to_index(d: DimSize) -> DimSize: + # Convert np.array and jax.Array to int, leave symbolic dimensions alone + try: + return operator.index(d) + except: + return d + + # Must resolve statically if step is {<0, ==0, >0} + step = convert_to_index(s.step) if s.step is not None else 1 + try: + if step == 0: + raise ValueError("slice step cannot be zero") + step_gt_0 = (step > 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the step ({step}) must " + + f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") + + def clamp_index(i: DimSize, which: str): + try: + i_ge_0 = (i >= 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the {which} ({i}) must " + + f"be resolved statically if it is >= 0.\nDetails: {e}") + if i_ge_0: + if step_gt_0: + return min_dim(axis_size, i) + else: + return min_dim(axis_size - 1, i) + else: + if step_gt_0: + return max_dim(0, axis_size + i) + else: + return max_dim(-1, axis_size + i) + + if s.start is None: + start = 0 if step_gt_0 else axis_size - 1 + else: + start = clamp_index(convert_to_index(s.start), "start") + + if s.stop is None: + stop = axis_size if step_gt_0 else -1 + else: + stop = clamp_index(convert_to_index(s.stop), "stop") + + gap = step if step_gt_0 else - step + distance = (stop - start) if step_gt_0 else (start - stop) + slice_size = max_dim(0, distance + gap - 1) // gap + return start, step, slice_size + + class SomeTracer: __slots__ = () def __repr__(self): return "[dynamic]" @@ -2338,11 +2199,10 @@ class CallPrimitive(Primitive): multiple_results = True call_primitive = True - def bind(self, fun, *args, **params): - call_bind_continuation, top_trace, fun_, tracers, params = ( - call_bind_with_continuation(self, fun, *args, **params)) - outs = top_trace.process_call(self, fun_, tracers, params) - return call_bind_continuation(outs) + def bind_with_trace(self, trace, fun_and_args, params): + fun = fun_and_args[0] + args = fun_and_args[1:] + return trace.process_call(self, fun, args, params) def get_bind_params(self, params): new_params = dict(params) @@ -2352,45 +2212,9 @@ def get_bind_params(self, params): subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr)) return [subfun], new_params -def call_bind_with_continuation(primitive: CallPrimitive, fun, *args, **params): - top_trace = find_top_trace(args) - fun_, env_trace_todo = process_env_traces_call( - fun, primitive, top_trace.level, tuple(params.items())) - tracers = map(top_trace.full_raise, args) - fun_ = lu.annotate(fun_, fun.in_type) - - def call_bind_continuation(outs): - return map(full_lower, apply_todos(env_trace_todo(), outs)) - return call_bind_continuation, top_trace, fun_, tracers, params - -@lu.transformation_with_aux -def process_env_traces_call(primitive: CallPrimitive, level: int, - params_tuple: tuple, *args): - outs = yield args, {} - params = dict(params_tuple) - todo = [] - while True: - tracers = [x for x in outs if isinstance(x, Tracer) and x._trace.level > level] - if not tracers: - break - ans = max(tracers, key=_get_trace_level) - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, cur_todo = trace.post_process_call(primitive, outs, params) - todo.append(cur_todo) - yield outs, tuple(todo) # Ensure the aux output is immutable - -def apply_todos(todos, outs): - todos_list = list(todos) - while todos_list: - outs = map(full_lower, todos_list.pop()(outs)) - return outs - - def call_impl(f: lu.WrappedFun, *args, **params): del params # params parameterize the call primitive, not the function - with new_sublevel(): - return f.call_wrapped(*args) + return f.call_wrapped(*args) call_p: CallPrimitive = CallPrimitive('call') call = call_p.bind @@ -2409,49 +2233,21 @@ def get_bind_params(self, params): closed_call_p.def_effectful_abstract_eval( lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects)) - -outfeed_primitives: set[Primitive] = set() -def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool: - """Finds if there are outfeed primitives anywhere inside a Jaxpr.""" - return any(primitive_uses_outfeed(eqn.primitive, eqn.params) - for eqn in jaxpr.eqns) - -def _param_uses_outfeed(param): - if type(param) is Jaxpr: - if jaxpr_uses_outfeed(param): - return True - elif type(param) is ClosedJaxpr: - if jaxpr_uses_outfeed(param.jaxpr): - return True - return False - -def primitive_uses_outfeed(prim: Primitive, params: dict) -> bool: - if prim in outfeed_primitives: - return True - for param in params.values(): - if isinstance(param, tuple): - if any(unsafe_map(_param_uses_outfeed, param)): - return True - elif _param_uses_outfeed(param): - return True - return False - # ------------------- Map ------------------- class MapPrimitive(Primitive): multiple_results = True map_primitive = True - def bind(self, fun, *args, **params): + def bind_with_trace(self, trace, fun_and_args, params): + fun = fun_and_args[0] + args = fun_and_args[1:] assert len(params['in_axes']) == len(args) - return map_bind(self, fun, *args, **params) + return trace.process_map(self, fun, args, params) def process(self, trace, fun, tracers, params): return trace.process_map(self, fun, tracers, params) - def post_process(self, trace, out_tracers, params): - return trace.post_process_map(self, out_tracers, params) - def get_bind_params(self, params): new_params = dict(params) jaxpr = new_params.pop('call_jaxpr') @@ -2460,59 +2256,6 @@ def get_bind_params(self, params): new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes) return [subfun], new_params - -def map_bind_with_continuation(primitive: MapPrimitive, fun, *args, - out_axes_thunk, **params): - # The new thunk depends deterministically on the old thunk and the wrapped - # function. Any caching already has to include the wrapped function as part - # of the key, so we only use the previous thunk for equality checks. - @as_hashable_function(closure=out_axes_thunk) - def new_out_axes_thunk(): - out_axes = out_axes_thunk() - _, out_axes_transforms = todo_and_xforms() - for t in out_axes_transforms: - out_axes = t(out_axes) - return out_axes - params = dict(params, out_axes_thunk=new_out_axes_thunk) - top_trace = find_top_trace(args) - fun, todo_and_xforms = process_env_traces_map( - fun, primitive, top_trace and top_trace.level, tuple(params.items())) - tracers = map(top_trace.full_raise, args) - - def map_bind_continuation(outs): - env_trace_todo, _ = todo_and_xforms() - return map(full_lower, apply_todos(env_trace_todo, outs)) - - return map_bind_continuation, top_trace, fun, tracers, params - - -def map_bind(primitive: MapPrimitive, fun, *args, **params): - map_bind_continuation, top_trace, fun, tracers, params = ( - map_bind_with_continuation(primitive, fun, *args, **params)) - return map_bind_continuation( - primitive.process(top_trace, fun, tracers, params)) - -@lu.transformation_with_aux -def process_env_traces_map(primitive: MapPrimitive, level: int, - params_tuple: tuple, *args): - outs = yield args, {} - params = dict(params_tuple) - todo = [] - out_axes_transforms = [] - while True: - tracers = [x for x in outs if isinstance(x, Tracer) - and (level is None or x._trace.level > level)] - if not tracers: - break - ans = max(tracers, key=_get_trace_level) - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, (cur_todo, cur_xform) = primitive.post_process(trace, outs, params) - todo.append(cur_todo) - out_axes_transforms.append(cur_xform) - yield outs, (tuple(todo), tuple(out_axes_transforms)) - - def mapped_aval(size: AxisSize, axis: int | None, aval: AbstractValue) -> AbstractValue: handler, _ = aval_mapping_handlers.get(type(aval), (None, None)) @@ -2535,16 +2278,20 @@ def _map_shaped_array( assert axis is None or aval.shape[axis] == size # TODO: Extend the named shape if axis is None: return aval + sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) + if config.sharding_in_types.value else None) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type) + weak_type=aval.weak_type, sharding=sharding) def _unmap_shaped_array( size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray ) -> ShapedArray: if axis is None: return aval elif type(axis) is int: + sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name)) + if config.sharding_in_types.value else None) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type) + weak_type=aval.weak_type, sharding=sharding) else: raise TypeError(axis) def _map_dshaped_array( @@ -2567,60 +2314,9 @@ def _unmap_dshaped_array( aval_mapping_handlers: dict[type, AvalMapHandlerPair] = { DShapedArray: (_map_dshaped_array, _unmap_dshaped_array), ShapedArray: (_map_shaped_array, _unmap_shaped_array), - ConcreteArray: (_map_shaped_array, _unmap_shaped_array), AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a) } -@contextmanager -def extend_axis_env(axis_name: AxisName, size: int, tag: Any): - frame = AxisEnvFrame(axis_name, size, tag) - ts = thread_local_state.trace_state - ts.axis_env.append(frame) - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - try: - yield - finally: - ts.axis_env.pop() - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - -@contextmanager -def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None): - frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes] - ts = thread_local_state.trace_state - ts.axis_env.extend(frames) - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - try: - yield - finally: - for _ in frames: ts.axis_env.pop() - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - - -@contextmanager -def stash_axis_env(): - "Promise that a function or with-suite does not depend implicitly on axis env" - # If the promise is broken, then a NameError about an unbound axis name will - # be raised. - ts = thread_local_state.trace_state - prev_axis_env, ts.axis_env = ts.axis_env, [] - config.update_thread_local_jit_state(axis_env_state=()) - try: - yield - finally: - ts.axis_env = prev_axis_env - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - - # When a mapped function is given no axis name, we generate a name object based # on the id of the function object. Collisions aren't important because this # name can't be used in collectives, as user code never gets a ref to this @@ -2646,20 +2342,6 @@ def __lt__(self, other): return type(other) is _TempAxisName and self.id < other.id -def axis_frame(axis_name: AxisName, main_trace: MainTrace | None = None - ) -> AxisEnvFrame: - frames = thread_local_state.trace_state.axis_env - for frame in reversed(frames): - if (frame.name == axis_name and - (main_trace is None or frame.main_trace is main_trace)): - return frame - named_axes = [frame.name for frame in reversed(frames) - if not isinstance(frame.name, _TempAxisName)] - raise NameError( - f'unbound axis name: {axis_name}. The following axis names (e.g. defined ' - f'by pmap) are available to collective operations: {named_axes}') - - @dataclass(frozen=True) class NamedAxisEffect(effects.Effect): """A side-effect introducing a new named axis into the current scope.""" @@ -2687,98 +2369,9 @@ def remove_named_axis_effects( return jaxpr return jaxpr.replace(effects=filter_named_axis_effects(jaxpr.effects, names)) - -ParamDict = dict[str, Any] -AxisSubst = Callable[[AxisName], tuple[AxisName, ...]] - -class NameGatheringSubst: - def __init__(self): - self.axis_names = set() - def __call__(self, axis_name): - self.axis_names.add(axis_name) - return (axis_name,) - -def used_axis_names(primitive: Primitive, params: ParamDict) -> set[AxisName]: - subst = NameGatheringSubst() - subst_axis_names(primitive, params, subst) - return subst.axis_names - -def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict: - if primitive in axis_substitution_rules: - return axis_substitution_rules[primitive](params, subst, traverse) - if not traverse: - return params - # Default implementation: substitute names in all jaxpr parameters - if isinstance(primitive, MapPrimitive): - def shadowed_subst(name): - return (name,) if name == params['axis_name'] else subst(name) - else: - shadowed_subst = subst - jaxpr_params = [(n, v) for n, v in params.items() if isinstance(v, (Jaxpr, ClosedJaxpr))] - if not jaxpr_params: - return params - new_params = dict(params) - for name, jaxpr in jaxpr_params: - new_params[name] = subst_axis_names_jaxpr(jaxpr, shadowed_subst) - return new_params - -class DuplicateAxisNameError(Exception): - def __init__(self, var): - self.var = var - self.eqn = None - -def subst_axis_names_effects(effects: Set[Effect], subst: AxisSubst) -> Set[Effect]: - new_effects = set[Effect]() - for e in effects: - if isinstance(e, NamedAxisEffect): - new_effects.update(map(NamedAxisEffect, subst(e.name))) - else: - new_effects.add(e) - return new_effects - -def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var: - # Var identity is load-bearing, so we can't have duplicates! - if isinstance(v, DropVar): return v - assert v not in var_map - var_map[v] = v - return v - -def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn: - invars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars] - try: - outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars] - except DuplicateAxisNameError as e: - e.eqn = eqn - raise - params = subst_axis_names(eqn.primitive, eqn.params, subst) - effects = subst_axis_names_effects(eqn.effects, subst) - return eqn.replace(invars=invars, outvars=outvars, params=params, effects=effects) - -def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst): - consts = None - if isinstance(jaxpr, ClosedJaxpr): - consts = jaxpr.consts - jaxpr = jaxpr.jaxpr - var_map: dict[Var, Var] = {} - invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars] # type: ignore[union-attr] - constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr] - eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] - outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr] - effects = subst_axis_names_effects(jaxpr.effects, subst) - new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, effects) - if consts is not None: - return ClosedJaxpr(new_jaxpr, consts) - return new_jaxpr - def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr): return {e.name for e in jaxpr.effects if isinstance(e, NamedAxisEffect)} -def subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst): - if isinstance(subst, NameGatheringSubst): # This is a common case, so we optimize it! - subst.axis_names |= used_axis_names_jaxpr(jaxpr) - return jaxpr - return do_subst_axis_names_jaxpr(jaxpr, subst) - def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects): return _replace_jaxpr_effects(jaxpr, frozenset(effects)) @@ -2786,23 +2379,6 @@ def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects): def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: frozenset[Effect]): return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(effects=set(effects))) - -axis_substitution_rules: dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {} - -# ------------------- AxisPrimitive ------------------- -# Primitives that store axis names in params and want those axis names to -# participate in dispatch should subclass AxisPrimitive. - -class AxisPrimitive(Primitive): - def bind(self, *args, **params): - top_trace = find_top_trace(args) - axis_main = max((axis_frame(a).main_trace for a in used_axis_names(self, params)), - default=None, key=lambda t: getattr(t, 'level', -1)) - top_trace = (top_trace if not axis_main or axis_main.level < top_trace.level - else axis_main.with_cur_sublevel()) - return self.bind_with_trace(top_trace, args, params) - - # ------------------- Jaxpr checking ------------------- def typecheck(aval: AbstractValue, x) -> bool: @@ -2811,18 +2387,23 @@ def typecheck(aval: AbstractValue, x) -> bool: def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: """Determine whether `aval` conforms to `aval_ref`. Ignores weak_type.""" try: - return typematch(aval_ref, lattice_join(aval_ref, aval)) + return typematch(aval_ref, aval) except TypeError: return False -def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool: - """Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type.""" - if aval1 == aval2: return True - # unequal avals may still represent the same type, because type is represented - # by avals at the shaped level, and because weak type tags aren't considered - # part of the type - return (raise_to_shaped(aval1, weak_type=False) == - raise_to_shaped(aval2, weak_type=False)) +def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: + """Determine whether `t1` and `t2` are equivalent. Ignores weak_type.""" + t1 = t1.normalize() + t2 = t2.normalize() + if t1 == t2: + return True + elif (isinstance(t1, (ShapedArray, DShapedArray)) and + isinstance(t2, (ShapedArray, DShapedArray))): + # This case handles DShapedArray and shape polynomials. Alternatively we + # could try normalizing first and then doing simple equality. + return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + else: + return False class JaxprTypeError(TypeError): pass @@ -2950,10 +2531,11 @@ def write(v: Var, a: AbstractValue) -> None: # Check the computed effect type matches the eqn's annotation, and is # included in the jaxpr's annotation. - if prim is mutable_array_p: - outvar, = eqn.outvars - in_idx[outvar] = None # type: ignore - mut_arrays.add(outvar) + if prim.ref_primitive: + if prim is mutable_array_p: + outvar, = eqn.outvars + in_idx[outvar] = None # type: ignore + mut_arrays.add(outvar) if eqn.effects != eqn_effects: raise JaxprTypeError("Inferred effects do not match equation effects. " f"Equation effects: {eqn.effects}. " @@ -3073,10 +2655,16 @@ def substitute(aval: AbstractValue): return aval for v, x in zip(call_jaxpr.invars, in_atoms): if not typecompat(substitute(v.aval), x.aval): - # TODO(mattjj): vars in error message are confusing b/c of Var.__repr__ - raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " - f"{x.aval} to jaxpr expecting type " - f"{substitute(v.aval)}") + # TODO(yashkatariya): Remove this once numpy array's aval has a sharding + # on it. + if (config.sharding_in_types.value and isinstance(x, Literal) and + v.aval.sharding is not None and x.val.ndim == 0): + pass + else: + # TODO(mattjj): vars in error message are confusing b/c of Var.__repr__ + raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " + f"{x.aval} to jaxpr expecting type " + f"{substitute(v.aval)}") env[v] = x if type(x) is Var else x.val _check_jaxpr(ctx_factory, call_jaxpr) @@ -3120,7 +2708,7 @@ def _check_map(ctx_factory, prim, in_avals, params): raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} " f"to jaxpr expecting {binder_aval}") - with extend_axis_env(params['axis_name'], axis_size, None): + with extend_axis_env_nd([(params['axis_name'], axis_size)]): _check_jaxpr(ctx_factory, call_jaxpr) mapped_out_avals = [v.aval for v in call_jaxpr.outvars] @@ -3437,46 +3025,45 @@ def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any], # Comparable object for checking whether JAX's trace state has changed. class OpaqueTraceState: - def __init__(self, trace_info, convention): - self._trace_info = trace_info - self._convention = convention + def __init__(self, trace_ref): + self._trace_ref = trace_ref def __eq__(self, other): if isinstance(other, OpaqueTraceState): - if self._convention in ["nnx"]: - return self._trace_info is other._trace_info - elif self._convention in ["haiku", "flax"]: - return self._trace_info == other._trace_info - else: - raise Exception(f"unrecognized convention: {self._convention}") - - -# Each library has its own opinion about what the important fragment of jax's -# internal state is. TODO: reconcile the differences and remove the flag. -def get_opaque_trace_state(convention="flax"): - if convention == "flax": - trace_info = find_top_trace(()).level - elif convention == "haiku": - trace_stack = thread_local_state.trace_state.trace_stack.stack - top_type = trace_stack[0].trace_type - level = trace_stack[-1].level - sublevel = cur_sublevel() - trace_info = (top_type, level, sublevel) - elif convention == "nnx": - trace_info = thread_local_state.trace_state.trace_stack.dynamic - else: - raise Exception(f"unrecognized convention: {convention}") + return self._trace_ref == other._trace_ref + else: + return False - return OpaqueTraceState(trace_info, convention) +def get_opaque_trace_state(convention): + del convention + return OpaqueTraceState(ref(trace_ctx.trace)) def nonempty_axis_env() -> bool: - return bool(thread_local_state.trace_state.axis_env) + return bool(trace_ctx.axis_env.axis_sizes) def unsafe_am_i_under_a_jit() -> bool: - return 'DynamicJaxprTrace' in str(thread_local_state.trace_state.trace_stack) + return 'DynamicJaxprTrace' in str(unsafe_get_trace_stack(trace_ctx.trace)) def unsafe_am_i_under_a_vmap() -> bool: - return 'BatchTrace' in str(thread_local_state.trace_state.trace_stack) + return 'BatchTrace' in str(unsafe_get_trace_stack(trace_ctx.trace)) + +# TODO(douglam): deprecate/delete +def find_top_trace(_): + return unsafe_get_current_trace() + + +def unsafe_get_current_trace(): + return trace_ctx.trace + +def unsafe_get_trace_stack(trace): + if hasattr(trace, "parent_trace"): + return unsafe_get_trace_stack(trace.parent_trace) + [trace] + else: + return [trace] + +def unsafe_get_axis_names() -> list[Any]: + return list(trace_ctx.axis_env.axis_sizes) -def unsafe_get_axis_names() -> list[str]: - return [axis.name for axis in thread_local_state.trace_state.axis_env] +# TODO(douglam): deprecate/delete +def axis_frame(axis_name): + return trace_ctx.axis_env.axis_size(axis_name) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index e20271f66301..c45bb8a9efd8 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -18,8 +18,8 @@ import math import jax -from jax import core from jax import dtypes +from jax._src import core from jax._src import dispatch from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching @@ -295,8 +295,8 @@ def check_is_flash_attention( _, T, _, H = query.shape _, S, _, _ = key.shape - if not ((H <= 128 and H % 8 == 0) and - (not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)): + if (H > 128 or H % 8 != 0 or + (is_training and has_bias and (T % 2 != 0 or S % 2 != 0))): # check if flash attention is supported # for training, for patterns with bias, seqlen should be divisible by 2 raise NotImplementedError( @@ -1022,11 +1022,9 @@ def dot_product_attention(query: Array, Returns: Output of the same shape as the query. """ - # check if cuDNN is installed + # TODO(b/380898464): Check the compute capability, e.g., require GPU device, + # in the kernel implementation (c++) code. cudnn_version = check_cudnn_version() - # only support at least Ampere - if not check_compute_capability("8.0"): - raise RuntimeError("Require at least Ampere arch to run") layout = _normalize_layout(qkv_layout) if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask") @@ -1047,7 +1045,7 @@ def dot_product_attention(query: Array, # combine bias and mask if bias is None: - bias = mask + bias = mask else: if mask is not None: # should be broadcast to same shape diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index 8a13399e3d63..f320672463cb 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -14,7 +14,7 @@ import functools import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax.interpreters import mlir from jax.interpreters.mlir import hlo from jax.interpreters.mlir import ir diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 35e7d33430bd..74ad261b3218 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Callable +from typing import Any import functools import operator @@ -48,17 +49,93 @@ @custom_api_util.register_custom_decorator_type class custom_vmap: - fun: Callable - vmap_rule: Callable | None - - def __init__(self, fun: Callable): + """Customize the vmap behavior of a JAX-transformable function. + + This decorator is used to customize the behavior of a JAX function under the + :func:`jax.vmap` transformation. A ``custom_vmap``-decorated function will + mostly (see below for caveats) have the same behavior as the underlying + function, except when batched using :py:func:`jax.vmap`. When batched, the + rule defined using :py:func:`~jax.custom_batching.custom_vmap.def_vmap` will + be used. + + For example: + + >>> @jax.custom_batching.custom_vmap + ... def f(x, y): + ... return x + y + ... + >>> @f.def_vmap + ... def f_vmap_rule(axis_size, in_batched, xs, ys): + ... assert all(in_batched) + ... assert xs.shape[0] == axis_size + ... assert ys.shape[0] == axis_size + ... out_batched = True + ... return xs * ys, out_batched + ... + >>> xs = jnp.arange(3) + >>> ys = jnp.arange(1, 4) + >>> jax.vmap(f)(xs, ys) # prints xs * ys instead of xs + ys + Array([0, 2, 6], dtype=int32) + + Of note, ``custom_vmap`` functions do not support reverse-mode autodiff. To + customize both vmap and reverse-mode autodiff, combine ``custom_vmap`` with + :py:class:`jax.custom_vjp`. For example: + + >>> @jax.custom_vjp + ... @jax.custom_batching.custom_vmap + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> @f.def_vmap + ... def f_vmap_rule(axis_size, in_batched, xs, ys): + ... return jnp.cos(xs) * ys, True + ... + >>> def f_fwd(x, y): + ... return f(x, y), (jnp.cos(x), jnp.sin(x), y) + ... + >>> def f_bwd(res, g): + ... cos_x, sin_x, y = res + ... return (cos_x * g * y, sin_x * g) + ... + >>> f.defvjp(f_fwd, f_bwd) + >>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3)) + Array([1., 1., 1.], dtype=float32) + >>> jax.grad(f)(jnp.zeros(()), jnp.ones(())) + Array(1., dtype=float32) + + Note that the :py:class:`jax.custom_vjp` must be on the ouside, wrapping the + ``custom_vmap``-decorated function. + """ + + fun: Callable[..., Any] + vmap_rule: Callable[..., tuple[Any, Any]] | None + + def __init__(self, fun: Callable[..., Any]): functools.update_wrapper(self, fun) self.fun = fun self.vmap_rule = None __getattr__ = custom_api_util.forward_attr - def def_vmap(self, vmap_rule: Callable) -> Callable: + def def_vmap( + self, + vmap_rule: Callable[..., tuple[Any, Any]], + ) -> Callable[..., tuple[Any, Any]]: + """Define the vmap rule for this custom_vmap function. + + Args: + vmap_rule: A function that implements the vmap rule. This function should + accept the following arguments: (1) an integer ``axis_size`` as its + first argument, (2) a pytree of booleans with the same structure as the + inputs to the function, specifying whether each argument is batched, + and (3) the batched arguments. It should return a tuple of the batched + output and a pytree of booleans with the same structure as the output, + specifying whether each output element is batched. See the documentation + for :py:func:`jax.custom_batching.custom_vmap` for some examples. + + Returns: + This method passes the rule through, returning ``vmap_rule`` unchanged. + """ self.vmap_rule = vmap_rule return vmap_rule @@ -138,9 +215,9 @@ def maybe_bdim_at_front(x, bdim): # axes instead of accepting and matching a given spec of output axes. Assumes # `f` is pytree-flattened def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size): - f, out_axes = batching.batch_subtrace(f) - f = batching._batch_outer(f, axis_name, axis_size, in_axes, - batching.BatchTrace, None) + axis_data = batching.AxisData(axis_name, axis_size, None) + tag = core.TraceTag() + f, out_axes = batching.batch_subtrace(f, tag, axis_data, in_axes) outs = f.call_wrapped(*args) return outs, out_axes() @@ -272,6 +349,31 @@ def tree_merge(mask, lhs_tree, rhs_tree): mask, lhs_tree, rhs_tree) def sequential_vmap(f): + """A special case of ``custom_vmap`` that uses a loop. + + A function decorated with ``sequential_vmap`` will be called sequentially + within a loop when batched. This is useful for functions that don't natively + support batch dimensions. + + For example: + + >>> @jax.custom_batching.sequential_vmap + ... def f(x): + ... jax.debug.print("{}", x) + ... return x + 1 + ... + >>> jax.vmap(f)(jnp.arange(3)) + 0 + 1 + 2 + Array([1, 2, 3], dtype=int32) + + Where the print statements demonstrate that this :py:func:`~jax.vmap` is being + generated using a loop. + + See the documentation for :py:class:`~jax.custom_batching.custom_vmap` for + more details. + """ f = custom_vmap(f) @f.def_vmap diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index f5ecdfcda286..69130cc1831e 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,7 +31,6 @@ stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs) -from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -76,13 +75,14 @@ def _zeros_like_pytree(x): # like the api_util.py function, but also grabs output avals for error checking -@lu.transformation_with_aux -def _flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def _flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} + ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) - ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat] - yield ans_flat, (ans_tree, ans_avals) + ans_avals = [core.get_aval(x) for x in ans_flat] + store.store((ans_tree, ans_avals)) + return ans_flat ### JVPs @@ -267,18 +267,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable def _add_args(f, extra_args): return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args)) -@lu.transformation -def _add_args_(extra_args, *args, **kwargs): +@lu.transformation2 +def _add_args_(f, extra_args, *args, **kwargs): extra_args = tuple(arg.val for arg in extra_args) all_args = (extra_args + args) - yield (yield all_args, kwargs) + return f(*all_args, **kwargs) -@partial(lu.transformation_with_aux, use_eq_store=True) -def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args): primals_in, tangents_in = split_list(args, [len(args) // 2]) py_primals = tree_unflatten(in_tree, primals_in) py_tangents = tree_unflatten(in_tree, tangents_in) - pair_out = yield (py_primals, py_tangents), {} + pair_out = f(py_primals, py_tangents) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} " "must produce a pair (list or tuple of length two) representing " @@ -287,7 +287,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): py_primals_out, py_tangents_out = pair_out primals_out, out_tree = tree_flatten(py_primals_out) tangents_out, out_tree2 = tree_flatten(py_tangents_out) - primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] + primal_avals = [core.get_aval(x) for x in primals_out] if out_tree != out_tree2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must " "produce primal and tangent outputs with equal container (pytree) " @@ -327,11 +327,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] + primal_avals_out = [core.get_aval(x).strip_weak_type() for x in primals_out] expected_tangent_avals_out = [ - raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval() + core.get_aval(x).strip_weak_type().to_tangent_aval() for x in primals_out] - tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) + tangent_avals_out = [core.get_aval(t).strip_weak_type() if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] if expected_tangent_avals_out != tangent_avals_out: @@ -349,30 +349,18 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): if av_et != av_t) raise TypeError(msg.format('\n'.join(disagreements))) - yield primals_out + tangents_out, (out_tree, primal_avals) + store.store((out_tree, primal_avals)) + return primals_out + tangents_out class CustomJVPCallPrimitive(core.Primitive): multiple_results = True - def bind(self, fun, jvp, *args, symbolic_zeros): - args = map(core.full_lower, args) - top_trace = core.find_top_trace(args) - fun, env_trace_todo1 = process_env_traces( - fun, self, top_trace and top_trace.level, False) - jvp, env_trace_todo2 = process_env_traces( - jvp, self, top_trace and top_trace.level, True) - tracers = map(top_trace.full_raise, args) - outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers, - symbolic_zeros=symbolic_zeros) - _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) - return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) + def bind_with_trace(self, trace, args, params): + fun, jvp, tracers = args[0], args[1], args[2:] + return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params) def impl(self, fun, _, *args): - with core.new_sublevel(): - return fun.call_wrapped(*args) - - def post_process(self, trace, out_tracers, jvp_was_run: bool): - return trace.post_process_custom_jvp_call(out_tracers, jvp_was_run) + raise NotImplementedError def get_bind_params(self, params): new_params = dict(params) @@ -402,24 +390,6 @@ def jvp(*xs): return [*out_primals, *out_tangents] return jvp -@partial(lu.transformation_with_aux, use_eq_store=True) -def process_env_traces(primitive, level: int, jvp_was_run: bool, *args): - outs = yield args, {} - todo = [] - while True: - tracers = [x for x in outs if isinstance(x, core.Tracer) - and (level is None or x._trace.level > level)] - if tracers: - ans = max(tracers, key=lambda x: x._trace.level) - else: - break - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, cur_todo = primitive.post_process(trace, outs, jvp_was_run) - todo.append(cur_todo) - yield outs, tuple(todo) # Ensure the aux output is immutable - - effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') @@ -637,7 +607,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable f_, dyn_args = lu.wrap_init(self.fun), args fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd) args_flat, in_tree = tree_flatten(dyn_args) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + in_avals = [core.get_aval(x) for x in args_flat] flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name, fwd_name, in_tree, out_type) @@ -684,15 +654,15 @@ def _check_for_tracers(x): "arguments should typically not be indicated as nondiff_argnums.") raise UnexpectedTracerError(msg) -@partial(lu.transformation_with_aux, use_eq_store=True) -def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, *args): if symbolic_zeros: args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])] else: args = args[::2] py_args = tree_unflatten(in_tree, args) - pair_out = yield py_args, {} + pair_out = f(*py_args) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -705,7 +675,7 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, py_primals_out, res = pair_out primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) - primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] + primal_avals = [core.get_aval(x) for x in primals_out] # If the primal function already ran, check out_tree agreement. try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None @@ -742,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - yield (*res, *primals_out), (out_tree, res_tree) + store.store((out_tree, res_tree)) + return (*res, *primals_out) -@lu.transformation -def _flatten_bwd(in_tree, in_avals, out_trees, *args): +@lu.transformation2 +def _flatten_bwd(f, in_tree, in_avals, out_trees, *args): out_tree, res_tree = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) py_res = tree_unflatten(res_tree, res) py_cts_out = tree_unflatten(out_tree, cts_out) - py_cts_in = yield (py_res, py_cts_out), {} + py_cts_in = f(py_res, py_cts_out) if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)): py_cts_in = tuple(py_cts_in) # For each None in py_cts_in, indicating an argument for which the rule @@ -803,11 +774,11 @@ def append(x, d): msg = ("Custom VJP bwd rule must produce an output with the same " "shape/dtypes as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " - f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding " + f"shape/dtype {a_.str_short()} corresponding " f"to an input of shape/dtype {a.str_short()}.") raise ValueError(msg) results.append(ct) - yield results + return results # TODO(mattjj): remove both these exceptions to cotangent compatibility check def _temporary_dtype_exception(a, a_) -> bool: @@ -824,61 +795,12 @@ def _temporary_shape_exception(a, a_) -> bool: class CustomVJPCallPrimitive(core.CallPrimitive): initial_style: core.Primitive - def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros): - args = map(core.full_lower, args) - top_trace = core.find_top_trace(args) - fun, env_trace_todo1 = process_env_traces( - fun, self, top_trace and top_trace.level, False) - fwd, env_trace_todo2 = process_env_traces_fwd( - fwd, top_trace and top_trace.level, out_trees) - tracers = map(top_trace.full_raise, args) - bwd_ = lambda *args: bwd(*args) - outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers, - out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) - if fst: - return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) - else: - env_trace_todo, bwd_transform = env_trace_todo - bwd = _apply_bwd_transform(bwd_transform, bwd) - return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) - - def impl(self, fun, fwd, bwd, *args, out_trees): - del fwd, bwd, out_trees - with core.new_sublevel(): - return fun.call_wrapped(*args) + def bind_with_trace(self, trace, args, params): + fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] + return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) - def post_process(self, trace, out_tracers, params): - return trace.post_process_custom_vjp_call(out_tracers, params) custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') -@partial(lu.transformation_with_aux, use_eq_store=True) -def process_env_traces_fwd(level: int, out_trees, *args): - outs = yield args, {} - todo = [] - bwd_transforms = [] - while True: - tracers = [x for x in outs if isinstance(x, core.Tracer) - and (level is None or x._trace.level > level)] - if tracers: - ans = max(tracers, key=lambda x: x._trace.level) - else: - break - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, cur_todo, bwd_xform = trace.post_process_custom_vjp_call_fwd(outs, out_trees) - todo.append(cur_todo) - bwd_transforms.append(bwd_xform) - yield outs, (tuple(todo), tuple(bwd_transforms)) - - -def _apply_bwd_transform(todos, bwd): - todos_list = list(todos) - while todos_list: - bwd = todos_list.pop()(bwd) - return bwd - def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): return core.jaxpr_as_fun(fun_jaxpr)(*args) @@ -889,7 +811,7 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): f'Effects not supported in `custom_vjp`: {disallowed_effects}') return fun_jaxpr.out_avals, fun_jaxpr.effects -custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr') +custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr') custom_vjp_call_jaxpr_p.multiple_results = True custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl) custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval) @@ -911,7 +833,7 @@ def _custom_vjp_call_jaxpr_jvp( _, res_tree = out_trees() res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, @@ -921,18 +843,16 @@ def _custom_vjp_call_jaxpr_jvp( ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp def _custom_vjp_call_jaxpr_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *, + axis_data, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr, fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] _, args_batched = split_list(in_batched, [num_consts]) batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name, - main_type) + fun_jaxpr, axis_data, in_batched, False) out_dims1 = [0 if b else not_mapped for b in out_batched] out_dims2 = [] @@ -940,16 +860,15 @@ def _custom_vjp_call_jaxpr_vmap( def batched_fwd_jaxpr_thunk(*zeros): fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name, - main_type) + fwd_jaxpr, axis_data, args_batched, False) out_dims2.append([0 if b else not_mapped for b in out_batched]) return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts fwd_args_batched = [0 if b else not_mapped for b in args_batched] fwd_out_dims = lambda: out_dims2[0] + tag = core.TraceTag() batched_bwd = batching.batch_custom_vjp_bwd( - bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type, - spmd_axis_name) + bwd, tag, axis_data, fwd_out_dims, fwd_args_batched) batched_outs = custom_vjp_call_jaxpr_p.bind( *args, fun_jaxpr=batched_fun_jaxpr, @@ -957,10 +876,7 @@ def batched_fwd_jaxpr_thunk(*zeros): num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) out_dims = out_dims2[0] if out_dims2 else out_dims1 return batched_outs, out_dims -batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \ - _custom_vjp_call_jaxpr_vmap -batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial( - _custom_vjp_call_jaxpr_vmap, None) +batching.fancy_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) @@ -1144,11 +1060,12 @@ def rev(objective_fn, res, g): def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/jax-ml/jax/issues/6415 for motivation. - x = core.full_lower(x) + # See https://github.com/google/jax/issues/6415 for motivation. if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. return False + elif isinstance(x, ad.JVPTracer) and isinstance(x.tangent, ad.Zero): + return _maybe_perturbed(x.primal) elif isinstance(x, pe.DynamicJaxprTracer): # If x is a DynamicJaxprTracer then we're staging out; differentiation could # happen later, but some types always have trivial tangents. @@ -1195,7 +1112,7 @@ def merge(l1, l2): return out, merge def abstractify(x): - return core.raise_to_shaped(core.get_aval(x)) + return core.get_aval(x) ### Custom transposition @@ -1296,7 +1213,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args, lin_avals = map(abstractify, operands_lin) f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals)) f_jaxpr = _close_jaxpr(f_jaxpr) - out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals) + out_avals = f_jaxpr.out_avals t_in_tree = treedef_tuple((res_tree, out_tree())) t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree) @@ -1350,7 +1267,7 @@ def _linear_call_transpose_rule(cts, *args, callee, transpose, return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out def _linear_call_abstract_eval(*args, **kwargs): - return map(core.raise_to_shaped, kwargs['callee'].out_avals) + return kwargs['callee'].out_avals linear_call_p = core.Primitive('linear_call') linear_call_p.multiple_results = True @@ -1483,7 +1400,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: in_tree, out_type) flat_fwd = _fix_fwd_args(flat_fwd) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + in_avals = [core.get_aval(x) for x in args_flat] fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) prim_tree, res_tree = out_trees() @@ -1511,11 +1428,11 @@ def fun_jaxpr_thunk(): return wrapped_fwd -@lu.transformation -def _fix_fwd_args(*args): +@lu.transformation2 +def _fix_fwd_args(f, *args): args = [(x, True) for x in args] args = [x for pair in args for x in pair] - yield (yield args, {}) + return f(*args) def _remat_opt_impl( *args, @@ -1532,7 +1449,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_): return fwd_jaxpr.out_avals, fwd_jaxpr.effects def _remat_opt_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, + axis_data, args, in_dims, *, num_consts: int, num_res: int, @@ -1541,11 +1458,9 @@ def _remat_opt_vmap( ): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_size, in_batched, False, - axis_name, spmd_axis_name, main_type) + fwd_jaxpr, axis_data, in_batched, False) extra_consts = batched_fwd_jaxpr.consts batched_fwd_jaxpr = pe.close_jaxpr( pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) @@ -1557,8 +1472,7 @@ def _remat_opt_vmap( def batched_fun_jaxpr_thunk(): fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name, - main_type) + fun_jaxpr, axis_data, prim_batched, False) return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts batched_outs = remat_opt_p.bind(*extra_consts, *args, @@ -1592,7 +1506,7 @@ def _remat_opt_jvp( [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out]) fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr)) - @pe._memoize + # @pe._memoize def fun_jvp_jaxpr_thunk(): fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) in_nz = [True] * len(primals) @@ -1620,6 +1534,8 @@ def _remat_opt_transpose( "remat optimization for custom_vjp does not support higher-order AD") def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): + if not any(used_outs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]]) outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] if any(used_res): @@ -1666,8 +1582,9 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): xla.register_initial_style_primitive(remat_opt_p) mlir.register_lowering(remat_opt_p, mlir.lower_fun( _remat_opt_impl, multiple_results=True)) -batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap -batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None) + + +batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose pe.dce_rules[remat_opt_p] = _remat_opt_dce diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index c5cf0edf14c6..95e0578f0b2d 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -458,7 +458,9 @@ def __call__(self, *args, **kwargs): in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_partitioning") - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + mesh = mesh_lib.thread_resources.env.physical_mesh + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) assert not len(consts) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) out_flat = custom_partitioning_p.bind( diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py new file mode 100644 index 000000000000..1e3e7fe60683 --- /dev/null +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -0,0 +1,435 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements SdyShardingRule.""" + +from collections import OrderedDict + +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import sdy + + +# A single character replacement for ... to simplify parsing. +BATCHING: str = "…" + +# A prefix for names of batching dimension factors, used for expanding the +# leading ... into factors. +_BATCHING_DIM_FACTOR_PREFIX = "?" + +def _check_factor(factor:str): + """Validates a factor. + + A factor is a string starting with a letter and containing only letters, + digits, or underscores. + """ + if not factor[0].isalpha(): + raise ValueError(f"Factor names have to start with a letter, but got '{factor[0]}'") + for char in factor[1:]: + if char != "_" and not char.isdigit() and not char.isalpha(): + raise ValueError(f"Unknown character '{char}'") + +class CompoundFactor(tuple): + """Describes the factors for a compound factor. + + A compound factor should contain at least two factors, e.g. + * CompoundFactor('b', 'c'). + """ + def __init__(self, *factors): + if len(factors) < 2: + raise ValueError("A compound factor should contain at least two factors") + for factor in factors: + if not isinstance(factor, str): + raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}") + if factor == BATCHING: + raise ValueError("Ellipsis can't be used in a compound factor") + else: + _check_factor(factor) + + def __new__(cls, *factors): + return tuple.__new__(CompoundFactor, factors) + + +class ArrayMapping(tuple): + """Describes the factors for an operand or result. + + Each element is either a factor or a CompoundFactor. A leading element can + also be BATCHING, which represents batching dimensions. examples: + * ArrayMapping('a') + * ArrayMapping('b', 'c') + * ArrayMapping(CompoundFactor('b', 'c'), 'd') + * ArrayMapping(BATCHING, CompoundFactor('b', 'c'), 'd') + """ + def __init__(self, *dim_mappings): + for i, d in enumerate(dim_mappings): + if not isinstance(d, str) and not isinstance(d, CompoundFactor): + raise ValueError( + "Each element of ArrayMapping must be a str or CompoundFactor, but" + f" got {type(d)}") + if isinstance(d, str): + if d == BATCHING: + if i != 0: + raise ValueError("Ellipsis can only be used at the beginning of a dimension") + else: + _check_factor(d) + + def __new__(cls, *dim_mappings): + return tuple.__new__(ArrayMapping, dim_mappings) + + +class SdyShardingRule: + """Represents a Shardy sharding rule. + + An SdyShardingRule contains the ArrayMappings for operands and results, and an + optional list of factor sizes. A factor is a name used in the ArrayMappings. + If a factor is only used in CompoundFactors, its size must be specified. + """ + operand_mappings: tuple[ArrayMapping, ...] + result_mappings: tuple[ArrayMapping, ...] + factor_sizes: dict[str, int] + + def __init__(self, operand_mappings: tuple[ArrayMapping, ...], + result_mappings: tuple[ArrayMapping, ...], **factor_sizes): + # Find all factors and mark whether their size can be inferred. + factors_inferrable = dict() + for value in operand_mappings + result_mappings: + for dim in value: + if isinstance(dim, str): + factors_inferrable[dim] = True + else: + for factor in dim: + if factor not in factors_inferrable.keys(): + factors_inferrable[factor] = False + + # Check that factors in factor_sizes are used in the rule. + for factor in factor_sizes: + if factor not in factors_inferrable: + raise ValueError( + f"Factor {factor} is not used in the rule, but size is provided") + + # Check that factors that are used for a whole dimension aren't in + # factor_sizes and factors that are never used for a whole dimension are + # in factor_sizes. + for factor, inferrable in factors_inferrable.items(): + if factor not in factor_sizes and not inferrable: + raise ValueError( + f"Factor {factor} is only used in compound factors; must specify" + " its size") + if factor in factor_sizes and inferrable: + raise ValueError( + f"Factor {factor} represents a whole dimension; do not specify its" + " size") + + self.operand_mappings = operand_mappings + self.result_mappings = result_mappings + self.factor_sizes = factor_sizes + + def __str__(self): + return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})" + + +def _get_batching_dim_factor_name(batch_dim_order : int): + """Constructs a factor name for a batching dimension. + + We expand the leading ... into factors representing the batching dimensions + to support building the MLIR representation for the sharding rule. For this + reason, we construct a factor name that won't be used by users for the + batching dimensions. + """ + return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}" + +def _parse_values( + rule: str, +) -> tuple[ArrayMapping, ...]: + """Parses the LHS or RHS of an Einsum notation like string. + + Converts each operand or result in the Einsum notation like string to a tuple + of ArrayMapping. This very closely follows how einops parses their rules in + einops/parsing.py. + + Args: + rule: The Einsum notation for the operands or results of an operation. + + Returns: + The tuple of ArrayMapping. + + Raises: + ValueError: If the rule is not balanced or contains unknown characters. + """ + + # Remove unnecessary spaces in the rule to simplify the parsing process. + words = rule.split() + rule = " ".join(words) + + # Similar to einops rules, an empty LHS/RHS has a single scalar value. + if not rule: + return (ArrayMapping(),) + + all_values = [] + # Represent all dimensions of an value. When an value[0]==BATCHING, the + # value may have 0 or more leading dimensions. + value = [] + current_factor = None + # A value of None indicates the current dimension is not a compound dimension, + # while a value of [] indicates that we have just started parsing a compound + # dimension. + current_compound_dim: list[str] | None = None + + def add_factor(x): + if current_compound_dim is None: + value.append(x) + else: + current_compound_dim.append(x) + + for char in rule: + if char == BATCHING: + if (current_factor is not None or current_compound_dim is not None + or value): + raise ValueError( + "Ellipsis can only be used at the beginning of a dimension") + add_factor(BATCHING) + continue + if char in "(), ": + if current_factor is not None: + add_factor(current_factor) + current_factor = None + if char == "(": + if current_compound_dim is not None: + raise ValueError( + "Compound factors should be one level, nested brackets are not" + " allowed") + current_compound_dim = [] + elif char == ")": + if current_compound_dim is None: + raise ValueError("Brackets are not balanced") + if len(current_compound_dim) <= 1: + raise ValueError("Brackets should contain at least two factors") + value.append(CompoundFactor(*current_compound_dim)) + current_compound_dim = None + elif char == ",": + all_values.append(ArrayMapping(*value)) + value = [] + elif char == "_" or char.isdigit() or char.isalpha(): + if current_factor is None: + if str.isdigit(char): + raise ValueError(f"Factor names have to start with a letter, but got '{char}'") + current_factor = char + else: + current_factor += char + else: + raise ValueError(f"Unknown character '{char}'") + + if current_compound_dim is not None: + raise ValueError(f"Brackets are not balanced in rule: '{rule}'") + if current_factor is not None: + add_factor(current_factor) + all_values.append(ArrayMapping(*value)) + + return tuple(all_values) + +def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule: + """Constructs a SdyShardingRule object from the Einsum notation like string. + + This is done by verifying that the input Einsum notation like string and + with optional factor sizes represents a valid sharding rule and converting + it to an internal representation. + + Args: + rule: The Einsum notation like string for an operation. + **factor_sizes: The optional factor sizes. + + Raises: + ValueError: If there is any problem with the rule or factor_sizes. + """ + if not isinstance(rule, str): + raise TypeError(f"rule must be a str, but got {type(rule)}") + if not all(isinstance(size, int) for size in factor_sizes.values()): + raise TypeError( + f"factor_sizes must be a dict of str to int, but got {factor_sizes}") + + # Replace ... with a single char to simplify parsing. + if BATCHING in rule: + raise ValueError(f"Unknown character '{BATCHING}'") + if "." in rule: + rule = rule.replace("...", BATCHING) + if "." in rule: + raise ValueError("Character '.' must be used inside ellipsis '...'") + + try: + operands, results = rule.split("->") + except ValueError as e: + raise ValueError(f"There is no -> in rule: '{rule}'") from e + + operand_mappings = _parse_values(operands) + result_mappings = _parse_values(results) + + return SdyShardingRule(operand_mappings, result_mappings, **factor_sizes) + + +def sdy_sharding_rule_to_mlir( + rule: SdyShardingRule, + operand_types: list[ir.Type], + result_types: list[ir.Type],) -> ir.Attribute: + """Builds the MLIR representation for the sharding rule. + + This is done by verifying that the rule is consistent with the types of + the operation and converting the Einsum notation like string to + OpShardingRuleAttr. + """ + if len(rule.operand_mappings) != len(operand_types): + raise ValueError( + f"Sharding rule has {len(rule.operand_mappings)} operands, but the operation" + f" has {len(operand_types)} operands") + if len(rule.result_mappings) != len(result_types): + raise ValueError( + f"Sharding rule has {len(rule.result_mappings)} results, but the operation" + f" has {len(result_types)} results") + + factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict() + types = operand_types + result_types + UNKNOWN = -1 # Representation for unknown factor size or factor index. + + def get_message_for_value(i): + if i >= len(operand_types): + return f"{i - len(operand_types)}th result" + else: + return f"{i}th operand" + + def get_rank_for_value(i): + return ir.ShapedType(types[i]).rank + + def get_size_for_value_dim(i, j): + return ir.ShapedType(types[i]).shape[j] + + def add_factor(factor, size): + """Adds a factor to factors_to_indices_sizes. + + `size` may be a dimensions size, a user specified factor size, or UNKNOWN + if a factor is first used as in a compound factor and then used for a + whole dimension. + """ + factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN]) + if factor_index != UNKNOWN: + # Not the first time seeing the factor. + if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size: + factor_or_batching_dim = ( + f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor + else f"Batching dimension {factor[1:]}") + raise ValueError( + f"{factor_or_batching_dim} corresponds to two sizes:" + f" {factor_size} and {size}") + if size != UNKNOWN and factor_size == UNKNOWN: + factors_to_indices_sizes[factor] = [factor_index, size] + else: + # First time seeing the factor. + factor_index = len(factors_to_indices_sizes) + factors_to_indices_sizes[factor] = [factor_index, size] + + def add_batching_dim_factor(batch_dim_order, factor_size): + ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order) + add_factor(ellipsis_batch_dim_name, factor_size) + + def build_dim_mapping_for_compound_factors(i, j, factors): + accumulated_size = 1 + all_indices = [] + for factor in factors: + factor_index, factor_size = factors_to_indices_sizes[factor] + accumulated_size *= factor_size + all_indices.append(factor_index) + + dim_size = get_size_for_value_dim(i, j) + if accumulated_size != dim_size: + raise ValueError( + f"{get_message_for_value(i)} actual size {dim_size} doesn't match" + f" the size {accumulated_size} derived from the compound factors" + f" {factors}") + + return sdy.DimMappingAttr.get(factor_indices=all_indices) + + # Add factors and their sizes in the order they appear in the rule, + # including the batching dimensions represented by ellipsis. + ellipsis_rank = None + for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings): + value = tuple(mapping) + if value and value[0] == BATCHING: + has_batching = True + value = value[1:] + else: + has_batching = False + rule_rank = len(value) + op_rank = get_rank_for_value(i) + # The number of dimensions represented by ellipsis. + current_batching_rank = 0 + if has_batching and op_rank >= rule_rank: + current_batching_rank = op_rank - rule_rank + if has_batching: + if ellipsis_rank is None: + ellipsis_rank = current_batching_rank + elif ellipsis_rank != current_batching_rank: + raise ValueError( + "Ellipsis represents different number of leading dimensions" + f" {ellipsis_rank} and {current_batching_rank}") + rule_rank += current_batching_rank + if rule_rank != op_rank: + msg = get_message_for_value(i) + raise ValueError( + f"Sharding rule {msg} has rank {rule_rank}, but the operation" + f" {msg} has rank {op_rank}") + + for j in range(current_batching_rank): + add_batching_dim_factor(j, get_size_for_value_dim(i, j)) + + for j, dim in enumerate(value): + if isinstance(dim, str): + add_factor(dim, get_size_for_value_dim(i, j + current_batching_rank)) + else: + for factor in dim: + add_factor(factor, rule.factor_sizes.get(factor, UNKNOWN)) + + # Build the tensor mappings for each operand and result. + tensor_mappings = [] + for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings): + value = tuple(mapping) + dim_mappings = [] + + if value and value[0] == BATCHING: + value = value[1:] + if ellipsis_rank is None: + current_batching_rank = 0 + else: + current_batching_rank = ellipsis_rank + else: + current_batching_rank = 0 + + for j in range(current_batching_rank): + dim_mappings.append( + sdy.DimMappingAttr.get(factor_indices=[ + factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]])) + + for j, dim in enumerate(value): + if isinstance(dim, str): + dim_mappings.append( + sdy.DimMappingAttr.get( + factor_indices=[factors_to_indices_sizes[dim][0]])) + else: + dim_mappings.append( + build_dim_mapping_for_compound_factors( + i, j + current_batching_rank, dim)) + + tensor_mappings.append( + sdy.TensorMappingAttr.get(dim_mappings=dim_mappings)) + + return sdy.OpShardingRuleAttr.get( + factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], + operand_mappings=tensor_mappings[0:len(operand_types)], + result_mappings=tensor_mappings[len(operand_types):]) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index a4de1b8cc46c..9fe77ca0a6ac 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -155,18 +155,9 @@ class CustomTransposePrimitive(core.Primitive): map_primitive = False multiple_results = True - def bind(self, call, *args, **params): - # TODO(frostig,mattjj): This doesn't handle closures yet, which is - # a bit involved. Closures are complicated by us binding `call` - # twice in the JVP rule for custom transpose. The `env_trace_todo` - # output by `process_env_traces` due to one of those two bindings - # should be passable to the other, and need to be passed onward - # since the second bind is deferred by partial eval (since it - # typically receives unknowns) - top_trace = core.find_top_trace(args) - tracers = map(top_trace.full_raise, args) - outs = top_trace.process_custom_transpose(self, call, tracers, **params) - return outs + def bind_with_trace(self, trace, call_args, params): + call, tracers = call_args[0], call_args[1:] + return trace.process_custom_transpose(self, call, tracers, **params) # TODO(frostig,mattjj): consider keeping `call` as a named parameter # instead of following this "call primitive" convention. diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 962244a321a9..de4028885d35 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -125,12 +125,10 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register("jax-numpy-astype-complex-to-real") register("jax-numpy-array-none") -register('jax-scipy-beta-args') -register('tracer-hash') -register('jax-numpy-reshape-newshape') register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') +register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') register('pallas-gpu-triton') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 179f8430febe..54c3a43e8a84 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -16,7 +16,7 @@ from __future__ import annotations import atexit -from collections.abc import Callable, Sequence +from collections.abc import Sequence import contextlib import dataclasses import enum @@ -48,13 +48,12 @@ from jax._src import lib from jax._src.mesh import AbstractMesh, Mesh from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - SingleDeviceSharding, NamedSharding, - GSPMDSharding, TransferToMemoryKind, is_single_device_sharding) + SingleDeviceSharding, NamedSharding, TransferToMemoryKind, + is_single_device_sharding) from jax._src.layout import Layout, DeviceLocalLayout @@ -96,7 +95,8 @@ def apply_primitive(prim, *args, **params): @util.cache() def xla_primitive_callable(prim: core.Primitive, **params): def prim_fun(*args): - return prim.bind(*args, **params) + with config.eager_constant_folding(False): + return prim.bind(*args, **params) prim_fun.__name__ = prim.name prim_fun.__qualname__ = prim.name return api.jit(prim_fun) @@ -137,7 +137,7 @@ def get_token_input( # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. s = jax.sharding.GSPMDSharding.get_replicated(devices) - sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0]) + sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -278,17 +278,6 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool: return False -# We can optionally set a Jaxpr rewriter that can be applied just before -# compilation. This mechanism is used for compiling id_tap, we can -# remove it once we bring the id_tap implementation into the core. -outfeed_rewriter: Callable[[core.Jaxpr], core.Jaxpr] | None = None -def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr: - if outfeed_rewriter is not None: - return outfeed_rewriter(jaxpr) - else: - return jaxpr - - def check_arg(arg: Any): if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)): raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid " @@ -361,50 +350,21 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): f"platform {inp_plat} and target sharding's device set " f"ids: {target_ids} on platform {target_plat}") - if xla_extension_version >= 292: - if inp_sharding.is_fully_replicated: - permute_order = None - else: - permute_order = np.vectorize(target_sharding._device_assignment.index, - otypes=[int])(inp_sharding._device_assignment) - new_mesh = Mesh( - target_sharding.mesh.devices.reshape(inp_sharding.mesh.axis_sizes), - inp_sharding.mesh.axis_names) - new_s = NamedSharding( - new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, - _logical_device_ids=(None if permute_order is None else - tuple(permute_order.tolist()))) - new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) - return api.jit(_identity_fn, out_shardings=target_sharding, - donate_argnums=donate_argnums)(new_x) + if inp_sharding.is_fully_replicated: + permute_order = None else: - old_hlo_sharding = inp_sharding._to_xla_hlo_sharding(x.ndim) - if old_hlo_sharding.is_replicated(): - new_hlo_sharding = old_hlo_sharding - else: - permute_order = np.vectorize(target_sharding._device_assignment.index, + permute_order = np.vectorize(target_sharding._device_assignment.index, otypes=[int])(inp_sharding._device_assignment) - # Unfortunately need to fallback to V1 sharding here. - new_op_sharding = old_hlo_sharding.to_proto() - new_op_sharding.iota_reshape_dims = [] - new_op_sharding.iota_transpose_perm = [] - new_op_sharding.tile_assignment_devices = np.take( - permute_order, old_hlo_sharding.tile_assignment_devices() - ) - new_hlo_sharding = xc.HloSharding.from_proto(new_op_sharding) - assert (list(np.take(inp_sharding._device_assignment, - old_hlo_sharding.tile_assignment_devices())) - == list(np.take(target_sharding._device_assignment, - new_op_sharding.tile_assignment_devices))) - - new_x = array.make_array_from_single_device_arrays( - x.shape, - GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding, - memory_kind=target_sharding.memory_kind), - x._arrays, - ) - return api.jit(_identity_fn, out_shardings=target_sharding, - donate_argnums=donate_argnums)(new_x) + new_mesh = Mesh( + target_sharding.mesh.devices.reshape(inp_sharding.mesh.axis_sizes), + inp_sharding.mesh.axis_names) + new_s = NamedSharding( + new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, + _logical_device_ids=(None if permute_order is None else + tuple(permute_order.tolist()))) + new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) + return api.jit(_identity_fn, out_shardings=target_sharding, + donate_argnums=donate_argnums)(new_x) @dataclasses.dataclass(frozen=True) @@ -420,6 +380,7 @@ class _DeferredShardArg: s: Sharding aval: core.AbstractValue committed: bool + copy_semantics: CopySemantics @property def result_handler(self): @@ -450,21 +411,18 @@ def _device_put_sharding_impl(x, aval, device, copy): if not s.is_fully_addressable: if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types): - # TODO(yashkatariya): Move this check to `jit`. multihost_utils.assert_equal( x, fail_message=( f"{type(x)} passed to device_put is not the same on each" " process. Make sure you are passing the same value of" f" {type(x)} on each process.")) - return api.jit( - _identity_fn, out_shardings=s, - donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x) + return _DeferredShardArg(x, s, aval, True, copy) # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array. raise ValueError( "device_put's second argument must be a Device or a Sharding which" f" represents addressable devices, but got {s}. Please pass device or" " Sharding which represents addressable devices.") - return _DeferredShardArg(x, s, aval, True) + return _DeferredShardArg(x, s, aval, True, copy) # Only `Device` exists below. `Sharding` instance is handled above. if isinstance(x, array.ArrayImpl): @@ -472,16 +430,19 @@ def _device_put_sharding_impl(x, aval, device, copy): raise ValueError( "device_put's first argument must be a fully addressable array, but " f"got value with devices {x.devices()}") - if device is None and copy == CopySemantics.ALIAS: - return x + if device is None: + if copy == CopySemantics.ALIAS: + return x + else: + return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) elif is_single_device_sharding(x.sharding): device = x.sharding._device_assignment[0] if device is None else device return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], [device]) - sh = SingleDeviceSharding(pxla._get_default_device() + sh = SingleDeviceSharding(pxla.get_default_device() if device is None else device) - return _DeferredShardArg(x, sh, aval, device is not None) + return _DeferredShardArg(x, sh, aval, device is not None, copy) def _device_put_impl( @@ -530,12 +491,14 @@ def _batched_device_put_impl( copy_semantics: Sequence[CopySemantics]): ys = [] shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], [] + shard_arg_copy_semantics = [] for i, (x, device, src, cp) in enumerate(zip(xs, devices, srcs, copy_semantics)): y = _device_put_impl(x, device=device, src=src, copy=cp) if isinstance(y, _DeferredShardArg): shard_arg_indices.append(i) shard_arg_xs.append(y.x) shard_arg_shardings.append(y.s) + shard_arg_copy_semantics.append(y.copy_semantics) ys.append(y) if shard_arg_xs: @@ -544,7 +507,8 @@ def _batched_device_put_impl( # device_put handles `Layout` via a different path, so just pass `None` as # the layout here. shard_arg_results = pxla.shard_args( - shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs) + shard_arg_shardings, [None] * len(shard_arg_xs), + shard_arg_copy_semantics, shard_arg_xs) for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results): assert isinstance(ys[i], _DeferredShardArg) ys[i] = ys[i].result_handler(shard_arg_result) @@ -555,7 +519,12 @@ def _batched_device_put_impl( device_put_p = core.Primitive('device_put') device_put_p.multiple_results = True device_put_p.def_impl(_batched_device_put_impl) -device_put_p.def_abstract_eval(lambda *xs, devices, srcs, copy_semantics: xs) + +def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics): + if config.sharding_in_types.value: + return [x.update(sharding=s) for x, s in zip(xs, devices)] + return xs +device_put_p.def_abstract_eval(_device_put_abstract_eval) def _device_put_transpose(cts, *_, devices, srcs, copy_semantics): results = [None] * len(cts) @@ -596,6 +565,12 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's # being used inside jit? Atleast for now, this preserves the old behavior. if ctx.module_context.all_default_mem_kind: + if config.sharding_in_types.value: + return [ + mlir.wrap_with_sharding_op( + ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto()) + for x, a in zip(xs, ctx.avals_out) + ] return xs def lower(x, device, aval, out_aval): if (isinstance(device, (Sharding, TransferToMemoryKind)) and @@ -621,6 +596,12 @@ def lower(x, device, aval, out_aval): def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): + if config.sharding_in_types.value: + return [ + mlir.wrap_with_sharding_op( + ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto()) + for x, a in zip(xs, ctx.avals_out) + ] return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 3ea9304b67aa..e9796d61c6f3 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -27,6 +27,13 @@ logger = logging.getLogger(__name__) +_CHECK_PROXY_ENVS = config.bool_flag( + name="jax_check_proxy_envs", + default=True, + help="Checks proxy vars in user envs and emit warnings.", +) + + class State: process_id: int = 0 num_processes: int = 1 @@ -42,7 +49,11 @@ def initialize(self, local_device_ids: int | Sequence[int] | None = None, cluster_detection_method: str | None = None, initialization_timeout: int = 300, - coordinator_bind_address: str | None = None): + coordinator_bind_address: str | None = None, + service_heartbeat_interval_seconds: int = 10, + service_max_missing_heartbeats: int = 10, + client_heartbeat_interval_seconds: int = 10, + client_max_missing_heartbeats: int = 10): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): @@ -51,16 +62,18 @@ def initialize(self, if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')): local_device_ids = list(map(int, env_ids.split(","))) - (coordinator_address, num_processes, process_id, local_device_ids) = ( - clusters.ClusterEnv.auto_detect_unset_distributed_params( - coordinator_address, - num_processes, - process_id, - local_device_ids, - cluster_detection_method, - initialization_timeout, - ) - ) + if (cluster_detection_method != 'deactivate' and + None in (coordinator_address, num_processes, process_id, local_device_ids)): + (coordinator_address, num_processes, process_id, local_device_ids) = ( + clusters.ClusterEnv.auto_detect_unset_distributed_params( + coordinator_address, + num_processes, + process_id, + local_device_ids, + cluster_detection_method, + initialization_timeout, + ) + ) if coordinator_address is None: raise ValueError('coordinator_address should be defined.') @@ -88,8 +101,10 @@ def initialize(self, self.process_id = process_id - # Emit a warning about PROXY variables if they are in the user's env: - proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()] + proxy_vars = [] + if _CHECK_PROXY_ENVS.value: + proxy_vars = [key for key in os.environ.keys() + if '_proxy' in key.lower()] if len(proxy_vars) > 0: vars = " ".join(proxy_vars) + ". " @@ -107,7 +122,9 @@ def initialize(self, 'Starting JAX distributed service on %s', coordinator_bind_address ) self.service = xla_extension.get_distributed_runtime_service( - coordinator_bind_address, num_processes) + coordinator_bind_address, num_processes, + heartbeat_interval=service_heartbeat_interval_seconds, + max_missing_heartbeats=service_max_missing_heartbeats) self.num_processes = num_processes @@ -115,7 +132,9 @@ def initialize(self, raise RuntimeError('distributed.initialize should only be called once.') self.client = xla_extension.get_distributed_runtime_client( - coordinator_address, process_id, init_timeout=initialization_timeout) + coordinator_address, process_id, init_timeout=initialization_timeout, + heartbeat_interval=client_heartbeat_interval_seconds, + max_missing_heartbeats=client_max_missing_heartbeats, use_compression=True) logger.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() @@ -171,7 +190,9 @@ def initialize(coordinator_address: str | None = None, ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments. Otherwise, you must provide the ``coordinator_address``, - ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`. + ``num_processes``, ``process_id``, and ``local_device_ids`` arguments + to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster + environment auto detection will be skipped. Please note: on some systems, particularly HPC clusters that only access external networks through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to @@ -197,7 +218,8 @@ def initialize(coordinator_address: str | None = None, cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed run. Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment, and launch the applicatoin with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``. - Legacy auto-detect options (OMPI, Slurm) remain enabled. + Legacy auto-detect options "ompi" (OMPI) and "slurm" (Slurm) remain enabled. "deactivate" bypasses + automatic cluster detection. initialization_timeout: Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 82be38d1cb57..b53e1777f6a9 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,12 +90,17 @@ def type(self) -> type: ... # fp8 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float8_e3m4: type[np.generic] | None = None +float8_e4m3: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz +_float8_e3m4_dtype: np.dtype | None = None +_float8_e4m3_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] +# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 +if hasattr(ml_dtypes, "float8_e4m3"): + float8_e4m3 = ml_dtypes.float8_e4m3 + _float8_e4m3_dtype = np.dtype(float8_e4m3) + _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e4m3_dtype) + _float8_dtypes.insert(0, _float8_e4m3_dtype) +if hasattr(ml_dtypes, "float8_e3m4"): + float8_e3m4 = ml_dtypes.float8_e3m4 + _float8_e3m4_dtype = np.dtype(float8_e3m4) + _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e3m4_dtype) + _float8_dtypes.insert(0, _float8_e3m4_dtype) + # 2-bit integer support int2: type[np.generic] | None = None uint2: type[np.generic] | None = None @@ -339,8 +358,11 @@ def _issubclass(a: Any, b: Any) -> bool: return False +_types_for_issubdtype = (type, np.dtype, ExtendedDType) + # TODO(jakevdp): consider whether to disallow None here. We allow it # because np.issubdtype allows it (and treats it as equivalent to float64). +@set_module('jax.numpy') def issubdtype(a: DTypeLike | ExtendedDType | None, b: DTypeLike | ExtendedDType | None) -> bool: """Returns True if first argument is a typecode lower/equal in type hierarchy. @@ -360,8 +382,8 @@ def issubdtype(a: DTypeLike | ExtendedDType | None, # unhashable (e.g. custom objects with a dtype attribute). The following check is # fast and covers the majority of calls to this function within JAX library code. return _issubdtype_cached( - a if isinstance(a, (type, np.dtype, ExtendedDType)) else np.dtype(a), # type: ignore[arg-type] - b if isinstance(b, (type, np.dtype, ExtendedDType)) else np.dtype(b), # type: ignore[arg-type] + a if isinstance(a, _types_for_issubdtype) else np.dtype(a), # type: ignore[arg-type] + b if isinstance(b, _types_for_issubdtype) else np.dtype(b), # type: ignore[arg-type] ) @@ -456,6 +478,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, } +@set_module('jax.numpy') def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool: """Returns a boolean indicating whether a provided dtype is of a specified kind. @@ -648,10 +671,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy "JAX's internal logic; please report it to the JAX maintainers." ) +@set_module('jax.numpy') def promote_types(a: DTypeLike, b: DTypeLike) -> DType: """Returns the type to which a binary operation should cast its arguments. - For details of JAX's type promotion semantics, see :ref:`type-promotion`. + JAX implementation of :func:`numpy.promote_types`. For details of JAX's + type promotion semantics, see :ref:`type-promotion`. Args: a: a :class:`numpy.dtype` or a dtype specifier. @@ -659,6 +684,35 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType: Returns: A :class:`numpy.dtype` object. + + Examples: + Type specifiers may be strings, dtypes, or scalar types, and the return + value is always a dtype: + + >>> jnp.promote_types('int32', 'float32') # strings + dtype('float32') + >>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes + dtype('float32') + >>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types + dtype('float32') + + Built-in scalar types (:type:`int`, :type:`float`, or :type:`complex`) are + treated as weakly-typed and will not change the bit width of a strongly-typed + counterpart (see discussion in :ref:`type-promotion`): + + >>> jnp.promote_types('uint8', int) + dtype('uint8') + >>> jnp.promote_types('float16', float) + dtype('float16') + + This differs from the NumPy version of this function, which treats built-in scalar + types as equivalent to 64-bit types: + + >>> import numpy + >>> numpy.promote_types('uint8', int) + dtype('int64') + >>> numpy.promote_types('float16', float) + dtype('float64') """ # Note: we deliberately avoid `if a in _weak_types` here because we want to check # object identity, not object equality, due to the behavior of np.dtype.__eq__ @@ -784,7 +838,7 @@ def check_user_dtype_supported(dtype, fun_name=None): int2, int4, uint2, - uint4, + uint4 ] if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 4be10e733c0d..7bade8171078 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -108,12 +108,12 @@ def global_shards(self): # TODO(mattjj): _set_array_base_attributes -def _earray_shard_arg_handler(xs, shardings, layouts): +def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics): arrs = [x._data for x in xs] phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] # TODO(yashkatariya): `layouts` should be converted to physical layouts. - return pxla.shard_args(phys_shardings, layouts, arrs) + return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 590f68ac0b3b..6540fd1f5d41 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -677,7 +677,7 @@ class KeyReuseError(JAXTypeError): KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 This sort of key reuse is problematic because the JAX PRNG is stateless, and keys - must be manually split; For more information on this see `Sharp Bits: Random Numbers - `_. + must be manually split; For more information on this see `the Pseudorandom Numbers + tutorial `_. """ pass diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index b1bb797d538c..dc87b501b4b6 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -26,7 +26,6 @@ import json import re from typing import Any, Protocol, TypeVar, Union, cast -import warnings from absl import logging import numpy as np @@ -102,20 +101,6 @@ def custom_call(cls, target_name: str) -> DisabledSafetyCheck: """ return DisabledSafetyCheck(f"custom_call:{target_name}") - @classmethod - def shape_assertions(cls) -> DisabledSafetyCheck: - """DEPRECATED: A noop. - - Was used previously to allow invocations with shapes that do not meet the - constraints. Has no effect anymore, shape assertions cannot be disabled. - """ - # TODO(necula): remove this after compatibility period. Was deprecated in - # May 2024. - warnings.warn( - "DisabledSafetyCheck.shape_assertions is deprecated, has no effect anymore", - DeprecationWarning, stacklevel=2) - return DisabledSafetyCheck("shape_assertions") - def is_custom_call(self) -> str | None: """Returns the custom call target allowed by this directive.""" m = re.match(r'custom_call:(.+)$', self._impl) @@ -218,6 +203,7 @@ class Exported: _get_vjp: Callable[[Exported], Exported] | None def mlir_module(self) -> str: + """A string representation of the `mlir_module_serialized`.""" return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) def __str__(self): @@ -226,8 +212,8 @@ def __str__(self): return f"Exported(fun_name={self.fun_name}, ...)" def in_shardings_jax( - self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + self, + mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: """Creates Shardings corresponding to self.in_shardings_hlo. The Exported object stores `in_shardings_hlo` as HloShardings, which are @@ -236,30 +222,31 @@ def in_shardings_jax( `jax.device_put`. Example usage: - >>> from jax import export - >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) - >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), - ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) - ... )(np.arange(jax.device_count())) - >>> exp.in_shardings_hlo - ({devices=[8]<=[8]},) - - # Create a mesh for running the exported object - >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) - >>> - # Put the args and kwargs on the appropriate devices - >>> run_arg = jax.device_put(np.arange(jax.device_count()), - ... exp.in_shardings_jax(run_mesh)[0]) - >>> res = exp.call(run_arg) - >>> res.addressable_shards - [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), - Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), - Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), - Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), - Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), - Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), - Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), - Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])] + + >>> from jax import export + >>> # Prepare the exported object: + >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) + >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), + ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) + ... )(np.arange(jax.device_count())) + >>> exp.in_shardings_hlo + ({devices=[8]<=[8]},) + >>> # Create a mesh for running the exported object + >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) + >>> # Put the args and kwargs on the appropriate devices + >>> run_arg = jax.device_put(np.arange(jax.device_count()), + ... exp.in_shardings_jax(run_mesh)[0]) + >>> res = exp.call(run_arg) + >>> res.addressable_shards + [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), + Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), + Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), + Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), + Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), + Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), + Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), + Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])] + """ return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) for s in self.in_shardings_hlo) @@ -267,40 +254,13 @@ def in_shardings_jax( def out_shardings_jax( self, mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: - """Creates Shardings corresponding to self.out_shardings_hlo. + """Creates Shardings corresponding to `self.out_shardings_hlo`. See documentation for in_shardings_jax. """ return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) for s in self.out_shardings_hlo) - # For backwards compatibility - # TODO(necula): remove after September 2024. - @property - def lowering_platforms(self): - """DEPRECATED.""" - warnings.warn("lowering_platform is deprecated. Use .platforms instead.", - DeprecationWarning, stacklevel=2) - return self.platforms - - # For backwards compatibility - # TODO(necula): remove after September 2024. - @property - def mlir_module_serialization_version(self): - """DEPRECATED.""" - warnings.warn("mlir_module_serialization_version is deprecated. Use .calling_convention_version instead.", - DeprecationWarning, stacklevel=2) - return self.calling_convention_version - - # For backwards compatibility - # TODO(necula): remove after September 2024. - @property - def uses_shape_polymorphism(self): - """DEPRECATED.""" - warnings.warn("uses_shape_polymorphism is deprecated. Use .uses_global_constants instead.", - DeprecationWarning, stacklevel=2) - return self.uses_global_constants - def has_vjp(self) -> bool: """Returns if this Exported supports VJP.""" return self._get_vjp is not None @@ -331,6 +291,21 @@ def serialize(self, return serialize(self, vjp_order=vjp_order) def call(self, *args, **kwargs): + """Call an exported function from a JAX program. + + Args: + args: the positional arguments to pass to the exported function. This + should be a pytree of arrays with the same pytree structure as the + arguments for which the function was exported. + kwargs: the keyword arguments to pass to the exported function. + + Returns: a pytree of result array, with the same structure as the + results of the exported function. + + The invocation supports reverse-mode AD, and all the features supported + by exporting: shape polymorphism, multi-platform, device polymorphism. + See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html). + """ return call_exported(self)(*args, **kwargs) @@ -546,109 +521,11 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: aval = core.raise_to_shaped(core.get_aval(a)) return aval.shape, aval.dtype -def args_specs( - args, # pytree of arguments - polymorphic_shapes, # prefix pytree of strings - get_shape_and_dtype=shape_and_dtype_jax_array, -): - # TODO: deprecated in January 2024, to be removed. - warnings.warn( - "export.args_specs is deprecated in favor of export.symbolic_args_specs", - DeprecationWarning, stacklevel=2) - if get_shape_and_dtype is not shape_and_dtype_jax_array: - # This was needed in some older jax2tf implementations - args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)), - args) - return shape_poly.symbolic_args_specs(args, polymorphic_shapes) - - -# TODO(necula): remove this once we remove jax.experimental.export. -def export_back_compat( - fun_jax: Callable, - *, - lowering_platforms: Sequence[str] | None = None, - disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, - ) -> Callable[..., Exported]: - """Exports native serialization for a JAX function. - - Note: this function exists only for internal usage by jax2tf and for - backwards compatibility with jax.experimental.export. Use - `jax.export` instead. - See https://jax.readthedocs.io/en/latest/export/export.html - - Args: - fun_jax: the function to lower and serialize. - lowering_platforms: - Optional sequence containing a subset of 'tpu', 'cpu', - 'cuda', 'rocm'. If more than one platform is specified, then - the lowered code takes an argument specifying the platform. - If None, then use the default JAX backend. - The calling convention for multiple platforms is explained - at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. - disabled_checks: the safety checks to disable. See docstring - of `DisabledSafetyCheck`. - - Returns: - a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, - or values with `.shape` and `.dtype` attributes, and returns an - `Exported`. - - Usage: - - def f_jax(*args, **kwargs): ... - exported = jax_export.export(f_jax)(*args, **kwargs) - """ - - def do_export(*args_specs, **kwargs_specs) -> Exported: - if hasattr(fun_jax, "trace"): - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. - wrapped_fun_jax = fun_jax - else: - # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also - # convert(f_jax), in which case a "jit" is implied. In that case we raise - # an error if the lowered function contains non-replicated sharding annotations. - wrapped_fun_jax = jax.jit(fun_jax) - - if lowering_platforms is not None: - actual_lowering_platforms = tuple(lowering_platforms) - else: - actual_lowering_platforms = (default_export_platform(),) - - # TODO: move to `lower` - symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] - for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: - # Static args may have no `shape` attribute. - if not hasattr(aval, "shape"): - continue - for d in aval.shape: - if shape_poly.is_symbolic_dim(d): - if symbolic_scope is None: - symbolic_scope = (d.scope, k_path) - continue - symbolic_scope[0]._check_same_scope( - d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}", - self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", - other_descr=shape_poly.args_kwargs_path_to_str(k_path)) - - traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs) - lowered = traced.lower( - lowering_platforms=actual_lowering_platforms, - _private_parameters=mlir.LoweringParameters( - for_export=True, - export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) - return _export_lowered( - lowered, traced.jaxpr, traced.fun_name, - disabled_checks=disabled_checks, - _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) - return do_export def export( fun_jit: stages.Wrapped, *, platforms: Sequence[str] | None = None, - lowering_platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), ) -> Callable[..., Exported]: """Exports a JAX function for persistent serialization. @@ -662,7 +539,6 @@ def export( If None, then use the default JAX backend. The calling convention for multiple platforms is explained at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. - lowering_platforms: DEPRECATED, use `platforms`. disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. @@ -689,34 +565,38 @@ def export( >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32)) Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) """ + return _export_internal(fun_jit, platforms=platforms, + disabled_checks=disabled_checks) + + +# TODO(necula): remove this once we improve the integration with jax2tf. +def _export_internal( + fun_jit: stages.Wrapped, + *, + platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + _device_assignment_for_internal_jax2tf_use_only = None, + ) -> Callable[..., Exported]: + """Exports native serialization for a JAX function. + + Note: this function exists only for internal usage by jax2tf. Use + `jax.export` instead. + See https://jax.readthedocs.io/en/latest/export/export.html + + See docstring of `export` for more details. + """ if not isinstance(fun_jit, stages.Wrapped): raise ValueError( f"Function to be exported must be the result of `jit` but is: {fun_jit}") - if platforms is not None and lowering_platforms is not None: - raise ValueError("Cannot use both `platforms` and `lowering_platforms`") - if platforms is None and lowering_platforms is not None: - platforms = lowering_platforms - if platforms is not None: - actual_lowering_platforms = tuple(platforms) - else: - actual_lowering_platforms = (default_export_platform(),) def do_export(*args_specs, **kwargs_specs) -> Exported: + if platforms is not None: + actual_lowering_platforms = tuple(platforms) + else: + actual_lowering_platforms = (default_export_platform(),) + # TODO: move to `lower` - symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] - for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: - # Static args may have no `shape` attribute. - if not hasattr(aval, "shape"): - continue - for d in aval.shape: - if shape_poly.is_symbolic_dim(d): - if symbolic_scope is None: - symbolic_scope = (d.scope, k_path) - continue - symbolic_scope[0]._check_same_scope( - d, when=f"when exporting {util.fun_name(fun_jit)}", - self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", - other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + check_symbolic_scope_errors(fun_jit, args_specs, kwargs_specs) traced = fun_jit.trace(*args_specs, **kwargs_specs) lowered = traced.lower( @@ -726,12 +606,32 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) return _export_lowered( lowered, traced.jaxpr, traced.fun_name, - disabled_checks=disabled_checks) + disabled_checks=disabled_checks, + _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) return do_export + +def check_symbolic_scope_errors(fun_jax, args_specs, kwargs_specs): + symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] + for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: + # Static args may have no `shape` attribute. + if not hasattr(aval, "shape"): + continue + for d in aval.shape: + if shape_poly.is_symbolic_dim(d): + if symbolic_scope is None: + symbolic_scope = (d.scope, k_path) + continue + symbolic_scope[0]._check_same_scope( + d, when=f"when exporting {util.fun_name(fun_jax)}", + self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", + other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + + def _export_lowered( lowered: stages.Lowered, - jaxpr: core.ClosedJaxpr, fun_name: str, + jaxpr: core.ClosedJaxpr, + fun_name: str, disabled_checks: Sequence[DisabledSafetyCheck] = (), _device_assignment_for_internal_jax2tf_use_only = None, ) -> Exported: @@ -801,9 +701,9 @@ def _export_lowered( nr_devices = len(lowering.compile_args["device_assignment"]) def export_sharding(s: LoweringSharding, aval: core.ShapedArray) -> HloSharding | None: - if sharding_impls.is_unspecified(s): + if isinstance(s, sharding_impls.UnspecifiedValue): return None - return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + return s._to_xla_hlo_sharding(aval.ndim) all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], module_kept_var_idx, @@ -1114,7 +1014,9 @@ def _check_lowering(lowering) -> None: "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi", "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", + "lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi", "lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi", + "lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. @@ -1122,7 +1024,8 @@ def _check_lowering(lowering) -> None: *_CPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", "cu_threefry2x32", "cu_threefry2x32_ffi", - "__gpu$xla.gpu.triton", # Pallas call on GPU + # Triton IR does not guarantee stability. + # "__gpu$xla.gpu.triton", # cholesky on CPU "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on TPU @@ -1137,6 +1040,8 @@ def _check_lowering(lowering) -> None: "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", # schur on CPU "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", + # tridiagonal on CPU + "lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd", # hessenberg on CPU "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd", # lu on GPU @@ -1414,9 +1319,10 @@ def pp_arg_dim(dim_idx: int | None) -> str: # Must express the exported_dim_vars in terms of the shapes in in_avals. solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( exported.in_avals, args_kwargs_tree=exported.in_tree) - synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) + synthetic_env: shape_poly.DimVarEnv = { + vname: in_avals[arg_idx].shape[dim_idx] + for (vname, arg_idx, dim_idx) in synth_dim_vars} + synthetic_eval = shape_poly.ShapeEvaluator(synthetic_env) # We discharge all the constraints statically. This results in much simpler # composability (because we do not have to worry about the constraints of the # Exported called recursively; we only need to worry about entry-point diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index b72d0134cf1f..b71b377d8999 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -67,6 +67,8 @@ enum DType: byte { i4 = 15, ui4 = 16, + f8_e3m4 = 24, + f8_e4m3 = 23, f8_e4m3b11fnuz = 17, f8_e4m3fn = 18, f8_e4m3fnuz = 19, @@ -97,7 +99,7 @@ table Effect { enum DisabledSafetyCheckKind: byte { platform, custom_call, - shape_assertions, + shape_assertions, // unused } table DisabledSafetyCheck { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 434c4c5cf10c..0d9ce961b556 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -359,6 +359,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, } +if dtypes._float8_e3m4_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 +if dtypes._float8_e4m3_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() @@ -485,8 +489,6 @@ def _serialize_disabled_safety_check( custom_call_target = builder.CreateString(custom_call_target_str) elif check == _export.DisabledSafetyCheck.platform(): kind = ser_flatbuf.DisabledSafetyCheckKind.platform - elif check == _export.DisabledSafetyCheck.shape_assertions(): - kind = ser_flatbuf.DisabledSafetyCheckKind.shape_assertions else: raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}") @@ -510,5 +512,10 @@ def _deserialize_disabled_safety_check( if kind == ser_flatbuf.DisabledSafetyCheckKind.platform: return _export.DisabledSafetyCheck.platform() if kind == ser_flatbuf.DisabledSafetyCheckKind.shape_assertions: - return _export.DisabledSafetyCheck.shape_assertions() + # shape_assertions has been deprecated in June 2024 (turned into a no-op), + # and removed in November 2024. We deserialize it to a DisabledSafetyCheck + # that has no effect. + # TODO(necula): remove this after June 2025, when we should not have any + # more serialized artifacts with shape_assertions. + return _export.DisabledSafetyCheck.custom_call("no op") assert False, kind diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 18dd2c3cbab1..70d298020961 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,6 +53,8 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 + f8_e3m4 = 24 + f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 77786cbf1a9d..010edef1e54a 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -764,13 +764,23 @@ def __rmul__(self, other): return _DimExpr._linear_combination(self, other, 0, 0, self.scope) return _ensure_poly(other, "mul", self.scope).__mul__(self) - def __pow__(self, power, modulo=None): - assert modulo is None - try: - power = int(power) - except: - raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") - return functools.reduce(op.mul, [self] * power) + def __pow__(self, power: core.DimSize, modulo=None): + if modulo is not None: + raise NotImplementedError("__pow__ modulo not implemented") + if is_symbolic_dim(power): + return power.__rpow__(self) # type: ignore + if power != int(power): + raise ValueError(f"Symbolic dimension cannot be raised to non-integer powers: '{self}' ** '{power}'") + if power >= 0: + return functools.reduce(op.mul, [self] * power, 1) + # We don't support negative powers, because JAX does not allow negative + # powers for integers + raise ValueError(f"Symbolic dimension cannot be raised to negative powers: '{self}' ** '{power}'") + + def __rpow__(self, other, modulo=None): + if modulo is not None: + raise NotImplementedError("__rpow__ modulo not implemented") + return self.__jax_array__().__rpow__(other) def __floordiv__(self, divisor): if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): @@ -1198,12 +1208,6 @@ def is_symbolic_dim(p: DimSize) -> bool: """ return isinstance(p, _DimExpr) -def is_poly_dim(p: DimSize) -> bool: - # TODO: deprecated January 2024, remove June 2024. - warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim", - DeprecationWarning, stacklevel=2) - return is_symbolic_dim(p) - dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] def _einsum_contract_path(*operands, **kwargs): @@ -1413,8 +1417,6 @@ def symbolic_args_specs( shapes_specs, # prefix pytree of strings constraints: Sequence[str] = (), scope: SymbolicScope | None = None, - symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24 - symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24 ): """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. @@ -1435,25 +1437,10 @@ def symbolic_args_specs( arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). constraints: as for :func:`jax.export.symbolic_shape`. scope: as for :func:`jax.export.symbolic_shape`. - symbolic_constraints: DEPRECATED, use `constraints`. - symbolic_scope: DEPRECATED, use `scope`. Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes replaced with symbolic dimensions as specified by `shapes_specs`. """ - if symbolic_constraints: - warnings.warn("symbolic_constraints is deprecated, use constraints", - DeprecationWarning, stacklevel=2) - if constraints: - raise ValueError("Cannot use both symbolic_constraints and constraints") - constraints = symbolic_constraints - if symbolic_scope is not None: - warnings.warn("symbolic_scope is deprecated, use scope", - DeprecationWarning, stacklevel=2) - if scope is not None: - raise ValueError("Cannot use both symbolic_scope and scope") - scope = symbolic_scope - polymorphic_shapes = shapes_specs args_flat, args_tree = tree_util.tree_flatten(args) @@ -1746,11 +1733,10 @@ def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]: return sorted(dim_vars) -class CachingShapeEvaluator: - def __init__(self, **env): +class ShapeEvaluator: + def __init__(self, env: DimVarEnv): self.env = env - @functools.lru_cache(128) def evaluate(self, e: DimSize): if core.is_constant_dim(e): res = op.index(e) # type: ignore @@ -1769,7 +1755,7 @@ class ShapeConstraint: # is formed by evaluating the DimSize and concatenating the sequence. error_message_pieces: Sequence[str | DimSize] - def check_statically(self, eval: CachingShapeEvaluator) -> None: + def check_statically(self, eval: ShapeEvaluator) -> None: """Evaluates a constraint statically.""" left, right = eval.evaluate(self.left), eval.evaluate(self.right) try: @@ -1785,7 +1771,7 @@ def check_statically(self, eval: CachingShapeEvaluator) -> None: if not ok: raise self.make_error(eval) - def compute(self, eval: CachingShapeEvaluator) -> jax.Array | None: + def compute(self, eval: ShapeEvaluator) -> jax.Array | None: """Computes if the constraint is satisfied. If the constraint can be resolved statically returns None @@ -1820,7 +1806,7 @@ def __str__(self): def error_message_and_inputs( self, - eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]: + eval: ShapeEvaluator) -> tuple[str, Sequence[Any]]: """Forms the error_message and error message_inputs. See shape_assertion. """ @@ -1849,7 +1835,7 @@ def error_message_and_inputs( return ("".join(error_message_strings), error_message_inputs) - def make_error(self, eval: CachingShapeEvaluator) -> Exception: + def make_error(self, eval: ShapeEvaluator) -> Exception: error_message, error_message_inputs = self.error_message_and_inputs(eval) return ValueError(error_message.format(*error_message_inputs)) @@ -1865,7 +1851,7 @@ def add_constraint(self, c = ShapeConstraint(comp, left, right, error_message_pieces) self.constraints.append(c) - def check_statically(self, eval: CachingShapeEvaluator) -> None: + def check_statically(self, eval: ShapeEvaluator) -> None: """Evaluates all the constraints statically. If the static checking of any constraint fails, raises ValueError. @@ -1873,7 +1859,7 @@ def check_statically(self, eval: CachingShapeEvaluator) -> None: for constraint in self.constraints: constraint.check_statically(eval) - def shape_assertions(self, eval: CachingShapeEvaluator) -> None: + def shape_assertions(self, eval: ShapeEvaluator) -> None: """Computes the shape assertions for the set of constraints. See jax_export.Exported docstring. @@ -2006,7 +1992,8 @@ def compute_dim_vars_from_arg_shapes( generate the code for computing the dimension variables. It also generates the shape assertions. - Returns: the values of the dimension variables, in the order determined by + Returns: + The values of the dimension variables, in the order determined by `all_dim_vars(args_avals)`. """ dim_vars = all_dim_vars(args_avals) @@ -2014,13 +2001,13 @@ def compute_dim_vars_from_arg_shapes( tuple(args_avals), args_kwargs_tree=args_kwargs_tree) # Replace the synthetic vars with the dynamic shape of the actual arg - synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx], - dimension=dim_idx) - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = CachingShapeEvaluator(**synthetic_env) + synthetic_env: DimVarEnv = { + vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx) + for (vname, arg_idx, dim_idx) in synth_dim_vars + } + synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) - dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars] - return tuple(dim_values) + return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) def _solve_dim_equations( eqns: list[_DimEquation], @@ -2154,7 +2141,8 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)] if not eqns: add_explicit_symbolic_constraints(shape_env) - return shape_env, shape_constraints # SUCCESS + # SUCCESS + return shape_env, shape_constraints # pytype: disable=bad-return-type elif len(eqns) >= nr_eqns: break diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 9a45b3f77a93..5207e6289e26 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -18,7 +18,7 @@ import ctypes import functools import os -from typing import Any +from typing import Any, overload import numpy as np @@ -27,7 +27,7 @@ from jax._src import dispatch from jax._src import effects from jax._src import util -from jax._src.callback import _check_shape_dtype, callback_batching_rule +from jax._src.callback import callback_batching_rule from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -116,17 +116,17 @@ def _aval_shape(aval: core.AbstractValue) -> Shape: return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error -def _convert_layout(aval: core.AbstractValue, - layout: FfiLayoutOptions = None) -> Sequence[int]: +def _convert_layout_for_lowering( + aval: core.AbstractValue, layout: FfiLayoutOptions = None) -> Sequence[int]: """Convert a layout to the minor-to-major order used by the custom call API.""" if layout is None: - return list(reversed(range(len(_aval_shape(aval))))) + return tuple(reversed(range(len(_aval_shape(aval))))) elif isinstance(layout, DeviceLocalLayout): if layout._tiling is not None: raise ValueError("The FFI does not support layouts with tiling") return layout.major_to_minor[::-1] else: - return layout + return tuple(layout) def ffi_lowering( @@ -134,7 +134,7 @@ def ffi_lowering( *, operand_layouts: Sequence[FfiLayoutOptions] | None = None, result_layouts: Sequence[FfiLayoutOptions] | None = None, - backend_config: Mapping[str, ir.Attribute] | None = None, + backend_config: Mapping[str, ir.Attribute] | str | None = None, **lowering_args: Any ) -> mlir.LoweringRule: """Build a lowering rule for an foreign function interface (FFI) target. @@ -143,6 +143,10 @@ def ffi_lowering( compute the input and output types and shapes for the custom call, assuming row-major layouts. + Note that layouts passed to this function as tuples should be in + minor-to-major order (as expected by XLA) rather than major-to-minor as used + by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``. + If keyword arguments are passed to the lowering rule, these are treated as attributes, and added to `backend_config`. @@ -163,20 +167,32 @@ def _lowering( ) -> Sequence[ir.Value | Sequence[ir.Value]]: kwargs = dict(lowering_args) kwargs.setdefault("api_version", 4) - kwargs["backend_config"] = dict( - backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) + if kwargs["api_version"] >= 4: + if backend_config is not None and not isinstance(backend_config, dict): + raise ValueError( + "When api_version > 4, backend_config must be a dictionary.") + kwargs["backend_config"] = dict( + backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) + else: + if params: + raise ValueError( + "The use of ffi_call attributes requires a custom call API version " + f"of at least 4; got api_version={kwargs['api_version']}.") + kwargs["backend_config"] = backend_config if "result_types" not in kwargs: kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] if operand_layouts is None: - kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in) + kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in) else: kwargs["operand_layouts"] = [ - _convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)] + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_in, operand_layouts)] if result_layouts is None: - kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out) + kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out) else: kwargs["result_layouts"] = [ - _convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)] + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_out, result_layouts)] if "result_shapes" not in kwargs and not all( core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): kwargs["result_shapes"] = [ @@ -193,21 +209,88 @@ def _lowering( def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]: avals: list[core.AbstractValue] = [] - for result in results: + for idx, result in enumerate(results): if isinstance(result, core.AbstractToken): avals.append(result) else: - _check_shape_dtype(result) + if not hasattr(result, "shape") or not hasattr(result, "dtype"): + raise ValueError( + "All elements of result_shape_dtypes must have 'shape' and 'dtype' " + f"attributes. Got {result} at position {idx}.") avals.append(core.ShapedArray(result.shape, result.dtype)) return tuple(avals) +def _check_compatible_avals(a: core.AbstractValue, b: core.AbstractValue) -> bool: + if isinstance(a, core.AbstractToken) and isinstance(b, core.AbstractToken): + return True + if getattr(a, "shape", ()) != getattr(b, "shape", ()): + return False + if getattr(a, "dtype", ()) != getattr(b, "dtype", ()): + return False + return True + + +def _convert_layouts_for_ffi_call( + avals: Sequence[core.AbstractValue], + layouts: Sequence[FfiLayoutOptions]) -> tuple[Sequence[int], ...]: + return tuple( + _convert_layout_for_lowering( + aval, + layout if layout is None or isinstance(layout, DeviceLocalLayout) + else layout[::-1] + ) + for aval, layout in zip(avals, layouts)) + + +# ffi_call() returns as many results as result_shape_dtypes. +@overload +def ffi_call( + target_name: str, + result_shape_dtypes: ResultMetadata, + *deprecated_args: ArrayLike, + has_side_effect: bool = ..., + vmap_method: str | None = ..., + input_layouts: Sequence[FfiLayoutOptions] | None = ..., + output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ..., + input_output_aliases: dict[int, int] | None = ..., + custom_call_api_version: int = ..., + legacy_backend_config: str | None = ..., + vectorized: bool | DeprecatedArg = ..., + **deprecated_kwargs: Any, +) -> Callable[..., Array] | Array: + ... + + +@overload +def ffi_call( + target_name: str, + result_shape_dtypes: Sequence[ResultMetadata], + *deprecated_args: ArrayLike, + has_side_effect: bool = ..., + vmap_method: str | None = ..., + input_layouts: Sequence[FfiLayoutOptions] | None = ..., + output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ..., + input_output_aliases: dict[int, int] | None = ..., + custom_call_api_version: int = ..., + legacy_backend_config: str | None = ..., + vectorized: bool | DeprecatedArg = ..., + **deprecated_kwargs: Any, +) -> Callable[..., Sequence[Array]] | Sequence[Array]: + ... + + def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], *deprecated_args: ArrayLike, has_side_effect: bool = False, vmap_method: str | None = None, + input_layouts: Sequence[FfiLayoutOptions] | None = None, + output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = None, + input_output_aliases: dict[int, int] | None = None, + custom_call_api_version: int = 4, + legacy_backend_config: str | None = None, vectorized: bool | DeprecatedArg = DeprecatedArg(), **deprecated_kwargs: Any, ) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]: @@ -227,7 +310,7 @@ def ffi_call( Args: target_name: the name of the XLA FFI custom call target that was registered - using :func:`~jaxlib.xla_client.register_custom_call_target`. + using :func:`~jax.extend.ffi.register_ffi_target`. result_shape_dtypes: an object, or sequence of objects, with ``shape`` and ``dtype`` attributes which are expected to match the shape and dtype of the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often @@ -238,6 +321,32 @@ def ffi_call( outputs are not used. vmap_method: string specifying how the FFI call transforms under :func:`~jax.vmap` as described above. + input_layouts: a sequence of layouts for each input argument. In each case, + the layout can be (a) ``None`` indicating that this input is in default + row-major order, (b) a ``DeviceLocalLayout`` specifying the axis order, + or (c) a sequence of integers specifying the major-to-minor axis + ordering. Users who are familiar with XLA layouts should note that this + function expects layouts in major-to-minor order instead of the + minor-to-major order that XLA uses. For example, a batch of row-major + matrices could be specified using the layout ``[0, 1, 2]``, whereas a + batch of column-major matrices would have layout ``[0, 2, 1]``. In both + of these examples, the leading/batch dimension is the "slowest" axis. The + ``input_layouts`` parameter should be used to request the memory layout + expected by the FFI call target, and XLA will ensure that the buffers + have the correct layouts before the handler is executed. + output_layouts: like ``input_layouts``, but specifying the required layouts + for the output arrays. + input_output_aliases: a dictionary where the keys are input indices and the + values are output indices. This mapping indicates which output arrays + alias specific input arrays. + custom_call_api_version: the version number of the custom call API + implemented by the FFI target ``target_name``. The only formally + supported version is the typed FFI API with ``custom_call_api_version=4``, + but earlier unsupported custom calls can be executed using this argument. + legacy_backend_config: for legacy targets implemented using + ``custom_call_api_version<4``, attributes are passed using the opaque + string representation provided by this argument. This parameter cannot be + used with ``custom_call_api_version>=4``. Returns: A function that can be called with the input arrays as positional arguments @@ -263,14 +372,73 @@ def ffi_call( f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, " f"but got: {vmap_method}") + output_layouts_: Sequence[FfiLayoutOptions] | None if isinstance(result_shape_dtypes, Sequence): + output_layouts_ = output_layouts # type: ignore multiple_results = True result_avals = _result_avals(result_shape_dtypes) else: multiple_results = False result_avals = _result_avals((result_shape_dtypes,)) + output_layouts_ = (output_layouts,) # type: ignore + + if custom_call_api_version >= 4 and legacy_backend_config is not None: + raise ValueError( + "The use of the legacy_backend_config parameter requires " + f"custom_call_api_version < 4; got {custom_call_api_version}.") def wrapped(*args: ArrayLike, **kwargs: Any): + in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args] + + if input_layouts is None: + static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals)) + else: + if len(input_layouts) != len(in_avals): + raise ValueError( + f"The number of input arguments ({len(in_avals)}) must equal the " + f"number of input layouts ({len(input_layouts)}).") + static_input_layouts = _convert_layouts_for_ffi_call(in_avals, + input_layouts) + if output_layouts_ is None: + static_output_layouts = tuple(map(_convert_layout_for_lowering, + result_avals)) + else: + if len(output_layouts_) != len(result_avals): + raise ValueError( + f"The number of outputs ({len(result_avals)}) must equal the " + f"number of output layouts ({len(output_layouts_)}).") + static_output_layouts = _convert_layouts_for_ffi_call(result_avals, + output_layouts_) + + static_input_output_aliases: tuple[tuple[int, int], ...] = () + if input_output_aliases is not None: + for i_idx, o_idx in sorted(input_output_aliases.items()): + i_idx, o_idx = int(i_idx), int(o_idx) + if i_idx >= len(args): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with input index {i_idx} outside the range [0, " + f"{len(args)}).") + if o_idx >= len(result_avals): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with output index {o_idx} outside the range [0, " + f"{len(result_avals)}).") + in_aval = in_avals[i_idx] + out_aval = result_avals[o_idx] + if not _check_compatible_avals(in_aval, out_aval): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to an input with abstract value {in_aval} and an " + f"output with a different abstract value {out_aval}.") + if static_input_layouts[i_idx] != static_output_layouts[o_idx]: + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to an input with layout {static_input_layouts[i_idx]} " + "and an output with a different layout " + f"{static_output_layouts[o_idx]}.") + static_input_output_aliases += ((i_idx, o_idx),) + results = ffi_call_p.bind( *args, result_avals=result_avals, @@ -278,6 +446,11 @@ def wrapped(*args: ArrayLike, **kwargs: Any): vmap_method=vmap_method, target_name=target_name, has_side_effect=has_side_effect, + input_layouts=static_input_layouts, + output_layouts=static_output_layouts, + input_output_aliases=static_input_output_aliases, + custom_call_api_version=custom_call_api_version, + legacy_backend_config=legacy_backend_config, attributes=_wrap_kwargs_hashable(kwargs), ) if multiple_results: @@ -383,26 +556,23 @@ def __str__(self): def ffi_call_abstract_eval( *avals_in, result_avals: tuple[core.AbstractValue, ...], - target_name: str, - vectorized: bool | DeprecatedArg, - vmap_method: str | None, has_side_effect: bool, - attributes: Sequence[tuple[str, Any]], + **_, ): - del avals_in, target_name, vectorized, vmap_method, attributes + del avals_in # unused effects = {_FfiEffect} if has_side_effect else core.no_effects return result_avals, effects -def ffi_call_jvp(*args, target_name, **kwargs): - del args, kwargs +def ffi_call_jvp(*args, target_name, **_): + del args raise ValueError( f"The FFI call to `{target_name}` cannot be differentiated. " "You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.") -def ffi_call_transpose(*args, target_name, **kwargs): - del args, kwargs +def ffi_call_transpose(*args, target_name, **_): + del args raise ValueError( f"The FFI call to `{target_name}` cannot be differentiated. " "You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.") @@ -411,15 +581,22 @@ def ffi_call_transpose(*args, target_name, **kwargs): def ffi_call_lowering( ctx: mlir.LoweringRuleContext, *operands: ir.Value, - result_avals: tuple[core.AbstractValue, ...], target_name: str, - vectorized: bool | DeprecatedArg, - vmap_method: str | None, has_side_effect: bool, + input_layouts: Sequence[Sequence[int]], + output_layouts: Sequence[Sequence[int]], + input_output_aliases: Sequence[tuple[int, int]], + custom_call_api_version: int, + legacy_backend_config: str | None, attributes: Sequence[tuple[str, Any]], + **_, ) -> Sequence[ir.Value]: - del result_avals, vectorized, vmap_method - rule = ffi_lowering(target_name, has_side_effect=has_side_effect) + rule = ffi_lowering(target_name, has_side_effect=has_side_effect, + operand_layouts=input_layouts, + result_layouts=output_layouts, + operand_output_aliases=dict(input_output_aliases), + api_version=custom_call_api_version, + backend_config=legacy_backend_config) return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes)) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py index d7e6e5a1bc48..309aa73f20ba 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py @@ -241,3 +241,218 @@ mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0bOO\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\x06\x06\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\x0b\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgees\x00", xla_call_module_version=6, ) # End paste + +data_2024_11_29 = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], + [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], + [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], + [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]]),), + expected_outputs=(array([[ 3.2464249196572972e+01+0.j, -1.3416407864998739e+01+0.j, + -1.2558842947806125e-14+0.j, -7.3490869705474997e-15+0.j], + [ 0.0000000000000000e+00+0.j, -2.4642491965729798e+00+0.j, + -2.5534994473279107e-15+0.j, -1.3671521621839345e-16+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + -1.8779126463272594e-15+0.j, 7.2486619604759691e-16+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + 0.0000000000000000e+00+0.j, 4.8523679991768567e-16+0.j]]), array([[ 0.11417645138733863+0.j, -0.8288327563197511 +0.j, + 0.5401354211381763 +0.j, -0.09085002384085737+0.j], + [ 0.33000459866554743+0.j, -0.43714638836388686+0.j, + -0.6524649518290251 +0.j, 0.5237265380279561 +0.j], + [ 0.545832745943757 +0.j, -0.04546002040802424-0.j, + -0.31547635975648136+0.j, -0.774903004533341 +0.j], + [ 0.7616608932219662 +0.j, 0.346226347547838 +0.j, + 0.42780589044732925+0.j, 0.3420264903462419 +0.j]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xcomplex> loc("input")) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:5 = stablehlo.custom_call @lapack_zgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor<4xcomplex>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + return %6, %10 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0bO\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02>\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x0b\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_zgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], + [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], + [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], + [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]], dtype=complex64),), + expected_outputs=(array([[ 3.2464264e+01+0.j, -1.3416414e+01+0.j, -2.1337737e-06+0.j, + 1.8261760e-06+0.j], + [ 0.0000000e+00+0.j, -2.4642489e+00+0.j, -6.0543999e-07+0.j, + 4.8744488e-07+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, -6.5878328e-07+0.j, + 3.9895070e-07+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 0.0000000e+00+0.j, + 3.0199919e-07+0.j]], dtype=complex64), array([[ 0.11417647 +0.j, -0.8288329 +0.j, 0.5404726 +0.j, + -0.08882082 +0.j], + [ 0.3300045 +0.j, -0.4371462 +0.j, -0.6544272 +0.j, + 0.52127254 +0.j], + [ 0.54583293 +0.j, -0.045460045-0.j, -0.312564 +0.j, + -0.77608234 +0.j], + [ 0.76166105 +0.j, 0.34622625 +0.j, 0.42651838 +0.j, + 0.34363067 +0.j]], dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xcomplex> loc("input")) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:5 = stablehlo.custom_call @lapack_cgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor<4xcomplex>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + return %6, %10 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02\x1e\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\t\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_cgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]], dtype=float32),), + expected_outputs=(array([[ 3.2464233e+01, -1.3416398e+01, -1.6680369e-05, 4.0411728e-06], + [ 0.0000000e+00, -2.4642496e+00, -1.8640144e-06, 6.7429795e-07], + [ 0.0000000e+00, 0.0000000e+00, -7.2618576e-07, 3.9895073e-07], + [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.0443638e-07]], + dtype=float32), array([[-0.11417632 , 0.8288333 , -0.5413438 , 0.08334288 ], + [-0.33000442 , 0.43714583 , 0.65967286 , -0.5146185 ], + [-0.54583275 , 0.045459934, 0.30468878 , 0.7792079 ], + [-0.7616609 , -0.34622616 , -0.4230168 , -0.34793234 ]], + dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf32> loc("input")) -> (tensor<4x4xf32> {jax.result_info = "[0]"}, tensor<4x4xf32> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:6 = stablehlo.custom_call @lapack_sgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3) + return %6, %10 : tensor<4x4xf32>, tensor<4x4xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\n\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\t\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_sgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]),), + expected_outputs=(array([[ 3.2464249196572979e+01, -1.3416407864998748e+01, + 4.7128510442204522e-15, -8.6687960588453852e-15], + [ 0.0000000000000000e+00, -2.4642491965729767e+00, + 1.8990547895861982e-15, -2.4680570671743780e-16], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.8780225147134376e-15, -7.2486619604759710e-16], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + 0.0000000000000000e+00, 4.8523923435746521e-16]]), array([[-0.1141764513873386 , 0.8288327563197505 , 0.5401360966805397 , + 0.09084600741204968], + [-0.3300045986655475 , 0.43714638836388714, -0.6524688462214561 , + -0.5237216863090944 ], + [-0.5458327459437569 , 0.04546002040802441, -0.31547059759870844, + 0.774905350382041 ], + [-0.7616608932219663 , -0.34622634754783793, 0.4278033471396243 , + -0.3420296714849957 ]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf64> loc("input")) -> (tensor<4x4xf64> {jax.result_info = "[0]"}, tensor<4x4xf64> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:6 = stablehlo.custom_call @lapack_dgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3) + return %6, %10 : tensor<4x4xf64>, tensor<4x4xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\x1a\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x0b\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_dgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py new file mode 100644 index 000000000000..9e245052e03a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py @@ -0,0 +1,844 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, float32, complex64 + +data_2024_09_03 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zhetrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[-1.6782909868280393 +0.j , + -0.44670237330570184+4.847000766107959j , + 2.05945450900321 -2.2848432268240106j , + -1.852046418980849 +1.672382006137275j ], + [ 8.516713699516982 +0.j , + -2.7881860505313174 +0.j , + 0.9238284715039695 -2.3790501284019947j , + 0.5005102262291599 -1.30066052934836j ], + [-0.12132810525381293-0.2963030371159077j , + -3.6374350042782893 +0.j , + 0.5605752523031344 +0.j , + -2.9865099107523174 +0.5492956557924651j ], + [-0.40379248092949666-0.7813328344426929j , + -0.07101654492399719-0.27208840961051617j, + -7.4654253782049285 +0.j , + -8.172380353916964 +0.j ]], + + [[-3.996403598623405 +0.j , + 0.59408630943699 +2.531609474375295j , + -1.789098034543644 -2.538389274566601j , + -1.291106590337488 +3.1576544511573843j ], + [10.8950662522622 +0.j , + -2.8151642043836693 +0.j , + 6.18998567202382 +1.1866537964613415j , + 3.1900218245393352 +2.7291222716752372j ], + [-0.3142889671188478 -0.37781876498252764j, + 3.049208563595754 +0.j , + -2.4383044880335487 +0.j , + 4.075435464493341 -0.6653616942280807j ], + [ 0.32757687545025194+0.565870910342534j , + 0.8177026465997795 -0.15906305615104555j, + 3.3415143060767125 +0.j , + 4.094619408678314 +0.j ]]]), array([[-1.6782909868280393, -2.7881860505313174, 0.5605752523031344, + -8.172380353916964 ], + [-3.996403598623405 , -2.8151642043836693, -2.4383044880335487, + 4.094619408678314 ]]), array([[ 8.516713699516982 , -3.6374350042782893, -7.4654253782049285], + [10.8950662522622 , 3.049208563595754 , 3.3415143060767125]]), array([[1.0626274644222748+0.06050271598884928j, + 1.834630852474663 +0.18575551495730305j, + 1.981584368497257 +0.19102912741736966j], + [1.0365789616521406-0.40942548304121656j, + 1.0872592163018966-0.3187050677167622j , + 1.0458498304770472-0.9989483435319496j ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(-1.6782909868280393,-0.44303325034407437), (-0.44670237330570184,4.8470007661079588), (2.0594545090032099,-2.2848432268240106), (-1.852046418980849,1.6723820061372749)], [(-0.53338018421119981,-0.5152843101202178), (-8.6208093221459947,-1.4723511111926109), (0.92382847150396952,-2.3790501284019947), (0.50051022622915986,-1.30066052934836)], [(0.94535043721506584,2.744088772946665), (-5.9178492824175759,-4.3744650461123786), (1.8341291553102983,-4.8378584827626838), (-2.9865099107523174,0.54929565579246509)], [(3.2517513113853891,7.2792034361133062), (-0.09841002311276037,0.88008791818205689), (-0.035759860211603468,2.4677764344580244), (-3.6133109853094476,-2.2833696560058976)]], [[(-3.996403598623405,2.42308766118121), (0.59408630943699003,2.531609474375295), (-1.789098034543644,-2.538389274566601), (-1.2911065903374881,3.1576544511573843)], [(-0.39853021063902833,4.4607177630985086), (1.0742061295773189,-2.6002112528615386), (6.1899856720238198,1.1866537964613415), (3.1900218245393352,2.7291222716752372)], [(5.2347956435718022,2.8649782894514577), (2.3527586611916762,2.4688953673448575), (-2.317572140163894,4.3609023810820053), (4.0754354644933413,-0.66536169422808067)], [(-6.2237114632988675,-4.9294897244018943), (4.2994486027667103,-1.3300494261380422), (-0.51942958410141249,0.60038999428238982), (0.084516726847668963,-7.2944134049318752)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_zhetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo/O/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x10\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\x12\x10\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x08d\x91Y\xa6G\xda\xfa\xbf$-Q"\xa8Z\xdc\xbfL0\x19\x8d\xc5\x96\xdc\xbf\x86{8+Tc\x13@\xf0%\x1eI\xc3y\x00@\xe4\x91\xbd\xe2[G\x02\xc0\x85%\x03m\xfb\xa1\xfd\xbf\x9atl\xa2\x13\xc2\xfa?\x9c\xb0\xf0Qs\x11\xe1\xbf\xd8v\x83\x855}\xe0\xbf\x84V/\xb8\xda=!\xc0\n\xd3\xec\t\xc0\x8e\xf7\xbf\x98$\x07\xba\x00\x90\xed?\xd5?\x08oK\x08\x03\xc0>\xf8\x9e\x05.\x04\xe0?\xf2\xfcKj\x81\xcf\xf4\xbf\xe4"c\x8fO@\xee?y\x03\x89\xd0\xe4\xf3\x05@\xee\x8f\xaa\xae\xe0\xab\x17\xc0\xf20\xda\xc3s\x7f\x11\xc0V*+\xd0\x97X\xfd?P\x91\xf8\x92\xf7Y\x13\xc0\x7f\xe3\xdeN_\xe4\x07\xc0\x14\xd5\xae{\xd4\x93\xe1?\xbc\x00\t1\x96\x03\n@`&l\x81\xe7\x1d\x1d@X/\xde6f1\xb9\xbf\x06KF#\xae)\xec?\xcd\x9a<\xcc\x1dO\xa2\xbf\x91\xb1>\x92\x01\xbe\x03@\xf2s\x01\x97\x0f\xe8\x0c\xc0\xf5\xcaiOWD\x02\xc0F\xa2-s\xa2\xf8\x0f\xc0X\xea\xa0\xc8{b\x03@\x0b\x10\xc1J\xc1\x02\xe3?2|\xd5w\xbc@\x04@\xca>\xbbB%\xa0\xfc\xbf\xe8>6\t\x9fN\x04\xc0\xafdRb_\xa8\xf4\xbf\x80Q>V\xe0B\t@UhJ\xdb\x84\x81\xd9\xbf\t\xc7\xb4e\xc6\xd7\x11@<(;\xc4\xf2/\xf1?\x1a\xda\xad\x8e;\xcd\x04\xc0\x1c4\xa0\x9a\x8b\xc2\x18@z\x9c\xf7\xb0\x88\xfc\xf2?\xaea\x8f)*\x85\t@\x00\x0b\xbd\x0e>\xd5\x05@b\x89\xe9Dn\xf0\x14@a\x8d\xc7\xbcy\xeb\x06@\x8a\x97\t"s\xd2\x02@\xc2\xef\xdf6L\xc0\x03@J\xff Cc\x8a\x02\xc0\xd7.\xcfd\x90q\x11@s\xd4S\xf4>M\x10@t\x10\x97\x9b\xa4J\xe5\xbf\x8eo*\x9e\x14\xe5\x18\xc0\xc5\x18\x81\'\xcc\xb7\x13\xc0\x19\xdd\x8e\xa7\xa22\x11@-95\xe8\xe1G\xf5\xbfZK\x89\xca*\x9f\xe0\xbfR;\xc9\x13e6\xe3?\x7f\x94\xc6a\xe3\xa2\xb5?\xe2\xbe&\xb5z-\x1d\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\x0b\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_zhetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_chetrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[ 3.3228416 +0.j , -1.9756439 +4.593356j , + 7.367708 +0.88518727j , -8.659938 +1.6132793j ], + [-6.9206004 +0.j , -3.6362798 +0.j , + 3.3011198 -4.644362j , -4.8589935 -0.61439794j ], + [ 0.64957 +0.060723424j, 6.620491 +0.j , + 0.2882607 +0.j , -1.0288142 +1.8544064j ], + [-0.05458622 +0.10473086j , -0.15611424 +0.06925995j , + -4.431866 +0.j , 2.364208 +0.j ]], + + [[-4.1803885 +0.j , 0.5670845 +0.6913016j , + 2.675204 -0.23881845j , -0.41825035 -1.4060576j ], + [ 8.33625 +0.j , 2.6144838 +0.j , + -2.4941807 -1.9316154j , 0.6687787 -2.209776j ], + [ 0.019031923+0.17462212j , 2.7034955 +0.j , + -0.70924187 +0.j , 2.7962255 +1.5316825j ], + [-0.057821754+0.023692288j, -0.62805307 -0.0882424j , + 6.6364865 +0.j , -1.698973 +0.j ]]], + dtype=complex64), array([[ 3.3228416 , -3.6362798 , 0.2882607 , 2.364208 ], + [-4.1803885 , 2.6144838 , -0.70924187, -1.698973 ]], + dtype=float32), array([[-6.9206004, 6.620491 , -4.431866 ], + [ 8.33625 , 2.7034955, 6.6364865]], dtype=float32), array([[1.360567 +0.1977107j , 1.7586378-0.56989706j, + 1.5772758-0.8165493j ], + [1.9152443-0.1834492j , 1.1593437+0.55631363j, + 1.6889225-0.724835j ]], dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(3.32284164,1.14621949), (-1.97564387,4.59335613), (7.36770821,0.885187268), (-8.65993785,1.61327934)], [(2.495340e+00,1.36827672), (-3.96969199,-0.636681795), (3.3011198,-4.64436197), (-4.85899353,-0.614397943)], [(6.03322554,1.46055949), (-3.89591122,-4.1833396), (-1.46423841,-0.106284566), (-1.0288142,1.85440636)], [(-0.657281339,0.911450386), (3.18693113,-2.02812219), (-2.64483237,0.351429433), (4.45011663,-1.79112875)]], [[(-4.18038845,-3.65238023), (0.567084491,0.691301584), (2.67520404,-0.238818452), (-0.418250352,-1.4060576)], [(-7.62970591,1.5292784), (0.269325763,2.48722434), (-2.49418068,-1.93161535), (0.668778717,-2.20977592)], [(-0.570908666,-2.75890398), (-0.235837936,3.45861554), (-0.946199476,0.23120968), (2.79622555,1.53168249)], [(0.886947453,-0.466695577), (-3.194850e+00,-0.0176551137), (-4.37602425,-3.7703948), (0.883143305,-4.70016575)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_chetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo//\x1f\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\xe2\x0b\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\xc0\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x04p\xa9T@R\xb7\x92?\xe6\xe1\xfc\xbf\xc6\xfc\x92@D\xc4\xeb@\xa2\x9bb?\x1b\x8f\n\xc1\xf0\x7f\xce?\xa7\xb3\x1f@\xb1#\xaf?o\x0f~\xc0\x94\xfd"\xbf\x8cES@\x9d\x9e\x94\xc0\xe0|\x9b\xc0/I\x1d\xbf/\x10\xc1@\x9d\xf3\xba?\x9cVy\xc0\xeb\xdd\x85\xc0*l\xbb\xbf\xb9\xab\xd9\xbd/\xb0\x83\xbf0]\xed?\x97C(\xbf\xd0Ti?\xae\xf6K@\xc1\xcc\x01\xc0\xefD)\xc0\x8f\xee\xb3>[g\x8e@\xb5C\xe5\xbf\xbe\xc5\x85\xc0\x99\xc0i\xc0s,\x11?$\xf90?\x8b6+@\xd3\x8ct\xbe\xe9$\xd6\xbe\xb2\xf9\xb3\xbf\x8d&\xf4\xc0e\xbf\xc3?\x11\xe5\x89>\xaf.\x1f@\xa8\xa0\x1f\xc0,?\xf7\xbf\x155+?\xf8l\r\xc0\x12\'\x12\xbf\xe2\x910\xc0\x80\x7fq\xbe\xf5Y]@!:r\xbf;\xc2l>\\\xf52@,\x0e\xc4?\xfd\x0ec?\xb9\xf2\xee\xbelxL\xc0u\xa1\x90\xbcd\x08\x8c\xc0&Nq\xc0\xae\x15b?\xc2g\x96\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\t\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_chetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_ssytrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[-0.8395241 , 0.156272 , -1.6810869 , 0.23832119], + [-2.985257 , -5.571 , -0.22652794, -0.83806676], + [ 0.27237308, -1.6295947 , 2.0042834 , -1.148861 ], + [-0.17183593, 0.57464546, 0.5536146 , -4.206357 ]], + + [[ 1.7666914 , 2.569005 , -0.86576384, -0.1617768 ], + [-5.143918 , 5.0426254 , -3.7237067 , 4.383015 ], + [ 0.33311516, -1.5299042 , -8.854181 , -2.896776 ], + [ 0.3419102 , 0.2669245 , -2.8250606 , 5.752488 ]]], + dtype=float32), array([[-0.8395241, -5.571 , 2.0042834, -4.206357 ], + [ 1.7666914, 5.0426254, -8.854181 , 5.752488 ]], dtype=float32), array([[-2.985257 , -1.6295947, 0.5536146], + [-5.143918 , -1.5299042, -2.8250606]], dtype=float32), array([[1.8120625, 1.5035137, 0. ], + [1.6288393, 1.8669801, 0. ]], dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[-0.83952409, 1.562720e-01, -1.6810869, 0.238321185], [2.42421508, -5.17118931, -0.226527944, -0.838066756], [1.47339451, -1.32866347, -3.3505435, -1.14886105], [-0.929541587, -0.955984473, 2.71886253, 0.748659431]], [[1.76669145, 2.56900501, -0.865763843, -0.161776796], [3.23469758, -0.362713158, -3.72370672, 4.38301516], [2.79104376, 7.36582708, -3.04437494, -2.89677596], [2.86473417, 0.981746375, -2.13533139, 5.34802151]]]> : tensor<2x4x4xf32> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_ssytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf32>, tensor) -> tensor<2x4x4xf32> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc7) + return %2 : tensor<2x4x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b/\x1f\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x04\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02b\t\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x02\r\xebV\xbf\xc4\x05 >\xdb-\xd7\xbfx\nt>W&\x1b@bz\xa5\xc0\xf1\xf6g\xbe\x8b\x8bV\xbf1\x98\xbc?\xa5\x11\xaa\xbfNoV\xc0\xe1\r\x93\xbfp\xf6m\xbff\xbbt\xbf\xd8\x01.@%\xa8??\xf2"\xe2?\x94j$@\xb3\xa2]\xbf\xd1\xa8%\xbeI\x05O@\x8a\xb5\xb9\xbe6Qn\xc0\xa9A\x8c@v\xa02@\xdb\xb4\xeb@\n\xd7B\xc0\xc7d9\xc0\xceW7@\xbbS{?E\xa9\x08\xc0\xfe"\xab@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\t)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_ssytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dsytrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[ 0.8251247184208595 , -2.6963562039892532 , + 0.8082445002373937 , -1.551980329390836 ], + [-2.629505060186711 , 4.427374205796291 , + -2.2111093161901074 , 7.552489598405787 ], + [ 0.2269453213819231 , 0.3650586474106988 , + -3.5933639667756205 , 4.828829679372501 ], + [-0.6415372293575187 , -0.2519326897319508 , + -1.7607827845801751 , -3.381311711243865 ]], + + [[-4.000421911405985 , 3.6303350337601055 , + 2.8066821235532355 , 1.099224389184342 ], + [-4.141622408467332 , -5.276404169116551 , + -0.8496056221591237 , -2.275319346221659 ], + [ 0.5828958067901202 , 0.9351254869793256 , + 2.7765603683442177 , -4.339686212557215 ], + [-0.6391146585297987 , 0.3129920702652711 , + -0.25441692469349864, -1.4155240723557498 ]]]), array([[ 0.8251247184208595, 4.427374205796291 , -3.5933639667756205, + -3.381311711243865 ], + [-4.000421911405985 , -5.276404169116551 , 2.7765603683442177, + -1.4155240723557498]]), array([[-2.629505060186711 , 0.3650586474106988 , -1.7607827845801751 ], + [-4.141622408467332 , 0.9351254869793256 , -0.25441692469349864]]), array([[1.3669846724688552, 1.8806358893589366, 0. ], + [1.1440109149169537, 1.8215532880266878, 0. ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[0.82512471842085955, -2.6963562039892532, 0.80824450023739369, -1.5519803293908361], [0.96498805326781766, -4.1313349231964409, -2.2111093161901074, 7.5524895984057867], [0.81575339483804743, 1.0647235400727899, -1.0064296232364345, 4.8288296793725012], [-2.3060011529502993, -2.9182106402942192, -1.7781896154088577, 2.5904630742096817]], [[-4.0004219114059847, 3.6303350337601055, 2.8066821235532355, 1.0992243891843421], [0.59643883228393779, -1.5243235004961249, -0.84960562215912372, -2.275319346221659], [2.7617960295487092, -0.57538970930521982, 0.12559406141906576, -4.3396862125572149], [-3.0281643919760217, 0.38177997229319849, 3.860398204232184, -2.5166384340510231]]]> : tensor<2x4x4xf64> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_dsytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<128xf64>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf64>, tensor) -> tensor<2x4x4xf64> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc7) + return %2 : tensor<2x4x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b//\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02r\x0b\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x04A\xa4\x17\xf4kg\xea?\x1f\x01\x943#\x92\x05\xc0\x86 \xf6\x91#\xdd\xe9?\x9dMlS\xe9\xd4\xf8\xbf\x88\x1c:\xa0.\xe1\xee?8\xce\x7f\xa9|\x86\x10\xc0\xe8V\xc7\x14Z\xb0\x01\xc0\xd2!R\xd5\xbf5\x1e@\xbf\xc5\r\xdd\xa6\x1a\xea?\xbcM\xfe\x8c\x1b\t\xf1?\xdbj\xd8\xf2U\x1a\xf0\xbf\xado;\xba\xb8P\x13@\xbb\xad\x83\xbb\xb0r\x02\xc0\x1f9\xf7\xd1~X\x07\xc0)ID\xf4vs\xfc\xbfD\xcfI\xb4D\xb9\x04@\x16\xc3\xfe\x99n\x00\x10\xc0\x82.\x1c\x18\xed\n\r@\x8cn\xd7\xc1\x15t\x06@|2(Pl\x96\xf1?\x88*\xd7\xe3\x06\x16\xe3?F{\xf2\t\xa1c\xf8\xbf8z5!\xf8/\xeb\xbf4\xd3\x1f\xa1\xda3\x02\xc0)\x13I\x84(\x18\x06@\xbcw\xfd\xad\x97i\xe2\xbf\x1e\xf0.Yw\x13\xc0?dW\xd7\xb3\xd6[\x11\xc0\x04\x97\xb3@\xae9\x08\xc0\xbc\x17\xd1C\x15o\xd8?\x02\xb7%t\x18\xe2\x0e@\xac\xd8\xd0T\x13"\x04\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\x0b)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_dsytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_12_01 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zhetrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[-1.6782909868280393 +0.j , + -0.44670237330570184+4.847000766107959j , + 2.05945450900321 -2.2848432268240106j , + -1.852046418980849 +1.672382006137275j ], + [ 8.516713699516982 +0.j , + -2.7881860505313174 +0.j , + 0.9238284715039695 -2.3790501284019947j , + 0.5005102262291599 -1.30066052934836j ], + [-0.12132810525381293-0.2963030371159077j , + -3.6374350042782893 +0.j , + 0.5605752523031344 +0.j , + -2.9865099107523174 +0.5492956557924651j ], + [-0.40379248092949666-0.7813328344426929j , + -0.07101654492399719-0.27208840961051617j, + -7.4654253782049285 +0.j , + -8.172380353916964 +0.j ]], + + [[-3.996403598623405 +0.j , + 0.59408630943699 +2.531609474375295j , + -1.789098034543644 -2.538389274566601j , + -1.291106590337488 +3.1576544511573843j ], + [10.8950662522622 +0.j , + -2.8151642043836693 +0.j , + 6.18998567202382 +1.1866537964613415j , + 3.1900218245393352 +2.7291222716752372j ], + [-0.3142889671188478 -0.37781876498252764j, + 3.049208563595754 +0.j , + -2.4383044880335487 +0.j , + 4.075435464493341 -0.6653616942280807j ], + [ 0.32757687545025194+0.565870910342534j , + 0.8177026465997795 -0.15906305615104555j, + 3.3415143060767125 +0.j , + 4.094619408678314 +0.j ]]]), array([[-1.6782909868280393, -2.7881860505313174, 0.5605752523031344, + -8.172380353916964 ], + [-3.996403598623405 , -2.8151642043836693, -2.4383044880335487, + 4.094619408678314 ]]), array([[ 8.516713699516982 , -3.6374350042782893, -7.4654253782049285], + [10.8950662522622 , 3.049208563595754 , 3.3415143060767125]]), array([[1.0626274644222748+0.06050271598884928j, + 1.834630852474663 +0.18575551495730305j, + 1.981584368497257 +0.19102912741736966j], + [1.0365789616521406-0.40942548304121656j, + 1.0872592163018966-0.3187050677167622j , + 1.0458498304770472-0.9989483435319496j ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(-1.6782909868280393,-0.44303325034407437), (-0.44670237330570184,4.8470007661079588), (2.0594545090032099,-2.2848432268240106), (-1.852046418980849,1.6723820061372749)], [(-0.53338018421119981,-0.5152843101202178), (-8.6208093221459947,-1.4723511111926109), (0.92382847150396952,-2.3790501284019947), (0.50051022622915986,-1.30066052934836)], [(0.94535043721506584,2.744088772946665), (-5.9178492824175759,-4.3744650461123786), (1.8341291553102983,-4.8378584827626838), (-2.9865099107523174,0.54929565579246509)], [(3.2517513113853891,7.2792034361133062), (-0.09841002311276037,0.88008791818205689), (-0.035759860211603468,2.4677764344580244), (-3.6133109853094476,-2.2833696560058976)]], [[(-3.996403598623405,2.42308766118121), (0.59408630943699003,2.531609474375295), (-1.789098034543644,-2.538389274566601), (-1.2911065903374881,3.1576544511573843)], [(-0.39853021063902833,4.4607177630985086), (1.0742061295773189,-2.6002112528615386), (6.1899856720238198,1.1866537964613415), (3.1900218245393352,2.7291222716752372)], [(5.2347956435718022,2.8649782894514577), (2.3527586611916762,2.4688953673448575), (-2.317572140163894,4.3609023810820053), (4.0754354644933413,-0.66536169422808067)], [(-6.2237114632988675,-4.9294897244018943), (4.2994486027667103,-1.3300494261380422), (-0.51942958410141249,0.60038999428238982), (0.084516726847668963,-7.2944134049318752)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_zhetrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf5\x99G\x011\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03i\x0f\x0b\x0b\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0boO/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x10\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03C\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0f\x07\x07\x13\x0b\x1b\x07\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02\x9a\x0f\x1d\x1d\t\x1f\x1d!\t\x1d-\t\x17\x1f\xde\n\x1b\x1d#\t\x1d/\t\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'\x81\x05\'\x1d+\t\x05)\x05+\x05-\x1f5\x01\x1d/\x1d1\x1d3\x03\x07;;;\r\x0335\x03\x03;\x1d5\x1f\x17\t\x00\x00\x00\x00\t\x07\x07\x01\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fA!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1d;\x1d=\x1f?1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f#)\x03\taeim\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\r\x057o35\x1dE\x1dG\x1dI#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x02\x08d\x91Y\xa6G\xda\xfa\xbf$-Q"\xa8Z\xdc\xbfL0\x19\x8d\xc5\x96\xdc\xbf\x86{8+Tc\x13@\xf0%\x1eI\xc3y\x00@\xe4\x91\xbd\xe2[G\x02\xc0\x85%\x03m\xfb\xa1\xfd\xbf\x9atl\xa2\x13\xc2\xfa?\x9c\xb0\xf0Qs\x11\xe1\xbf\xd8v\x83\x855}\xe0\xbf\x84V/\xb8\xda=!\xc0\n\xd3\xec\t\xc0\x8e\xf7\xbf\x98$\x07\xba\x00\x90\xed?\xd5?\x08oK\x08\x03\xc0>\xf8\x9e\x05.\x04\xe0?\xf2\xfcKj\x81\xcf\xf4\xbf\xe4"c\x8fO@\xee?y\x03\x89\xd0\xe4\xf3\x05@\xee\x8f\xaa\xae\xe0\xab\x17\xc0\xf20\xda\xc3s\x7f\x11\xc0V*+\xd0\x97X\xfd?P\x91\xf8\x92\xf7Y\x13\xc0\x7f\xe3\xdeN_\xe4\x07\xc0\x14\xd5\xae{\xd4\x93\xe1?\xbc\x00\t1\x96\x03\n@`&l\x81\xe7\x1d\x1d@X/\xde6f1\xb9\xbf\x06KF#\xae)\xec?\xcd\x9a<\xcc\x1dO\xa2\xbf\x91\xb1>\x92\x01\xbe\x03@\xf2s\x01\x97\x0f\xe8\x0c\xc0\xf5\xcaiOWD\x02\xc0F\xa2-s\xa2\xf8\x0f\xc0X\xea\xa0\xc8{b\x03@\x0b\x10\xc1J\xc1\x02\xe3?2|\xd5w\xbc@\x04@\xca>\xbbB%\xa0\xfc\xbf\xe8>6\t\x9fN\x04\xc0\xafdRb_\xa8\xf4\xbf\x80Q>V\xe0B\t@UhJ\xdb\x84\x81\xd9\xbf\t\xc7\xb4e\xc6\xd7\x11@<(;\xc4\xf2/\xf1?\x1a\xda\xad\x8e;\xcd\x04\xc0\x1c4\xa0\x9a\x8b\xc2\x18@z\x9c\xf7\xb0\x88\xfc\xf2?\xaea\x8f)*\x85\t@\x00\x0b\xbd\x0e>\xd5\x05@b\x89\xe9Dn\xf0\x14@a\x8d\xc7\xbcy\xeb\x06@\x8a\x97\t"s\xd2\x02@\xc2\xef\xdf6L\xc0\x03@J\xff Cc\x8a\x02\xc0\xd7.\xcfd\x90q\x11@s\xd4S\xf4>M\x10@t\x10\x97\x9b\xa4J\xe5\xbf\x8eo*\x9e\x14\xe5\x18\xc0\xc5\x18\x81\'\xcc\xb7\x13\xc0\x19\xdd\x8e\xa7\xa22\x11@-95\xe8\xe1G\xf5\xbfZK\x89\xca*\x9f\xe0\xbfR;\xc9\x13e6\xe3?\x7f\x94\xc6a\xe3\xa2\xb5?\xe2\xbe&\xb5z-\x1d\xc0\r\x03\x83\x85\x1dK\x13=L\x0b\x03\x1dM\x1dO\x05\x01\x03\x03W\x03\x03\x93\x15\x03\x01\x01\x01\x03\x0bWKKK\x97\x1fC\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x1f)\x05\t\x05\x13)\x05\t\x11\x19)\x05\t\r\x19)\x05\t\r\x1f)\x01\x1f)\x01\x19\x01)\x03\t\')\x01\'\x0b\x1d)\x03\t\x13\x03\x19)\x07\t\x05\x05\x13\x13)\x05\t\r\x13\x1b\x11\x01\t\x05\t\x0b\r\x11\x07!\x05\x0f\x03\x05\x11\x07\x07\t\x11\x03\t\x11\x07\x07\x0b\x11\x03\x0b\x11\x07\x07\r\x0f\x03\r)\x03\t\x1b)\x03\x01\x1b)\x05\t\x11\x13)\x07\t\x11\x11\x13)\x03\r\x1b!)\x03\r#)\x03\t#)\x03\x05#)\x03\x05\x1b\x04\xbe\x06\x05\x01Q\x03\x11\x01\x07\x04\x96\x06\x03\x01\x15\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\x05\x11G)%\x07\x0b\x05\t\x0b\r\x15\x03\x01\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\r\rF\x07\r\x03\x1d\x05\x0b\x0f\x03F\r\x0f\x03!\x03\x11\x05B\x03\x11\x03\x0f\x0fF\x01\x13\x03\x05\x07\x13\x03\x15\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\x19\rF\x07\r\x03\x1d\x05\x0b\x1b\x03F\r\x0f\x03\x07\x03\x1d\x05B\x03\x15\x03\x11\x0fF\x01\x17\x03\t\x07\x1f\x05!\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03%\rF\x07\r\x03\x1d\x05\x0b\'\x03F\r\x0f\x03\x07\x03)\x05B\x03\x15\x03\x11\x0fF\x01\x19\x03\x0b\x07+\x07-\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x031\rF\x07\r\x03\x1d\x05\x0b3\x03F\r\x0f\x03\x07\x035\x05B\x03\x11\x03\x0f\x0fF\x01\x1b\x03\r\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x1d\x07\x04S\x03\r\x13\x07C\x01\x0b\x01\x1f\x01\x00\x03F\x05\x1f\x039\x03\x01\x03F\x05\x0b\x03\x05\x03\x05\x0b\x06\x0b\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x0f\x01\x13\x01#\x01\x00\x03F\x05#\x037\x03\x01\x03F\x05\x0b\x03\t\x03\x05\x0b\x06\x0b\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x0f\x01\x17\x01#\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\x0b\x03\x05\x0b\x06\x0b\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\'\x07\x04S\x03\r\x13\x07\x0f\x01\x1b\x01\x1f\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\r\x03\x05\x0b\x06\x0b\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x12\nQ%\x03\x0b\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/A)Sci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_zhetrd_ffi\x00\x08\x8d)\x057\x01\x0bM]_qs\x03\x7f\x11\x87\x89\x8bM\x8d\x8f\x91\x95\x03A\x031\x05CE\x03G\x03Y\x03O\x03[\x03Q\x03S\x03U\x0b9u=O?\x03}\x0b9w=Q?\x03I\x0b9y=S?\x0b9{=U?', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_chetrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[ 3.3228416 +0.j , -1.9756439 +4.593356j , + 7.367708 +0.88518727j , -8.659938 +1.6132793j ], + [-6.9206004 +0.j , -3.6362798 +0.j , + 3.3011198 -4.644362j , -4.8589935 -0.61439794j ], + [ 0.64957 +0.060723424j, 6.620491 +0.j , + 0.2882607 +0.j , -1.0288142 +1.8544064j ], + [-0.05458622 +0.10473086j , -0.15611424 +0.06925995j , + -4.431866 +0.j , 2.364208 +0.j ]], + + [[-4.1803885 +0.j , 0.5670845 +0.6913016j , + 2.675204 -0.23881845j , -0.41825035 -1.4060576j ], + [ 8.33625 +0.j , 2.6144838 +0.j , + -2.4941807 -1.9316154j , 0.6687787 -2.209776j ], + [ 0.019031923+0.17462212j , 2.7034955 +0.j , + -0.70924187 +0.j , 2.7962255 +1.5316825j ], + [-0.057821754+0.023692288j, -0.62805307 -0.0882424j , + 6.6364865 +0.j , -1.698973 +0.j ]]], + dtype=complex64), array([[ 3.3228416 , -3.6362798 , 0.2882607 , 2.364208 ], + [-4.1803885 , 2.6144838 , -0.70924187, -1.698973 ]], + dtype=float32), array([[-6.9206004, 6.620491 , -4.431866 ], + [ 8.33625 , 2.7034955, 6.6364865]], dtype=float32), array([[1.360567 +0.1977107j , 1.7586378-0.56989706j, + 1.5772758-0.8165493j ], + [1.9152443-0.1834492j , 1.1593437+0.55631363j, + 1.6889225-0.724835j ]], dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(3.32284164,1.14621949), (-1.97564387,4.59335613), (7.36770821,0.885187268), (-8.65993785,1.61327934)], [(2.495340e+00,1.36827672), (-3.96969199,-0.636681795), (3.3011198,-4.64436197), (-4.85899353,-0.614397943)], [(6.03322554,1.46055949), (-3.89591122,-4.1833396), (-1.46423841,-0.106284566), (-1.0288142,1.85440636)], [(-0.657281339,0.911450386), (3.18693113,-2.02812219), (-2.64483237,0.351429433), (4.45011663,-1.79112875)]], [[(-4.18038845,-3.65238023), (0.567084491,0.691301584), (2.67520404,-0.238818452), (-0.418250352,-1.4060576)], [(-7.62970591,1.5292784), (0.269325763,2.48722434), (-2.49418068,-1.93161535), (0.668778717,-2.20977592)], [(-0.570908666,-2.75890398), (-0.235837936,3.45861554), (-0.946199476,0.23120968), (2.79622555,1.53168249)], [(0.886947453,-0.466695577), (-3.194850e+00,-0.0176551137), (-4.37602425,-3.7703948), (0.883143305,-4.70016575)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_chetrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf5\x99G\x011\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03i\x0f\x0b\x0b\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0bo/\x1f\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03C\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0f\x07\x07\x13\x0b\x1b\x07\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02j\x0b\x1d\x1d\t\x1f\x1d!\t\x1d-\t\x17\x1f\xde\n\x1b\x1d#\t\x1d/\t\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'\x81\x05\'\x1d+\t\x05)\x05+\x05-\x1f5\x01\x1d/\x1d1\x1d3\x03\x07;;;\r\x0335\x03\x03;\x1d5\x1f\x17\t\x00\x00\x00\x00\t\x07\x07\x01\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fA!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1d;\x1d=\x1f?1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x11\t\x00\x00\xc0\x7f#)\x03\taeim\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\r\x057o35\x1dE\x1dG\x1dI#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x02\x04p\xa9T@R\xb7\x92?\xe6\xe1\xfc\xbf\xc6\xfc\x92@D\xc4\xeb@\xa2\x9bb?\x1b\x8f\n\xc1\xf0\x7f\xce?\xa7\xb3\x1f@\xb1#\xaf?o\x0f~\xc0\x94\xfd"\xbf\x8cES@\x9d\x9e\x94\xc0\xe0|\x9b\xc0/I\x1d\xbf/\x10\xc1@\x9d\xf3\xba?\x9cVy\xc0\xeb\xdd\x85\xc0*l\xbb\xbf\xb9\xab\xd9\xbd/\xb0\x83\xbf0]\xed?\x97C(\xbf\xd0Ti?\xae\xf6K@\xc1\xcc\x01\xc0\xefD)\xc0\x8f\xee\xb3>[g\x8e@\xb5C\xe5\xbf\xbe\xc5\x85\xc0\x99\xc0i\xc0s,\x11?$\xf90?\x8b6+@\xd3\x8ct\xbe\xe9$\xd6\xbe\xb2\xf9\xb3\xbf\x8d&\xf4\xc0e\xbf\xc3?\x11\xe5\x89>\xaf.\x1f@\xa8\xa0\x1f\xc0,?\xf7\xbf\x155+?\xf8l\r\xc0\x12\'\x12\xbf\xe2\x910\xc0\x80\x7fq\xbe\xf5Y]@!:r\xbf;\xc2l>\\\xf52@,\x0e\xc4?\xfd\x0ec?\xb9\xf2\xee\xbelxL\xc0u\xa1\x90\xbcd\x08\x8c\xc0&Nq\xc0\xae\x15b?\xc2g\x96\xc0\r\x03\x83\x85\x1dK\x13=L\x0b\x03\x1dM\x1dO\x05\x01\x03\x03W\x03\x03\x93\x15\x03\x01\x01\x01\x03\x0bWKKK\x97\x1fC\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x1f)\x05\t\x05\x13)\x05\t\x11\x19)\x05\t\r\x19)\x05\t\r\x1f)\x01\x1f)\x01\x19\x01)\x03\t\')\x01\'\t\x1d)\x03\t\x13\x03\x19)\x07\t\x05\x05\x13\x13)\x05\t\r\x13\x1b\x11\x01\t\x05\t\x0b\r\x11\x07!\x05\x0f\x03\x05\x11\x07\x07\t\x11\x03\t\x11\x07\x07\x0b\x11\x03\x0b\x11\x07\x07\r\x0f\x03\r)\x03\t\x1b)\x03\x01\x1b)\x05\t\x11\x13)\x07\t\x11\x11\x13)\x03\r\x1b!)\x03\r#)\x03\t#)\x03\x05#)\x03\x05\x1b\x04\xbe\x06\x05\x01Q\x03\x11\x01\x07\x04\x96\x06\x03\x01\x15\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\x05\x11G)%\x07\x0b\x05\t\x0b\r\x15\x03\x01\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\r\rF\x07\r\x03\x1d\x05\x0b\x0f\x03F\r\x0f\x03!\x03\x11\x05B\x03\x11\x03\x0f\x0fF\x01\x13\x03\x05\x07\x13\x03\x15\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\x19\rF\x07\r\x03\x1d\x05\x0b\x1b\x03F\r\x0f\x03\x07\x03\x1d\x05B\x03\x15\x03\x11\x0fF\x01\x17\x03\t\x07\x1f\x05!\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03%\rF\x07\r\x03\x1d\x05\x0b\'\x03F\r\x0f\x03\x07\x03)\x05B\x03\x15\x03\x11\x0fF\x01\x19\x03\x0b\x07+\x07-\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x031\rF\x07\r\x03\x1d\x05\x0b3\x03F\r\x0f\x03\x07\x035\x05B\x03\x11\x03\x0f\x0fF\x01\x1b\x03\r\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x1d\x07\x04S\x03\r\x13\x07C\x01\x0b\x01\x1f\x01\x00\x03F\x05\x1f\x039\x03\x01\x03F\x05\x0b\x03\x05\x03\x05\x0b\x06\x0b\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x0f\x01\x13\x01#\x01\x00\x03F\x05#\x037\x03\x01\x03F\x05\x0b\x03\t\x03\x05\x0b\x06\x0b\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x0f\x01\x17\x01#\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\x0b\x03\x05\x0b\x06\x0b\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\'\x07\x04S\x03\r\x13\x07\x0f\x01\x1b\x01\x1f\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\r\x03\x05\x0b\x06\x0b\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x12\nQ%\x03\x0b\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/A)Sci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_chetrd_ffi\x00\x08\x8d)\x057\x01\x0bM]_qs\x03\x7f\x11\x87\x89\x8bM\x8d\x8f\x91\x95\x03A\x031\x05CE\x03G\x03Y\x03O\x03[\x03Q\x03S\x03U\x0b9u=O?\x03}\x0b9w=Q?\x03I\x0b9y=S?\x0b9{=U?', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_ssytrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[-0.8395241 , 0.156272 , -1.6810869 , 0.23832119], + [-2.985257 , -5.571 , -0.22652794, -0.83806676], + [ 0.27237308, -1.6295947 , 2.0042834 , -1.148861 ], + [-0.17183593, 0.57464546, 0.5536146 , -4.206357 ]], + + [[ 1.7666914 , 2.569005 , -0.86576384, -0.1617768 ], + [-5.143918 , 5.0426254 , -3.7237067 , 4.383015 ], + [ 0.33311516, -1.5299042 , -8.854181 , -2.896776 ], + [ 0.3419102 , 0.2669245 , -2.8250606 , 5.752488 ]]], + dtype=float32), array([[-0.8395241, -5.571 , 2.0042834, -4.206357 ], + [ 1.7666914, 5.0426254, -8.854181 , 5.752488 ]], dtype=float32), array([[-2.985257 , -1.6295947, 0.5536146], + [-5.143918 , -1.5299042, -2.8250606]], dtype=float32), array([[1.8120625, 1.5035137, 0. ], + [1.6288393, 1.8669801, 0. ]], dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[-0.83952409, 1.562720e-01, -1.6810869, 0.238321185], [2.42421508, -5.17118931, -0.226527944, -0.838066756], [1.47339451, -1.32866347, -3.3505435, -1.14886105], [-0.929541587, -0.955984473, 2.71886253, 0.748659431]], [[1.76669145, 2.56900501, -0.865763843, -0.161776796], [3.23469758, -0.362713158, -3.72370672, 4.38301516], [2.79104376, 7.36582708, -3.04437494, -2.89677596], [2.86473417, 0.981746375, -2.13533139, 5.34802151]]]> : tensor<2x4x4xf32> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_ssytrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xf32>, tensor) -> tensor<2x4x4xf32> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc7) + return %2 : tensor<2x4x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe7\x93?\x011\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03c\x0f\x0b\x0b\x0b\x13\x1f\x0b\x0b/\x1f\x17\x0f\x0b\x0bO\x0b\x0b\x0bOo\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x04\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03;\x17\x0f\x1b\x17\x17\x07\x13\x0f\x07\x07\x13\x1b\x07\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02\xea\x08\x1d\x1d\x07\x1f\x1d-\x07\x17\x1f\xde\n\x1b\x1d!\x07\x1d/\x07\x1d#\x07\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'{\x05\'\x1d+\x07\x05)\x05+\x05-\x1f-\x01\x1d/\x1d1\x1d3\r\x0335\x1f\x13\t\x00\x00\x00\x00\t\x07\x07\x01\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\xc0\x7f\x03\x07999\x03\x039\x1d5\x1d7\x1f9!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d9\x1d;\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f71\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t]aei\r\x057_35\x1d=\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\x1dE\x1dG###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\x02\x02\r\xebV\xbf\xc4\x05 >\xdb-\xd7\xbfx\nt>W&\x1b@bz\xa5\xc0\xf1\xf6g\xbe\x8b\x8bV\xbf1\x98\xbc?\xa5\x11\xaa\xbfNoV\xc0\xe1\r\x93\xbfp\xf6m\xbff\xbbt\xbf\xd8\x01.@%\xa8??\xf2"\xe2?\x94j$@\xb3\xa2]\xbf\xd1\xa8%\xbeI\x05O@\x8a\xb5\xb9\xbe6Qn\xc0\xa9A\x8c@v\xa02@\xdb\xb4\xeb@\n\xd7B\xc0\xc7d9\xc0\xceW7@\xbbS{?E\xa9\x08\xc0\xfe"\xab@\r\x03}\x7f\x1dI\x135L\x0b\x03\x1dK\x1dM\x05\x01\x03\x03W\x03\x03\x8d\x15\x03\x01\x01\x01\x03\x0bWMMM\x91\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\t\r\x15)\x01\x15)\x07\t\x11\x11\x15)\x05\t\x11\x15)\x05\t\x05\x0f\x01)\x03\t\x1f)\x01\x1f\t\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f\x13\x1b\x11\x01\t\t\x0b\x05\x05\x11\x07\x1b\t\x07\x03\t\x11\x07\r\x0b\x07\x03\x0b\x11\x07\r\x05\x07\x03\x05)\x05\t\r\x0f)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x0f)\x07\t\x11\x11\x0f)\x03\r\x17!)\x03\r\x1d)\x03\t\x1d)\x03\x05\x1d)\x03\x05\x17\x04\xfe\x05\x05\x01Q\x03\x11\x01\x07\x04\xd6\x05\x03\x01\x11\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\t\x11G)%\x07\x0b\t\x0b\x05\x05\x11\x03\x01\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\r\x0bF\x05\r\x03\x19\x05\x0b\x0f\x03F\x0b\x0f\x03\x1b\x03\x11\x05B\x03\x11\x03\x07\rF\x01\x13\x03\t\x07\x13\x03\x15\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\x19\x0bF\x05\r\x03\x19\x05\x0b\x1b\x03F\x0b\x0f\x03\r\x03\x1d\x05B\x03\x11\x03\x07\rF\x01\x15\x03\x0b\x07\x1f\x05!\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03%\x0bF\x05\r\x03\x19\x05\x0b\'\x03F\x0b\x0f\x03\r\x03)\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x07+\x07-\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x031\x0bF\x05\r\x03\x19\x05\x0b3\x03F\x0b\x0f\x03\r\x035\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x19\x07\x04S\x03\r\x13\x077\x01\x13\x01\x0f\x01\x00\x03F\t\x1b\x031\x03\x01\x03F\t\x0b\x03\t\x03\x05\x0f\x06\r\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\x1d\x07\x04S\x03\r\x13\x07\x1b\x01\x17\x01\x0f\x01\x00\x03F\t\x1f\x03/\x03\x01\x03F\t\x0b\x03\x0b\x03\x05\x0f\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x1b\x01\x0b\x01\x0f\x01\x00\x03F\t\x1f\x03)\x03\x01\x03F\t\x0b\x03\x05\x03\x05\x0f\x06\r\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\xea\tO%\x03\x0b\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/A)Sci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_ssytrd_ffi\x00\x08y#\x057\x01\x0bOY[mo\x03y\x11\x81\x83\x85O\x87\x89\x8b\x8f\x03;\x031\x05=?\x03A\x03C\x03Q\x03S\x03K\x0bEqGQI\x03w\x0bEsGSI\x03U\x0bEuGKI', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dsytrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[ 0.8251247184208595 , -2.6963562039892532 , + 0.8082445002373937 , -1.551980329390836 ], + [-2.629505060186711 , 4.427374205796291 , + -2.2111093161901074 , 7.552489598405787 ], + [ 0.2269453213819231 , 0.3650586474106988 , + -3.5933639667756205 , 4.828829679372501 ], + [-0.6415372293575187 , -0.2519326897319508 , + -1.7607827845801751 , -3.381311711243865 ]], + + [[-4.000421911405985 , 3.6303350337601055 , + 2.8066821235532355 , 1.099224389184342 ], + [-4.141622408467332 , -5.276404169116551 , + -0.8496056221591237 , -2.275319346221659 ], + [ 0.5828958067901202 , 0.9351254869793256 , + 2.7765603683442177 , -4.339686212557215 ], + [-0.6391146585297987 , 0.3129920702652711 , + -0.25441692469349864, -1.4155240723557498 ]]]), array([[ 0.8251247184208595, 4.427374205796291 , -3.5933639667756205, + -3.381311711243865 ], + [-4.000421911405985 , -5.276404169116551 , 2.7765603683442177, + -1.4155240723557498]]), array([[-2.629505060186711 , 0.3650586474106988 , -1.7607827845801751 ], + [-4.141622408467332 , 0.9351254869793256 , -0.25441692469349864]]), array([[1.3669846724688552, 1.8806358893589366, 0. ], + [1.1440109149169537, 1.8215532880266878, 0. ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[0.82512471842085955, -2.6963562039892532, 0.80824450023739369, -1.5519803293908361], [0.96498805326781766, -4.1313349231964409, -2.2111093161901074, 7.5524895984057867], [0.81575339483804743, 1.0647235400727899, -1.0064296232364345, 4.8288296793725012], [-2.3060011529502993, -2.9182106402942192, -1.7781896154088577, 2.5904630742096817]], [[-4.0004219114059847, 3.6303350337601055, 2.8066821235532355, 1.0992243891843421], [0.59643883228393779, -1.5243235004961249, -0.84960562215912372, -2.275319346221659], [2.7617960295487092, -0.57538970930521982, 0.12559406141906576, -4.3396862125572149], [-3.0281643919760217, 0.38177997229319849, 3.860398204232184, -2.5166384340510231]]]> : tensor<2x4x4xf64> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_dsytrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xf64>, tensor) -> tensor<2x4x4xf64> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc7) + return %2 : tensor<2x4x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe7\x93?\x011\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03c\x0f\x0b\x0b\x0b\x13\x1f\x0b\x0b//\x17\x0f\x0b\x0bO\x0b\x0b\x0bOo\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03;\x17\x0f\x1b\x17\x17\x07\x13\x0f\x07\x07\x13\x1b\x07\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02\xfa\n\x1d\x1d\x07\x1f\x1d-\x07\x17\x1f\xde\n\x1b\x1d!\x07\x1d/\x07\x1d#\x07\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'{\x05\'\x1d+\x07\x05)\x05+\x05-\x1f-\x01\x1d/\x1d1\x1d3\r\x0335\x1f\x13\t\x00\x00\x00\x00\t\x07\x07\x01\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x03\x07999\x03\x039\x1d5\x1d7\x1f9!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d9\x1d;\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f71\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t]aei\r\x057_35\x1d=\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\x1dE\x1dG###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\x02\x04A\xa4\x17\xf4kg\xea?\x1f\x01\x943#\x92\x05\xc0\x86 \xf6\x91#\xdd\xe9?\x9dMlS\xe9\xd4\xf8\xbf\x88\x1c:\xa0.\xe1\xee?8\xce\x7f\xa9|\x86\x10\xc0\xe8V\xc7\x14Z\xb0\x01\xc0\xd2!R\xd5\xbf5\x1e@\xbf\xc5\r\xdd\xa6\x1a\xea?\xbcM\xfe\x8c\x1b\t\xf1?\xdbj\xd8\xf2U\x1a\xf0\xbf\xado;\xba\xb8P\x13@\xbb\xad\x83\xbb\xb0r\x02\xc0\x1f9\xf7\xd1~X\x07\xc0)ID\xf4vs\xfc\xbfD\xcfI\xb4D\xb9\x04@\x16\xc3\xfe\x99n\x00\x10\xc0\x82.\x1c\x18\xed\n\r@\x8cn\xd7\xc1\x15t\x06@|2(Pl\x96\xf1?\x88*\xd7\xe3\x06\x16\xe3?F{\xf2\t\xa1c\xf8\xbf8z5!\xf8/\xeb\xbf4\xd3\x1f\xa1\xda3\x02\xc0)\x13I\x84(\x18\x06@\xbcw\xfd\xad\x97i\xe2\xbf\x1e\xf0.Yw\x13\xc0?dW\xd7\xb3\xd6[\x11\xc0\x04\x97\xb3@\xae9\x08\xc0\xbc\x17\xd1C\x15o\xd8?\x02\xb7%t\x18\xe2\x0e@\xac\xd8\xd0T\x13"\x04\xc0\r\x03}\x7f\x1dI\x135L\x0b\x03\x1dK\x1dM\x05\x01\x03\x03W\x03\x03\x8d\x15\x03\x01\x01\x01\x03\x0bWMMM\x91\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\t\r\x15)\x01\x15)\x07\t\x11\x11\x15)\x05\t\x11\x15)\x05\t\x05\x0f\x01)\x03\t\x1f)\x01\x1f\x0b\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f\x13\x1b\x11\x01\t\t\x0b\x05\x05\x11\x07\x1b\t\x07\x03\t\x11\x07\r\x0b\x07\x03\x0b\x11\x07\r\x05\x07\x03\x05)\x05\t\r\x0f)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x0f)\x07\t\x11\x11\x0f)\x03\r\x17!)\x03\r\x1d)\x03\t\x1d)\x03\x05\x1d)\x03\x05\x17\x04\xfe\x05\x05\x01Q\x03\x11\x01\x07\x04\xd6\x05\x03\x01\x11\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\t\x11G)%\x07\x0b\t\x0b\x05\x05\x11\x03\x01\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\r\x0bF\x05\r\x03\x19\x05\x0b\x0f\x03F\x0b\x0f\x03\x1b\x03\x11\x05B\x03\x11\x03\x07\rF\x01\x13\x03\t\x07\x13\x03\x15\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\x19\x0bF\x05\r\x03\x19\x05\x0b\x1b\x03F\x0b\x0f\x03\r\x03\x1d\x05B\x03\x11\x03\x07\rF\x01\x15\x03\x0b\x07\x1f\x05!\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03%\x0bF\x05\r\x03\x19\x05\x0b\'\x03F\x0b\x0f\x03\r\x03)\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x07+\x07-\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x031\x0bF\x05\r\x03\x19\x05\x0b3\x03F\x0b\x0f\x03\r\x035\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x19\x07\x04S\x03\r\x13\x077\x01\x13\x01\x0f\x01\x00\x03F\t\x1b\x031\x03\x01\x03F\t\x0b\x03\t\x03\x05\x0f\x06\r\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\x1d\x07\x04S\x03\r\x13\x07\x1b\x01\x17\x01\x0f\x01\x00\x03F\t\x1f\x03/\x03\x01\x03F\t\x0b\x03\x0b\x03\x05\x0f\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x1b\x01\x0b\x01\x0f\x01\x00\x03F\t\x1f\x03)\x03\x01\x03F\t\x0b\x03\x05\x03\x05\x0f\x06\r\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\xea\tO%\x03\x0b\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/A)Sci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_dsytrd_ffi\x00\x08y#\x057\x01\x0bOY[mo\x03y\x11\x81\x83\x85O\x87\x89\x8b\x8f\x03;\x031\x05=?\x03A\x03C\x03Q\x03S\x03K\x0bEqGQI\x03w\x0bEsGSI\x03U\x0bEuGKI', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 70826eec8806..5d5e95b5cb9a 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -294,7 +294,7 @@ def serialize(self, args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) exported = export.export( jax.jit(func), - lowering_platforms=(self.default_jax_backend(),), + platforms=(self.default_jax_backend(),), disabled_checks=tuple( export.DisabledSafetyCheck.custom_call(target) for target in allow_unstable_custom_call_targets) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 1c03158953f0..48c645c4d033 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -1058,7 +1058,8 @@ def _make_broadcast_in_dim_harness(name, lax.broadcast_in_dim_p, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{outshape=}_broadcastdimensions={broadcast_dimensions}", lambda operand: lax.broadcast_in_dim_p.bind( - operand, shape=outshape, broadcast_dimensions=broadcast_dimensions), + operand, shape=outshape, broadcast_dimensions=broadcast_dimensions, + sharding=None), [RandArg(shape, dtype)], shape=shape, dtype=dtype, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index f1f46a5c18f7..3f6c5ee5b043 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -29,18 +29,16 @@ from jax._src import core from jax._src import source_info_util from jax._src.ad_util import ( - add_jaxvals, replace_internal_symbolic_zeros, zeros_like_jaxval, + add_jaxvals, replace_internal_symbolic_zeros, replace_rule_output_symbolic_zeros, Zero, zeros_like_aval) from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401 from jax._src.api_util import flatten_fun, flatten_fun_nokwargs -from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, - raise_to_shaped) +from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) from jax._src.dtypes import dtype, float0 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, partition_list) - zip = safe_zip map = safe_map def identity(x): return x @@ -69,54 +67,101 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, fun, aux = jvp_subtrace_aux(fun) return jvpfun(fun, instantiate, transform_stack), aux - -@lu.transformation -def jvpfun(instantiate, transform_stack, primals, tangents): +@lu.transformation2 +def jvpfun(f, instantiate, transform_stack, primals, tangents): + tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) - with core.new_main(JVPTrace) as main, ctx: - out_primals, out_tangents = yield (main, primals, tangents), {} - del main + with ctx: + out_primals, out_tangents = f(tag, primals, tangents) if type(instantiate) is bool: instantiate = [instantiate] * len(out_tangents) out_tangents = [instantiate_zeros(t) if inst else t for t, inst in zip(out_tangents, instantiate)] - yield out_primals, out_tangents - -@lu.transformation -def jvp_subtrace(main, primals, tangents): - trace = JVPTrace(main, core.cur_sublevel()) - for x in list(primals) + list(tangents): - if isinstance(x, Tracer): - if x._trace.level >= trace.level: - raise core.escaped_tracer_error( - x, f"Tracer from a higher level: {x} in trace {trace}") - assert x._trace.level < trace.level - in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x - for x, t in zip(primals, tangents)] - ans = yield in_tracers, {} - out_tracers = map(trace.full_raise, ans) - yield unzip2([(out_tracer.primal, out_tracer.tangent) - for out_tracer in out_tracers]) - -@lu.transformation_with_aux -def jvp_subtrace_aux(main, primals, tangents): - trace = JVPTrace(main, core.cur_sublevel()) - for x in list(primals) + list(tangents): - if isinstance(x, Tracer): - assert x._trace.level < trace.level - ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {} - ans_tracers = map(trace.full_raise, ans) - out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers) - aux_primals = [core.full_lower(x.primal) - if isinstance(x, JVPTracer) and x._trace.level == trace.level - else x for x in aux] - yield (out_primals, out_tangents), aux_primals - + return out_primals, out_tangents + +@lu.transformation2 +def jvp_subtrace(f, tag, primals, tangents): + with core.take_current_trace() as parent_trace: + trace = JVPTrace(parent_trace, tag) + in_tracers = [maybe_jvp_tracer(trace, x, t) + for x, t in zip(primals, tangents)] + with core.set_current_trace(trace): + ans = f(*in_tracers) + out = unzip2(map(trace.to_primal_tangent_pair, ans)) + return out + +@lu.transformation_with_aux2 +def jvp_subtrace_aux(f, store, tag, primals, tangents): + with core.take_current_trace() as parent_trace: + trace = JVPTrace(parent_trace, tag) + with core.set_current_trace(trace): + ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents))) + out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) + aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag + else x for x in aux] + store.store(aux_primals) + return out_primals, out_tangents + +def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr: + dbg = jaxpr.debug_info and jaxpr.debug_info._replace( + arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars)) + return core.Jaxpr(constvars=(), + invars=jaxpr.invars + jaxpr.constvars, + outvars=jaxpr.outvars, eqns=jaxpr.eqns, + effects=jaxpr.effects, debug_info=dbg) + +def linearize_jaxpr(jaxpr, nonzeros): + primal_trace = pe.DynamicJaxprTrace() + tangent_trace = pe.DynamicJaxprTrace() + lin_trace = LinearizeTrace(primal_trace, tangent_trace) + + def new_arg(primal_aval, nz): + primal = primal_trace.new_arg(primal_aval) + tangent_aval = primal_aval.to_tangent_aval() + tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval) + return LinearizeTracer(lin_trace, primal, tangent) + + tracers = [new_arg(v.aval, nz) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)] + with core.set_current_trace(lin_trace): + ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers) + + out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans)) + nzs_out = [type(t) is not Zero for t in out_tangents] + out_tangents = [tangent_trace.to_jaxpr_tracer(t) + for (nz, t) in zip(nzs_out, out_tangents) if nz] + tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) + del attrs_tracked # TODO: attrs + residuals_and_primals = (*tangent_consts, *out_primals) + primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals) + num_residuals = len(tangent_consts) + tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr)) + del attrs_tracked # TODO: attrs + return core.ClosedJaxpr(primal_jaxpr, primal_consts), num_residuals, nzs_out, tangent_jaxpr + +def direct_linearize(traceable, *primals, **kwargs): + has_aux = kwargs.pop('has_aux', False) + assert not has_aux + with core.take_current_trace() as parent_trace: + tangent_trace = pe.DynamicJaxprTrace() + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] + linearize_trace = LinearizeTrace(parent_trace, tangent_trace) + tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] + with core.set_current_trace(linearize_trace): + ans = traceable.call_wrapped(*tracers) + + out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans)) + out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) + jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) + out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents] + del attrs_tracked # TODO: attrs + return out_primals, out_tangents_pvals, jaxpr, consts def linearize(traceable, *primals, **kwargs): + if config.use_direct_linearize.value: + return direct_linearize(traceable, *primals, **kwargs) has_aux = kwargs.pop('has_aux', False) if not has_aux: jvpfun = jvp(traceable) @@ -130,7 +175,11 @@ def linearize(traceable, *primals, **kwargs): jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree) jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals) - assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals) + if any(not out_primal_pval.is_known() for out_primal_pval in out_primals_pvals): + raise ValueError( + "Linearization failed to produce known values for all output primals. " + "This is typically caused by attempting to differentiate a function " + "uses an operation that does not support reverse-mode autodiff.") out_primals_consts = [pval.get_known() for pval in out_primals_pvals] if not has_aux: return out_primals_consts, out_tangents_pvals, jaxpr, consts @@ -166,7 +215,6 @@ def unpair_pval(pval): aval_1, aval_2 = aval return (aval_1, const_1), (aval_2, const_2) - # NOTE: The FIXMEs below are caused by primal/tangent mixups (type # errors if you will) def backward_pass(jaxpr: core.Jaxpr, transform_stack, @@ -219,6 +267,20 @@ def write_primal(v, val): with ctx: map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) for eqn in jaxpr.eqns[::-1]: + if eqn.primitive.ref_primitive: + if eqn.primitive is core.mutable_array_p: + val_var, = eqn.invars + ref_var, = eqn.outvars + ref = read_primal(ref_var) + ct_out = core.freeze(ref) + write_cotangent(eqn.primitive, val_var, ct_out) + elif eqn.primitive is core.freeze_p: + val_var, = eqn.outvars # type: ignore + ref_var, = eqn.invars # type: ignore + ct_in = instantiate_zeros(read_cotangent(val_var)) + write_primal(ref_var, core.mutable_array(ct_in)) + continue + invals = map(read_primal, eqn.invars) if eqn.primitive.multiple_results: cts_in = map(read_cotangent, eqn.outvars) @@ -233,9 +295,6 @@ def write_primal(v, val): call_jaxpr = params.pop('call_jaxpr') cts_out = get_primitive_transpose(eqn.primitive)( params, call_jaxpr, invals, cts_in, cts_in_avals) - elif eqn.primitive in reducing_transposes: - cts_out = reducing_transposes[eqn.primitive]( - cts_in, *invals, **eqn.params) else: cts_out = get_primitive_transpose(eqn.primitive)( cts_in, *invals, **eqn.params) @@ -274,44 +333,48 @@ def get_primitive_transpose(p): "Transpose rule (for reverse-mode differentiation) for '{}' " "not implemented".format(p)) from err -@lu.transformation_with_aux -def nonzero_tangent_outputs(*args, **kwargs): - results = (_, tangents_out) = yield args, kwargs - yield results, [type(r) is not Zero for r in tangents_out] +@lu.transformation_with_aux2 +def nonzero_tangent_outputs(f, store, *args, **kwargs): + results = (_, tangents_out) = f(*args, **kwargs) + store.store([type(r) is not Zero for r in tangents_out]) + return results class JVPTrace(Trace): + def __init__(self, parent_trace, tag): + self.tag = tag + self.parent_trace = parent_trace - def pure(self, val): - tangent_zero = Zero.from_primal_value(val) - return JVPTracer(self, val, tangent_zero) - - def lift(self, val): - tangent_zero = Zero.from_primal_value(val) - return JVPTracer(self, val, tangent_zero) - - def sublift(self, val): - return JVPTracer(self, val.primal, val.tangent) + def to_primal_tangent_pair(self, val): + if isinstance(val, JVPTracer) and val._trace.tag is self.tag: + return (val.primal, val.tangent) + else: + tangent_zero = Zero.from_primal_value(val) + return (val, tangent_zero) def process_primitive(self, primitive, tracers, params): - primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) + if all(type(t) is Zero for t in tangents_in): + return primitive.bind_with_trace(self.parent_trace, primals_in, params) jvp = primitive_jvps.get(primitive) if not jvp: msg = f"Differentiation rule for '{primitive}' not implemented" raise NotImplementedError(msg) - primal_out, tangent_out = jvp(primals_in, tangents_in, **params) + with core.set_current_trace(self.parent_trace): + primal_out, tangent_out = jvp(primals_in, tangents_in, **params) + if primitive.multiple_results: - return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)] + return [maybe_jvp_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)] else: - return JVPTracer(self, primal_out, tangent_out) + return maybe_jvp_tracer(self, primal_out, tangent_out) def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results - primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) + primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not Zero for t in tangents] tangents = [t if type(t) is not Zero else None for t in tangents] args, in_tree = tree_flatten((primals, tangents)) - f_jvp = jvp_subtrace(f, self.main) + f_jvp = jvp_subtrace(f, self.tag) f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp) if isinstance(call_primitive, core.MapPrimitive): in_axes = params['in_axes'] @@ -328,76 +391,59 @@ def new_out_axes_thunk(): f_jvp, out_tree = traceable(f_jvp, in_tree) update_params = call_param_updaters.get(call_primitive) new_params = update_params(params, which_nz) if update_params else params - result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz), - *args, **new_params) + fun_and_args = (_update_annotation(f_jvp, f.in_type, which_nz),) + tuple(args) + result = call_primitive.bind_with_trace(self.parent_trace, fun_and_args, new_params) primal_out, tangent_out = tree_unflatten(out_tree(), result) tangent_out = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primal_out, tangent_out)] - return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] - - def post_process_call(self, call_primitive, out_tracers, params): - primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) - out, treedef = tree_flatten((primals, tangents)) - tangents_nz = [type(t) is not Zero for t in tangents] - del primals, tangents - main = self.main - def todo(x): - primals, tangents = tree_unflatten(treedef, x) - trace = JVPTrace(main, core.cur_sublevel()) - return map(partial(JVPTracer, trace), primals, tangents) - if call_primitive.map_primitive: - def out_axes_transform(out_axes): - return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz)) - todo = (todo, out_axes_transform) - return out, todo + return [maybe_jvp_tracer(self, p, t) for p, t in zip(primal_out, tangent_out)] # The only difference between process_map and process_call is that # the `in_axes` and `out_axes_thunk` params must be updated; # that's handled in process_call. process_map = process_call - post_process_map = post_process_call - def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): - primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) - primals_in = map(core.full_lower, primals_in) - if not symbolic_zeros: - tangents_in = map(instantiate_zeros, tangents_in) - else: - tangents_in = map(replace_internal_symbolic_zeros, tangents_in) - outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) + def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros): + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) + if all(type(t) is Zero for t in tangents_in): + return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in), + dict(symbolic_zeros=symbolic_zeros)) + with core.set_current_trace(self.parent_trace): + if not symbolic_zeros: + tangents_in = map(instantiate_zeros, tangents_in) + else: + tangents_in = map(replace_internal_symbolic_zeros, tangents_in) + outs = f_jvp.call_wrapped(*(tuple(primals_in) + tuple(tangents_in))) + primals_out, tangents_out = split_list(outs, [len(outs) // 2]) tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) - return map(partial(JVPTracer, self), primals_out, tangents_out) - - def post_process_custom_jvp_call(self, out_tracers, _): - raise CustomJVPException() + return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) - def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): - # Local import to prevent an import cycle. - from jax._src.lax import lax - - primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) - fwd_in = [(core.full_lower(p), type(t) is not Zero) - for p, t in zip(primals_in, tangents_in)] + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) + if all(type(t) is Zero for t in tangents_in): + return prim.bind_with_trace(self.parent_trace, + (fun, fwd, bwd, *primals_in), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) + fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)] fwd_in = [x for pair in fwd_in for x in pair] # flatten - res_and_primals_out = fwd.call_wrapped(*fwd_in) + with core.set_current_trace(self.parent_trace): + res_and_primals_out = fwd.call_wrapped(*fwd_in) + _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! - tangents_in = map(instantiate_zeros, tangents_in) - tangents_out = custom_lin_p.bind( + with core.set_current_trace(self.parent_trace): + tangents_in = map(instantiate_zeros, tangents_in) + tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - return map(partial(JVPTracer, self), primals_out, tangents_out) - - def post_process_custom_vjp_call(self, out_tracers, _): - raise CustomVJPException() + return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) def process_custom_transpose(self, prim, call, tracers, **params): - ps_in, ts_in = unzip2((t.primal, t.tangent) for t in tracers) + ps_in, ts_in = unzip2(map(self.to_primal_tangent_pair, tracers)) res_ps_in, lin_ps_in = split_list(ps_in, [params['res_tree'].num_leaves]) res_ts_in, lin_ts_in = split_list(ts_in, [params['res_tree'].num_leaves]) @@ -421,24 +467,18 @@ def process_custom_transpose(self, prim, call, tracers, **params): raise NotImplementedError( 'JVP of custom transpose with respect to non-symbolic-zero residuals') - ps_out = prim.bind(call, *ps_in, **params) + with core.set_current_trace(self.parent_trace): + ps_out = prim.bind(call, *ps_in, **params) + lin_ts_in = map(instantiate_zeros, lin_ts_in) + ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params) - lin_ts_in = map(instantiate_zeros, lin_ts_in) - ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params) - - return map(partial(JVPTracer, self), ps_out, ts_out) - - def join(self, xt, yt): - xz, yz = type(xt) is Zero, type(yt) is Zero - if xz == yz: - return xt, yt - elif yz and not xz: - return xt, zeros_like_jaxval(xt) - elif xz and not yz: - return zeros_like_jaxval(yt), yt - else: - raise TypeError((xt, yt)) + return map(partial(maybe_jvp_tracer, self), ps_out, ts_out) +def maybe_jvp_tracer(trace, primal, tangent): + if type(tangent) is Zero: + return primal + else: + return JVPTracer(trace, primal, tangent) class JVPTracer(Tracer): __slots__ = ['primal', 'tangent'] @@ -452,7 +492,6 @@ def __init__(self, trace, primal, tangent): @property def aval(self): - # TODO(dougalm): add epsilon ball return get_aval(self.primal) def full_lower(self): @@ -461,10 +500,13 @@ def full_lower(self): else: return self + def to_concrete_value(self): + return core.to_concrete_value(self.primal) + def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: - primal_aval = raise_to_shaped(get_aval(primal), weak_type=False) - tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False) + primal_aval = get_aval(primal).strip_weak_type() + tangent_aval = get_aval(tangent).strip_weak_type() assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) @@ -472,15 +514,94 @@ def _primal_tangent_shapes_match(primal, tangent): call_param_updaters: dict[core.Primitive, Callable] = {} call_transpose_param_updaters: dict[core.Primitive, Callable] = {} +# -------------------- Linearize trace -------------------- + +class LinearizeTrace(Trace): + + def __init__(self, parent_trace, tangent_trace, tag=None): + self.tag = core.TraceTag() if tag is None else tag + self.parent_trace = parent_trace + self.tangent_trace = tangent_trace + + def to_primal_tangent_pair(self, val): + if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag: + return (val.primal, val.tangent) + else: + tangent_zero = Zero.from_primal_value(val) + return (val, tangent_zero) + + def process_primitive(self, primitive, args, params): + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) + tangent_nonzeros = [type(t) is not Zero for t in tangents_in] + if all(type(t) is Zero for t in tangents_in): + return primitive.bind_with_trace(self.parent_trace, primals_in, params) + lin = primitive_linearizations.get(primitive) + if lin is None: + lin = partial(fallback_linearize_rule, primitive) + with core.set_current_trace(self.parent_trace): + primal_out, tangent_nonzeros_out, residuals, linearized = lin( + tangent_nonzeros, *primals_in, **params) + with core.set_current_trace(self.tangent_trace): + tangent_out = linearized(residuals, *tangents_in) + if primitive.multiple_results: + return [maybe_linearize_tracer(self, x, nz, t) + for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)] + else: + return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out) + +def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): + if is_nonzero: + assert not type(tangent) is Zero + return LinearizeTracer(trace, primal, tangent) + else: + assert type(tangent) is Zero + return primal + +def fallback_linearize_rule(prim, _, *args, **kwargs): + assert not prim.multiple_results + + def call_prim(*args_): + return [prim.bind(*args_, **kwargs)] + + with config.use_direct_linearize(False): + (out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize( + lu.wrap_init(call_prim), *args, **kwargs) + + def linearized(residuals, *tangents): + out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents) + return out_tangent + + return out_primal, True, consts, linearized + +class LinearizeTracer(Tracer): + __slots__ = ['primal', 'tangent'] + + def __init__(self, trace, primal, tangent): + if config.enable_checks.value: + _primal_tangent_shapes_match(primal, tangent) + self._trace = trace + self.primal = primal + self.tangent = tangent + + @property + def aval(self): + return get_aval(self.primal) + + def full_lower(self): + if type(self.tangent) is Zero: + return core.full_lower(self.primal) + else: + return self + + def to_concrete_value(self): + return core.to_concrete_value(self.primal) + # -------------------- Primitives -------------------- primitive_jvps : dict[core.Primitive, Callable] = {} - primitive_transposes: dict[core.Primitive, Callable] = {} -# transpose rules that internally perform reductions over the given named axes -reducing_transposes: dict[core.Primitive, Callable] = {} - +primitive_linearizations : dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive) @@ -573,15 +694,16 @@ def zero_jvp(primitive, primals, tangents, **params): def instantiate_zeros(tangent): return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent -@lu.transformation_with_aux -def traceable(in_tree, *primals_and_tangents): +@lu.transformation_with_aux2 +def traceable(f, store, in_tree, *primals_and_tangents): primals, tangents = tree_unflatten(in_tree, primals_and_tangents) tangents = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primals, tangents)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) tangents_out = [None if type(t) is Zero else t for t in tangents_out] out_flat, out_tree = tree_flatten((primals_out, tangents_out)) - yield out_flat, out_tree + store.store(out_tree) + return out_flat def call_transpose(primitive, params, call_jaxpr, args, ct, _): @@ -618,10 +740,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals): primitive_transposes[core.closed_call_p] = _closed_call_transpose -@lu.transformation_with_aux -def nonzero_outputs(*args, **kwargs): - results = yield args, kwargs - yield results, [type(r) is not Zero for r in results] +@lu.transformation_with_aux2 +def nonzero_outputs(f, store, *args, **kwargs): + results = f(*args, **kwargs) + store.store([type(r) is not Zero for r in results]) + return results def map_transpose(primitive, params, call_jaxpr, args, ct, _): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts @@ -685,17 +808,18 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate): jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() -@lu.transformation_with_aux -def f_jvp_traceable(nonzeros, *primals_and_nztangents): +@lu.transformation_with_aux2 +def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents): num_primals = len(nonzeros) primals = list(primals_and_nztangents[:num_primals]) nonzero_tangents = iter(primals_and_nztangents[num_primals:]) tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p) for p, nz in zip(primals, nonzeros)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) out_nonzeros = [type(t) is not Zero for t in tangents_out] nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero] - yield list(primals_out) + nonzero_tangents_out, out_nonzeros + store.store(out_nonzeros) + return list(primals_out) + nonzero_tangents_out def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) @@ -760,3 +884,6 @@ def __init__(self): "closed-over value into the custom_vjp function as an argument, and " "adapting the custom_vjp fwd and bwd rules.") super().__init__(msg) + +# TODO(mattjj): remove this vestigial dict +reducing_transposes: dict[core.Primitive, Callable] = {} diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index eb174cc5c052..f4658ec2be29 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -14,7 +14,7 @@ from __future__ import annotations import collections -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import partial from typing import Any, Union @@ -29,12 +29,12 @@ from jax._src.ad_util import (Zero, instantiate, SymbolicZero, replace_rule_output_symbolic_zeros, add_jaxvals, add_jaxvals_p) -from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName +from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) from jax._src.typing import Array -from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, +from jax._src.util import (unzip2, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache) @@ -217,7 +217,7 @@ def __init__(self, a): self.a = a for d in a.shape)) if type(a) is core.DShapedArray else a for a, e in orig_type if e] - new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens] + new_avals = [core.get_aval(s) for s in segment_lens] sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size for a, d in zip(avals, explicit_in_dims): if isinstance(d, RaggedAxis): @@ -264,10 +264,17 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: + if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis): + # TODO(mvoz): A vaguely questionable assumption that it is always + # sound to have a 0 axis here. This is true for the current use cases + # and comes from how we handle intermediary products of jumbles in + # vmap. + return BatchTracer(trace, x, 0, source_info_util.current()) # TODO(mvoz): This is a terrible place to fall into if you pass # a non jumble type in, make it clearer what went wrong. assert False, f'Unexpected type in ELT? {type(x)}' + to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, @@ -277,8 +284,7 @@ def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, def _cont(axis_size, elt, axis): return from_elt(trace, axis_size, i, elt, axis) return handler(_cont, axis_size, x, spec) - x_ = trace.full_raise(x) - val, bdim = x_.val, x_.batch_dim + val, bdim = trace.to_batch_info(x) if type(bdim) is RaggedAxis: if spec is not jumble_axis: # TODO(mattjj): improve this error message @@ -286,9 +292,9 @@ def _cont(axis_size, elt, axis): return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val) else: try: - return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val) + return matchaxis(trace.axis_data.name, axis_size, bdim, spec, val) except SpecMatchError: - raise SpecMatchError(i, x_.batch_dim, spec) from None + raise SpecMatchError(i, x.batch_dim, spec) from None from_elt_handlers: dict[type, FromEltHandler] = {} def make_iota(axis_size: AxisSize) -> Array: @@ -321,23 +327,25 @@ def unregister_vmappable(data_type: type) -> None: def is_vmappable(x: Any) -> bool: return type(x) is Jumble or type(x) in vmappables -@lu.transformation_with_aux -def flatten_fun_for_vmap(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_for_vmap(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans, is_leaf=is_vmappable) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable) + store.store(out_tree) + return ans # Propagate ragged masking rules from invars to outvars -# rule([raggedness_per_invar], outvars) -> +# rule([params], [raggedness_per_invar], outvars) -> # [raggedness_per_invar, raggedness_per_outvar] RaggedMaskingRule = Callable[ - [list[Any], list[Any]], tuple[list[Any], list[Any]] + [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]] ] ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {} -def ragged_mask_elementwise_rule(invar_raggedness, outvars): +def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars): # TODO(mvoz): A util for getting the ragged representations first_invar_raggedness = invar_raggedness[0] for other_invar_raggedness in invar_raggedness[1:]: @@ -348,17 +356,19 @@ def ragged_mask_elementwise_rule(invar_raggedness, outvars): return invar_raggedness, outvar_raggedness -def ragged_mask_assert_no_op_rule(invar_raggedness, outvars): +def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars): if any(invar_raggedness): raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}') return invar_raggedness, [None] * len(outvars) -def ragged_mask_no_op_rule(invar_raggedness, outvars): +def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars): return invar_raggedness, [None] * len(outvars) -def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness): +def ragged_mask_transfer_identity( + eqn_params, invar_raggedness, outvar_raggedness +): assert len(invar_raggedness) == 1, invar_raggedness outvar_raggedness = invar_raggedness return invar_raggedness, outvar_raggedness @@ -379,7 +389,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, if config.enable_checks.value: assert type(batch_dim) in (NotMapped, int, RaggedAxis) if type(batch_dim) is int: - aval = raise_to_shaped(core.get_aval(val)) + aval = core.get_aval(val) assert 0 <= batch_dim < len(aval.shape) self._trace = trace self.val = val @@ -388,7 +398,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, @property def aval(self): - aval = raise_to_shaped(core.get_aval(self.val)) + aval = core.get_aval(self.val) if self.batch_dim is not_mapped: return aval elif type(self.batch_dim) is int: @@ -426,165 +436,118 @@ def get_referent(self): else: # TODO(mattjj): could handle the RaggedAxis case? return self +@dataclasses.dataclass(frozen=True) +class AxisData: + name : Any + size : Any + spmd_name : Any + + class BatchTrace(Trace): - def __init__(self, *args, axis_name, spmd_axis_name = None): - super().__init__(*args) - self.axis_name = axis_name - self.spmd_axis_name = spmd_axis_name - - def pure(self, val): - return BatchTracer(self, val, not_mapped, source_info_util.current()) - - def lift(self, val): - return BatchTracer(self, val, not_mapped, source_info_util.current()) - - def sublift(self, val): - return BatchTracer(self, val.val, val.batch_dim, source_info_util.current()) - - def get_primitive_batcher(self, primitive, frame): - if primitive in primitive_batchers: - return primitive_batchers[primitive] - elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers: - return partial(spmd_axis_primitive_batchers[primitive], - self.spmd_axis_name, frame.size, frame.name, - frame.main_trace.trace_type) - elif primitive in axis_primitive_batchers: - return self.get_axis_primitive_batcher(primitive, frame) - msg = "Batching rule for '{}' not implemented" - raise NotImplementedError(msg.format(primitive)) - - def get_axis_primitive_batcher(self, primitive, frame): - return partial(axis_primitive_batchers[primitive], - frame.size, frame.name, frame.main_trace.trace_type) - - def get_frame(self, vals, dims) -> core.AxisEnvFrame: - if any(d is not not_mapped for d in dims): - sizes = (x.shape[d] if type(d) is int else d.size - for x, d in zip(vals, dims) if d is not not_mapped) - axis_size, = core.dedup_referents(sizes) + def __init__(self, parent_trace, tag, axis_data): + self.parent_trace = parent_trace + assert isinstance(axis_data, AxisData) + self.axis_data = axis_data + self.tag = tag + + def to_batch_info(self, val): + if isinstance(val, BatchTracer) and val._trace.tag is self.tag: + return val.val, val.batch_dim else: - axis_size = None # can't be inferred from data - if self.axis_name is core.no_axis_name: - assert axis_size is not None # must be inferable from data - return core.AxisEnvFrame(self.axis_name, axis_size, self.main) - frame = core.axis_frame(self.axis_name, self.main) - assert axis_size is None or axis_size == frame.size, (axis_size, frame.size) - assert frame.main_trace is self.main - return frame - - def process_primitive(self, primitive, tracers, params): + return val, not_mapped + + def process_primitive(self, p, tracers, params): if config.dynamic_shapes.value: - primitive.abstract_eval(*(t.aval for t in tracers), **params) - vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers) - is_axis_primitive = primitive in axis_primitive_batchers - used_names = core.used_axis_names(primitive, params) - if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names): - frame = self.get_frame(vals_in, dims_in) - batcher_primitive = self.get_axis_primitive_batcher(primitive, frame) - val_out, dim_out = batcher_primitive(vals_in, dims_in, **params) - elif all(bdim is not_mapped for bdim in dims_in): - return primitive.bind(*vals_in, **params) + p.abstract_eval(*(map(core.get_aval, tracers)), **params) + vals_in, dims_in = unzip2(map(self.to_batch_info, tracers)) + args_not_mapped = all(bdim is not_mapped for bdim in dims_in) + if p in fancy_primitive_batchers: + if (args_not_mapped + and p in skippable_batchers + and not any(self.axis_data.name == axis_name + for axis_name in skippable_batchers[p](params))): + # no-op shortcut + return p.bind_with_trace(self.parent_trace, vals_in, params) + else: + with core.set_current_trace(self.parent_trace): + val_out, dim_out = fancy_primitive_batchers[p](self.axis_data, vals_in, dims_in, **params) + elif args_not_mapped: + # no-op shortcut + return p.bind_with_trace(self.parent_trace, vals_in, params) + elif p in primitive_batchers: + with core.set_current_trace(self.parent_trace): + val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params) else: - frame = self.get_frame(vals_in, dims_in) - batched_primitive = self.get_primitive_batcher(primitive, frame) - val_out, dim_out = batched_primitive(vals_in, dims_in, **params) + raise NotImplementedError("Batching rule for '{}' not implemented".format(p)) src = source_info_util.current() - if primitive.multiple_results: - return [BatchTracer(self, x, d, src) for x, d in zip(val_out, dim_out)] + if p.multiple_results: + with core.set_current_trace(self.parent_trace): # val_out may be lazy map + return [BatchTracer(self, x, d, src) if d is not not_mapped else x + for x, d in zip(val_out, dim_out)] else: - return BatchTracer(self, val_out, dim_out, src) + return (BatchTracer(self, val_out, dim_out, src) + if dim_out is not not_mapped else val_out) def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results params = dict(params, name=params.get('name', f.__name__)) - vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) - if all(bdim is not_mapped for bdim in dims): - return call_primitive.bind(f, *vals, **params) - sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths) - for x, d in zip(vals, dims) if d is not not_mapped) - axis_size, = core.dedup_referents(sizes) + vals, dims = unzip2(map(self.to_batch_info, tracers)) segment_lens, dims = indirectify_ragged_axes(dims) - f_, dims_out = batch_subtrace(f, self.main, tuple(dims)) + f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims)) f_ = _update_annotation( - f_, f.in_type, axis_size, self.axis_name, dims, segment_lens) - vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params) + f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens) + + with core.set_current_trace(self.parent_trace): + vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params) vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out()) src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)] - def post_process_call(self, call_primitive, out_tracers, params): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main - def todo(vals): - trace = main.with_cur_sublevel() - return map(partial(BatchTracer, trace), vals, dims, srcs) - return vals, todo - def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): - vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) - if all(dim is not_mapped for dim in dims): - return map_primitive.bind(f, *vals, **params) - else: - assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1 - # The logic for the dimension math below is as follows: - # ╔═════════════╦════════════════════════════════════════╦═══════════╗ - # ║ d / in_axis ║ None ║ int ║ - # ╠═════════════╬════════════════════════════════════════╩═══════════╣ - # ║ None ║ No extra axis, so in_axis unaffected ║ - # ╠═════════════╬════════════════════════════════════════╦═══════════╣ - # ║ int ║ Not mapped, so batching dim unaffected ║ See below ║ - # ╚═════════════╩════════════════════════════════════════╩═══════════╝ - # When both d and in_axis are defined then: - # - If `d <= in_axis`, we have to move the `in_axis` one dimension further; - # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed). - def both_mapped(in_out_axis, d): - return in_out_axis is not None and d is not not_mapped - new_in_axes = tuple( - in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis - for d, in_axis in zip(dims, params['in_axes'])) - new_dims = tuple( - d - 1 if both_mapped(in_axis, d) and in_axis < d else d - for d, in_axis in zip(dims, params['in_axes'])) - f, dims_out = batch_subtrace(f, self.main, new_dims) - out_axes_thunk = params['out_axes_thunk'] - # NOTE: This assumes that the choice of the dimensions over which outputs - # are batched is entirely dependent on the function and not e.g. on the - # data or its shapes. - @as_hashable_function(closure=out_axes_thunk) - def new_out_axes_thunk(): - return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis - for out_axis, d in zip(out_axes_thunk(), dims_out())) - new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk) - vals_out = map_primitive.bind(f, *vals, **new_params) - dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d - for d, out_axis in zip(dims_out(), out_axes_thunk())] - src = source_info_util.current() - return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)] - - def post_process_map(self, call_primitive, out_tracers, params): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main + vals, dims = unzip2(map(self.to_batch_info, tracers)) + # The logic for the dimension math below is as follows: + # ╔═════════════╦════════════════════════════════════════╦═══════════╗ + # ║ d / in_axis ║ None ║ int ║ + # ╠═════════════╬════════════════════════════════════════╩═══════════╣ + # ║ None ║ No extra axis, so in_axis unaffected ║ + # ╠═════════════╬════════════════════════════════════════╦═══════════╣ + # ║ int ║ Not mapped, so batching dim unaffected ║ See below ║ + # ╚═════════════╩════════════════════════════════════════╩═══════════╝ + # When both d and in_axis are defined then: + # - If `d <= in_axis`, we have to move the `in_axis` one dimension further; + # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed). def both_mapped(in_out_axis, d): return in_out_axis is not None and d is not not_mapped - def todo(vals): - trace = main.with_cur_sublevel() - return [BatchTracer(trace, v, d + 1 if both_mapped(oa, d) and oa <= d else d, s) - for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)] - if call_primitive.map_primitive: - def out_axes_transform(out_axes): - return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis - for out_axis, d in zip(out_axes, dims)) - todo = (todo, out_axes_transform) - return vals, todo + new_in_axes = tuple( + in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis + for d, in_axis in zip(dims, params['in_axes'])) + new_dims = tuple( + d - 1 if both_mapped(in_axis, d) and in_axis < d else d + for d, in_axis in zip(dims, params['in_axes'])) + f, dims_out = batch_subtrace(f, self.tag, self.axis_data, new_dims) + out_axes_thunk = params['out_axes_thunk'] + # NOTE: This assumes that the choice of the dimensions over which outputs + # are batched is entirely dependent on the function and not e.g. on the + # data or its shapes. + @as_hashable_function(closure=out_axes_thunk) + def new_out_axes_thunk(): + return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis + for out_axis, d in zip(out_axes_thunk(), dims_out())) + new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk) + with core.set_current_trace(self.parent_trace): + vals_out = map_primitive.bind(f, *vals, **new_params) + dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d + for d, out_axis in zip(dims_out(), out_axes_thunk())] + src = source_info_util.current() + return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)] def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) - jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims) - out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) + in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) + fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) + jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims) + out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp) + tuple(in_vals), + dict(symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: assert out_dims == out_dims[:len(out_dims) // 2] * 2 @@ -592,34 +555,18 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] - def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main - def todo(vals): - trace = main.with_cur_sublevel() - if jvp_was_run: - primal_dims, tangent_dims = dims[:len(vals)], dims[len(vals):] - assert primal_dims == tangent_dims - primal_srcs = srcs[:len(vals)] - return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs) - else: - return map(partial(BatchTracer, trace), vals, dims, srcs) - return vals, todo - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, symbolic_zeros): # pytype: disable=signature-mismatch - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) - if d is not not_mapped} + in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] - fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) - fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims) - bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, - out_dims2, in_dims, self.main.trace_type, - self.spmd_axis_name) - out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) + + fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) + fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims) + + bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, out_dims2, in_dims) + out_vals = prim.bind_with_trace(self.parent_trace, + (fun, fwd, bwd) + tuple(in_vals), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: _, res_tree = out_trees() @@ -627,83 +574,47 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] - def post_process_custom_vjp_call(self, out_tracers, _): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main - def todo(vals): - trace = main.with_cur_sublevel() - return map(partial(BatchTracer, trace), vals, dims, srcs) - return vals, todo - - def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped} - main, trace_type = self.main, self.main.trace_type - axis_name = self.axis_name - _, res_tree = out_trees() - num_res = res_tree.num_leaves - res_dims, primal_dims = split_list(dims, [num_res]) - _, primal_srcs = split_list(srcs, [num_res]) - def todo(vals): - trace = main.with_cur_sublevel() - return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs) - def bwd_transform(bwd): - return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,), - trace_type, self.spmd_axis_name) - return vals, todo, bwd_transform - -def _main_trace_for_axis_names(main_trace: core.MainTrace, - axis_name: Iterable[AxisName], - ) -> bool: - # This function exists to identify whether a main trace corresponds to any of - # the axis names used by a primitive. Axis names alone aren't enough because - # axis names can shadow, so we use the main trace as a tag. - return any(main_trace is core.axis_frame(n).main_trace for n in axis_name) - ### API for batching callables with vmappable inputs and outputs -def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size, - in_dims, out_dim_dests, main_type: type[BatchTrace] = BatchTrace, - spmd_axis_name: tuple[AxisName, ...] | None = None - ) -> lu.WrappedFun: +def batch(fun: lu.WrappedFun, axis_data, + in_dims, out_dim_dests) -> lu.WrappedFun: # we split up _batch_inner and _batch_outer for the leak checker - f = _batch_inner(fun, axis_size, out_dim_dests) - return _batch_outer(f, axis_name, axis_size, in_dims, main_type, - spmd_axis_name) - -@lu.transformation -def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name, - *in_vals): - with core.new_main( - main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main: - with core.extend_axis_env(axis_name, axis_size, main): - with source_info_util.transform_name_stack('vmap'): - outs = yield (main, in_dims, *in_vals), {} - del main - yield outs - -@lu.transformation -def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals): + f = _batch_inner(fun, axis_data, out_dim_dests) + return _batch_outer(f, axis_data, in_dims) + +@lu.transformation2 +def _batch_outer(f, axis_data, in_dims, *in_vals): + tag = TraceTag() + with source_info_util.transform_name_stack('vmap'): + outs, trace = f(tag, in_dims, *in_vals) + with core.ensure_no_leaks(trace): del trace + return outs + +@lu.transformation2 +def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims - trace = main.with_cur_sublevel() - idx = memoize(lambda: BatchTracer(trace, make_iota(axis_size), 0, - source_info_util.current())) - in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) - outs = yield in_tracers, {} + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, + source_info_util.current())) + in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) + with (core.set_current_trace(trace), + core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), + core.add_spmd_axis_names(axis_data.spmd_name)): + outs = f(*in_tracers) + out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests - out_vals = map(partial(from_elt, trace, axis_size), range(len(outs)), + out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), outs, out_dim_dests) - yield out_vals + + return out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. def vtile(f_flat: lu.WrappedFun, in_axes_flat: tuple[int | None, ...], out_axes_flat: tuple[int | None, ...], tile_size: int | None, - axis_name: AxisName, - main_type: type[BatchTrace] = BatchTrace): + axis_name: AxisName): @curry def tile_axis(arg, axis: int | None, tile_size): if axis is None: @@ -719,31 +630,33 @@ def untile_axis(out, axis: int | None): shape[axis:axis+2] = [shape[axis] * shape[axis+1]] return out.reshape(shape) - @lu.transformation - def _map_to_tile(*args_flat): + @lu.transformation2 + def _map_to_tile(f, *args_flat): sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None) tile_size_ = tile_size or next(sizes, None) assert tile_size_ is not None, "No mapped arguments?" - outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {} - yield map(untile_axis, outputs_flat, out_axes_flat) + outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat)) + return map(untile_axis, outputs_flat, out_axes_flat) - return _map_to_tile(batch( - f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type)) + axis_data = AxisData(axis_name, tile_size, None) + return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat)) ### API for batching functions with jaxpr type inputs and outputs -@lu.transformation_with_aux -def batch_subtrace(main, in_dims, *in_vals): - trace = main.with_cur_sublevel() - in_dims = in_dims() if callable(in_dims) else in_dims - in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) - in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) - if dim is not None else x for x, dim in zip(in_vals, in_dims)] - outs = yield in_tracers, {} - out_tracers = map(trace.full_raise, outs) - out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) - segment_lens, out_dims = indirectify_ragged_axes(out_dims) - yield (*segment_lens, *out_vals), out_dims +@lu.transformation_with_aux2 +def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals): + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + with core.set_current_trace(trace): + in_dims = in_dims() if callable(in_dims) else in_dims + in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) + in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) + if dim is not None else x for x, dim in zip(in_vals, in_dims)] + outs = f(*in_tracers) + out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) + segment_lens, out_dims = indirectify_ragged_axes(out_dims) + store.store(out_dims) + return (*segment_lens, *out_vals) def indirectify_ragged_axes(dims): if not any(type(d) is RaggedAxis for d in dims): @@ -814,38 +727,30 @@ def fetch(idx): # Can reuse same pattern for all dynamic shape stuff. def batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, - axis_size: core.AxisSize, + axis_data, in_axes: tuple[int | NotMapped | RaggedAxis, ...], - axis_name: AxisName, - spmd_axis_name: AxisName, - main_type: type[BatchTrace], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]: # This is only ever used in pjit. The difference vs batch_jaxpr is that # batch_jaxpr2 lets the callee decide which outputs are batched and what # their batch axes are; whereas batch_jaxpr has to obey caller-imposed # consistency constraints, such as type-agreement across arms of a # `lax.cond`, or input-output agreement for the body of a `lax.scan`. - return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name, - spmd_axis_name, main_type) + return _batch_jaxpr2(closed_jaxpr, axis_data, tuple(in_axes)) @weakref_lru_cache def _batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, - axis_size: core.AxisSize, + axis_data, in_axes: tuple[int | NotMapped | RaggedAxis, ...], - axis_name: AxisName, - spmd_axis_name: AxisName, - main_type: type[BatchTrace], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]: f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) - f, out_axes = _batch_jaxpr_inner(f, axis_size) - f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, - main_type) + f, out_axes = _batch_jaxpr_inner(f, axis_data) + f = _batch_jaxpr_outer(f, axis_data, in_axes) in_axes2, avals_in = unzip2([ handle_ragged(closed_jaxpr.in_avals, dim, aval) if isinstance(dim, RaggedAxis) else (dim, aval) for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) - avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval) + avals_in2 = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped else aval for aval, b in unsafe_zip(avals_in, in_axes2)] jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) @@ -859,14 +764,11 @@ def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis, new_aval = aval.update(shape=tuple(new_shape)) return dim.stacked_axis, new_aval -def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, - spmd_axis_name, main_type): +def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate): inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _batch_jaxpr(closed_jaxpr, axis_size, tuple(in_batched), inst, - axis_name, spmd_axis_name, main_type) + return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), inst) -def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, - spmd_axis_name, main_type): +def _batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate): assert (isinstance(instantiate, bool) or isinstance(instantiate, (list, tuple)) and all(isinstance(b, bool) for b in instantiate)) @@ -874,46 +776,43 @@ def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, instantiate = [instantiate] * len(closed_jaxpr.out_avals) in_axes = [0 if b else not_mapped for b in in_batched] out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate] - return batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, - axis_name, spmd_axis_name, main_type) + return batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest) -def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name, - spmd_axis_name, main_type): - return _batch_jaxpr_axes(closed_jaxpr, axis_size, tuple(in_axes), - tuple(out_axes_dest), axis_name, spmd_axis_name, - main_type) +def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): + return _batch_jaxpr_axes(closed_jaxpr, axis_data, tuple(in_axes), tuple(out_axes_dest)) @weakref_lru_cache -def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, - axis_name, spmd_axis_name, main_type): +def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) - f, out_axes = _batch_jaxpr_inner(f, axis_size) - f, out_batched = _match_axes_jaxpr(f, axis_size, out_axes_dest, out_axes) - f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, - main_type) - avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped + f, out_axes = _batch_jaxpr_inner(f, axis_data) + f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes) + f = _batch_jaxpr_outer(f, axis_data, in_axes) + avals_in = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)] jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) return core.ClosedJaxpr(jaxpr_out, consts), out_batched() -@lu.transformation_with_aux -def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals): - trace = main.with_cur_sublevel() - _, in_axes = resolve_ragged_axes(in_vals, in_axes) - in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val - for val, dim in zip(in_vals, in_axes)] - outs = yield in_tracers, {} - out_tracers = map(trace.full_raise, outs) - out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in out_tracers) - new_out_axes = indirectify_ragged_axes_against_inputs_outputs( - out_axes, in_vals, out_vals) - yield out_vals, new_out_axes - -@lu.transformation_with_aux -def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes, +@lu.transformation_with_aux2 +def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals): + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + _, in_axes = resolve_ragged_axes(in_vals, in_axes) + in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val + for val, dim in zip(in_vals, in_axes)] + with (core.set_current_trace(trace), + core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), + core.add_spmd_axis_names(axis_data.spmd_name)): + outs = f(*in_tracers) + out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) + new_out_axes = indirectify_ragged_axes_against_inputs_outputs( + out_axes, in_vals, out_vals) + store.store(new_out_axes) + return out_vals + +@lu.transformation_with_aux2 +def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes, *in_vals): - trace = main.with_cur_sublevel() - out_vals = yield (main, in_axes, *in_vals), {} + out_vals = f(trace, in_axes, *in_vals) out_axes = out_axes() out_axes_dest = [(None if src is not_mapped else 0) if dst is zero_if_mapped else dst @@ -921,25 +820,19 @@ def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes, if len(out_axes_dest) != len(out_axes): out_axis_dest, = out_axes_dest out_axes_dest = [out_axis_dest] * len(out_axes) - out_vals = map(partial(matchaxis, trace.axis_name, axis_size), + out_vals = map(partial(matchaxis, axis_data.name, axis_data.size), out_axes, out_axes_dest, out_vals) out_batched = [dst is not None for dst in out_axes_dest] - yield out_vals, out_batched + store.store(out_batched) + return out_vals -@lu.transformation -def _batch_jaxpr_outer(axis_name, spmd_axis_name, axis_size, in_dims, main_type, - *in_vals): - if axis_size is None: - axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} +@lu.transformation2 +def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) else ax for x, ax in unsafe_zip(in_vals, in_dims)] - with core.new_main(main_type, axis_name=axis_name, - spmd_axis_name=spmd_axis_name) as main: - with core.extend_axis_env(axis_name, axis_size, main): - out_vals = yield (main, in_dims, *in_vals), {} - del main - yield out_vals + tag = TraceTag() + return f(tag, in_dims, *in_vals) def _merge_bdims(x, y): if x == y: @@ -956,32 +849,35 @@ class ZeroIfMapped: pass ### functions for handling custom_vjp -@lu.transformation_with_aux -def batch_custom_jvp_subtrace(main, in_dims, *in_vals): - size, = {x.shape[d] for x, d in zip(in_vals, in_dims * 2) - if d is not not_mapped} - trace = main.with_cur_sublevel() - in_tracers = [val if dim is None else - SymbolicZero(core.mapped_aval(size, dim, val.aval)) - if type(val) is SymbolicZero else BatchTracer(trace, val, dim) - for val, dim in zip(in_vals, in_dims * 2)] - outs = yield in_tracers, {} - # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can - # be wasteful in the rare case it actually triggers; handle symbolically! - outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] - out_tracers = map(trace.full_raise, outs) - out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) +@lu.transformation_with_aux2 +def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): + size = axis_data.size + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + in_tracers = [val if dim is None else + SymbolicZero(core.mapped_aval(size, dim, val.aval)) + if type(val) is SymbolicZero else BatchTracer(trace, val, dim) + for val, dim in zip(in_vals, in_dims * 2)] + with core.set_current_trace(trace): + outs = f(*in_tracers) + # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can + # be wasteful in the rare case it actually triggers; handle symbolically! + outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] + + out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) - out_primals = map(partial(matchaxis, trace.axis_name, size), + out_primals = map(partial(matchaxis, trace.axis_data.name, size), out_primal_bds, out_dims, out_primals) - out_tangents = map(partial(matchaxis, trace.axis_name, size), + out_tangents = map(partial(matchaxis, trace.axis_data.name, size), out_tangent_bds, out_dims, out_tangents) - yield out_primals + out_tangents, out_dims * 2 + store.store(out_dims * 2) + return out_primals + out_tangents -def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, - main_type, spmd_axis_name): +def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): + axis_size = axis_data.size + axis_name = axis_data.name def new_bwd(*args): in_dims_ = in_dims() if callable(in_dims) else in_dims args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval)) @@ -989,19 +885,17 @@ def new_bwd(*args): for x, dim in zip(args, in_dims_)] in_dims_ = [None if type(x) is SymbolicZero else d for x, d in zip(args, in_dims_)] - bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd)) - bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type, - spmd_axis_name) + bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_) bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests) return bwd_.call_wrapped(*args) return new_bwd -@lu.transformation -def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): +@lu.transformation2 +def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): # this is like _match_axes, but we do reduce-sums as needed - out_vals = yield in_vals, {} - yield map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, + out_vals = f(*in_vals) + return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, sum_match=True), out_dims_thunk(), out_dim_dests, out_vals) def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False): @@ -1030,8 +924,23 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False) tuple[Any, Union[int, None, tuple[Union[int, None], ...]]] ] primitive_batchers : dict[core.Primitive, BatchingRule] = {} -axis_primitive_batchers: dict[core.Primitive, Callable] = {} -spmd_axis_primitive_batchers: dict[core.Primitive, Callable] = {} +# "fancy" primitive batchers just take a extra leading `AxisData` and "trace type" args +fancy_primitive_batchers: dict[core.Primitive, Callable] = {} + +# backwards compat shim. TODO: delete +class AxisPrimitiveBatchersProxy: + def __setitem__(self, prim, batcher): + def wrapped(axis_data, vals, dims, **params): + return batcher(axis_data.size, axis_data.name, None, vals, dims, **params) + fancy_primitive_batchers[prim] = wrapped + +axis_primitive_batchers = AxisPrimitiveBatchersProxy() + + +# Presence in this table allows fancy batchers to be skipped by batch traces for +# irrelevant axes. The Callable takes the params and returns a list of relevant +# axes. +skippable_batchers : dict[core.Primitive, Callable] = {} def defvectorized(prim): primitive_batchers[prim] = partial(vectorized_batcher, prim) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2adeb4b16cd9..531177b7244c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -184,13 +184,13 @@ def _is_ir_values(x: IrValues) -> bool: if dtypes.int2 is not None: assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( - ir.IntegerType.get_signless, 2 - ) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( - ir.IntegerType.get_unsigned, 2 - ) + _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) + _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) +if dtypes.float8_e3m4 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get +if dtypes.float8_e4m3 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): @@ -230,7 +230,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes: raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err ir_type_handlers[core.ShapedArray] = _array_ir_types -ir_type_handlers[core.ConcreteArray] = _array_ir_types ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get() ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types @@ -587,9 +586,18 @@ def module_to_bytecode(module: ir.Module) -> bytes: return output.getvalue() # Translation rules + +class JaxIrContext(ir.Context): + def __init__(self, *args, **kwargs): + # Note: we're very intentionally *not* calling the __init__() of our + # immediate superclass ir.Context, whose __init__() has the unfortunate side + # effect of loading all the dialects linked into the binary into the + # context. We want to ensure that only the dialects we need are loaded. + super(ir.Context, self).__init__(*args, **kwargs) + def make_ir_context() -> ir.Context: """Creates an MLIR context suitable for JAX IR.""" - context = ir.Context() + context = JaxIrContext() context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() @@ -762,7 +770,13 @@ def backend(self) -> xb.XlaBackend: return self.backend_or_name def new_channel(self) -> int: - return next(self.channel_iterator) + channel = next(self.channel_iterator) + # `xla::HostCallback` requires a 16-bit channel ID. + if channel >= (1 << 16): + raise RuntimeError( + "Host callback lowering created too many channels. PjRt does not" + " support more than 65535 channels") + return channel # Adds an IFRT host callback object to the context. A reference to these # callbacks will be provided to IFRT during compilation so it can do things @@ -1685,6 +1699,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). # The below custom call achieves the sharding like above example. + assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) if config.use_shardy_partitioner.value: physical_ndim = core.physical_aval(aval).ndim s = sharding_impls.SdyArraySharding( @@ -1738,6 +1753,44 @@ def _emit_lowering_rule_as_fun(lowering_rule, return func_op +class HashableLiteral: + """Hashable wrapper of core.Literal, used for deduplicating IR constants.""" + + __slots__ = ["value", "data"] + + value: core.Literal + + # Copy of the value suitable for an equality comparison. We are careful to + # avoid floating point comparisons here, because in particular we don't want + # 0.0 and -0.0 to be considered equal, but we are fine with NaNs being equal. + data: bytes | int | bool | None + + def __init__(self, value): + self.value = value + if isinstance(value.val, (np.generic, np.ndarray)): + self.data = value.val.tobytes() + elif isinstance(value.val, (bool, int)): + self.data = value.val + elif isinstance(value.val, float): + self.data = np.float64(value.val).tobytes() + elif isinstance(value.val, complex): + self.data = np.complex128(value.val).tobytes() + else: + self.data = None # Unhandled case. + + def __hash__(self): + return hash(self.data) + + def __eq__(self, other): + if type(self.value.val) != type(other.value.val): + return False + if self.value.aval != other.value.aval: + return False + if self.data is None: + return id(self) == id(other) + return self.data == other.data + + def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, name_stack: source_info_util.NameStack, tokens: TokenSet, @@ -1753,9 +1806,16 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, IR function, in the order of ctx.shape_poly_state.dim_vars. """ assert "gpu" not in ctx.platforms + cached_ir_consts: dict[HashableLiteral, IrValues] = {} + def read(v: core.Atom) -> IrValues: if type(v) is core.Literal: - return ir_constant(xla.canonicalize_dtype(v.val)) + h = HashableLiteral(v) + c = cached_ir_consts.get(h) + if c is None: + c = ir_constant(xla.canonicalize_dtype(v.val)) + cached_ir_consts[h] = c + return c else: assert isinstance(v, core.Var) return env[v] @@ -1878,6 +1938,8 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None return () if eqn_ctx.compute_type == 'device_host': return ('cpu',) + if eqn_ctx.compute_type == 'tpu_sparsecore': + return ('tpu',) return () @@ -2160,8 +2222,10 @@ def map_compute_type(c_type): return 'host' elif c_type == 'device': return 'dense' + elif c_type == 'tpu_sparsecore': + return 'sparse' raise ValueError('Invalid compute type received. Current supported values ' - 'are `device_host` and `device`') + 'are `device_host`, `device` and `tpu_sparsecore') def wrap_compute_type_in_place(ctx, op): if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None: @@ -2456,6 +2520,18 @@ def _wrap_with_spmd_op(name: str, wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape") +def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): + # Don't emit a wsc under full manual mode to avoid increasing HLO size. + if aval.sharding.mesh._are_all_axes_collective: + return op + proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + if sharding_proto is None else sharding_proto) + # TODO(yashkatariya): Enable this + # unspecified_dims = (set(range(aval.ndim)) + # if aval.sharding.mesh._any_axis_collective else None) + return wrap_with_sharding_op(ctx, op, aval, proto) + + def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ab00e5729cc2..49da3e2a0ec9 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -15,7 +15,7 @@ from collections import namedtuple from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager, AbstractContextManager +from contextlib import contextmanager from functools import partial import inspect import itertools as it @@ -38,9 +38,9 @@ from jax._src import xla_metadata as xla_metadata_lib from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, fun_sourceinfo) -from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, +from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - ConcreteArray, Var, DropVar, raise_to_shaped, Atom, + Var, DropVar, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) @@ -143,28 +143,26 @@ def get_aval(self) -> AbstractValue: class JaxprTrace(Trace['JaxprTracer']): - def __init__(self, *args, name_stack: source_info_util.NameStack): - super().__init__(*args) + def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, tag:TraceTag): self.name_stack = name_stack + self.tag = tag + self.parent_trace = parent_trace - def pure(self, val: Any) -> JaxprTracer: - return self.new_const(val) - - def lift(self, val: Tracer) -> JaxprTracer: - return self.new_const(val) - - def sublift(self, val: JaxprTracer) -> JaxprTracer: - return JaxprTracer(self, val.pval, FreeVar(val)) + def to_jaxpr_tracer(self, x): + if isinstance(x, JaxprTracer) and x._trace.tag is self.tag: + if x._trace is self: + return x + else: + return JaxprTracer(self, x.pval, FreeVar(x)) + else: + return self.new_const(x) def new_const(self, val) -> JaxprTracer: - if isinstance(val, Tracer) and val._trace.level == self.level: - raise Exception return JaxprTracer(self, PartialVal.known(val), None) def new_instantiated_literal(self, val) -> JaxprTracer: aval = get_aval(val) - return JaxprTracer(self, PartialVal.unknown(aval), - Literal(val, raise_to_shaped(aval))) + return JaxprTracer(self, PartialVal.unknown(aval), Literal(val, aval)) def new_instantiated_const(self, val) -> JaxprTracer: aval = get_aval(val) @@ -179,9 +177,12 @@ def new_arg(self, pval: PartialVal) -> JaxprTracer: if const is None: aval = pval.get_aval() if type(aval) is DShapedArray: + # TODO(dougalm): Fix the type error and remove the pytype pragmas. + # pytype: disable=attribute-error shape = [self.new_instantiated_const(d) if isinstance(d, Tracer) and d._trace.level < self.level else d for d in aval.shape] + # pytype: enable=attribute-error aval = aval.update(shape=tuple(shape)) return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding()) else: @@ -202,22 +203,25 @@ def instantiate_const_abstracted(self, tracer) -> JaxprTracer: if const is None: return tracer else: - aval = raise_to_shaped(get_aval(const), np.isscalar(const)) + aval = get_aval(const).update_weak_type(np.isscalar(const)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) def process_primitive(self, primitive, tracers, params): - if primitive in custom_partial_eval_rules: - return custom_partial_eval_rules[primitive](self, *tracers, **params) - else: - return self.default_process_primitive(primitive, tracers, params) + with core.set_current_trace(self.parent_trace): + if primitive in custom_partial_eval_rules: + tracers = map(self.to_jaxpr_tracer, tracers) + return custom_partial_eval_rules[primitive](self, *tracers, **params) + else: + return self.default_process_primitive(primitive, tracers, params) def default_process_primitive(self, primitive, tracers, params): # By default, if all the input tracers are known, then bind the primitive # and consider all outputs known. Otherwise, stage the application into the # jaxpr and consider all outputs unknown. + tracers = map(self.to_jaxpr_tracer, tracers) consts = [t.pval.get_known() for t in tracers] if all(c is not None for c in consts): - return primitive.bind(*consts, **params) + return primitive.bind_with_trace(self.parent_trace, consts, params) tracers = map(self.instantiate_const, tracers) avals = [t.aval for t in tracers] out_aval, effects = primitive.abstract_eval(*avals, **params) @@ -237,6 +241,7 @@ def default_process_primitive(self, primitive, tracers, params): return out_tracer def process_call(self, primitive, f, tracers, params): + tracers = map(self.to_jaxpr_tracer, tracers) rule = call_partial_eval_rules.get(primitive) if rule: return rule(self, primitive, f, tracers, params) @@ -253,15 +258,15 @@ def process_call(self, primitive, f, tracers, params): # which were unknown to the first call (corresponding to in_avals). # Wrap f to perform the partial evaluation and plumb out aux data. - f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) - f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), - tuple(in_avals)) + f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, False) + f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals)) + # Adjust parameters (e.g. donated_invars) for the call to be evaluated now. const_params = update_params(params, in_knowns, 0) # Run the call, getting known out vals and aux data used for staged-out call - out = primitive.bind(_update_annotation_known(f_, f.in_type, in_knowns), - *in_consts, **const_params) + fun_and_args = (_update_annotation_known(f_, f.in_type, in_knowns),) + tuple(in_consts) + out = primitive.bind_with_trace(self.parent_trace, fun_and_args, const_params) fwds, out_knowns, out_type, jaxpr, env = aux() # Split apart known outputs from the original call and non-fwded residuals. out_consts, non_fwd_res = split_list(out, [sum(out_knowns)]) @@ -284,7 +289,7 @@ def process_call(self, primitive, f, tracers, params): # Create the input tracers for the staged-out (unknown-value) call. res_tracers = map(self.instantiate_const, map(self.new_const, res)) - env_tracers = map(self.full_raise, env) + env_tracers = map(self.to_jaxpr_tracer, env) unknown_arg_tracers = [t for t in tracers if not t.is_known()] # Adjust parameters (e.g. donated_invars) for the staged-out call's args. num_new_args = len(res_tracers) + len(env_tracers) @@ -296,7 +301,6 @@ def process_call(self, primitive, f, tracers, params): # With dynamic shapes, we may need to substitute Tracers into avals. out_tracers = [] for aval, _ in out_type: - assert not isinstance(aval, ConcreteArray) if type(aval) is DShapedArray: shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val] if type(d) is InDBIdx else d for d in aval.shape] @@ -314,6 +318,7 @@ def process_call(self, primitive, f, tracers, params): return merge_lists(out_knowns, out_tracers, out_consts) def process_map(self, primitive, f: lu.WrappedFun, tracers, params): + tracers = map(self.to_jaxpr_tracer, tracers) update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) @@ -329,7 +334,7 @@ def process_map(self, primitive, f: lu.WrappedFun, tracers, params): for ax, aval in zip(unk_in_axes, in_avals)] # Wrap f to perform partial evaluation and plumb out aux data. - f = trace_to_subjaxpr_nounits(f, self.main, False) + f = trace_to_subjaxpr_nounits2(f, self.tag, False) f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals_mapped)) # Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk) @@ -344,13 +349,13 @@ def const_out_axes_thunk(): out_axes_thunk=const_out_axes_thunk) # Run the map, getting known out vals and aux data used for staged-out map. - out = primitive.bind(f, *in_consts, **const_params) + out = primitive.bind_with_trace(self.parent_trace, (f, *in_consts), const_params) out_knowns, out_avals_mapped, jaxpr, env = aux() # Split apart known outputs from the original call and residuals. out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) # We can only check_jaxpr with the dynamic axis environment extended: - with core.extend_axis_env(params['axis_name'], params['axis_size'], None): + with core.extend_axis_env_nd([(params['axis_name'], params['axis_size'])]): call_jaxpr = convert_constvars_jaxpr(jaxpr) # Compute staged and const out_axes, taking into account residuals. @@ -358,9 +363,9 @@ def const_out_axes_thunk(): staged_out_axes, _ = partition_list(out_knowns, out_axes) staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,) - # Create the input tracers for the staged-out (unkonwn-value) call. + # Create the input tracers for the staged-out (unknown-value) call. const_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.full_raise, env) + env_tracers = map(self.to_jaxpr_tracer, env) unknown_arg_tracers = [t for t in tracers if not t.is_known()] # Adjust params for staged-out call on unknown values. num_new_args = len(const_tracers) + len(env_tracers) @@ -381,95 +386,24 @@ def const_out_axes_thunk(): return merge_lists(out_knowns, out_tracers, out_consts) - def post_process_call(self, primitive, out_tracers, params): - unknown_out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) - out_pvals = [t.pval for t in out_tracers] - out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - out = [*out_consts, *res] - main = self.main - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) - const_tracers = map(trace.new_instantiated_const, res) - in_tracers = (*const_tracers, *map(trace.full_raise, env)) - out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) - for a in out_avals] - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - new_params = update_params(params, [], len(in_tracers)) - new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, - jaxpr.effects, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - return out, todo - - def post_process_map(self, primitive, out_tracers, params): - unknown_out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) - out_pvals = [t.pval for t in out_tracers] - out_knowns, out_avals_mapped, out_consts = partition_pvals(out_pvals) - out = [*out_consts, *res] - main = self.main - - with core.extend_axis_env(params['axis_name'], params['axis_size'], None): - call_jaxpr = convert_constvars_jaxpr(jaxpr) - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) - const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.full_raise, env) - - staged_out_axes = tuple(out_axes_unknown) # set by out_axes_transform - staged_in_axes = (0,) * len(res) + (None,) * len(env) - - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - staged_params = update_params(params, [], len(res) + len(env)) - staged_params = dict(staged_params, in_axes=staged_in_axes, - out_axes=tuple(staged_out_axes), - call_jaxpr=call_jaxpr) - - out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], d, a) - for d, a in zip(staged_out_axes, out_avals_mapped)] - out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) - for a in out_avals] - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, - primitive, staged_params, jaxpr.effects, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - def out_axes_transform(out_axes): - nonlocal out_axes_unknown - out_axes_unknown, out_axes_known = partition_list(out_knowns, out_axes) - return tuple(out_axes_known) + (0,) * len(jaxpr.constvars) - out_axes_unknown: list | None = None - - return out, (todo, out_axes_transform) - def _current_truncated_name_stack(self): return source_info_util.current_name_stack()[len(self.name_stack):] - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - # We assume partial evaluation is only performed to build linear functions, - # and hence we don't need to keep the custom JVP rule around anymore. + def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) + if all(t.is_known() for t in tracers): + with core.set_current_trace(self.parent_trace): + vals = [t.pval[1] for t in tracers] + return prim.bind(fun, jvp, *vals, symbolic_zeros=symbolic_zeros) + # We assume non-trivial partial evaluation is only performed to build linear + # functions, and hence we don't need to keep the custom JVP rule around + # anymore. del jvp, symbolic_zeros - assert not all(t.is_known() for t in tracers) - return fun.call_wrapped(*tracers) - - def post_process_custom_jvp_call(self, out_tracers, _): - # This path should only be reachable if we expose a partial eval API - # unrelated to autodiff, since we raise an error when differentiation with - # respect to values over which a custom_jvp function closes is detected. - raise NotImplementedError # TODO(mattjj) + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) def process_custom_transpose(self, prim, call, tracers, **params): + tracers = map(self.to_jaxpr_tracer, tracers) res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves]) assert all(t.is_known() for t in res_ts) lin_all_known = all(t.is_known() for t in lin_ts) @@ -487,36 +421,41 @@ def process_custom_transpose(self, prim, call, tracers, **params): for t in out_tracers: t.recipe = eqn return out_tracers - def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, - symbolic_zeros): - # TODO(mattjj): after old remat is deleted, make this method trivial. - # Because we instantiate all tracers, in_knowns is all False. - tracers = map(self.instantiate_const_abstracted, tracers) - in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) - f = trace_to_subjaxpr_nounits(f, self.main, True) - f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - out_knowns, out_avals, jaxpr, env = aux() - out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - res_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.full_raise, env) - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_avals] - closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) - - @_memoize - def fwd_jaxpr_thunk(*zeros): - fwd_ = _interleave_fun(fwd, zeros) - fwd_ = trace_to_subjaxpr_nounits(fwd_, self.main, True) - fwd_, aux = partial_eval_wrapper_nounits( - fwd_, tuple(in_knowns), tuple(in_avals)) - with core.new_sublevel(): - out_flat = fwd_.call_wrapped() + def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) + if all(t.is_known() for t in tracers): + vals = [t.pval[1] for t in tracers] + with core.set_current_trace(self.parent_trace): + return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) + else: + # TODO(mattjj): remove non-ad users of partial eval, then drop this case. + # We stage out the whole thing, i.e. no nontrivial partial evaluation. + tracers = map(self.instantiate_const_abstracted, tracers) + # Because we instantiate all tracers, in_knowns is all False. + in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) + f = trace_to_subjaxpr_nounits(f, self, True) + f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,)) + with core.set_current_trace(self.parent_trace): + out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) out_knowns, out_avals, jaxpr, env = aux() - _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) - return converted_jaxpr, (*res, *env) + out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) + res_tracers = map(self.new_instantiated_const, res) + env_tracers = map(self.to_jaxpr_tracer, env) + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_avals] + closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) + + @_memoize + def fwd_jaxpr_thunk(*zeros): + fwd_ = _interleave_fun(fwd, zeros) + fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True) + fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,)) + out_flat = fwd_.call_wrapped() + out_knowns, out_avals, jaxpr, env = aux() + _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) + converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) + return converted_jaxpr, (*res, *env) name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) @@ -531,12 +470,6 @@ def fwd_jaxpr_thunk(*zeros): for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) - def post_process_custom_vjp_call(self, out_tracers, _): - # This path should only be reachable if we expose a partial eval API - # unrelated to autodiff, since we raise an error when differentiation with - # respect to values over which a custom_vjp function closes is detected. - raise NotImplementedError # TODO(mattjj) - def partition_pvals( pvals: list[PartialVal] ) -> tuple[list[bool], list[AbstractValue], list[Any]]: @@ -545,18 +478,19 @@ def partition_pvals( consts = [pval.get_known() for pval in pvals if pval.is_known()] return knowns, avals, consts -@lu.transformation_with_aux +@lu.transformation_with_aux2 def partial_eval_wrapper_nounits( - in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], + f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], *in_consts: Any): in_avals_, in_consts_ = iter(in_avals), iter(in_consts) in_pvals = [PartialVal.known(next(in_consts_)) if known else PartialVal.unknown(next(in_avals_)) for known in in_knowns] sentinel = object() assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel - jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {} + jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals) out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) + store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env)) + return (*out_consts, *res) custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} @@ -587,12 +521,6 @@ def __init__(self, trace: JaxprTrace, pval: PartialVal, recipe: JaxprTracerRecipe | None): assert isinstance(pval, PartialVal) pv, const = pval - if isinstance(const, Tracer) and const._trace.level >= trace.level: - raise core.escaped_tracer_error( - const, f"Tracer from a higher level: {const} in trace {trace}") - if isinstance(pv, DShapedArray): - assert all(not isinstance(d, Tracer) or isinstance(d, JaxprTracer) and - d._trace.level == trace.level for d in pv.shape) self._trace = trace self.pval = pval self.recipe = recipe @@ -633,84 +561,68 @@ def get_referent(self): return self -@profiler.annotate_function -def trace_to_jaxpr( - fun: lu.WrappedFun, pvals: Sequence[PartialVal], - instantiate: bool | Sequence[bool] = False, - ) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]: - """ - Partially evaluate a function, building a jaxpr for un-evaluated computation. - - Args: - fun: lu.WrappedFun representing the function to be partially evaluated. The - function must be flattened, in the sense of accepting jaxpr type arguments - and returning a flat list of jaxpr type outputs. - pvals: sequence of PartialVals of length equal to the number of inputs to - `fun` indicating which inputs are known or unknown. - instantiate: optional bool or sequence of bools of length equal to the - number of outputs of `fun` indicating which outputs should be forced to be - treated as unknown and hence instantiated in the jaxpr. If a single bool, - the value is applied to all outputs. Default False. - - Returns: - A triple where the first element is a jaxpr representing the computation - which depends on unknown inputs; the second element is a list of PartialVals - of length equal to the length of the output of `fun` representing which - outputs are known and unknown (along with their values and abstract values, - respectively); the third element is a list of known residual values. The - returned jaxpr takes as inputs the known residual values followed by values - of the originally unknown inputs. - """ - current_name_stack = source_info_util.current_name_stack() - with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: - fun = trace_to_subjaxpr(fun, main, instantiate) - jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) - assert not env - del main, fun, env - - return jaxpr, out_pvals, consts - @profiler.annotate_function def trace_to_jaxpr_nounits( fun: lu.WrappedFun, pvals: Sequence[PartialVal], instantiate: bool | Sequence[bool] = False, ) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]: current_name_stack = source_info_util.current_name_stack() - with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: - fun = trace_to_subjaxpr_nounits(fun, main, instantiate) - jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) - assert not env - del main, fun, env - return jaxpr, out_pvals, consts - - -@lu.transformation + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, TraceTag()) + with core.ensure_no_leaks(trace): + fun = trace_to_subjaxpr_nounits(fun, trace, instantiate) + with core.set_current_trace(trace): + jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) + assert not env + del trace, fun + return jaxpr, out_pvals, consts + +# TODO(mattjj): superfluous wrapper...? +@lu.transformation2 def trace_to_subjaxpr_nounits( - main: core.MainTrace, + f, + trace: JaxprTrace, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - main, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + return jaxpr, (out_pvals, out_consts, env) -def _trace_to_subjaxpr_nounits(main, instantiate, in_pvals): - trace = main.with_cur_sublevel() +@lu.transformation2 +def trace_to_subjaxpr_nounits2( + f, + tag: TraceTag, + instantiate: bool | Sequence[bool], + in_pvals: Sequence[PartialVal]): + assert isinstance(tag, TraceTag) + assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals + current_name_stack = source_info_util.current_name_stack() + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, tag) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] + del out_tracers + return jaxpr, (out_pvals, out_consts, env) + +def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals): in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] in_args = merge_lists(in_knowns, in_tracers, in_consts) - ans = yield in_args, {} + with core.set_current_trace(trace): + ans = f(*in_args) assert isinstance(ans, (list, tuple)), ( f"Got unexpected return type when tracing function to jaxpr: {ans}") assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( f"Got unexpected return type when tracing function to jaxpr: {ans}") if isinstance(instantiate, bool): instantiate = [instantiate] * len(ans) - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t + out_tracers = map(trace.to_jaxpr_tracer, ans) + out_tracers = [trace.instantiate_const(t) if inst else t for inst, t in zip(instantiate, out_tracers)] out_tracers_ = [t for t in out_tracers if not t.is_known()] jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_) @@ -719,39 +631,48 @@ def _trace_to_subjaxpr_nounits(main, instantiate, in_pvals): # The below variant implements an optimization where residuals which are also # inputs are indicated in auxiliary data rather than passed as outputs. # TODO(mattjj): update all callers to use this version, delete other version. -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd( - main: core.MainTrace, + f, + tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - main, instantiate, in_pvals) - out_pvals = [t.pval for t in out_tracers] + current_name_stack = source_info_util.current_name_stack() + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, tag) + with core.set_current_trace(trace): + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] - # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. - in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] - id_map = {id(c): i for i, c in enumerate(in_consts)} - fwds: list[int | None] = [id_map.get(id(c)) for c in out_consts] - pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] + # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. + in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] + id_map = {id(c): i for i, c in enumerate(in_consts)} + fwds: list[int | None] = [id_map.get(id(c)) for c in out_consts] + pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] - del out_tracers - yield jaxpr, (fwds, out_pvals, pruned_consts, env) + del out_tracers + return jaxpr, (fwds, out_pvals, pruned_consts, env) # The below variant implements two optimizations: # 1. residuals that are also primal inputs are indicated in aux data rather # than passed as outputs; # 2. residuals that are also primal outputs are indicated in aux data rather # than passed as redundant outputs. -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd2( - main: core.MainTrace, + f, + tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits( - main, instantiate, in_pvals) - out_pvals = [t.pval for t in out_tracers] + current_name_stack = source_info_util.current_name_stack() + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, tag) + out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] # Which consts (aka residuals) are just forwarded inputs? Check obj id. in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] @@ -767,7 +688,7 @@ def trace_to_subjaxpr_nounits_fwd2( if f1 is None and f2 is None] del out_tracers - yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) + return jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) FreeVar = namedtuple('FreeVar', ['val']) @@ -801,7 +722,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], len(params["in_axes"]) == len(params["call_jaxpr"].invars)) assert ("donated_invars" in params and len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) - out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] + out_avals = [t.aval for t in out_tracers] ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), config.threefry_partitionable.value, @@ -1022,7 +943,7 @@ def fun(*known_vals_in): f, in_pvals, instantiate=instantiate) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] - res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals] + res_avals = [core.get_aval(r) for r in residuals] cell.append((out_unknowns, jaxpr_unknown, res_avals)) known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] @@ -1088,7 +1009,7 @@ def partial_eval_jaxpr_stateful( in_inst: bool | Sequence[bool], ensure_out_unknowns: bool | Sequence[bool], ensure_out_inst: bool | Sequence[bool], - saveable: Callable[..., RematCases_], + saveable: Callable[..., RematCases_] | None, ) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]: if type(in_inst) is bool: in_inst = (in_inst,) * len(jaxpr.invars) @@ -1096,6 +1017,8 @@ def partial_eval_jaxpr_stateful( ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars) if type(ensure_out_inst) is bool: ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars) + if saveable is None: + saveable = everything_saveable jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns), tuple(in_inst), @@ -1103,6 +1026,8 @@ def partial_eval_jaxpr_stateful( tuple(ensure_out_inst), saveable) return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref +everything_saveable = lambda *_, **__: True + @weakref_lru_cache def _partial_eval_jaxpr_custom_cached( jaxpr: Jaxpr, @@ -1283,7 +1208,7 @@ def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx, + ctx = trivial_ctx, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): @@ -1469,6 +1394,11 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], return new_jaxpr, used_consts, used_inputs +def has_effects(eqn: JaxprEqn) -> bool: + effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} + return bool(effs) + + @weakref_lru_cache def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], instantiate: tuple[bool, ...] @@ -1482,21 +1412,14 @@ def write(x: Atom, b: bool) -> None: if type(x) is Var: env[x] = read(x) or b - def has_effects(eqn: JaxprEqn) -> bool: - effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} - return bool(effs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params) - new_eqns = [] map(write, jaxpr.outvars, used_outputs) for eqn in jaxpr.eqns[::-1]: used_outs = map(read, eqn.outvars) - if not any(used_outs) and not has_effects(eqn): - used_ins = [False] * len(eqn.invars) - else: - rule = dce_rules.get(eqn.primitive, _default_dce_rule) - used_ins, new_eqn = rule(used_outs, eqn) - if new_eqn is not None: - new_eqns.append(new_eqn) + rule = dce_rules.get(eqn.primitive, _default_dce_rule) + used_ins, new_eqn = rule(used_outs, eqn) + if new_eqn is not None: + new_eqns.append(new_eqn) map(write, eqn.invars, used_ins) used_inputs = map(read, jaxpr.invars) used_inputs = map(op.or_, instantiate, used_inputs) @@ -1520,7 +1443,9 @@ def has_effects(eqn: JaxprEqn) -> bool: def _default_dce_rule( used_outs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outs) and not has_effects(eqn): + return [False] * len(eqn.invars), None return [True] * len(eqn.invars), eqn dce_rules: dict[Primitive, DCERule] = {} @@ -1528,6 +1453,8 @@ def _default_dce_rule( def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) new_params = dict(eqn.params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(eqn.primitive) @@ -1541,6 +1468,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn [v for v, used in zip(eqn.outvars, used_outputs) if used], eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn + dce_rules[core.call_p] = dce_jaxpr_call_rule @@ -1552,8 +1480,10 @@ def _cached_closed_call_dce(jaxpr_, used_outputs: tuple[bool, ...] return core.ClosedJaxpr(new_jaxpr, consts), used_inputs def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: # TODO(mattjj): de-duplicate with above rule? + if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr_ = eqn.params['call_jaxpr'] closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs)) new_params = dict(eqn.params, call_jaxpr=closed_jaxpr) @@ -1614,13 +1544,7 @@ def _contents(self): return () def _origin_msg(self): - if not self._trace.main.jaxpr_stack: - # If this Tracer has been leaked the jaxpr stack may no longer be - # available. So we can't print as much origin information. - return ("\nThis DynamicJaxprTracer was created on line " - f"{source_info_util.summarize(self._line_info)}") - else: - invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self) + invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self) dbg = self._debug_info if dbg is None: return "" @@ -1653,17 +1577,14 @@ def _origin_msg(self): origin += "\n\n(Additional originating lines are not shown.)" return "\n" + origin - def _assert_live(self) -> None: - if not self._trace.main.jaxpr_stack: # type: ignore - raise core.escaped_tracer_error(self, None) - def get_referent(self): frame = self._trace.frame val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) return self if val is None else get_referent(val) + def _dynamic_jaxpr_tracer_shaped_abstractify(x): - return core.raise_to_shaped(x.aval) + return x.aval api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: @@ -1737,7 +1658,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer] invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) - state_outvars = [self.tracer_to_var[id(trace.full_raise(x))] + state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x))] for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars @@ -1862,6 +1783,9 @@ def lit(a: Atom) -> Literal | None: newvars: dict[Var, Var] = {} newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) + lit_or_var = ( + lambda a: a if isinstance(a, Literal) else (lit(a) or var(a)) + ) dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: @@ -1880,10 +1804,10 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: new_invars = [var(v) for v in jaxpr.invars] new_eqns = [] for eqn in jaxpr.eqns: - invars = [lit(x) or var(x) for x in eqn.invars] + invars = [lit_or_var(x) for x in eqn.invars] outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) - new_outvars = [lit(v) or var(v) for v in jaxpr.outvars] + new_outvars = [lit_or_var(v) for v in jaxpr.outvars] jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns) new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, @@ -1892,11 +1816,25 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: class DynamicJaxprTrace(core.Trace): - __slots__ = [] - - @property - def frame(self): - return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error + def __init__(self): + self.frame = JaxprStackFrame() + + def invalidate(self): + # avoid cyclic refs + self.frame.tracers = [] + self.frame.constid_to_tracer = {} + + def to_jaxpr_tracer(self, x): + as_local_var = self.frame.tracer_to_var.get(id(x)) + if as_local_var is None: + if hasattr(x, "dimension_as_value"): # Used for shape_poly._DimExpr + with core.set_current_trace(self): + x = x.dimension_as_value() + return self.to_jaxpr_tracer(x) + else: + return self.new_const(x) + else: + return x def new_arg(self, aval): tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) @@ -1909,7 +1847,9 @@ def new_const(self, c): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: - aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c)) + aval = get_aval(c) + if hasattr(aval, "weak_type"): + aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) aval = self._lift_tracers_in_aval(aval) tracer = self._new_const(aval, c) return tracer @@ -1924,22 +1864,11 @@ def _new_const(self, aval, c) -> DynamicJaxprTracer: self.frame.constvar_to_val[var] = c return tracer - def sublift(self, t): - # When lifting closed-over tracers corresponding to this same trace, the - # variable to lift could have tracers (representing axis size variables) in - # its shape. We must lift those too! - tracer = self.frame.constid_to_tracer.get(id(t)) - if tracer is None: - aval = raise_to_shaped(get_aval(t), weak_type=dtypes.is_weakly_typed(t)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, t) - return tracer - def _lift_tracers_in_aval(self, aval): if (not isinstance(aval, DShapedArray) or not any(isinstance(d, Tracer) for d in aval.shape)): return aval - shape = [self.full_raise(d) if isinstance(d, Tracer) else d + shape = [self.to_jaxpr_tracer(d) if isinstance(d, Tracer) else d for d in aval.shape] return aval.update(shape=tuple(shape)) @@ -1956,17 +1885,16 @@ def makevar(self, tracer): var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) return var - def instantiate_const(self, val): - if (isinstance(val, Tracer) and val._trace.main is self.main - and val._trace.sublevel == self.sublevel): - return val - else: - return self.new_const(val) + def is_const(self, tracer): + return self.frame.tracer_to_var.get(id(tracer)) is None def process_primitive(self, primitive, tracers, params): + if (config.eager_constant_folding.value and all(map(self.is_const, tracers))): + return primitive.bind_with_trace(core.eval_trace, tracers, params) + jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: - return custom_staging_rules[primitive](self, *tracers, **params) - return self.default_process_primitive(primitive, tracers, params) + return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) + return self.default_process_primitive(primitive, jaxpr_tracers, params) def default_process_primitive(self, primitive, tracers, params): avals = [t.aval for t in tracers] @@ -1986,16 +1914,12 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f, explicit_tracers, params): if f.in_type is None: - f = lu.annotate(f, tuple((raise_to_shaped(t.aval), True) - for t in explicit_tracers)) + f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) - in_tracers = [*implicit_tracers, *explicit_tracers] + in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation - with core.new_sublevel(): - # TODO(lenamartens): Make call_primitive name -> API function name mapping. - # (currently this will display eg. 'xla_call' instead of `jit`) - dbg = debug_info_final(f, call_primitive.name) - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg) + dbg = debug_info_final(f, call_primitive.name) + jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) @@ -2009,7 +1933,7 @@ def process_call(self, call_primitive, f, explicit_tracers, params): aval = aval.update(shape=tuple(get_referent(d) for d in shape)) out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) update_params = call_param_updaters.get(call_primitive) @@ -2017,25 +1941,21 @@ def process_call(self, call_primitive, f, explicit_tracers, params): new_params = update_params(new_params, [True] * len(explicit_tracers), len(consts) + len(implicit_tracers)) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, - new_params, new_params['call_jaxpr'].effects, - source_info) + new_params, new_params['call_jaxpr'].effects, source_info) self.frame.add_eqn(eqn) return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] - def post_process_call(self, call_primitive, out_tracers, params): - assert False # unreachable - def process_map(self, map_primitive, f, tracers, params): + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) if in_axis is not None else a for a, in_axis in zip(in_avals, params['in_axes'])] - with core.extend_axis_env(axis_name, params["global_axis_size"], None): - with core.new_sublevel(): - jaxpr, reduced_out_avals, consts, () = trace_to_subjaxpr_dynamic( - f, self.main, reduced_in_avals, - debug_info=debug_info_final(f, map_primitive.name)) + with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]): + jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic( + f, reduced_in_avals, + debug_info=debug_info_final(f, map_primitive.name)) ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: raise ValueError("Ordered effects not supported for " @@ -2047,7 +1967,7 @@ def process_map(self, map_primitive, f, tracers, params): source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_in_axes = (None,) * len(consts) + params['in_axes'] new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, @@ -2062,16 +1982,12 @@ def process_map(self, map_primitive, f, tracers, params): self.frame.add_eqn(eqn) return out_tracers - def post_process_map(self, map_primitive, out_tracers, params): - assert False # unreachable - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] in_tangent_avals = [t.to_tangent_aval() for t in in_avals] - with core.new_sublevel(): - fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) + fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) - main_ = ref(self.main) @_memoize def jvp_jaxpr_thunk(*in_zeros): @@ -2079,12 +1995,12 @@ def jvp_jaxpr_thunk(*in_zeros): nz_tangent_avals, zero_avals = partition_list(in_zeros, in_tangent_avals) jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals)) in_avals_ = (*in_avals, *nz_tangent_avals) - jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_) + jaxpr, _, out_consts, () = trace_to_jaxpr_dynamic(jvp_, in_avals_) return jaxpr, out_consts, out_zeros() out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, @@ -2096,29 +2012,24 @@ def jvp_jaxpr_thunk(*in_zeros): self.frame.add_eqn(eqn) return out_tracers - def post_process_custom_jvp_call(self, out_tracers, _): - assert False # unreachable - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] - with core.new_sublevel(): - fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) + fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) - main_ = ref(self.main) - @_memoize def fwd_jaxpr_from_zeros(*zeros): for store in fwd.stores: store and store.reset() fwd_ = _interleave_fun(fwd, zeros) - jaxpr, _, consts, atr = trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals) + jaxpr, _, consts, atr = trace_to_jaxpr_dynamic(fwd_, in_avals) if atr: raise NotImplementedError return jaxpr, consts out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, dict(fun_jaxpr=closed_fun_jaxpr, @@ -2131,38 +2042,32 @@ def fwd_jaxpr_from_zeros(*zeros): self.frame.add_eqn(eqn) return out_tracers - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable - def process_custom_transpose(self, prim, call, tracers, *, transpose, out_types, lin_tree, res_tree, out_tree): + tracers = map(self.to_jaxpr_tracer, tracers) tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) in_avals_p = [t.aval for t in tracers] in_avals_t = [*[t.aval for t in tracers_res], *out_types] - with core.new_sublevel(): - call_jaxpr, out_avals, call_consts, () = trace_to_subjaxpr_dynamic( - call, self.main, in_avals_p) + call_jaxpr, out_avals, call_consts, _ = trace_to_jaxpr_dynamic(call, in_avals_p) closed_call_jaxpr = core.ClosedJaxpr( convert_constvars_jaxpr(call_jaxpr), ()) transpose_flat, in_tree2 = flatten_fun_nokwargs( lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree))) - main_ = ref(self.main) # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts @_memoize def transpose_jaxpr_thunk(): for store in transpose_flat.stores: store.reset() - jaxpr, _, consts, () = trace_to_subjaxpr_dynamic( - transpose_flat, main_(), in_avals_t) + jaxpr, _, consts, () = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) return jaxpr, consts out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, call_consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, call_consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_call_jaxpr, @@ -2174,42 +2079,42 @@ def transpose_jaxpr_thunk(): self.frame.add_eqn(eqn) return out_tracers + def to_jaxpr(self, out_tracers: Sequence[Tracer]): + return self.frame.to_jaxpr(self, out_tracers) + custom_staging_rules: dict[Primitive, Callable] = {} -@lu.transformation -def _interleave_fun(every_others, *args, **kwargs): +@lu.transformation2 +def _interleave_fun(f, every_others, *args, **kwargs): args_ = [x for pair in zip(args, every_others) for x in pair] - yield (yield (args_, kwargs)) + return f(*args_, **kwargs) +# TODO: consider renaming to "lazy_thunk" def _memoize(fn): cells = {} - saved_state = core.thread_local_state.trace_state.copy() sentinel = object() def memoized(*args): out = cells.get(args, sentinel) if out is sentinel: - prev_state = core.thread_local_state.trace_state - core.thread_local_state.trace_state = saved_state - try: + with core.set_current_trace(None): out = cells[args] = fn(*args) - finally: - core.thread_local_state.trace_state = prev_state return out return memoized -@lu.transformation_with_aux -def _jvp_jaxpr_zeros(in_zeros, zero_avals, *primal_tangent_avals): +@lu.transformation_with_aux2 +def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): in_primals, nz_in_tangents = split_list(primal_tangent_avals, [len(in_zeros)]) symbolic_zeros = map(ad_util.SymbolicZero, zero_avals) tangents = merge_lists(in_zeros, nz_in_tangents, symbolic_zeros) - out = yield (*in_primals, *tangents), {} + out = f(*in_primals, *tangents) n, ragged = divmod(len(out), 2) assert not ragged out_primals, out_tangents = out[:n], out[n:] out_zeros = [type(t) is ad_util.SymbolicZero for t in out_tangents] out_nz_tangents, _ = partition_list(out_zeros, out_tangents) - yield [*out_primals, *out_nz_tangents], out_zeros + store.store(out_zeros) + return [*out_primals, *out_nz_tangents] # TODO(mattjj): remove this DebugInfo and helper functions, replace with # api_util.py versions @@ -2271,106 +2176,43 @@ def trace_to_jaxpr_dynamic( debug_info: DebugInfo | None = None, *, keep_inputs: list[bool] | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: - main.jaxpr_stack = () # type: ignore - jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) - del main, fun - return jaxpr, out_avals, consts, attrs_tracked - - -def trace_to_subjaxpr_dynamic( - fun: lu.WrappedFun, - main: core.MainTrace, - in_avals: Sequence[AbstractValue], - *, - keep_inputs: Sequence[bool] | None = None, - debug_info: DebugInfo | None = None, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - frame = JaxprStackFrame() - frame.debug_info = debug_info - with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): - trace = DynamicJaxprTrace(main, core.cur_sublevel()) + trace = DynamicJaxprTrace() + trace.frame.debug_info = debug_info + with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) - in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - ans = fun.call_wrapped(*in_tracers_) - out_tracers = map(trace.full_raise, ans) - jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers) - del fun, main, trace, frame, in_tracers, out_tracers, ans + in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + with core.set_current_trace(trace): + ans = fun.call_wrapped(*in_tracers) + + out_tracers = map(trace.to_jaxpr_tracer, ans) + jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers) + del trace, fun, in_tracers, out_tracers, ans + config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked - @profiler.annotate_function def trace_to_jaxpr_dynamic2( fun: lu.WrappedFun, debug_info: DebugInfo | None = None ) -> tuple[Jaxpr, OutputType, list[Any]]: - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: - main.jaxpr_stack = () # type: ignore - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) - del main, fun - return jaxpr, out_type, consts - -def trace_to_subjaxpr_dynamic2( - fun: lu.WrappedFun, main: core.MainTrace, - debug_info: DebugInfo | None = None -) -> tuple[Jaxpr, OutputType, list[Any]]: - in_avals, keep_inputs = unzip2(fun.in_type) - frame = JaxprStackFrame() - frame.debug_info = debug_info - with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): - trace = DynamicJaxprTrace(main, core.cur_sublevel()) - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) - in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - ans = fun.call_wrapped(*in_tracers_) - out_tracers = map(trace.full_raise, ans) - jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers) - del fun, main, trace, frame, in_tracers, out_tracers, ans - return jaxpr, out_type, consts - - -@contextmanager -def extend_jaxpr_stack(main, frame): - main.jaxpr_stack = main.jaxpr_stack + (frame,) - try: - yield - finally: - assert frame is main.jaxpr_stack[-1] - main.jaxpr_stack = main.jaxpr_stack[:-1] - -@profiler.annotate_function -def trace_to_jaxpr_final( - fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue], - debug_info: DebugInfo | None = None, - keep_inputs: Sequence[bool] | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: - with core.new_base_main(DynamicJaxprTrace) as main: - main.jaxpr_stack = () # type: ignore - with core.new_sublevel(): - jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) - del fun, main - return jaxpr, out_avals, consts - - -@profiler.annotate_function -def trace_to_jaxpr_final2( - fun: lu.WrappedFun, debug_info: DebugInfo | None = None - ) -> tuple[Jaxpr, OutputType, list[Any]]: - with core.new_base_main(DynamicJaxprTrace) as main: - main.jaxpr_stack = () # type: ignore - with core.new_sublevel(): - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) - del fun, main - return jaxpr, out_type, consts + trace = DynamicJaxprTrace() + with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): + trace.frame.debug_info = debug_info + in_avals, keep_inputs = unzip2(fun.in_type) + in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + with core.set_current_trace(trace): + ans = fun.call_wrapped(*in_tracers) + out_tracers = map(trace.to_jaxpr_tracer, ans) + jaxpr = trace.frame.to_jaxpr2(out_tracers) + del trace, in_tracers, out_tracers, ans + return jaxpr AbstractedAxisName = Hashable AbstractedAxesSpec = Union[ @@ -2472,7 +2314,7 @@ def _collect_implicit( for i, name in spec.items(): if name not in idxs and id(x.shape[i]) not in explicit_tracers: idxs[name] = DBIdx(next(counter)) - implicit_types.append(raise_to_shaped(get_aval(x.shape[i]))) + implicit_types.append(get_aval(x.shape[i])) if isinstance(x, Tracer): explicit_tracers.setdefault(id(x), explicit_idx) # use the first @@ -2491,7 +2333,7 @@ def _arg_type( ) -> AbstractValue: # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames. aval = get_aval(x) # aval.shape could contain Tracers - if not spec: return core.raise_to_shaped(aval) + if not spec: return aval shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d for i, d in enumerate(aval.shape)] assert not any(isinstance(d, Tracer) for d in shape) @@ -2555,8 +2397,8 @@ def _extract_implicit_args( for d1, d2 in zip(aval.shape, tracer.aval.shape): if isinstance(d1, DBIdx): if tracers[d1.val] is None: - tracers[d1.val] = trace.instantiate_const(d2) - assert tracers[d1.val] is trace.instantiate_const(d2) + tracers[d1.val] = trace.to_jaxpr_tracer(d2) + assert tracers[d1.val] is trace.to_jaxpr_tracer(d2) assert all(t is not None for t in tracers) return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore @@ -2693,32 +2535,9 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): return prim.bind(*subfuns, *args, **bind_params) -# TODO(mattjj): the following are deprecated; update callers to _nounits version -# See https://github.com/jax-ml/jax/pull/9498 -@lu.transformation -def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool], - pvals: Sequence[PartialVal]): - assert all(isinstance(pv, PartialVal) for pv in pvals), pvals - trace = main.with_cur_sublevel() - in_tracers = map(trace.new_arg, pvals) - ans = yield in_tracers, {} - assert isinstance(ans, (list, tuple)), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers) - jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers) - out_pvals = [t.pval for t in out_tracers] - del trace, in_tracers, out_tracers - yield jaxpr, (out_pvals, consts, env) - -partial_eval_jaxpr: Callable - def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): if instantiate: - return trace.instantiate_const(trace.full_raise(tracer)) + return trace.instantiate_const(tracer) else: return tracer diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a14ca3dcabd8..d48e81b9092c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -16,7 +16,6 @@ from __future__ import annotations import enum -from contextlib import contextmanager import collections from collections import namedtuple from collections.abc import Callable, Sequence, Iterable @@ -68,8 +67,8 @@ from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED, - UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto, - is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources, + UnspecifiedValue, get_array_mapping as _get_array_mapping, + array_mapping_to_axis_resources, SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, @@ -106,21 +105,42 @@ class WeakRefList(list): ### util + +def to_xc_copy_semantics(copy_semantics): + out = [] + for cs in copy_semantics: + if cs is None or cs == dispatch.CopySemantics.ALIAS: + out.append(xc.ArrayCopySemantics.REUSE_INPUT) + elif cs == dispatch.CopySemantics.COPY: + out.append(xc.ArrayCopySemantics.ALWAYS_COPY) + elif cs == dispatch.CopySemantics.DONATE: + out.append(xc.ArrayCopySemantics.DONATE_INPUT) + else: + assert isinstance(cs, xc.ArrayCopySemantics) + out.append(cs) + return out + + def identity(x): return x @profiler.annotate_function -def shard_args(shardings: Sequence[JSharding], layouts, args, - canonicalize=True) -> Sequence[xc.ArrayImpl]: +def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics, + args, canonicalize=True) -> Sequence[xc.ArrayImpl]: + xc_copy_semantics = to_xc_copy_semantics(copy_semantics) + del copy_semantics # Fast path for one argument. if len(args) == 1: arg = args[0] if canonicalize: arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)]([arg], shardings, layouts) - - # type(arg) -> (list[indices], list[args], list[shardings]) - batches = collections.defaultdict(lambda: ([], [], [], [])) # type: ignore - for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)): + return shard_arg_handlers[type(arg)]([arg], shardings, layouts, + xc_copy_semantics) + + # type(arg) -> (list[indices], list[args], list[shardings], list[layouts], + # list[copy_semantics]) + batches = collections.defaultdict(lambda: ([], [], [], [], [])) # type: ignore + for i, (arg, sharding, layout, cs) in enumerate( + safe_zip(args, shardings, layouts, xc_copy_semantics)): if canonicalize: arg = xla.canonicalize_dtype(arg) batch = batches[type(arg)] @@ -128,14 +148,15 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, batch[1].append(arg) batch[2].append(sharding) batch[3].append(layout) + batch[4].append(cs) # Call `shard_arg_handlers` per batch and build a flat list of arrays returned # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. results: list[jax.Array | None] = [None] * len(args) - for t, (indices, a, s, l) in batches.items(): - outs = shard_arg_handlers[t](a, s, l) + for t, (indices, a, s, l, cs) in batches.items(): + outs = shard_arg_handlers[t](a, s, l, cs) for i, out in safe_zip(indices, outs): results[i] = out assert all(result is not None for result in results) @@ -143,13 +164,14 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, shard_arg_handlers: dict[ - Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]] + Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any], Sequence[Any]], + Sequence[Any]] ] = {} @lru_cache(maxsize=2048) def is_default_layout(curr_layout, sharding, aval): - if curr_layout is None or sharding is None or is_unspecified(sharding): + if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue): return True if (aval is core.abstract_token or aval.dtype == dtypes.float0 or dtypes.issubdtype(aval.dtype, dtypes.extended)): @@ -173,17 +195,12 @@ def is_default_layout(curr_layout, sharding, aval): raise -@lru_cache(maxsize=1024) -def _get_replicated_slices(num_addressable_devices: int): - return ((slice(None),),) * num_addressable_devices - - -def _masked_array_error(xs, shardings, layouts): +def _masked_array_error(xs, shardings, layouts, copy_semantics): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_np_array(xs, shardings, layouts): +def _shard_np_array(xs, shardings, layouts, copy_semantics): results = [] for x, sharding, layout in safe_zip(xs, shardings, layouts): devices = sharding._addressable_device_assignment @@ -203,12 +220,12 @@ def _shard_np_array(xs, shardings, layouts): for _t in array_types: shard_arg_handlers[_t] = _shard_np_array -def _shard_darray(xs, shardings, layouts): - return shard_args(shardings, layouts, [x._data for x in xs]) +def _shard_darray(xs, shardings, layouts, copy_semantics): + return shard_args(shardings, layouts, copy_semantics, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(xs, shardings, layouts): - return shard_args(shardings, layouts, [x._buf for x in xs]) +def _shard_mutable_array(xs, shardings, layouts, copy_semantics): + return shard_args(shardings, layouts, copy_semantics, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, @@ -374,14 +391,15 @@ def _emap_impl(fun: lu.WrappedFun, *args, emap_info = EmapInfo(backend, devices) shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes] - with core.new_base_main(MapTrace, emap_info=emap_info) as main: - with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main): - t = main.with_cur_sublevel() - tracers = [MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)] + trace = MapTrace(axis_name, emap_info) + with core.extend_axis_env_nd([(axis_name, axis_size)]): + tracers = [MapTracer(trace, arg, s) for arg, s in zip(args, shard_axes)] + with core.set_current_trace(trace): ans = fun.call_wrapped(*tracers) - out_tracers = map(t.full_raise, ans) - outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers) - del main + + out_tracers = map(trace.to_map_tracer, ans) + outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers) + out_axes = out_axes_thunk() platform = xb.get_backend(backend).platform @@ -441,25 +459,33 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName], class MapTrace(core.Trace): - def __init__(self, *args, emap_info): - super().__init__(*args) + def __init__(self, axis_name, emap_info): self.emap_info = emap_info + self.axis_name = axis_name - def pure(self, val): - return MapTracer(self, val, {}) - - def sublift(self, tracer): - return MapTracer(self, tracer.val, tracer.shard_axes) + def to_map_tracer(self, val): + if isinstance(val, MapTracer): + return val + else: + return MapTracer(self, val, {}) def process_primitive(self, primitive, tracers, params): - info = self.main.payload["emap_info"] + if primitive is jax._src.lax.parallel.axis_index_p: + return self.process_axis_index(**params) + if primitive is jax._src.lax.parallel.psum_p: + f = HashableFunction( + lambda *xs: jax._src.lax.parallel.psum( + xs, axis_name=params['axes'], axis_index_groups=params['axis_index_groups']), + (primitive, tuple(params.items()))) + else: + f = HashableFunction(lambda *args: primitive.bind(*args, **params), + (primitive, tuple(params.items()))) + tracers = map(self.to_map_tracer, tracers) vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers]) - names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env - if f.main_trace is self.main) + info = self.emap_info + names = core.get_axis_env().axis_names() all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes) # pytype: disable=wrong-arg-types # always-use-return-annotations - f = HashableFunction(lambda *args: primitive.bind(*args, **params), - (primitive, tuple(params.items()))) - f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes) + f_mapped, out_shard_axes = _multi_pmap(f, self.emap_info, names, all_axes) with core.eval_context(), jax.disable_jit(False): outvals = f_mapped(*vals) if primitive.multiple_results: @@ -484,14 +510,12 @@ def process_map(self, map_primitive, fun, tracers, params): shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s} if ax is not None else s for v, ax, s in zip(vals, in_axes, shard_axes)] - # TODO(mattjj): use _emap_subtrace here? - with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main): - t = self.main.with_cur_sublevel() - in_tracers = map(partial(MapTracer, t), vals, shard_axes) - ans = fun.call_wrapped(*in_tracers) - out_tracers = map(t.full_raise, ans) + in_tracers = map(partial(MapTracer, self), vals, shard_axes) + with core.extend_axis_env_nd([(axis_name, axis_size)]): + with core.set_current_trace(self): + ans = fun.call_wrapped(*in_tracers) + out_tracers = map(self.to_map_tracer, ans) out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers) - del t, in_tracers, ans, out_tracers out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst) for v, s, dst in zip(out, outaxes, out_axes_thunk())) return map(partial(MapTracer, self), out, outaxes) @@ -502,11 +526,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros # always base main, can drop jvp - in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) - fun, out_axes = _emap_subtrace(fun, self.main, in_axes) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(MapTracer, self), out_vals, out_axes()) + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): @@ -515,32 +536,18 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp - in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) - fun, out_axes = _emap_subtrace(fun, self.main, in_axes) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(MapTracer, self), out_vals, out_axes()) + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) - def process_axis_index(self, frame): + def process_axis_index(self, axis_name): bind = HashableFunction( - lambda _: jax.lax.axis_index(frame.name), - (jax.lax.axis_index, frame.name)) + lambda _: jax.lax.axis_index(axis_name), + (jax.lax.axis_index, axis_name)) fake_primitive = FakePrimitive(multiple_results=False, bind=bind) - with core.eval_context(): - range = jax.lax.iota(np.int32, frame.size) - dummy_tracer = MapTracer(self, range, {frame.name: 0}) + range = jax.lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) + dummy_tracer = MapTracer(self, range, {axis_name: 0}) return self.process_primitive(fake_primitive, (dummy_tracer,), {}) -@lu.transformation_with_aux -def _emap_subtrace(main, in_axes, *in_vals): - t = main.with_cur_sublevel() - in_tracers = map(partial(MapTracer, t), in_vals, in_axes) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - out_vals, out_axes = unzip2((t.val, t.shard_axes) for t in out_tracers) - del t, in_tracers, ans, out_tracers - yield out_vals, out_axes - def _annot_to_flat(ndim: int, mapped_axes: Iterable[int], annotation: int | None) -> int | None: if annotation is None: return None @@ -680,15 +687,15 @@ def find_replicas( num_global_replicas = global_axis_size * jaxpr_replicas return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas) -@lu.transformation -def _change_argument_ranks(in_axes, out_axes_thunk, *args): +@lu.transformation2 +def _change_argument_ranks(f, in_axes, out_axes_thunk, *args): args = tuple( arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,)) for in_axis, arg in zip(in_axes, args) ) - results = yield (args, {}) + results = f(*args) out_axes = out_axes_thunk() - yield tuple( + return tuple( x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,)) for x, axis in zip(results, out_axes) ) @@ -706,14 +713,13 @@ def stage_parallel_callable( fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk) else: fun = orig_fun - with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): + with core.extend_axis_env_nd([(pci.axis_name, pci.global_axis_size)]): with dispatch.log_elapsed_time( - "Finished tracing + transforming {fun_name} for pmap in {elapsed_time:.9f} sec", + "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): - jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( + jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic( fun, sharded_avals, pe.debug_info_final(fun, "pmap")) jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info) - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) assert len(out_sharded_avals) == len(pci.out_axes), ( len(out_sharded_avals), len(pci.out_axes)) @@ -748,7 +754,8 @@ def get_pmap_jaxpr( pci = ParallelCallableInfo( name, backend, axis_name, axis_size, global_axis_size, devices, in_axes, out_axes_thunk, avals) - jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) + with core.extend_axis_env_nd([(axis_name, axis_size)]): + jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) return closed_jaxpr, backend, replicas, shards, pci @@ -847,7 +854,7 @@ def lower_parallel_callable( backend.platform) module_name = f"pmap_{fun.__name__}" platforms = lowering_platforms or (backend.platform,) - with maybe_extend_axis_env(axis_name, global_axis_size, None): + with core.extend_axis_env_nd([(axis_name, global_axis_size)]): ordered_effects = list( effects.ordered_effects.filter_in(closed_jaxpr.effects)) if ordered_effects: @@ -1150,7 +1157,8 @@ class InputsHandler: def __init__(self, in_shardings, in_layouts, local_devices=None, input_indices=None): - self.handler = partial(shard_args, in_shardings, in_layouts) + self.handler = partial(shard_args, in_shardings, in_layouts, + [None] * len(in_shardings)) self.in_shardings = in_shardings self.in_layouts = in_layouts self.local_devices = local_devices @@ -1342,8 +1350,10 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval): def _pmap_dce_rule(used_outputs, eqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None axis_name = eqn.params["axis_name"] - with maybe_extend_axis_env(axis_name, eqn.params["global_axis_size"], None): + with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]): new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars']) _, in_axes = partition_list(used_inputs, eqn.params['in_axes']) @@ -1402,21 +1412,6 @@ def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts): ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) -def _pmap_axis_subst(params, subst, traverse): - if 'call_jaxpr' not in params: - return params - if not traverse: - return params - def shadowed_subst(name): - return (name,) if name in params['axis_name'] else subst(name) - with maybe_extend_axis_env(params['axis_name'], - params['global_axis_size'], None): - new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'], - shadowed_subst) - return dict(params, call_jaxpr=new_jaxpr) -core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst - - def _unravel_index_hlo(axis_env): div = mlir.ir_constant( np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32)) @@ -1525,7 +1520,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, if in_axis is not None else in_node for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes)) - with maybe_extend_axis_env(axis_name, global_axis_size, None): + with core.extend_axis_env_nd([(axis_name, global_axis_size)]): sub_ctx = ctx.module_context.replace( axis_context=sharding_impls.ReplicaAxisContext(new_env)) sharded_outs, _ = mlir.jaxpr_subcomp( @@ -1643,7 +1638,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp def check_if_any_auto( shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool: for s in shardings: - if is_auto(s): + if isinstance(s, AUTO): return True return False @@ -1712,8 +1707,11 @@ class DeviceAssignmentMismatchError(Exception): ] -def _get_default_device() -> xc.Device: - return config.default_device.value or xb.local_devices()[0] +def get_default_device() -> xc.Device: + if isinstance(config.default_device.value, str): + return xb.get_backend(config.default_device.value).local_devices()[0] + else: + return config.default_device.value or xb.local_devices()[0] def _get_and_check_device_assignment( @@ -1727,14 +1725,14 @@ def _get_and_check_device_assignment( devices = tuple(devices) for i, s_type, source_info in shardings: - if is_unspecified(i): + if isinstance(i, UnspecifiedValue): continue if first_sharding_info is None: first_sharding_info = ( - (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore - else (i._device_assignment, s_type, source_info)) # type: ignore - arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore + (i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO) + else (i._device_assignment, s_type, source_info)) + arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment if not devices: if first_sharding_info[0] != arr_device_assignment: raise DeviceAssignmentMismatchError([ @@ -1748,7 +1746,7 @@ def _get_and_check_device_assignment( if first_sharding_info is None and devices: final_device_assignment = devices elif first_sharding_info is None: - final_device_assignment = (_get_default_device(),) + final_device_assignment = (get_default_device(),) else: final_device_assignment = first_sharding_info[0] # type: ignore return xb.get_device_backend(final_device_assignment[0]), final_device_assignment @@ -1786,7 +1784,6 @@ def _dce_jaxpr(closed_jaxpr, api_name, fun_name, donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx) del kept_const_idx - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) return closed_jaxpr, donated_invars, kept_var_idx, name_stack @@ -1836,7 +1833,8 @@ class SemanticallyEqualShardings: def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], avals: tuple[core.AbstractValue]): gspmd_shardings = [ - s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim) # type: ignore + s if isinstance(s, (UnspecifiedValue, AUTO)) + else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error for s, a in zip(shardings, avals)] self._gspmd_shardings = gspmd_shardings self.shardings = shardings @@ -2004,7 +2002,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings): except: return True for i in shardings: - if is_unspecified_or_auto(i): + if isinstance(i, (UnspecifiedValue, AUTO)): continue if i.memory_kind is None: # pytype: disable=attribute-error continue @@ -2034,7 +2032,7 @@ def _default_rule(prim, num_outvars, *_, **__): if in_shardings is None: invar_mem_kind = [None] * len(jaxpr.invars) else: - invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind + invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind for s in in_shardings] safe_map(write, jaxpr.invars, invar_mem_kind) safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars)) @@ -2129,7 +2127,7 @@ def _abstract_to_concrete_mesh(abstract_mesh): out = [] for s, a in zip(shardings, avals): - if is_unspecified(s) and a.sharding is not None: + if isinstance(s, UnspecifiedValue) and a.sharding is not None: out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh), a.sharding.spec)) else: @@ -2150,6 +2148,7 @@ def lower_sharding_computation( *, keep_unused: bool, context_mesh: mesh_lib.Mesh | None, + compiler_options_kvs: tuple[tuple[str, Any], ...], lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None, @@ -2211,14 +2210,17 @@ def lower_sharding_computation( out_shardings = _concretize_abstract_shardings( out_shardings, global_out_avals, device_assignment) - platforms = lowering_platforms or (backend.platform,) + # TODO(parkers): One _raw_platform has been unified with platform, + # change this back to just read platform. + platforms = lowering_platforms or ( + getattr(backend, "_raw_platform", backend.platform),) committed = bool( devices_from_context or len(device_assignment) > 1 or - any(not is_unspecified(i) for i in unique_in_shardings) or - any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or - any(not is_unspecified(o) for o in unique_out_shardings)) + any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or + any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or + any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings)) da_object = _create_da_object(tuple(device_assignment)) @@ -2276,6 +2278,7 @@ def lower_sharding_computation( module, donated_invars, platforms, + compiler_options_kvs, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2327,11 +2330,13 @@ class MeshComputation(stages.XlaLowering): def __init__(self, name: str, hlo: ir.Module, donated_invars: Sequence[bool], platforms: Sequence[str], + compiler_options_kvs: tuple[tuple[str, Any], ...], **compile_args): self._name = name self._hlo = hlo self._donated_invars = donated_invars self._platforms = platforms + self._compiler_options_kvs = compiler_options_kvs self.compile_args = compile_args self._executable = None @@ -2341,11 +2346,14 @@ def stablehlo(self) -> ir.Module: return self._hlo def compile(self, compiler_options=None) -> MeshExecutable: - if self._executable is None or compiler_options is not None: + t_compiler_options = (() if compiler_options is None else + tuple(compiler_options.items())) + compiler_options_kvs = self._compiler_options_kvs + t_compiler_options + if self._executable is None or compiler_options_kvs: executable = UnloadedMeshExecutable.from_hlo( self._name, self._hlo, **self.compile_args, - compiler_options=compiler_options) - if compiler_options is None: + compiler_options_kvs=compiler_options_kvs) + if not compiler_options_kvs: self._executable = executable return executable return self._executable @@ -2610,8 +2618,7 @@ def create_compile_options( else: xla_device_assignment = np_dev.reshape((num_replicas, num_partitions)) - fdo_profile = (None if compiler_options is None else - compiler_options.pop("fdo_profile", None)) + fdo_profile = compiler_options.pop("fdo_profile", None) compile_options = compiler.get_compile_options( num_replicas=num_replicas, @@ -2643,17 +2650,11 @@ def create_compile_options( def _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, - da, pmap_nreps, compiler_options_keys, - compiler_options_values, - pgle_profiler): + da, pmap_nreps, compiler_options_kvs, pgle_profiler): # One would normally just write: dev = np.array(device_assignment) # The formulation below is substantially faster if there are many devices. dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da))) - - if compiler_options_keys is None: - compiler_options = None - else: - compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) + compiler_options = dict(compiler_options_kvs) compile_options = create_compile_options( computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, @@ -2690,7 +2691,7 @@ def _maybe_get_and_check_in_shardings( new_in_shardings = [] for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings, global_in_avals): - if is_unspecified(orig): + if isinstance(orig, UnspecifiedValue): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) @@ -2726,7 +2727,7 @@ def _maybe_get_and_check_out_shardings( new_out_shardings = [] for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings, global_out_avals): - if is_unspecified(orig): + if isinstance(orig, UnspecifiedValue): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) @@ -2746,11 +2747,11 @@ def _maybe_get_and_check_out_shardings( return new_out_shardings -def finalize_out_shardings(out_shardings, device_assignment): +def finalize_shardings(shardings, device_assignment): if len(device_assignment) == 1: return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind) - if isinstance(o, GSPMDSharding) else o for o in out_shardings] - return out_shardings + if isinstance(o, GSPMDSharding) else o for o in shardings] + return shardings @dataclasses.dataclass @@ -2817,53 +2818,49 @@ def from_hlo(name: str, committed: bool, in_layouts: MaybeLayout, out_layouts: MaybeLayout, + compiler_options_kvs: tuple[tuple[str, Any], ...], pmap_nreps: int = 1, mut: MutationData | None = None, shape_poly_state: mlir.ShapePolyLoweringState | None = None, all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None, - compiler_options=None, pgle_profiler: profiler.PGLEProfiler | None = None, intermediate_shardings: Sequence[JSharding] | None = None, context_mesh: mesh_lib.Mesh | None = None ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) - compiler_options_keys = tuple( - compiler_options.keys()) if compiler_options is not None else None - compiler_options_values = tuple( - compiler_options.values()) if compiler_options is not None else None if isinstance(device_assignment, xc.DeviceList): da = device_assignment else: da = _create_da_object(tuple(device_assignment)) del device_assignment - allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i) + allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings) - allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o) + allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO)) for o in out_shardings) mesh = None if auto_spmd_lowering: for i in it.chain.from_iterable([in_shardings, out_shardings]): - if is_auto(i): - mesh = i.mesh # type: ignore + if isinstance(i, AUTO): + mesh = i.mesh break xla_executable = _cached_compilation( hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, - compiler_options_keys, compiler_options_values, pgle_profiler) + compiler_options_kvs, pgle_profiler) if auto_spmd_lowering: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( xla_executable, mesh) - in_shardings = [x if is_auto(i) else i + in_shardings = [x if isinstance(i, AUTO) else i for x, i in safe_zip(in_shardings_xla, in_shardings)] - out_shardings = [x if is_auto(o) else o + out_shardings = [x if isinstance(o, AUTO) else o for x, o in safe_zip(out_shardings_xla, out_shardings)] else: if pmap_nreps == 1: @@ -2895,7 +2892,8 @@ def from_hlo(name: str, in_shardings, out_shardings, global_in_avals, global_out_avals, intermediate_shardings, context_mesh) - out_shardings = finalize_out_shardings(out_shardings, da) + in_shardings = finalize_shardings(in_shardings, da) + out_shardings = finalize_shardings(out_shardings, da) return UnloadedMeshExecutable( xla_executable=xla_executable, @@ -2947,6 +2945,7 @@ class JitGlobalCppCacheKeys: out_layouts_treedef: PyTreeDef | None = None out_layouts_leaves: tuple[Any, ...] | None = None use_resource_env: bool = False + compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None @functools.cached_property def contains_explicit_attributes(self): @@ -2954,10 +2953,11 @@ def contains_explicit_attributes(self): self.donate_argnames is not None or self.device is not None or self.backend is not None or - any(not is_unspecified(i) for i in self.in_shardings_leaves) or - any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or + any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or any(i is not None for i in self.in_layouts_leaves) or - any(o is not None for o in self.out_layouts_leaves)) + any(o is not None for o in self.out_layouts_leaves) or + self.compiler_options_kvs) def reflatten_outputs_for_dispatch(out_tree, out_flat): @@ -3078,7 +3078,7 @@ def aot_cache_miss(*args, **kwargs): JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): - return shard_args([sharding], [layout], [x])[0] + return shard_args([sharding], [layout], [None], [x])[0] def check_arg_avals_for_call(ref_avals, arg_avals, @@ -3130,7 +3130,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): def check_device_backend_on_shardings(shardings) -> bool: for i in shardings: - if is_unspecified(i) or is_auto(i): + if isinstance(i, (UnspecifiedValue, AUTO)): continue if getattr(i, '_device_backend', False): return True @@ -3156,7 +3156,7 @@ def check_array_xla_sharding_layout_match( args_after_dce, in_xla_shardings, in_xla_layouts, arg_names): if not isinstance(arg, ArrayImpl): continue - if is_unspecified_or_auto(xs): + if isinstance(xs, (UnspecifiedValue, AUTO)): continue db_xs = check_device_backend_on_shardings([xs]) @@ -3202,9 +3202,3 @@ def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: parsed_pspec = sharding_impls.prepare_axis_resources( pspec, "pspec to array_mapping") return _get_array_mapping(parsed_pspec) - - -@contextmanager -def maybe_extend_axis_env(*args, **kwargs): - with core.extend_axis_env(*args, **kwargs): - yield diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 14635a46ea33..46bc7bef7ca7 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -25,7 +25,7 @@ from jax._src import core from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -101,7 +101,6 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: _xla_shape_handlers: dict[type[core.AbstractValue], Callable[[Any], Sequence[xc.Shape]]] = { ShapedArray: _make_array_shape, - ConcreteArray: _make_array_shape, } _xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index db03143f1083..34395756f25a 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -28,7 +28,6 @@ fori_loop as fori_loop, map as map, scan as scan, - scan_bind as scan_bind, scan_p as scan_p, _scan_impl as _scan_impl, while_loop as while_loop, diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index c634148768fc..547415c098b4 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -35,7 +35,7 @@ from jax._src import util from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect -from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects +from jax._src.core import replace_jaxpr_effects from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -87,6 +87,7 @@ def switch(index, branches, *operands): Args: index: Integer scalar type, indicating which branch function to apply. branches: Sequence of functions (A -> B) to be applied based on ``index``. + All branches must return the same output structure. operands: Operands (A) input to whichever branch is applied. Returns: @@ -130,8 +131,7 @@ def switch(index, branches, *operands): hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) - if (config.disable_jit.value and - isinstance(core.get_aval(index), ConcreteArray)): + if (config.disable_jit.value and core.is_concrete(index)): return branches[int(index)](*operands) ops, ops_tree = tree_flatten(operands) @@ -148,11 +148,6 @@ def switch(index, branches, *operands): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') - if joined_effects: - # Raise index in case of effects to allow data-dependence-based discharging - # of those effects (even if they don't have an explicit data dependence). - index = core.raise_as_much_as_possible(index) - out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) return tree_unflatten(out_trees[0], out) @@ -225,7 +220,7 @@ def cond(pred, true_fun, false_fun, *operands): msg = ("Pred type must be either boolean or number, got {}.") raise TypeError(msg.format(pred_dtype)) - if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray): + if config.disable_jit.value and core.is_concrete(pred): if pred: return true_fun(*operands) else: @@ -263,10 +258,6 @@ def cond(pred, true_fun, false_fun, *operands): f'Effects not supported in `cond`: {disallowed_effects}') index = lax.convert_element_type(pred, np.int32) - if joined_effects: - # Raise index in case of effects to allow data-dependence-based discharging - # of those effects (even if they don't have an explicit data dependence). - index = core.raise_as_much_as_possible(index) false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects) true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) @@ -338,7 +329,7 @@ def _cond_abstract_eval(*avals, branches, **_): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') - return map(raise_to_shaped, branches[0].out_avals), joined_effects + return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): if np.ndim(pred) != np.ndim(on_true): @@ -352,8 +343,7 @@ def _bcast_select_n(pred, *cases): pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) return lax.select_n(pred, *cases) -def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, - dims, branches): +def _cond_batching_rule(axis_data, args, dims, branches): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -375,15 +365,13 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, # optimizations to XLA. # TODO(mattjj,frostig): assumes branches are side-effect-free, revise! index, *ops = ( - batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)) + batching.bdim_at_front(x, d, axis_data.size) for x, d in zip(args, dims)) in_batched = [True] * len(branches[0].in_avals) out_batched = [True] * len(branches[0].out_avals) branches_batched = [ - batching.batch_jaxpr( - jaxpr, axis_size, in_batched, out_batched, axis_name, spmd_axis_name, - main_type)[0] + batching.batch_jaxpr(jaxpr, axis_data, in_batched, out_batched)[0] for jaxpr in branches] branch_outs = [] @@ -401,13 +389,11 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, for b, x, d in zip(ops_bat, ops, op_dims)] branches_out_bat = [ - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, - spmd_axis_name, main_type)[1] + batching.batch_jaxpr(jaxpr, axis_data, ops_bat, False)[1] for jaxpr in branches] out_bat = [any(bat) for bat in zip(*branches_out_bat)] branches_batched = tuple( - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, - spmd_axis_name, main_type)[0] + batching.batch_jaxpr(jaxpr, axis_data, ops_bat, out_bat)[0] for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] @@ -527,7 +513,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # jaxpr for each branch. branches_known_ : list[core.ClosedJaxpr] = [] branches_staged_: list[core.ClosedJaxpr] = [] - branch_res_avals: list[core.AbstractValue] = [] + branch_res_avals: list[list[core.AbstractValue]] = [] for jaxpr in branches: jaxpr_known, jaxpr_staged, _, inst_out, num_res = \ pe.partial_eval_jaxpr_custom( @@ -657,7 +643,11 @@ def _ordered_unique(xs): return list(d.keys()) def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + closed_branches = eqn.params['branches'] branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches] @@ -691,7 +681,6 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, def _transpose_cond_jaxpr(jaxpr, num_res): res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) - primal_avals = map(raise_to_shaped, primal_avals) @lu.wrap_init def transposed(*args): @@ -708,7 +697,7 @@ def _cond_transpose(cts, *args, branches): index, *ops = args assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] - in_avals = map(raise_to_shaped, branches[0].in_avals) + in_avals = branches[0].in_avals num_res = len(ops) - sum(linear) if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): @@ -716,8 +705,7 @@ def _cond_transpose(cts, *args, branches): branches_trans = tuple( _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches) - lin_in_avals = [raise_to_shaped(a, weak_type=False) - for a, l in zip(in_avals, linear) if l] + lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l] assert all(core.typematch(out_aval, lin_in_aval) for jaxpr in branches_trans for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) @@ -733,12 +721,6 @@ def _cond_transpose(cts, *args, branches): assert next(out_iter, None) is None return [None] + out -def _cond_axis_substitution(params, subst, traverse): - if not traverse: - return params - branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches']) - return dict(params, branches=branches) - def _cond_typecheck(bind_time, *in_atoms, branches): if not bind_time: _, *in_atoms = in_atoms @@ -793,28 +775,16 @@ def _cond_typecheck(bind_time, *in_atoms, branches): f'called with operands of type {_avals_short(op_avals)}') return jaxpr0.out_avals, joined_effects -def cond_bind(*args, branches): - if config.enable_checks.value: - avals = map(core.get_aval, args) - in_atoms = [core.Var('', a) for a in avals] # dummies - _cond_typecheck(True, *in_atoms, branches=branches) - for jaxpr in branches: - core.check_jaxpr(jaxpr.jaxpr) - return core.AxisPrimitive.bind(cond_p, *args, branches=branches) - -cond_p = core.AxisPrimitive('cond') +cond_p = core.Primitive('cond') cond_p.multiple_results = True cond_p.def_impl(partial(dispatch.apply_primitive, cond_p)) cond_p.def_effectful_abstract_eval(_cond_abstract_eval) -cond_p.def_custom_bind(cond_bind) ad.primitive_jvps[cond_p] = _cond_jvp -ad.reducing_transposes[cond_p] = _cond_transpose +ad.primitive_transposes[cond_p] = _cond_transpose pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval -batching.spmd_axis_primitive_batchers[cond_p] = _cond_batching_rule -batching.axis_primitive_batchers[cond_p] = partial(_cond_batching_rule, None) +batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule xla.register_initial_style_primitive(cond_p) core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) -core.axis_substitution_rules[cond_p] = _cond_axis_substitution pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.dce_rules[cond_p] = _cond_dce_rule batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 21b522b3d8bb..b6ae09d364a3 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -278,34 +278,30 @@ def _cached_for_jaxpr(jaxpr): discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) return core.ClosedJaxpr(discharged_jaxpr, body_consts) -def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, +def _for_vmap(axis_data, args, dims, *, jaxpr, nsteps, reverse, which_linear, unroll): init_batched = [d is not batching.not_mapped for d in dims] closed_jaxpr = _cached_for_jaxpr(jaxpr) batched = init_batched for _ in range(len(batched)): _, out_batched = batching.batch_jaxpr( - closed_jaxpr, - axis_size, [False] + batched, instantiate=batched, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + closed_jaxpr, axis_data, [False] + batched, instantiate=batched) if out_batched == batched: break batched = map(operator.or_, batched, out_batched) else: raise Exception("Invalid fixpoint") - args = [batching.broadcast(x, axis_size, 0) if now_bat and not was_bat + args = [batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat else x for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)] batched_jaxpr_, _ = batching.batch_jaxpr( - pe.close_jaxpr(jaxpr), axis_size, [False] + batched, [], - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + pe.close_jaxpr(jaxpr), axis_data, [False] + batched, []) batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts # TODO consts out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps, reverse=reverse, which_linear=which_linear, unroll=unroll) return out_flat, [0 if b else batching.not_mapped for b in batched] -batching.axis_primitive_batchers[for_p] = functools.partial(_for_vmap, None) -batching.spmd_axis_primitive_batchers[for_p] = _for_vmap +batching.fancy_primitive_batchers[for_p] = _for_vmap def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, unroll): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 7a9596bf2c0d..9b2d688c322b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -35,7 +35,7 @@ from jax._src import state from jax._src import util from jax._src.api_util import shaped_abstractify -from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped +from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -227,6 +227,11 @@ def scan(f, init, xs, length=None): msg.format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err + if (config.sharding_in_types.value and + not all(x.sharding.spec[0] is None for x in xs_flat)): + raise ValueError('0th dimension of all xs should be replicated. Got ' + f'{", ".join(str(x.sharding.spec) for x in xs_flat)}') + if length is not None: try: length = int(length) @@ -250,7 +255,8 @@ def scan(f, init, xs, length=None): if config.disable_jit.value: if length == 0: - raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.") + raise ValueError("zero-length scan is not supported in disable_jit() " + "mode because the output type is unknown.") carry = init ys = [] maybe_reversed = reversed if reverse else lambda x: x @@ -262,7 +268,7 @@ def scan(f, init, xs, length=None): stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y - xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat] + xs_avals = [core.get_aval(x) for x in xs_flat] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] def _create_jaxpr(init): @@ -370,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'of the carry output is a {thing2}, so {explanation}' for path, thing1, thing2, explanation in equality_errors(in_carry, out_carry)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: @@ -387,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: @@ -418,9 +430,13 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, consts, carry, xs_ = split_list(args, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) num_trips, remainder = divmod(length, unroll) + if unroll != 1 and num_trips == 1 and remainder == 0: + # In that case, we explicitly want to fully unroll the loop. Put everything + # into the remainder block and avoid lowering to a while loop. + num_trips, remainder = 0, length if unroll == 1: xss = xs_ - yss = _map(partial(_empty_array, (length,)), y_avals) + yss = _map(partial(_empty_array, (length,), None), y_avals) else: if remainder: if not reverse: @@ -428,7 +444,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, else: xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_)) xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] - yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals) + yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals) def cond_fun(while_carry): i, _, _ = while_carry @@ -473,8 +489,11 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) -def _empty_array(prefix, aval): - return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape)) +def _empty_array(prefix, length_spec, aval): + sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec)) + if config.sharding_in_types.value else None) + return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), + sharding=sharding) eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True @@ -482,11 +501,13 @@ def _stage_jaxpr(trace, *tracers, jaxpr): params = dict(call_jaxpr=jaxpr) return trace.default_process_primitive(core.closed_call_p, tracers, params) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr + @eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf -def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects +def _stage_jaxpr_abstract_eval(*_, jaxpr): + return jaxpr.out_avals, jaxpr.effects def _prepend_dim_to_aval(sz, aval): - return core.unmapped_aval(sz, core.no_axis_name, 0, aval) + return core.unmapped_aval(sz, None, 0, aval) def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): @@ -670,7 +691,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, extensive_res = _map(trace.new_instantiated_const, extensive_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) - ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval) + ys_avals = [core.unmapped_aval(length, None, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in itertools.chain(carry_avals, ys_avals)] @@ -700,7 +721,7 @@ def _maybe_put(x): aval = shaped_abstractify(x) s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) - return result_handler(pxla.shard_args([s], [None], [x])) + return result_handler(pxla.shard_args([s], [None], [None], [x])) else: return x @@ -885,7 +906,7 @@ def transposed(*res1_cbar_bbar_res2): b_ys_avals_stripped + res2_avals)) -def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, +def _scan_batching_rule(axis_data, args, dims, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): @@ -902,11 +923,8 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, for _ in range(1 + len(carry_batched)): batched = const_batched + carry_batched + xs_batched jaxpr_batched, batched_out = batching.batch_jaxpr( - jaxpr, axis_size, batched, - instantiate=carry_batched + [False] * num_ys, - axis_name=axis_name, - spmd_axis_name=spmd_axis_name, - main_type=main_type) + jaxpr, axis_data, batched, + instantiate=carry_batched + [False] * num_ys) carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] if carry_batched_out == carry_batched: break @@ -919,7 +937,7 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 else x for x, d in zip(consts, consts_bdims)] - new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched + new_init = [batching.broadcast(x, axis_data.size, 0) if now_batched and not was_batched else batching.moveaxis(x, d, 0) if now_batched else x for x, d, was_batched, now_batched in zip(init, init_bdims, init_batched, carry_batched)] @@ -943,7 +961,9 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params) def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_xs = len(jaxpr.in_avals) - num_consts - num_carry @@ -1038,7 +1058,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Create residual variables. intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals) - ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a) + ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a) for a in ext_avals_mapped] newvar = core.gensym() intensive_res = _map(newvar, intensive_avals) @@ -1116,7 +1136,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, jaxpr.in_avals, [num_consts, num_carry]) carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry]) x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) - y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a) + y_avals = [core.unmapped_aval(length, None, 0, a) for a in y_avals_mapped] if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): @@ -1209,27 +1229,17 @@ def arrange_jaxpr_args_for_wrapped(args): assert len(refs_out_matching_in_avals) == len(in_avals) return refs_out_matching_in_avals, [*carry_out, *ys] -def scan_bind(*args, **params): - if config.enable_checks.value: - avals = _map(core.get_aval, args) - in_atoms = [core.Var('', a) for a in avals] # dummies - _scan_typecheck(True, *in_atoms, **params) - core.check_jaxpr(params['jaxpr'].jaxpr) - return core.AxisPrimitive.bind(scan_p, *args, **params) - -scan_p = core.AxisPrimitive("scan") +scan_p = core.Primitive("scan") scan_p.multiple_results = True -scan_p.def_custom_bind(scan_bind) scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp -ad.reducing_transposes[scan_p] = _scan_transpose +ad.primitive_transposes[scan_p] = _scan_transpose pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, mlir.lower_fun(_scan_impl, multiple_results=True)) -batching.axis_primitive_batchers[scan_p] = partial(_scan_batching_rule, None) -batching.spmd_axis_primitive_batchers[scan_p] = _scan_batching_rule +batching.fancy_primitive_batchers[scan_p] = _scan_batching_rule core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom pe.padding_rules[scan_p] = _scan_padding_rule @@ -1379,11 +1389,10 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') - return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects + return body_jaxpr.out_avals, joined_effects -def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, - args, dims, cond_nconsts, cond_jaxpr, +def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): from jax._src.callback import _IOEffect, _OrderedIOEffect if any(_OrderedIOEffect in fn.effects for fn in [body_jaxpr, cond_jaxpr]): @@ -1401,8 +1410,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # reach a fixpoint. for _ in range(1 + len(carry_bat)): _, carry_bat_out = batching.batch_jaxpr( - body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + body_jaxpr, axis_data, bconst_bat + carry_bat, instantiate=carry_bat) if carry_bat == carry_bat_out: break carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out) @@ -1412,8 +1420,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # Knowing how the carry is batched now, we can determine if the predicate is # batched. _, (pred_bat,) = batching.batch_jaxpr( - cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + cond_jaxpr, axis_data, cconst_bat + carry_bat, instantiate=False) if pred_bat: # If the predicate is batched, we have to batch *all* of the carry @@ -1424,13 +1431,9 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, carry_bat = [True] * len(carry_bat) carry_dims = [0] * len(carry_bat) body_jaxpr_batched, _ = batching.batch_jaxpr_axes( - body_jaxpr, axis_size, bconst_dims + carry_dims, - carry_dims, axis_name=axis_name, spmd_axis_name=spmd_axis_name, - main_type=main_type) + body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims) cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( - cond_jaxpr, axis_size, cconst_dims + carry_dims, [0], - axis_name=axis_name, spmd_axis_name=spmd_axis_name, - main_type=main_type) + cond_jaxpr, axis_data, cconst_dims + carry_dims, [0]) else: # If the predicate is not batched, we can look at the `cond_jaxpr`'s out # shape to determine the rank of the predicate. From this rank we pick the @@ -1440,13 +1443,11 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, cond_rank = len(cond_jaxpr.out_avals[0].shape) carry_dims = [cond_rank if b else None for b in carry_bat] body_jaxpr_batched, _ = batching.batch_jaxpr_axes( - body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims) # Now we need to rebatch the `cond_jaxpr` according to the new dims of the # carry. cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( - cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,), - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + cond_jaxpr, axis_data, cconst_dims + carry_dims, (None,)) # To prepare the `init` to the `while_p`, we broadcast values if they are # unbatched and need to have an out axis. If their current batch axis does not @@ -1455,7 +1456,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, new_init = [] for x, old_axis, new_axis in zip(init, init_dims, carry_dims): if old_axis is batching.not_mapped and new_axis is not batching.not_mapped: - new_init.append(batching.broadcast(x, axis_size, new_axis)) + new_init.append(batching.broadcast(x, axis_data.size, new_axis)) elif old_axis is batching.not_mapped and new_axis is batching.not_mapped: new_init.append(x) else: @@ -1891,7 +1892,7 @@ def new_cond(*consts_refs_carry): *[None] * num_carry] return invals_out, carry_out -while_p = core.AxisPrimitive('while') +while_p = core.Primitive('while') while_p.multiple_results = True while_p.def_impl(partial(dispatch.apply_primitive, while_p)) while_p.def_effectful_abstract_eval(_while_loop_abstract_eval) @@ -1899,8 +1900,7 @@ def new_cond(*consts_refs_carry): pe.custom_partial_eval_rules[while_p] = _while_partial_eval xla.register_initial_style_primitive(while_p) ad.primitive_transposes[while_p] = _while_transpose_error -batching.axis_primitive_batchers[while_p] = partial(_while_loop_batching_rule, None) -batching.spmd_axis_primitive_batchers[while_p] = _while_loop_batching_rule +batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) core.custom_typechecks[while_p] = _while_typecheck @@ -2034,12 +2034,11 @@ def fori_loop(lower, upper, body_fun, init_val): # If we can specialize on the trip count, call scan instead of a while_loop # to enable efficient reverse-mode differentiation. - if (isinstance(core.get_aval(lower), ConcreteArray) and - isinstance(core.get_aval(upper), ConcreteArray)): + if core.is_concrete(lower) and core.is_concrete(upper): try: lower_ = int(lower) upper_ = int(upper) - except TypeError: + except (TypeError, core.InconclusiveDimensionOperation): use_scan = False else: use_scan = True diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 4e0f5086b121..f97377b2df6c 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -23,7 +23,6 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import linear_util as lu -from jax._src.core import raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -300,7 +299,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return _map(raise_to_shaped, args_to_raise) + return args_to_raise def _custom_linear_solve_impl(*args, const_lengths, jaxprs): @@ -376,8 +375,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs): return [None] * sum(const_lengths) + cotangent_b -def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, - args, dims, const_lengths, jaxprs): +def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs): orig_bat = [d is not batching.not_mapped for d in dims] params, b = _split_linear_solve_args(args, const_lengths) @@ -397,15 +395,13 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): # Apply vecmat and solve -> new batched parts of x solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( - solve, axis_size, solve_bat + b_bat, instantiate=x_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + solve, axis_data, solve_bat + b_bat, instantiate=x_bat) if vecmat is None: vecmat_jaxpr_batched = None x_bat_out = solve_x_bat else: vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( - vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + vecmat, axis_data, vecmat_bat + b_bat, instantiate=b_bat) # batch all aux data by default x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat) # keep a slice of only the linear operator part of solve's avals @@ -413,15 +409,13 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # Apply matvec and solve_t -> new batched parts of b matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( - matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + matvec, axis_data, matvec_bat + x_bat_noaux, instantiate=b_bat) if solve_t is None: solve_t_jaxpr_batched = None b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) else: solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr( - solve_t, axis_size, solve_t_bat + x_bat_noaux, instantiate=x_bat_out, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + solve_t, axis_data, solve_t_bat + x_bat_noaux, instantiate=x_bat_out) assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)]) b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, @@ -445,7 +439,7 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, ] # Broadcast out b if necessary new_b = [ - batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else + batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) ] @@ -458,7 +452,7 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, return outs, out_dims -linear_solve_p = core.AxisPrimitive('custom_linear_solve') +linear_solve_p = core.Primitive('custom_linear_solve') linear_solve_p.multiple_results = True linear_solve_p.def_impl(_custom_linear_solve_impl) linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) @@ -468,5 +462,4 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True)) ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule -batching.axis_primitive_batchers[linear_solve_p] = partial(_linear_solve_batching_rule, None) -batching.spmd_axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule +batching.fancy_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 113c87b60ee0..84b936e97a4c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -47,8 +47,8 @@ from jax._src import state from jax._src import util from jax._src.abstract_arrays import array_types -from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, - raise_to_shaped, abstract_token, canonicalize_shape) +from jax._src.core import (Primitive, UnshapedArray, ShapedArray, + abstract_token, canonicalize_shape) from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -60,7 +60,6 @@ from jax._src.lax.utils import ( _input_dtype, dtype_to_string, standard_abstract_eval, standard_multi_result_abstract_eval, standard_primitive) -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -563,6 +562,10 @@ def _convert_element_type( new_dtype = np.dtype(new_dtype) new_dtype = dtypes.dtype(new_dtype, canonicalize=True) + if (config.sharding_in_types.value and sharding is None and + isinstance(operand, Array)): + sharding = operand.sharding + if (dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" @@ -579,8 +582,7 @@ def _convert_element_type( if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array) and - not (isinstance(operand, core.Tracer) and - isinstance(core.get_aval(operand), core.ConcreteArray)) and + not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and (sharding is None or getattr(operand, 'sharding', None) == sharding)): return operand else: @@ -874,18 +876,18 @@ def __str__(self) -> str: return self.name @property - def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: + def supported_lhs_types(self) -> tuple[DTypeLike, ...] | None: match self: case ( - DotAlgorithmPreset.DEFAULT | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.DEFAULT + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): return None case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32: - return np.float16 + return (np.float16,) case ( DotAlgorithmPreset.BF16_BF16_BF16 | DotAlgorithmPreset.BF16_BF16_F32 @@ -895,21 +897,21 @@ def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: # type. If not, we explicitly cast to bfloat16. return (dtypes.bfloat16, np.float32) case DotAlgorithmPreset.F64_F64_F64: - return np.float64 + return (np.float64,) case _: - return np.float32 + return (np.float32,) @property - def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: - return self.lhs_precision_type + def supported_rhs_types(self) -> tuple[DTypeLike, ...] | None: + return self.supported_lhs_types @property def accumulation_type(self) -> DTypeLike | None: match self: case ( - DotAlgorithmPreset.DEFAULT | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.DEFAULT + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): return None case DotAlgorithmPreset.F16_F16_F16: @@ -921,6 +923,40 @@ def accumulation_type(self) -> DTypeLike | None: case _: return np.float32 + def supported_output_types( + self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike + ) -> tuple[DTypeLike, ...] | None: + match self: + case ( + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + ): + return ( + np.float32, + np.float16, + dtypes.bfloat16, + dtypes.float8_e4m3fn, + dtypes.float8_e5m2, + dtypes.float8_e5m2fnuz, + dtypes.float8_e4m3fnuz, + dtypes.float8_e4m3b11fnuz, + ) + case DotAlgorithmPreset.F16_F16_F32: + # F16 output is only supported with F16 inputs. + if dtypes.promote_types(lhs_dtype, rhs_dtype) == np.float16: + return (np.float32, np.float16) + else: + return (np.float32,) + case DotAlgorithmPreset.BF16_BF16_F32: + # BF16 output is only supported with BF16 inputs. + if dtypes.promote_types(lhs_dtype, rhs_dtype) == dtypes.bfloat16: + return (np.float32, dtypes.bfloat16) + else: + return (np.float32,) + case _: + accumulation_type = self.accumulation_type + return None if accumulation_type is None else (accumulation_type,) + def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: f16 = ir.F16Type.get() @@ -930,26 +966,39 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, tf32 = ir.FloatTF32Type.get() match self: case ( - DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): - fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), - np.dtype(dtypes.float8_e4m3fn), - np.dtype(dtypes.float8_e4m3fnuz), - np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)) + fp8_dtypes = [ + np.dtype(dtypes.float8_e4m3b11fnuz), + np.dtype(dtypes.float8_e4m3fn), + np.dtype(dtypes.float8_e4m3fnuz), + np.dtype(dtypes.float8_e5m2), + np.dtype(dtypes.float8_e5m2fnuz), + ] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " - f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.") + f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.' + ) lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype)) rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype)) acc = ir.F32Type.get() return hlo.DotAlgorithm.get( - lhs, rhs, acc, 1, 1, 1, - self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM) + lhs, + rhs, + acc, + 1, + 1, + 1, + self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, + ) case DotAlgorithmPreset.F16_F16_F16: return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False) case DotAlgorithmPreset.F16_F16_F32: @@ -1040,7 +1089,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + out_type=None) -> Array: """General dot product/contraction operator. Wraps XLA's `DotGeneral @@ -1086,6 +1136,13 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. """ + if out_type is not None and not config.sharding_in_types.value: + raise NotImplementedError("out_type only works when sharding_in_types " + "config is True.") + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError( + '`out_type` argument of `dot_general` only supports NamedSharding ' + 'instances. Please file a bug if this is not enough for your use case.') (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers cdims = (api_util._ensure_index_tuple(lhs_contract), api_util._ensure_index_tuple(rhs_contract)) @@ -1097,7 +1154,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_type=out_type) def ragged_dot( @@ -1123,10 +1181,11 @@ def ragged_dot( """ return ragged_dot_p.bind(lhs, rhs, group_sizes, precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type, group_offset=group_offset) + preferred_element_type=preferred_element_type, + group_offset=group_offset) -def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array: +def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array: """Broadcasts an array, adding new leading dimensions Args: @@ -1140,13 +1199,14 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array: See Also: jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape. """ - if len(sizes) == 0: + if len(sizes) == 0 and sharding is None: return asarray(operand) dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand))) - return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims) + return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims, + sharding=sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int]) -> Array: + broadcast_dimensions: Sequence[int], sharding=None) -> Array: """Wraps XLA's `BroadcastInDim `_ operator. @@ -1164,7 +1224,11 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. """ - if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array): + if not config.sharding_in_types.value and sharding is not None: + raise NotImplementedError("sharding argument to broadcast_in_dim is only " + "allowed when sharding_in_types config is on.") + if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and + isinstance(operand, Array) and sharding is None): return operand if config.dynamic_shapes.value: # We must gate this behavior under a flag because otherwise the errors @@ -1174,7 +1238,8 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, dyn_shape, static_shape = [], shape # type: ignore return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), - broadcast_dimensions=tuple(broadcast_dimensions)) + broadcast_dimensions=tuple(broadcast_dimensions), + sharding=sharding) def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.""" @@ -1184,7 +1249,8 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: return broadcast(x, (1,) * (rank - ndim)) def reshape(operand: ArrayLike, new_sizes: Shape, - dimensions: Sequence[int] | None = None) -> Array: + dimensions: Sequence[int] | None = None, + sharding: NamedSharding | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -1238,7 +1304,8 @@ def reshape(operand: ArrayLike, new_sizes: Shape, return reshape_p.bind( operand, *dyn_shape, new_sizes=tuple(static_new_sizes), - dimensions=None if dims is None or same_dims else dims) + dimensions=None if dims is None or same_dims else dims, + sharding=sharding) def pad(operand: ArrayLike, padding_value: ArrayLike, padding_config: Sequence[tuple[int, int, int]]) -> Array: @@ -1259,6 +1326,36 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, Returns: The ``operand`` array with padding value ``padding_value`` inserted in each dimension according to the ``padding_config``. + + Examples: + >>> from jax import lax + >>> import jax.numpy as jnp + + Pad a 1-dimensional array with zeros, We'll specify two zeros in front and + three at the end: + + >>> x = jnp.array([1, 2, 3, 4]) + >>> lax.pad(x, 0, [(2, 3, 0)]) + Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32) + + Pad a 1-dimensional array with *interior* zeros; i.e. insert a single zero + between each value: + + >>> lax.pad(x, 0, [(0, 0, 1)]) + Array([1, 0, 2, 0, 3, 0, 4], dtype=int32) + + Pad a 2-dimensional array with the value ``-1`` at front and end, with a pad + size of 2 in each dimension: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) + Array([[-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, 1, 2, 3, -1, -1], + [-1, -1, 4, 5, 6, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) @@ -1419,23 +1516,24 @@ def _get_monoid_reducer(monoid_op: Callable, x, = xs aval = core.get_aval(x) dtype = _dtype(x) - if (type(aval) is ConcreteArray) and aval.shape == (): + if core.is_concrete(x) and aval.shape == (): + val = core.to_concrete_value(x) # allow bitwise reductions for boolean and integer types _is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer) if monoid_op is add: - return _reduce_sum if np.equal(aval.val, 0) else None + return _reduce_sum if np.equal(val, 0) else None elif monoid_op is mul: - return _reduce_prod if np.equal(aval.val, 1) else None + return _reduce_prod if np.equal(val, 1) else None elif monoid_op is bitwise_or and _is_intlike: - return _reduce_or if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None + return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is bitwise_and and _is_intlike: - return _reduce_and if np.equal(aval.val, _get_bitwise_and_identity(dtype)) else None + return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None elif monoid_op is bitwise_xor and _is_intlike: - return _reduce_xor if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None + return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is max: - return _reduce_max if np.equal(aval.val, _get_max_identity(dtype)) else None + return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None elif monoid_op is min: - return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None + return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None return None def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray: @@ -1603,17 +1701,16 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, weak_type = dtype is None and dtypes.is_weakly_typed(fill_value) dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) fill_value = _convert_element_type(fill_value, dtype, weak_type) - # In tracing mode we can't set sharing explictly and PmapShardng is not - # supported. - # NB: Consider using with_sharding_constraint in jitted computation - # if needed? if (sharding is not None and not isinstance(sharding, PmapSharding) and isinstance(fill_value, array.ArrayImpl)): broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) return array.make_array_from_callback(shape, sharding, lambda _: shard) - return broadcast(fill_value, shape) + if config.sharding_in_types.value and sharding is not None: + return broadcast(fill_value, shape, sharding=sharding) + else: + return broadcast(fill_value, shape) def zeros_like_shaped_array(aval: ShapedArray) -> Array: assert isinstance(aval, ShapedArray) @@ -1623,6 +1720,8 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) + if config.sharding_in_types.value: + return broadcast(scalar_zero, aval.shape, sharding=aval.sharding) return broadcast(scalar_zero, aval.shape) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array @@ -1631,6 +1730,9 @@ def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray: val = ad_util.zeros_like_aval(aval.inner_aval) return core.mutable_array(val) +# TODO(dougalm): this is nonsense but it's here because in places like +# custom_vjp we assume that all arguments have tangent spaces. We could have +# a distinct NotATangentType value instead. ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref # type: ignore def iota(dtype: DTypeLike, size: int) -> Array: @@ -1740,6 +1842,9 @@ def stop(x): return x elif (dtypes.issubdtype(_dtype(x), np.floating) or dtypes.issubdtype(_dtype(x), np.complexfloating)): + # break abstractions to support legacy leaked tracer use cases + if isinstance(x, ad.JVPTracer): + return stop(x.primal) return ad_util.stop_gradient_p.bind(x) else: return x @@ -1811,22 +1916,26 @@ def full_like(x: ArrayLike | DuckTypedArray, if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] - # If `x` has a sharding but no `_committed` attribute - # (in case of ShapeDtypeStruct), default it to True. - use_x_sharding = ( - sharding is None - # Tracer have special logic in handling sharding and even - # though hasattr(x, 'sharding') returns False, it is very slow. - # This bypasses the check. - and not isinstance(x, core.Tracer) - and hasattr(x, 'sharding') - and getattr(x, '_committed', True) - and not weak_type - and fill_shape == np.shape(x) # type: ignore[arg-type] - ) - if use_x_sharding: - # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. - sharding = x.sharding # type: ignore + if (config.sharding_in_types.value and sharding is None and + isinstance(x, Array)): + sharding = x.sharding + else: + # If `x` has a sharding but no `_committed` attribute + # (in case of ShapeDtypeStruct), default it to True. + use_x_sharding = ( + sharding is None + # Tracer have special logic in handling sharding and even + # though hasattr(x, 'sharding') returns False, it is very slow. + # This bypasses the check. + and not isinstance(x, core.Tracer) + and hasattr(x, 'sharding') + and getattr(x, '_committed', True) + and not weak_type + and fill_shape == np.shape(x) # type: ignore[arg-type] + ) + if use_x_sharding: + # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. + sharding = x.sharding # type: ignore val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type), sharding=sharding) return val @@ -1881,23 +1990,12 @@ def batch_matmul(lhs: Array, rhs: Array, def square(x: ArrayLike) -> Array: r"""Elementwise square: :math:`x^2`.""" - return integer_pow(x, 2) + return square_p.bind(x) def reciprocal(x: ArrayLike) -> Array: r"""Elementwise reciprocal: :math:`1 \over x`.""" return integer_pow(x, -1) -def _upcast_fp16_for_computation(f): - @functools.wraps(f) - def f_wrapped(x): - dtype = _dtype(x) - if dtype == np.float16 or dtype == dtypes.bfloat16: - return convert_element_type( - f(convert_element_type(x, np.float32)), dtype) - return f(x) - - return f_wrapped - def tan(x: ArrayLike) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" return tan_p.bind(x) @@ -2066,25 +2164,25 @@ def broadcasting_shape_rule(name, *avals): def broadcasting_sharding_rule(name, *avals): - shapes = [aval.shape for aval in avals if aval.shape] - if not shapes: - return () - if len({len(shape) for shape in shapes}) != 1: - msg = '{}: arrays must have same number of dimensions, got {}.' - raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes))))) - - specs = [a.sharding.spec for a in avals if a.shape] - mesh = None for a in avals: - if a.shape: - mesh = a.sharding.mesh + if a.sharding is not None: if mesh is not None and mesh != a.sharding.mesh: raise ValueError( f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') + mesh = a.sharding.mesh assert mesh is not None + shapes = [aval.shape for aval in avals if aval.shape] + if not shapes: + return NamedSharding(mesh, P()) + if len({len(shape) for shape in shapes}) != 1: + msg = '{}: arrays must have same number of dimensions, got {}.' + raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes))))) + + specs = [a.sharding.spec for a in avals if a.shape] + result_specs = [None] * len(shapes[0]) for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))): if all(s == ss[0] for s in ss[1:]): @@ -2181,14 +2279,8 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: out.append(op) else: - # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains - # CompilerShardingAxis, then specify `unspecified_dims` via - # `wrap_with_sharding_op`. - if config.use_shardy_partitioner.value: - sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim) - else: - sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() - out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) + proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto)) return out @@ -2204,11 +2296,7 @@ def _nary_lower_hlo(op: Callable, ctx, out = op(*args) if config.sharding_in_types.value: - if config.use_shardy_partitioner.value: - out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim) - else: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] else: return [out] @@ -2273,6 +2361,7 @@ def _round_lower(ctx, x, *, rounding_method): exp_p = standard_unop(_float | _complex, 'exp') ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) +batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule exp2_p = standard_unop(_float | _complex, 'exp2') ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) @@ -2334,12 +2423,17 @@ def _sin_lowering(ctx, x): return sine(ctx, x) return _nary_lower_hlo(hlo.sine, ctx, x) +def _sin_p_lin(nzs, x): + nz, = nzs + cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) + return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule - def _cos_complex(x): # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x))) # see also _sin_complex @@ -2361,21 +2455,9 @@ def _cos_lowering(ctx, x): ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) mlir.register_lowering(cos_p, _cos_lowering) -@_upcast_fp16_for_computation -def _tan_impl(x): - return div(sin(x), cos(x)) - tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this -# lowering is mostly supported, but it fails on export or with the PJRT plugin -# because those modes target an older StableHLO version, and the -# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't -# included in the 0.4.33 release. -if jaxlib_version <= (0, 4, 33): - mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) -else: - mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): @@ -2388,27 +2470,9 @@ def asin_impl(x): ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin)) -def acos_impl(x): - if dtypes.issubdtype(_dtype(x), np.complexfloating): - result = mul(_const(x, 1j), acosh(x)) - # By convention, numpy chooses the branch with positive real part. - rpart = real(result) - return select( - gt(rpart, _const(rpart, 0)), - result, - neg(result) - ) - else: - return select( - ne(x, _const(x, -1.0)), - mul(_const(x, 2), - atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x))), - full_like(x, np.pi)) - acos_p = standard_unop(_float | _complex, 'acos') ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) -mlir.register_lowering(acos_p, - mlir.lower_fun(acos_impl, multiple_results=False)) +mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos)) def atan_impl(x): return atan2(x, _const(x, 1)) @@ -2530,6 +2594,27 @@ def _abs_jvp_rule(g, ans, x): lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +square_p = standard_unop(_int | _float | _complex, 'square') + +def _square_complex(x): + a, b = real(x), imag(x) + # zero square(x).real is handled explicitly for abs(a)==abs(b) cases + # where for finite a, 2 * a is non-finite: + zero_re = is_finite(a) & (eq(a, b) | eq(a, -b)) + # equivalent to a**2 - b**2 but avoids overflow errors for large a + # and large b cases: + re = (a - b) * (a + b) + im = a * b * 2 + return select(zero_re, complex(_const(a, 0), im), complex(re, im)) + +def _square_lower_hlo(ctx, x): + if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): + return mlir.lower_fun(_square_complex, multiple_results=False)(ctx, x) + return [hlo.multiply(x, x)] + +ad.defjvp2(square_p, lambda g, ans, x: mul(g, mul(_const(x, 2), x))) +mlir.register_lowering(square_p, _square_lower_hlo) # TODO(pearu): use chlo.square + def _pow_dtype_rule(x, y): if (dtypes.issubdtype(x.dtype, np.inexact) and dtypes.issubdtype(y.dtype, np.integer)): @@ -2569,15 +2654,12 @@ def _pow_jvp_rhs(g, ans, x, y): def _pow_lower(ctx, x, y): x_aval, y_aval = ctx.avals_in - out_aval, = ctx.avals_out - convert = mlir.lower_fun( - partial(convert_element_type, new_dtype=out_aval.dtype), False) - x_aval_ = x_aval.update(dtype=out_aval.dtype) - y_aval_ = y_aval.update(dtype=out_aval.dtype) - [x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x) - [y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y) - ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_]) - return _nary_lower_hlo(hlo.power, ctx_, x_, y_) + if x_aval.dtype != y_aval.dtype: + out_aval, = ctx.avals_out + y_aval = y_aval.update(dtype=out_aval.dtype) + y = hlo.convert(mlir.aval_to_ir_type(y_aval), y) + ctx = ctx.replace(avals_in=[x_aval, y_aval]) + return _nary_lower_hlo(hlo.power, ctx, x, y) mlir.register_lowering(pow_p, _pow_lower) def _integer_pow_dtype_rule(x, *, y): @@ -2618,24 +2700,23 @@ def _integer_pow(x, *, y): def _integer_pow_lowering(ctx, x, *, y): # These cases are subsumed by the general case, but it's faster to emit these # common cases directly. - if y == 2: + if y == 1: + out = x + elif y == 2: out = hlo.multiply(x, x) elif y == 3: out = hlo.multiply(hlo.multiply(x, x), x) + elif y == -1: + out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x) else: lowering = mlir.lower_fun(_integer_pow, multiple_results=False) - # TODO(b/217551391): emitting an out-of-line call leads to a large - # expansion when the MLIR is lowered to HLO, because the HLO lowering - # clones the callee. Consider unconditionally caching when the MLIR->HLO - # lowering doesn't expand the program. - lowering = mlir.cache_lowering(lowering) - out = lowering(ctx, x, y=y) + if builtins.abs(y) >= 3: + lowering = mlir.cache_lowering(lowering) + out, = lowering(ctx, x, y=y) if config.sharding_in_types.value: aval_out, = ctx.avals_out - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - out = out[0] if isinstance(out, list) else out - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] - return out if isinstance(out, list) else [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(integer_pow_p, _integer_pow_lowering) @@ -2723,6 +2804,7 @@ def _sub_transpose(t, x, y): ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract)) +batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule def _mul_transpose(ct, x, y): @@ -2744,6 +2826,7 @@ def _mul_transpose(ct, x, y): lambda ydot, x, y: mul(x, ydot)) ad.primitive_transposes[mul_p] = _mul_transpose mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply)) +batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule def _div_transpose_rule(cotangent, x, y): assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y) @@ -2757,6 +2840,7 @@ def _div_transpose_rule(cotangent, x, y): lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2))) ad.primitive_transposes[div_p] = _div_transpose_rule mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide)) +batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( @@ -2780,12 +2864,14 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo)) +batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule min_p: core.Primitive = standard_naryop([_any, _any], 'min') ad.defjvp2(min_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo)) +batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule shift_left_p = standard_naryop([_int, _int], 'shift_left') ad.defjvp_zero(shift_left_p) @@ -2872,6 +2958,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt') ad.defjvp_zero(lt_p) mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False)) +batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to') ad.defjvp_zero(eq_to_p) @@ -2967,14 +3054,18 @@ def _convert_elt_type_pp_rule(eqn, context, settings): return core._pp_eqn(eqn.replace(params=params), context, settings) convert_element_type_p = Primitive('convert_element_type') -def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): - operand = core.Primitive.bind(convert_element_type_p, operand, - new_dtype=new_dtype, weak_type=weak_type, - sharding=sharding) - if sharding is not None: - operand = pjit.with_sharding_constraint(operand, sharding) + +# TODO(dougalm): I'm overriding bind_with_trace here because that's the closest thing to +# the old "custom bind" but it might not be the best way to do this. +def _convert_element_type_bind_with_trace(trace, args, params): + sharding = params['sharding'] + operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params) + if sharding is not None and not config.sharding_in_types.value: + with core.set_current_trace(trace): + operand = pjit.with_sharding_constraint(operand, sharding) return operand -convert_element_type_p.def_custom_bind(_convert_element_type_bind) +convert_element_type_p.def_bind_with_trace(_convert_element_type_bind_with_trace) + convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, @@ -3002,7 +3093,12 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, not dtypes.issubdtype(new_dtype, np.complexfloating)): operand = hlo.real(operand) aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype)) - return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)] + out = mlir.convert_hlo(ctx, operand, aval_in, aval_out) + if config.sharding_in_types.value: + if sharding is not None: + assert aval_out.sharding == sharding + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -3043,7 +3139,7 @@ def _to_edtype_abstract_eval(x, *, edtype): f" has a representation shape {rep_aval.shape} while the given " f"representation array has shape {x.shape}, so the shape suffix " f"does not match: given {shape_suffix} but required {rep_aval.shape}.") - return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype) + return x.update(shape=shape_prefix, dtype=edtype) to_edtype_p = Primitive('to_edtype') to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) @@ -3164,7 +3260,10 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): @@ -3241,24 +3340,29 @@ def _check_specs_match(lhs_spec, rhs_spec, msg): raise TypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): if lhs.sharding.mesh != rhs.sharding.mesh: raise ValueError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') + if out_type is not None: + assert isinstance(out_type, NamedSharding) + return out_type + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch) rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch) msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " - f"to have the consistent sharding, got {lhs_batch_spec} and " - f"{rhs_batch_spec}.") + f"to have the consistent sharding, got {lhs_batch_spec} and " + f"{rhs_batch_spec}.") _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg) lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) msg = ("dot_general requires contracting dimensions to have consistent " - f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") + f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg) return _dot_general_sharding_computation( @@ -3280,7 +3384,10 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError del dimension_numbers # unused # We're mostly matching XLA's logic here, namely in shape_inference.cc and # primitive_util.h's HigherPrecisionType, e.g. @@ -3327,7 +3434,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width): def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, - swap_ans=False): + out_type, swap_ans=False): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim x_kept = remaining(range(x_ndim), x_contract, x_batch) @@ -3338,21 +3445,31 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) dims = ((ans_y, y_kept), (ans_batch, y_batch)) x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) - out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) - x_bar = transpose(dot_general(g, y, dims, precision=precision, - preferred_element_type=preferred_element_type), - tuple(out_axes)) + unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y + out_axes = np.argsort(unsorted_axes) + if config.sharding_in_types.value: + xs = x.aval.sharding + inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) + ds = xs.with_spec(inverse_spec) + else: + ds = None + dot_general_out = dot_general(g, y, dims, precision=precision, + preferred_element_type=preferred_element_type, + out_type=ds) + x_bar = transpose(dot_general_out, tuple(out_axes)) if x_bar.dtype != x.aval.dtype: x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type) return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) y_bar = _dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type, swap_ans=True) + preferred_element_type=preferred_element_type, out_type=out_type, + swap_ans=True) if y_bar.dtype != y.aval.dtype: y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type) return y_bar @@ -3366,6 +3483,7 @@ def _dot_batch_rule( batch_dims, *, dimension_numbers, + out_type, precision, preferred_element_type: DTypeLike | None, **_, @@ -3395,12 +3513,16 @@ def _dot_batch_rule( rhs_shape = batching.bdim_as_shape(rbd, rhs.shape) else: rhs_shape = np.shape(rhs) + if out_type is not None: + raise NotImplementedError("vmap with out_type is not supported. " + "Please open an issue.") batched_out = invoke_prim( lhs, rhs, new_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, + out_type=out_type, ) result_batch_dim = batching.shape_as_bdim( result_stack_dim, @@ -3481,12 +3603,37 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: return core._pp_eqn(eqn.replace(params=printed_params), context, settings) -def _dot_general_ragged_prop_rule(invar_raggedness, outvars): +def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 2 assert len(outvars) == 1 invar_raggedness_lhs = invar_raggedness[0] invar_raggedness_rhs = invar_raggedness[1] + dimension_numbers = eqn_params['dimension_numbers'] + (lhs_contracting, rhs_contracting), (_, _) = dimension_numbers + + if not invar_raggedness_lhs and not invar_raggedness_rhs: + # Both are dense - it is valid to reach here, because dense operations + # are legal in code running under ragged prop. + return invar_raggedness, [None] + + if not invar_raggedness_lhs or not invar_raggedness_rhs: + # One ragged, one dense + if not invar_raggedness_lhs: + # left is dense, right is ragged + _, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs + if rhs_contracting != ragged_axis_dim_rhs: + # Contraction is on a dense dimension, this is valid! + return invar_raggedness, [None] + if not invar_raggedness_rhs: + # left is ragged, right is dense + _, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs + if lhs_contracting != ragged_axis_dim_lhs: + # Contraction is on a dense dimension, this is valid! + return invar_raggedness, [None] + + raise NotImplementedError('NYI - dense and ragged dim contraction') + stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs @@ -3505,9 +3652,8 @@ def _dot_general_ragged_prop_rule(invar_raggedness, outvars): assert len(outvars) == 1 # TODO(mvoz): A constant on batching.* ? - dense_jumble_raggedness = None # Dense (m, n) - no jumble only atm - return invar_raggedness, [dense_jumble_raggedness] + return invar_raggedness, [None] dot_general_p = standard_primitive( @@ -3568,12 +3714,44 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype) +def get_algorithm_compute_types( + algorithm: DotAlgorithm | DotAlgorithmPreset, + lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike, + out_dtype: DTypeLike | None = None, +) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]: + if isinstance(algorithm, DotAlgorithm): + return ( + algorithm.lhs_precision_type, + algorithm.rhs_precision_type, + algorithm.accumulation_type, + ) + + def maybe_convert_dtype(input_dtype, target_dtypes): + if target_dtypes is None: + return input_dtype + if np.dtype(input_dtype) in map(np.dtype, target_dtypes): + return input_dtype + return target_dtypes[0] + + lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types) + rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types) + out_type = maybe_convert_dtype( + out_dtype, algorithm.supported_output_types(lhs_dtype, rhs_dtype) + ) + return lhs_dtype, rhs_dtype, out_type + + def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, - platform: str = "default"): + out_type, platform: str = "default"): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + if dtypes.float8_e3m4 is not None: + fp8_dtypes += (dtypes.float8_e3m4,) + if dtypes.float8_e4m3 is not None: + fp8_dtypes += (dtypes.float8_e4m3,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in @@ -3598,6 +3776,8 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): if platform == "cpu" and precision not in { DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16, DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64, + DotAlgorithmPreset.BF16_BF16_F32, DotAlgorithmPreset.BF16_BF16_F32_X3, + DotAlgorithmPreset.BF16_BF16_F32_X6, }: raise ValueError( f"The precision '{precision}' is not supported by dot_general on CPU") @@ -3605,20 +3785,17 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): # If an explicit algorithm was specified, we always cast the input types to # the correct types. def maybe_convert_dtype(operand, operand_aval, target_dtype): - if target_dtype is None: - return operand, operand_aval.dtype - if not isinstance(target_dtype, tuple): - target_dtype = (target_dtype,) - if any(operand_aval.dtype == d for d in target_dtype): - return operand, operand_aval.dtype - aval = core.ShapedArray(operand_aval.shape, target_dtype[0]) - return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0] - - lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type) - rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type) - accumulation_type = precision.accumulation_type - if accumulation_type is not None: - accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type) + if target_dtype is None or operand_aval.dtype == target_dtype: + return operand + aval = core.ShapedArray(operand_aval.shape, target_dtype) + return mlir.convert_hlo(ctx, operand, operand_aval, aval) + + lhs_dtype, rhs_dtype, accumulation_dtype = get_algorithm_compute_types( + precision, lhs_dtype, rhs_dtype, aval_out.dtype) + lhs = maybe_convert_dtype(lhs, lhs_aval, lhs_dtype) + rhs = maybe_convert_dtype(rhs, rhs_aval, rhs_dtype) + if accumulation_dtype is not None: + accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_dtype) if precision != DotAlgorithmPreset.DEFAULT: algorithm_kwarg = { @@ -3639,7 +3816,6 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) - lhs_dtype = rhs_dtype = aval_out.dtype else: # cpu and gpu # Do not convert mixed fp8 types to output type. if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype): @@ -3647,7 +3823,6 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) - lhs_dtype = rhs_dtype = aval_out.dtype result = hlo.dot_general( mlir.aval_to_ir_type(accumulation_aval), @@ -3658,8 +3833,9 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): **algorithm_kwarg, ) if config.sharding_in_types.value: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp) + if out_type is not None: + assert aval_out.sharding == out_type + result = mlir.lower_sharding_under_shit(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) return [result] @@ -3711,12 +3887,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S return (m, n) def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: + precision, preferred_element_type: DTypeLike | None, + **_) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. - return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, - precision=precision, preferred_element_type=preferred_element_type) + return _dot_general_dtype_rule( + lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + precision=precision, preferred_element_type=preferred_element_type, + out_type=None) def _ragged_dot_jvp_rule( @@ -3839,7 +4018,9 @@ def _ragged_dot_invoke_prim( new_dimension_numbers, precision, preferred_element_type, + out_type, ): + del out_type return ragged_dot( lhs, rhs, @@ -3868,6 +4049,7 @@ def _ragged_dot_batch_rule( dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type, + out_type=None, ) @@ -3909,7 +4091,8 @@ def _ragged_dot_impl( mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False)) -def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): +def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, + sharding): _check_shapelike('broadcast_in_dim', 'shape', shape) _check_shapelike('broadcast_in_dim', 'broadcast_dimensions', broadcast_dimensions) @@ -3944,18 +4127,22 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): raise TypeError(msg.format(broadcast_dimensions)) return shape -def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions): +def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, + sharding): + if sharding is not None: + return sharding bds = set(broadcast_dimensions) orig_spec = iter(operand.sharding.spec) new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))] assert next(orig_spec, None) is None - return NamedSharding(operand.sharding.mesh, P(*new_spec)) + return operand.sharding.with_spec(new_spec) def _broadcast_in_dim_typecheck_rule( - _, operand, *dyn_shape, shape, broadcast_dimensions): + _, operand, *dyn_shape, shape, broadcast_dimensions, sharding): if not dyn_shape: out_aval, effects = broadcast_in_dim_p.abstract_eval( - operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions) + operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) return [out_aval], effects else: # TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule @@ -3966,7 +4153,7 @@ def _broadcast_in_dim_typecheck_rule( return [out_aval], core.no_effects def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, - shape, broadcast_dimensions): + shape, broadcast_dimensions, sharding): if type(ct) is ad_util.Zero: return [ad_util.Zero(operand.aval)] unit_dims = [i for i, s in enumerate(operand.aval.shape) @@ -3977,7 +4164,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, [None] * len(dyn_shape)) def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape, - broadcast_dimensions): + broadcast_dimensions, sharding): # `dyn_shape` is the dynamic portion of the target shape. `shape` # is the target shape, with `None` for dynamic sections. # broadcast_dimensions gives indices where dimensions of the input @@ -4023,6 +4210,8 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape, assert len(sizes) == stacked_size, msg dyn_limits.append(bound) new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits) + if sharding is not None: + raise NotImplementedError('Implement broadcast_in_dim_batch_rule') result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions) out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None] out_bdim = batching.make_batch_axis( @@ -4037,8 +4226,9 @@ def _broadcast_in_dim_fwd_rule(eqn): return [None], eqn def _broadcast_in_dim_staging_rule( - trace, x, *dyn, shape, broadcast_dimensions): - params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions) + trace, x, *dyn, shape, broadcast_dimensions, sharding): + params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) if not dyn: return trace.default_process_primitive(broadcast_in_dim_p, (x,), params) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) @@ -4063,24 +4253,28 @@ def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape, return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape), broadcast_dimensions=broadcast_dimensions)] -def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions): +def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions, + sharding): operand, *dyn_shape = primals operand_dot, *_ = tangents y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape, - broadcast_dimensions=broadcast_dimensions) + broadcast_dimensions=broadcast_dimensions, + sharding=sharding) if type(operand_dot) is ad_util.Zero: y_dot = ad_util.Zero.from_primal_value(y) else: y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape, - broadcast_dimensions=broadcast_dimensions) + broadcast_dimensions=broadcast_dimensions, + sharding=sharding) return y, y_dot def _broadcast_in_dim_partial_eval( - trace, operand, *dyn_shape, shape, broadcast_dimensions): + trace, operand, *dyn_shape, shape, broadcast_dimensions, sharding): if not dyn_shape: return trace.default_process_primitive( broadcast_in_dim_p, (operand, *dyn_shape), - dict(shape=shape, broadcast_dimensions=broadcast_dimensions)) + dict(shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding)) assert all(t.pval.is_known() for t in dyn_shape) operand_tracer = trace.instantiate_const(operand) dyn_shape_tracers = map(trace.instantiate_const, dyn_shape) @@ -4090,41 +4284,46 @@ def _broadcast_in_dim_partial_eval( out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) eqn = pe.new_eqn_recipe( [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, - dict(shape=shape, broadcast_dimensions=broadcast_dimensions), + dict(shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=None), core.no_effects, source_info_util.current()) out_tracer.recipe = eqn return out_tracer -def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) -> Sequence[ir.Value]: +def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, + sharding) -> Sequence[ir.Value]: aval_out, = ctx.avals_out if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=broadcast_dimensions) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + if sharding is not None: + assert sharding == aval_out.sharding + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] -def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions): +def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, + sharding): if (not dyn_shape and not any(isinstance(d, core.DArray) and type(core.get_aval(d).dtype) is core.bint for d in shape)): shape = _broadcast_in_dim_shape_rule( # error checking - x, shape=shape, broadcast_dimensions=broadcast_dimensions) + x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None) if config.sharding_in_types.value: - sharding = _broadcast_in_dim_sharding_rule( - x, shape=shape, broadcast_dimensions=broadcast_dimensions) + new_sharding = _broadcast_in_dim_sharding_rule( + x, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) else: - sharding = None - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=sharding) + new_sharding = None + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type) -def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars): +def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 1 assert not isinstance(invar_raggedness[0], core.Var) return invar_raggedness, [None] * len(outvars) @@ -4227,7 +4426,7 @@ def _concatenate_shape_rule(*operands, **kwargs): raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands]))) shapes = [operand.shape[:dimension] + operand.shape[dimension+1:] for operand in operands] - if not shapes[:-1] == shapes[1:]: + if shapes[:-1] != shapes[1:]: msg = ("Cannot concatenate arrays with shapes that differ in dimensions " "other than the one being concatenated: concatenating along " "dimension {} for shapes {}.") @@ -4238,6 +4437,13 @@ def _concatenate_shape_rule(*operands, **kwargs): ex_shape = operands[0].shape return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:] +def _concatenate_sharding_rule(*operands, **kwargs): + if not all(o.sharding == operands[0].sharding for o in operands): + ss = ", ".join(str(o.sharding) for o in operands) + raise TypeError( + f"All operands should have the same sharding. Got shardings {ss}") + return operands[0].sharding + def _concatenate_dtype_rule(*operands, **kwargs): check_same_dtypes('concatenate', *operands) return operands[0].dtype @@ -4278,14 +4484,19 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): raise NotImplementedError # TODO(mattjj) concatenate_p = standard_primitive( - _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate') + _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', + sharding_rule=_concatenate_sharding_rule) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): - return [hlo.concatenate(xs, mlir.i64_attr(dimension))] + aval_out, = ctx.avals_out + out = hlo.concatenate(xs, mlir.i64_attr(dimension)) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(concatenate_p, _concatenate_lower) @@ -4297,7 +4508,8 @@ def _pad_dtype_rule(operand, padding_value, *, padding_config): return _input_dtype(operand, padding_value) def _pad_shape_rule(operand, padding_value, *, padding_config): - del padding_value + if np.ndim(padding_value) != 0: + raise ValueError(f"padding_value must be a scalar; got {np.shape(padding_value)=}") op_shape = np.shape(operand) if not len(padding_config) == np.ndim(operand): raise ValueError("length of padding_config must equal the number of axes " @@ -4315,6 +4527,15 @@ def _pad_shape_rule(operand, padding_value, *, padding_config): raise ValueError(msg) return result +def _pad_sharding_rule(operand, padding_value, *, padding_config): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _pad_shape_rule(operand, padding_value, + padding_config=padding_config) + return slicing._get_sharding_for_varying_out_shape( + out_shape, operand, 'padding') + + def _pad_transpose(t, operand, padding_value, *, padding_config): if type(t) is ad_util.Zero: t_operand = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None @@ -4354,14 +4575,18 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): (operand_bdim,)) return select(mask, x, broadcasted_padding), operand_bdim -pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad') +pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', + sharding_rule=_pad_sharding_rule) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule def _pad_lower(ctx, x, padding_value, *, padding_config): aval_out, = ctx.avals_out low, high, interior = util.unzip3(padding_config) - return [mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)] + out = mlir.pad(ctx, aval_out, x, padding_value, low, high, interior) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(pad_p, _pad_lower) @@ -4383,6 +4608,12 @@ def _squeeze_dtype_rule(operand, *, dimensions): def _squeeze_shape_rule(operand, *, dimensions): return _compute_squeeze_shape(np.shape(operand), dimensions) +def _squeeze_sharding_rule(operand, *, dimensions): + dims_set = set(dimensions) + new_spec = tuple(s for i, s in enumerate(operand.sharding.spec) + if i not in dims_set) + return operand.sharding.with_spec(new_spec) + def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) if len(dims_set) != len(dimensions): @@ -4411,7 +4642,7 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze') + 'squeeze', sharding_rule=_squeeze_sharding_rule) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -4419,7 +4650,11 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. - return [mlir.reshape(ctx, operand, ctx.avals_out[0])] + aval_out, = ctx.avals_out + out = mlir.reshape(ctx, operand, aval_out) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(squeeze_p, _squeeze_lower) @@ -4428,6 +4663,8 @@ def shape_as_value(shape: core.Shape): """Converts a shape that may contain Poly values into a JAX value.""" if len(shape) == 0: return full((0,), np.array(0, np.int64)) + if core.is_constant_shape(shape): + return np.asarray(shape, dtype=np.int64) dims = [ expand_dims(convert_element_type(core.dimension_as_value(d), np.int64), (0,)) @@ -4435,7 +4672,7 @@ def shape_as_value(shape: core.Shape): ] return concatenate(dims, dimension=0) -def _reshape_shape_rule(operand, *, new_sizes, dimensions): +def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding): if not all(d >= 0 for d in new_sizes): msg = 'reshape new_sizes must all be positive, got {}.' raise TypeError(msg.format(new_sizes)) @@ -4455,10 +4692,33 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions): raise TypeError(msg.format(dimensions, np.shape(operand))) return tuple(new_sizes) -def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions): +def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): + if sharding is not None: + return sharding + filtered_spec = [ + (sh, sp) for sh, sp in zip(operand.shape, operand.sharding.spec) + if sh != 1 + ] + fs = iter(filtered_spec) + new_spec = [] + for n in new_sizes: + if n == 1: + new_spec.append(None) + else: + sh, sp = next(fs) + if n != sh: + raise ValueError( + 'This reshape is not supported. Please specify the sharding of the' + ' output via the `sharding` argument of reshape.') + new_spec.append(sp) + return operand.sharding.with_spec(new_spec) + +def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions, + sharding): if not dyn_shape: out_aval, effects = reshape_p.abstract_eval( - operand.aval, new_sizes=new_sizes, dimensions=dimensions) + operand.aval, new_sizes=new_sizes, dimensions=dimensions, + sharding=sharding) return [out_aval], effects else: # TODO(mattjj, necula): perform more checks like _reshape_shape_rule @@ -4469,18 +4729,29 @@ def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions): return [out_aval], core.no_effects -def _reshape_dtype_rule(operand, *, new_sizes, dimensions): +def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding): return operand.dtype -def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions): +def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): assert ad.is_undefined_primal(operand) if dimensions is None: + if config.sharding_in_types.value: + return [reshape(t, operand.aval.shape, sharding=operand.aval.sharding)] return [reshape(t, operand.aval.shape)] else: - return [transpose(reshape(t, np.take(operand.aval.shape, dimensions)), + if config.sharding_in_types.value: + t_s = operand.sharding.with_spec( + tuple(map(str, np.take(operand.aval.sharding.spec, dimensions)))) + else: + t_s = None + return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), + sharding=t_s), np.argsort(dimensions))] -def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions): +def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions, + sharding): + if sharding is not None: + raise NotImplementedError operand, = batched_args bdim, = batch_dims operand = batching.moveaxis(operand, bdim, 0) @@ -4489,24 +4760,29 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions): return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0 -def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): +def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): aval_out, = ctx.avals_out if dimensions is not None: x = hlo.transpose(x, mlir.dense_int_array(dimensions)) if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) - return [mlir.reshape(ctx, x, aval_out)] + out = mlir.reshape(ctx, x, aval_out) + if config.sharding_in_types.value: + if sharding is not None: + assert sharding == aval_out.sharding + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] def _reshape_staging_rule( - trace, x, *dyn, new_sizes, dimensions): - params = dict(new_sizes=new_sizes, dimensions=dimensions) + trace, x, *dyn, new_sizes, dimensions, sharding): + params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding) if not dyn: return trace.default_process_primitive(reshape_p, (x,), params) av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape') + 'reshape', sharding_rule=_reshape_sharding_rule) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.primitive_batchers[reshape_p] = _reshape_batch_rule mlir.register_lowering(reshape_p, _reshape_lower) @@ -4553,7 +4829,7 @@ def _transpose_shape_rule(operand, *, permutation): def _transpose_sharding_rule(operand, *, permutation): o_spec = operand.sharding.spec new_spec = [o_spec[old_idx] for old_idx in permutation] - return NamedSharding(operand.sharding.mesh, P(*new_spec)) + return operand.sharding.with_spec(new_spec) def _transpose_batch_rule(batched_args, batch_dims, *, permutation): operand, = batched_args @@ -4574,8 +4850,7 @@ def _transpose_lower(ctx, x, *, permutation): permutation = [*permutation, *trailing_dims] out = hlo.transpose(x, mlir.dense_int_array(permutation)) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] transpose_p = standard_primitive( @@ -4600,6 +4875,18 @@ def _select_shape_rule(which, *cases): raise TypeError(msg.format(which.shape, cases[0].shape)) return cases[0].shape +def _select_sharding_rule(which, *cases): + if any(case.sharding != cases[0].sharding for case in cases[1:]): + msg = "select cases must have the same shardings, got [{}]." + raise TypeError(msg.format(", ".join([str(c.sharding) for c in cases]))) + if which.shape and which.sharding != cases[0].sharding: + raise TypeError( + 'select `which` must be scalar or have the same sharding as cases, got' + f' `which` sharding {which.sharding} but case sharding' + f' {cases[0].sharding}.') + return cases[0].sharding + + def _select_dtype_rule(which, *cases): check_same_dtypes("select", *cases) if (not dtypes.issubdtype(which.dtype, np.bool_) and @@ -4702,18 +4989,24 @@ def _select_hlo_lowering_opaque(ctx, which, *cases): avals_in=[aval_which_bcast, *physical_avals_cases], avals_out=[physical_aval_out])[0] +def _add_shit_to_select(ctx, op, aval_out): + if config.sharding_in_types.value: + return mlir.lower_sharding_under_shit(ctx, op, aval_out) + return op def _select_hlo_lowering(ctx, which, *cases): which_aval = ctx.avals_in[0] aval_out, = ctx.avals_out if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - return [_select_hlo_lowering_opaque(ctx, which, *cases)] + op = _select_hlo_lowering_opaque(ctx, which, *cases) + return [_add_shit_to_select(ctx, op, aval_out)] if which_aval.dtype == np.dtype(np.bool_): assert len(cases) <= 2 if len(cases) == 1: return cases - return [hlo.select(which, cases[1], cases[0])] + op = hlo.select(which, cases[1], cases[0]) + return [_add_shit_to_select(ctx, op, aval_out)] if dtypes.issubdtype(which_aval.dtype, np.signedinteger): compare_type = 'SIGNED' @@ -4732,11 +5025,12 @@ def _select(offset, cases): return hlo.select(pred, _select(offset, cases[:mid]), _select(offset + mid, cases[mid:])) - return [_select(0, cases)] + op = _select(0, cases) + return [_add_shit_to_select(ctx, op, aval_out)] select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', - weak_type_rule=_select_weak_type_rule) + weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule @@ -4751,6 +5045,11 @@ def _reduce_shape_rule(*avals, computation, jaxpr, dimensions): raise ValueError(f'reduce found non-scalar initial value: {init_val_shapes}') return [tuple(np.delete(op.shape, dimensions)) for op in operand_avals] +def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions): + operand_avals, _ = split_list(avals, [len(avals) // 2]) + return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions)) + for op in operand_avals] + def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions): operand_avals, init_val_avals = split_list(avals, [len(avals) // 2]) operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals] @@ -4837,7 +5136,7 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -4859,6 +5158,9 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions): *reducer.arguments, dim_var_values=ctx.dim_var_values) hlo.return_(mlir.flatten_ir_values(out_nodes)) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, r, aval) + for r, aval in safe_zip(op.results, ctx.avals_out)] return op.results mlir.register_lowering(reduce_p, _reduce_lower) @@ -4874,7 +5176,11 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes): assert ad.is_undefined_primal(operand) input_shape = operand.aval.shape broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes)) - result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions) + if config.sharding_in_types.value: + result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions, + sharding=operand.aval.sharding) + else: + result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions) assert result.shape == input_shape return [result] @@ -4905,7 +5211,7 @@ def _reduce_op_sharding_rule(operand, *, axes): axes = frozenset(axes) new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec) if i not in axes)) - return NamedSharding(operand.sharding.mesh, new_spec) + return operand.sharding.with_spec(new_spec) reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), @@ -4914,6 +5220,7 @@ def _reduce_op_sharding_rule(operand, *, axes): batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum, _get_sum_identity) +batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule def _reduce_prod_jvp_rule(primals, tangents, *, axes): reducer = lambda x, y: [mul(x, y)] @@ -4948,6 +5255,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max, _get_max_identity) +batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype, @@ -4965,7 +5273,12 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype): if operand.shape[axis] < 1: raise ValueError("argmin and argmax require non-empty reduced dimension. " f"operand.shape={operand.shape} {axis=}") - return tuple(np.delete(operand.shape, axis)) + return util.tuple_delete(operand.shape, axis) + +def _argminmax_sharding_rule(operand, *, axes, index_dtype): + axis, = axes + return operand.sharding.with_spec( + util.tuple_delete(operand.sharding.spec, axis)) def _argminmax_dtype_rule(operand, *, axes, index_dtype): if not dtypes.issubdtype(index_dtype, np.integer): @@ -5002,7 +5315,9 @@ def _compute_argminmax(value_comparator, get_identity, # value_comparator is either lax.lt (for argmin) or lax.gt # get_identity(operand.dtype) is inf for argmin or -inf for argmax axis, = axes - indices = broadcasted_iota(index_dtype, np.shape(operand), axis) + indices = broadcasted_iota( + index_dtype, np.shape(operand), axis, + _sharding=operand.sharding if config.sharding_in_types.value else None) res = reduce([operand, indices], [get_identity(operand.dtype), np.array(0, index_dtype)], _ArgMinMaxReducer(value_comparator), @@ -5010,22 +5325,24 @@ def _compute_argminmax(value_comparator, get_identity, return res[1] argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, - 'argmin', weak_type_rule=_strip_weak_type) + 'argmin', weak_type_rule=_strip_weak_type, + sharding_rule=_argminmax_sharding_rule) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, - 'argmax', weak_type_rule=_strip_weak_type) + 'argmax', weak_type_rule=_strip_weak_type, + sharding_rule=_argminmax_sharding_rule) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) -mlir.register_lowering(argmin_p, mlir.cache_lowering(mlir.lower_fun( - partial(_compute_argminmax, lt, _get_min_identity), - multiple_results=False))) +mlir.register_lowering(argmin_p, mlir.cache_lowering( + mlir.lower_fun(partial(_compute_argminmax, lt, _get_min_identity), + multiple_results=False))) -mlir.register_lowering(argmax_p, mlir.cache_lowering(mlir.lower_fun( - partial(_compute_argminmax, gt, _get_max_identity), - multiple_results=False))) +mlir.register_lowering(argmax_p, mlir.cache_lowering( + mlir.lower_fun(partial(_compute_argminmax, gt, _get_max_identity), + multiple_results=False))) def _reduce_logical_shape_rule(operand, *, axes): @@ -5063,8 +5380,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): with ir.InsertionPoint(reducer_region): hlo.return_([reducer(*reducer_region.arguments)]) if config.sharding_in_types.value: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, op.result, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)] return op.results mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp, @@ -5122,7 +5438,7 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): def _sort_abstract_eval(*args, **kwargs): - args = tuple(raise_to_shaped(arg) for arg in args) + args = tuple(args) if any(arg.shape != args[0].shape for arg in args[1:]): shapes = " ".join(str(a.shape) for a in args) raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") @@ -5193,14 +5509,14 @@ def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys): shape = primals[0].shape iotas = [] for dim, size in enumerate(shape): - dtype = np.int32 if size < np.iinfo(np.int32).max else np.int64 - iotas.append(broadcasted_iota(dtype, shape, dim)) - primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension, - is_stable=is_stable, num_keys=num_keys) - idx = tuple(primals[-1] if i == dimension else iotas[i] + iotas.append(broadcasted_iota(np.int64, shape, dim)) + sorted_primals_and_idx = sort_p.bind( + *primals, iotas[dimension], dimension=dimension, + is_stable=is_stable, num_keys=num_keys) + idx = tuple(sorted_primals_and_idx[-1] if i == dimension else iotas[i] for i in range(len(shape))) tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents) - return tuple(primals[:-1]), tangents_out + return tuple(sorted_primals_and_idx[:-1]), tangents_out def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys): prototype_arg, new_bdim = next( @@ -5621,7 +5937,7 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule)) + _rng_bit_generator_weak_type_rule, None)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -5728,9 +6044,11 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) + iota_p = Primitive('iota') iota_p.def_impl(partial(dispatch.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) +batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding): params = dict(dtype=dtype, shape=shape, dimension=dimension, @@ -5761,8 +6079,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding): out = mlir.iota(ctx, aval_out, dimension=dimension) if config.sharding_in_types.value: assert aval_out.sharding == sharding - proto = sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(iota_p, _iota_lower) @@ -5971,7 +6288,13 @@ def _const(example, val): _zeros: Callable = partial(full_like, fill_value=0) _zero: Callable = partial(full_like, shape=(), fill_value=0) _ones: Callable = partial(full_like, fill_value=1) -_one: Callable = partial(full_like, shape=(), fill_value=1) + +def _one(x): + if config.sharding_in_types.value: + return full_like(x, shape=(), fill_value=1, + sharding=x.sharding.with_spec(P())) + return full_like(x, shape=(), fill_value=1) + _twos: Callable = partial(full_like, fill_value=2) _two: Callable = partial(full_like, shape=(), fill_value=2) @@ -6064,7 +6387,7 @@ def _eq_meet(a, b): def _abstractify(x): - return raise_to_shaped(core.get_aval(x)) + return core.get_aval(x) def empty(dtype): @@ -6176,3 +6499,7 @@ def _optimization_barrier_lowering_rule(ctx, *args): optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) mlir.register_lowering(optimization_barrier_p, _optimization_barrier_lowering_rule) + +def _optimization_barrier_batcher(batched_args, batch_dims, **params): + return optimization_barrier_p.bind(*batched_args, **params), batch_dims +batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index dc1d1d472ae2..a352cee757ca 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -33,7 +33,7 @@ from jax._src import dtypes from jax._src import util from jax._src.core import ( - Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape) + Primitive, ShapedArray, is_constant_dim, is_constant_shape) from jax._src.extend import ffi from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -48,6 +48,7 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -121,16 +122,46 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, - compute_right_eigenvectors: bool = True) -> list[Array]: + compute_right_eigenvectors: bool = True, + use_magma: bool | None = None) -> list[Array]: """Eigendecomposition of a general matrix. - Nonsymmetric eigendecomposition is at present only implemented on CPU. + Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU, + the default implementation calls LAPACK directly on the host CPU, but an + experimental GPU implementation using `MAGMA `_ + is also available. The MAGMA implementation is typically slower than the + equivalent LAPACK implementation for small matrices (less than about 2048), + but it may perform better for larger matrices. + + To enable the MAGMA implementation, you must install MAGMA yourself (there + are Debian and conda-forge packages, or you can build from source). Then set + the ``use_magma`` argument to ``True``, or set the ``jax_use_magma`` + configuration variable to ``"on"`` or ``"auto"``: + + .. code-block:: python + + jax.config.update('jax_use_magma', 'on') + + JAX will try to ``dlopen`` the installed MAGMA shared library, raising an + error if it is not found. To explicitly specify the path to the MAGMA + library, set the environment variable `JAX_GPU_MAGMA_PATH` to the full + installation path. + + If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will + be used if the library can be found, and the input matrix is sufficiently + large (>= 2048x2048). Args: x: A batch of square matrices with shape ``[..., n, n]``. compute_left_eigenvectors: If true, the left eigenvectors will be computed. compute_right_eigenvectors: If true, the right eigenvectors will be computed. + use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the + eigendecomposition is computed using MAGMA. If ``False``, the computation + is done using LAPACK on to the host CPU. If ``None`` (default), the + behavior is controlled by the ``jax_use_magma`` flag. This argument + is only used on GPU. + Returns: The eigendecomposition of ``x``, which is a tuple of the form ``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left @@ -142,7 +173,8 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, for that batch element. """ return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma) def eigh( @@ -678,12 +710,14 @@ def _symmetric_product_jax_fn(a, c, *, alpha, beta): # Asymmetric eigendecomposition -def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): +def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): return dispatch.apply_primitive( eig_p, operand, compute_left_eigenvectors=compute_left_eigenvectors, compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma, ) def eig_lower(*args, **kw): @@ -692,7 +726,8 @@ def eig_lower(*args, **kw): "If your matrix is symmetric or Hermitian, you should use eigh instead.") def eig_abstract_eval(operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if isinstance(operand, ShapedArray): if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]: raise ValueError("Argument to nonsymmetric eigendecomposition must have " @@ -716,7 +751,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors, return tuple(output) def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused operand_aval, = ctx.avals_in out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] @@ -763,18 +799,94 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, return output +def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors, + compute_right_eigenvectors, use_magma): + gpu_solver.initialize_hybrid_kernels() + dtype = x.dtype + is_real = dtype == np.float32 or dtype == np.float64 + if is_real: + target_name = f"{target_name_prefix}hybrid_eig_real" + complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + else: + target_name = f"{target_name_prefix}hybrid_eig_comp" + assert dtype == np.complex64 or dtype == np.complex128 + complex_dtype = dtype + + batch_dims = x.shape[:-2] + n, m = x.shape[-2:] + assert n == m + num_batch_dims = len(batch_dims) + + layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims) + out_types = [ + api.ShapeDtypeStruct(batch_dims + (n,), dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims, np.int32), + ] + out_layouts = [None, layout, layout, None] + if is_real: + out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types + out_layouts = [None] + out_layouts + + magma = config.gpu_use_magma.value + if use_magma is not None: + magma = "on" if use_magma else "off" + fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout], + output_layouts=out_layouts) + *w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + assert len(w) == 2 + w = lax.complex(*w) + else: + assert len(w) == 1 + w = w[0] + ok = lax.eq(info, lax.zeros_like_array(info)) + ok = _broadcast_to(ok[..., None], w.shape) + w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j)) + ok = _broadcast_to(ok[..., None], x.shape) + output = [w] + if compute_left_eigenvectors: + vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j)) + output.append(vl) + if compute_right_eigenvectors: + vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j)) + output.append(vr) + return output + + +def _eig_gpu_lowering(target_name_prefix, ctx, operand, *, + compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): + if ctx.is_forward_compat(): + raise NotImplementedError( + "Export of nonsymmetric eigendecomposition on GPU is not supported " + "because of forward compatibility. The " + "'jax_export_ignore_forward_compatibility' configuration option can be " + "used to disable this check.") + rule = mlir.lower_fun(partial( + _eig_gpu_impl, target_name_prefix, + compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), multiple_results=True) + return rule(ctx, operand) + + def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors), + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), (0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors)) def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if compute_left_eigenvectors or compute_right_eigenvectors: raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' @@ -793,6 +905,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, eig_p.def_abstract_eval(eig_abstract_eval) mlir.register_lowering(eig_p, eig_lower) mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'), + platform='cuda') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'), + platform='rocm') batching.primitive_batchers[eig_p] = eig_batching_rule ad.primitive_jvps[eig_p] = eig_jvp_rule @@ -1289,7 +1405,6 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size): - pivots = raise_to_shaped(pivots) if isinstance(pivots, ShapedArray): if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32): raise ValueError( @@ -1421,7 +1536,6 @@ def _lu_impl(operand): return lu, pivot, perm def _lu_abstract_eval(operand): - operand = raise_to_shaped(operand) if isinstance(operand, ShapedArray): if operand.ndim < 2: raise ValueError("Argument to LU decomposition must have ndims >= 2") @@ -2503,37 +2617,54 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals, batch_dims = operand_aval.shape[:-2] a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - gees_result = lapack.gees_hlo(operand_aval.dtype, operand, + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else () + gees_result = lapack.gees_hlo(*ctx_args, operand_aval.dtype, operand, jobvs=compute_schur_vectors, sort=sort_eig_vals, select=select_callable, a_shape_vals=a_shape_vals) - - # Number of return values depends on value of sort_eig_vals. - T, vs, *_, info = gees_result + if jaxlib_version >= (0, 4, 37) and not ctx.is_forward_compat(): + schur_form, schur_vectors, _eig_vals, _selected_eig_vals, info = gees_result + else: + # Number of return values depends on value of sort_eig_vals. + schur_form, schur_vectors, *_, info = gees_result ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") - select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - T = _broadcasting_select_hlo( + select_schur_form_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) + schur_form = _broadcasting_select_hlo( ctx, - mlir.broadcast_in_dim(ctx, ok, select_T_aval, - broadcast_dimensions=range(len(batch_dims))), - select_T_aval, - T, ctx.avals_out[0],_nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]) - output = [T] + mlir.broadcast_in_dim( + ctx, + ok, + select_schur_form_aval, + broadcast_dimensions=range(len(batch_dims)), + ), + select_schur_form_aval, + schur_form, + ctx.avals_out[0], + _nan_like_hlo(ctx, ctx.avals_out[0]), + ctx.avals_out[0], + ) + output = [schur_form] if compute_schur_vectors: select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - vs = _broadcasting_select_hlo( + schur_vectors = _broadcasting_select_hlo( ctx, - mlir.broadcast_in_dim(ctx, ok, select_vs_aval, - broadcast_dimensions=range(len(batch_dims))), + mlir.broadcast_in_dim( + ctx, ok, select_vs_aval, broadcast_dimensions=range(len(batch_dims)) + ), select_vs_aval, - vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]) + schur_vectors, + ctx.avals_out[1], + _nan_like_hlo(ctx, ctx.avals_out[1]), + ctx.avals_out[1], + ) - output.append(vs) + output.append(schur_vectors) return output @@ -2706,24 +2837,35 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) - return tridiagonal(x), 0 + return tridiagonal(x, lower=lower), 0 batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule -def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower): +def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower, platform): a_aval, = ctx.avals_in - a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower) + cpu_args = [] + if platform == "cpu": + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else () + cpu_args.extend(ctx_args) + a, d, e, taus, info = sytrd_impl(*cpu_args, a_aval.dtype, a, lower=lower) return a, d, e, taus, info mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo), - platform='cpu') + tridiagonal_p, + partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo, platform="cpu"), + platform="cpu", +) mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd), - platform='cuda') + tridiagonal_p, + partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd, platform="cuda"), + platform="cuda", +) mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd), - platform='rocm') + tridiagonal_p, + partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd, platform="rocm"), + platform="rocm", +) # Utilities diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 9d4614f344fb..6ae2d02f82b7 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,9 +24,12 @@ from jax import tree_util from jax._src import core +from jax._src import config +from jax._src import dispatch from jax._src import dtypes -from jax._src import sharding_impls -from jax._src.core import AxisName, ShapedArray, raise_to_shaped +from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, + NamedSharding, PartitionSpec as P) +from jax._src.core import AxisName, ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -119,8 +122,25 @@ def psum(x, axis_name, *, axis_index_groups=None): leaves = [lax.convert_element_type(l, np.int32) if dtypes.dtype(l) == np.bool_ else l for l in leaves] axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - out_flat = psum_p.bind( - *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + # handle the constant case specially + if all(not isinstance(leaf, core.Tracer) for leaf in leaves): + named_axes, pos_axes = axes_partition = [], [] + for axis in axis_name: + axes_partition[isinstance(axis, int)].append(axis) + def pos_reduce(x): + if not pos_axes: + return x + return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0)) + for axis in pos_axes]) + if axis_index_groups is not None: + assert not pos_axes + size = len(axis_index_groups[0]) + else: + size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) + out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) + else: + out_flat = psum_p.bind( + *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) def pmean(x, axis_name, *, axis_index_groups=None): @@ -233,7 +253,7 @@ def _axis_index_of_val(x, val, axis_name): mask = (val == x) validx = lax.select(mask, lax.full(mask.shape, idx), - lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtype=idx.dtype)) + lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtypes.dtype(idx))) return pmin(validx, axis_name) def _validate_reduce_axis_index_groups(axis_index_groups): @@ -303,6 +323,8 @@ def ppermute(x, axis_name, perm): Array(s) with the same shape as ``x`` with slices along the axis ``axis_name`` gathered from ``x`` according to the permutation ``perm``. """ + if not isinstance(axis_name, (list, tuple)): + axis_name = (axis_name,) return tree_util.tree_map( partial(ppermute_p.bind, axis_name=axis_name, perm=tuple(map(tuple, perm))), x) @@ -435,6 +457,55 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): return tree_util.tree_map(bind, x) +def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes): + """Ragged version of :func:`all_to_all`. + + For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent + and the outermost (ragged) dimension. ``axis_index_groups`` is default to all + replicas (e.g. there is only one group and covers all axis indices). + + Ragged arrays are defined by a set of three arrays: + * ``data``: the ``data`` array is "ragged" along its outermost dimension, + along which each indexed element has variable size. + * ``offsets``: the ``offsets`` array indexes the outermost dimension of the + ``data`` array, and represents the starting offset of each ragged element of + the ``data`` array. + * ``sizes``: the ``sizes`` array represents the size of each ragged element of + the ``data`` array, where the size is specified in units of sub-elements. A + sub-element is defined as the suffix of the ``data`` array shape obtained by + removing the outermost "ragged" dimension. + The ``offsets`` and ``sizes`` arrays must have the same size. + + # Example ragged tensor + data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}} + offsets: [3] = {0, 1, 4} + sizes: [3] = {1, 3, 4} + + # Index 'data' at 'offsets'[0], 'sizes'[0]' + {a,b,c} + + # Index 'data' at 'offsets'[1], 'sizes'[1]' + {d,e,f},{g,h,i},{j,k,l} + + # Index 'data' at 'offsets'[2], 'sizes'[2]' + {m,n,o},{p,q,r},{s,t,u},{v,w,x} + + Args: + operand: array with ragged dimension along its outermost dimension. + output: array of ragged input offsets. + input_offsets: array of ragged input send sizes. + send_sizes: array of ragged output data. + output_offsets: array of ragged output offsets. + recv_sizes: array of ragged output receive sizes. + Returns: + array with shape equal to ``output``. + """ + return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes, + output_offsets, recv_sizes) + +ragged_all_to_all_p = core.Primitive('ragged_all_to_all') + + def axis_index(axis_name): """Return the index along the mapped axis ``axis_name``. @@ -472,8 +543,15 @@ def axis_index(axis_name): [0 1] [0 1]] """ - return axis_index_p.bind(axis_name=axis_name) - + if not isinstance(axis_name, (tuple, list)): + return axis_index_p.bind(axis_name=axis_name) + else: + inner_size = 1 + index = 0 + for name in reversed(axis_name): + index += axis_index(name) * inner_size + inner_size *= psum(1, name) + return index def pgather(src, idx, axes: int | AxisName): """Uses the last positional axis of idx to index into src's axes.""" @@ -485,18 +563,30 @@ def pgather(src, idx, axes: int | AxisName): ### parallel primitives -def _subst_all_names_in_param( - pname: str, params: core.ParamDict, subst: core.AxisSubst, traverse: bool) -> core.ParamDict: - axis_name = params[pname] - if not isinstance(axis_name, (tuple, list)): - axis_name = (axis_name,) - result = dict(params) - result[pname] = sum(((name,) if isinstance(name, int) else subst(name) - for name in axis_name), - ()) - return result +def _names_in_param(pname: str, params: core.ParamDict) -> tuple[str]: + axis_names = params[pname] + if isinstance(axis_names, (tuple, list)): + return tuple(axis_names) + else: + return (axis_names,) + +def _constant_reduction(prim, axis_data, args, axes, axis_index_groups): + assert axis_data.name in axes + if axis_index_groups: raise NotImplementedError + new_axes = tuple(n for n in axes if n != axis_data.name) + if new_axes: + args = prim.bind(*args, axes=new_axes, axis_index_groups=axis_index_groups) + if prim is psum_p: + outs = [lax._const(x, axis_data.size) * x for x in args] + elif prim in (pmin_p, pmax_p): + outs = args + else: + raise Exception(f"Unrecognized reducer: {prim}") -def _reduction_with_positional_batcher(prim, vals_in, dims_in, axis_index_groups, + return outs, [None] * len(outs) + +def _reduction_with_positional_batcher( + prim, vals_in, dims_in, axis_index_groups, transform_unmapped, transform_mapped): if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap collectives. " @@ -536,10 +626,19 @@ def _reduction_batcher(prim, vals_in, dims_in, *, axes, axis_index_groups): return vals_out, [d if d is batching.not_mapped else 0 for d in dims_in] def _batched_reduction_collective( - prim, if_unmapped, axis_size, frame_name, _, vals_in, dims_in, axes, + prim, if_unmapped, axis_data, vals_in, dims_in, axes, axis_index_groups): assert prim.multiple_results - assert frame_name in axes + if all(d is None for d in dims_in): + if axis_data.name in axes: + return _constant_reduction(prim, axis_data, vals_in, axes, axis_index_groups) + else: + return prim.bind(*vals_in, axes=axes, axis_index_groups=axis_index_groups), dims_in + + if axis_data.name not in axes: + return _reduction_batcher(prim, vals_in, dims_in, axes=axes, + axis_index_groups=axis_index_groups) + # Note that we have a choice here. We can either unfuse the reduction into one # that handles the batched dims and then another one that handles the rest. # Alternatively, we can keep the dimension reduction fused with the rest, but @@ -548,12 +647,11 @@ def _batched_reduction_collective( # We choose the second strategy here. vals_out = _reduction_with_positional_batcher( prim, vals_in, dims_in, axis_index_groups, - lambda d, d_vals_in: (tuple(axis for axis in axes if axis != frame_name), - [if_unmapped(v, axis_size) for v in d_vals_in]), + lambda d, d_vals_in: (tuple(axis for axis in axes if axis != axis_data.name), + [if_unmapped(v, axis_data.size) for v in d_vals_in]), lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else - axis if axis != frame_name else - d - for axis in axes), + axis if axis != axis_data.name else + d for axis in axes), d_vals_in)) return vals_out, [batching.not_mapped] * len(vals_out) @@ -572,23 +670,40 @@ def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]] dtype=np.int64).T return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups)) -def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups): +def _allreduce_impl(prim, pos_reducer, *args, axes, axis_index_groups): assert axis_index_groups is None + if not all(isinstance(axis, int) for axis in axes): + return dispatch.apply_primitive(prim, *args, axes=axes, + axis_index_groups=axis_index_groups) assert all(isinstance(axis, int) for axis in axes) return [pos_reducer(arg, axes) for arg in args] def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): + _check_axis_names(axes) named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) if axis_index_groups is not None: if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") - out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes), - arg.dtype) for arg in args] + if config.sharding_in_types.value: + out_avals = [ + ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) + for arg in args + ] + else: + out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype) + for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +def _check_axis_names(axes): + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + axis_env = core.get_axis_env() + for name in named_axes: + if not axis_env.axis_exists(name): + raise NameError(f"unbound axis name: {name}") + def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms): len_0 = len(axis_index_groups[0]) @@ -615,10 +730,7 @@ def _positional_reduce(aval, arg): _replica_groups(ctx.module_context.axis_env, named_axes, axis_index_groups)) axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) def all_reduce(aval, x): if is_spmd: @@ -636,7 +748,11 @@ def all_reduce(aval, x): else: op = hlo.AllReduceOp( [x.type], [x], replica_groups=replica_groups, **other_args) - scalar_aval = core.ShapedArray((), aval.dtype) + if config.sharding_in_types.value: + scalar_aval = core.ShapedArray( + (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P())) + else: + scalar_aval = core.ShapedArray((), aval.dtype) scalar_type = mlir.aval_to_ir_type(scalar_aval) reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_block): @@ -669,64 +785,37 @@ def broadcast_positional(ct, arg): axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, nonzero_in_cts) -psum_p = core.AxisPrimitive('psum') +psum_p = core.Primitive('psum') psum_p.multiple_results = True -psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum)) +psum_p.def_impl(partial(_allreduce_impl, psum_p, lax._reduce_sum)) psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum)) ad.deflinear2(psum_p, _psum_transpose_rule) -batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p) -batching.axis_primitive_batchers[psum_p] = \ +batching.fancy_primitive_batchers[psum_p] = \ partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v) -core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes') - +batching.skippable_batchers[psum_p] = partial(_names_in_param, 'axes') -# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at -# tracing time. -@psum_p.def_custom_bind -def psum_bind(*args, axes, axis_index_groups): - if all(not isinstance(x, core.Tracer) for x in args): - named_axes, pos_axes = axes_partition = [], [] - for axis in axes: - axes_partition[isinstance(axis, int)].append(axis) - def pos_reduce(x): - if not pos_axes: - return x - return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0)) - for axis in pos_axes]) - if axis_index_groups is not None: - assert not pos_axes - size = len(axis_index_groups[0]) - else: - size = math.prod([core.axis_frame(name).size for name in named_axes]) - return tuple(lax._const(x, size) * pos_reduce(x) for x in args) - return core.AxisPrimitive.bind( - psum_p, *args, axes=axes, axis_index_groups=axis_index_groups) - - -pmax_p = core.AxisPrimitive('pmax') +pmax_p = core.Primitive('pmax') pmax_p.multiple_results = True -pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max)) +pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax._reduce_max)) pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max)) -batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p) -batching.axis_primitive_batchers[pmax_p] = \ +batching.fancy_primitive_batchers[pmax_p] = \ partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v) -core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes') +batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes') -pmin_p = core.AxisPrimitive('pmin') +pmin_p = core.Primitive('pmin') pmin_p.multiple_results = True -pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min)) +pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax._reduce_min)) pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min)) -batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p) -batching.axis_primitive_batchers[pmin_p] = \ +batching.fancy_primitive_batchers[pmin_p] = \ partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v) -core.axis_substitution_rules[pmin_p] = partial(_subst_all_names_in_param, 'axes') +batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes') def _ppermute_lowering(ctx, x, *, axis_name, perm): @@ -747,7 +836,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): axis_context = ctx.module_context.axis_context is_manual = ( - isinstance(axis_context, sharding_impls.SPMDAxisContext) + isinstance(axis_context, SPMDAxisContext) and axis_context.manual_axes ) if is_manual: @@ -765,15 +854,16 @@ def _ppermute_transpose_rule(t, x, perm, axis_name): inverse_perm = list(zip(dsts, srcs)) return [ppermute(t, axis_name=axis_name, perm=inverse_perm)] -def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, perm): +def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): + axis_size, frame_name = axis_data.size, axis_data.name (v,), (d,) = vals_in, dims_in if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) + if axis_data.name not in axis_name: + return ppermute_p.bind(v, perm=perm, axis_name=axis_name), d remaining_axes = tuple(axis for axis in axis_name if axis != frame_name) - if axis_size == 1 and remaining_axes: - return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d if remaining_axes: - raise NotImplementedError("ppermute batcher only supports a single axis") + return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d assert axis_name[0] == frame_name, "ppermute batcher called with a wrong axis!" assert len(perm) == axis_size, "Permutation doesn't match the axis size!" if d is batching.not_mapped: @@ -783,30 +873,33 @@ def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, per perm_indices[dst] = src return v.take(perm_indices, d), d -def _collective_batcher(prim, args, dims, **params): - return prim.bind(*args, **params), dims if prim.multiple_results else dims[0] +def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): + _check_axis_names(axis_name) + return x -ppermute_p = core.AxisPrimitive('ppermute') -ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) +ppermute_p = core.Primitive('ppermute') +ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval) ad.deflinear2(ppermute_p, _ppermute_transpose_rule) mlir.register_lowering(ppermute_p, _ppermute_lowering) -batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p) -batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher -core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[ppermute_p] = _ppermute_batcher +batching.skippable_batchers[ppermute_p] = partial(_names_in_param, 'axis_name') def _pbroadcast_transpose_rule(t, x, source, axis_name): is_source = axis_index(axis_name) == source tsum = psum(t, axis_name) return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))] -def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source): +def _pbroadcast_batcher(axis_data, vals_in, dims_in, axis_name, source): + axis_size = axis_data.size (v,), (d,) = vals_in, dims_in if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) - remaining_axes = tuple(axis for axis in axis_name if axis != frame_name) + if axis_data.name not in axis_name: + return pbroadcast_p.bind(v, axis_name=axis_name, source=source), d + remaining_axes = tuple(axis for axis in axis_name if axis != axis_data.name) if remaining_axes: raise NotImplementedError("pbroadcast batcher only supports a single axis") - assert axis_name[0] == frame_name, "pbroadcast batcher called with a wrong axis!" + assert axis_name[0] == axis_data.name, "pbroadcast batcher called with a wrong axis!" assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!" if axis_size == 1 and remaining_axes: return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d @@ -823,13 +916,12 @@ def source_to_front(group): return hlo.CollectiveBroadcastOp( x, replica_groups=_replica_groups_hlo(replica_groups)).results -pbroadcast_p = core.AxisPrimitive('pbroadcast') -pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) +pbroadcast_p = core.Primitive('pbroadcast') +pbroadcast_p.def_abstract_eval(_raise_to_shaped_abstract_eval) ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule) mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering) -batching.primitive_batchers[pbroadcast_p] = partial(_collective_batcher, pbroadcast_p) -batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher -core.axis_substitution_rules[pbroadcast_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher +batching.skippable_batchers[pbroadcast_p] = partial(_names_in_param, 'axis_name') def _moveaxis(src, dst, x): @@ -862,7 +954,7 @@ def _all_to_all_lowering( raise ValueError('Replica groups must be equally sized') is_spmd = isinstance( ctx.module_context.axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -914,11 +1006,22 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, ) return result, d -def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, +def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): + axis_size, frame_name = axis_data.size, axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") + + if isinstance(axis_name, (list, tuple)): + axes_names = axis_name + else: + axes_names = [axis_name] + if axis_data.name not in axes_names: + return _all_to_all_batcher( + vals_in, dims_in, axis_name=axis_name, split_axis=split_axis, + concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled) + x, = vals_in d, = dims_in if d is batching.not_mapped: @@ -974,12 +1077,12 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, def _all_to_all_effectful_abstract_eval( - x, axis_name, split_axis, concat_axis, axis_index_groups, tiled + input_aval, axis_name, split_axis, concat_axis, axis_index_groups, tiled ): del tiled # expand_dims and squeeze is done in `all_to_all` if `True` if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - input_aval = raise_to_shaped(x) + _check_axis_names(axis_name) shape = list(input_aval.shape) axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) @@ -990,13 +1093,70 @@ def _all_to_all_effectful_abstract_eval( return out_aval, effects -all_to_all_p = core.AxisPrimitive('all_to_all') +all_to_all_p = core.Primitive('all_to_all') all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval) mlir.register_lowering(all_to_all_p, _all_to_all_lowering) ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule) -batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher -batching.axis_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective -core.axis_substitution_rules[all_to_all_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective +batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name') + + +def _ragged_all_to_all_lowering(ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes): + N = input_offsets.type.shape[0] + backend_config = ir.DictAttr.get({ + 'replica_groups': ir.DenseIntElementsAttr.get( + np.arange(0, N, 1, dtype=np.int64), shape=[1, N] + ) + }) + return hlo.CustomCallOp( + result=[output.type], + inputs=[operand, output, input_offsets, send_sizes, output_offsets, + recv_sizes], + call_target_name=ir.StringAttr.get('ragged_all_to_all'), + backend_config=backend_config, + api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4), + ).results + +@ragged_all_to_all_p.def_abstract_eval +def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes): + if operand.shape != output.shape: + raise ValueError('ragged_all_to_all input and output shapes must be equal.') + if not dtypes.issubdtype(input_offsets.dtype, np.integer): + raise ValueError("ragged_all_to_all input_offsets must be integer type.") + if not dtypes.issubdtype(send_sizes.dtype, np.integer): + raise ValueError("ragged_all_to_all send_sizes must be integer type.") + if not dtypes.issubdtype(output_offsets.dtype, np.integer): + raise ValueError("ragged_all_to_all output_offsets must be integer type.") + if not dtypes.issubdtype(recv_sizes.dtype, np.integer): + raise ValueError("ragged_all_to_all recv_sizes must be integer type.") + if len(input_offsets.shape) != 1 or input_offsets.shape[0] < 1: + raise ValueError( + "ragged_all_to_all input_offsets must be rank 1 with positive dimension" + " size, but got shape {}".format(input_offsets.shape) + ) + if len(send_sizes.shape) != 1 or send_sizes.shape[0] < 1: + raise ValueError( + "ragged_all_to_all send_sizes must be rank 1 with positive dimension" + " size, but got shape {}".format(send_sizes.shape) + ) + if len(output_offsets.shape) != 1 or output_offsets.shape[0] < 1: + raise ValueError( + "ragged_all_to_all output_offsets must be rank 1 with positive" + " dimension size, but got shape {}".format(output_offsets.shape) + ) + if len(recv_sizes.shape) != 1 or recv_sizes.shape[0] < 1: + raise ValueError( + "ragged_all_to_all recv_sizes must be rank 1 with positive dimension" + " size, but got shape {}".format(recv_sizes.shape) + ) + return output.update( + shape=list(output.shape), + dtype=output.dtype, + weak_type=output.weak_type, + ) + +ragged_all_to_all_p.def_impl(partial(dispatch.apply_primitive, ragged_all_to_all_p)) +mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): @@ -1063,6 +1223,8 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): [[12 13 14 15] [ 4 5 6 7]]] """ + if not isinstance(axis_name, tuple): + axis_name = axis_name, axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) def bind(leaf): @@ -1071,7 +1233,7 @@ def bind(leaf): all_gather_dimension=canonicalize_axis( axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1), axis_name=axis_name, axis_index_groups=axis_index_groups, - axis_size=axis_size, tiled=tiled) + axis_size=int(axis_size), tiled=tiled) return tree_util.tree_map(bind, x) def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): @@ -1083,10 +1245,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, x_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) if not tiled: new_shape = list(x_aval.shape) new_shape.insert(all_gather_dimension, 1) @@ -1122,11 +1281,11 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, def _all_gather_effectful_abstract_eval( - x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled + x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - x_aval = raise_to_shaped(x) + _check_axis_names(axis_name) new_shape = list(x_aval.shape) if tiled: new_shape[all_gather_dimension] *= axis_size @@ -1144,10 +1303,11 @@ def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): (x,), (d,) = vals_in, dims_in - if d <= all_gather_dimension: - all_gather_dimension += 1 - elif not tiled: # Tiled all-gather doesn't modify the set of dimensions - d += 1 + if d is not batching.not_mapped: + if d <= all_gather_dimension: + all_gather_dimension += 1 + elif not tiled: # Tiled all-gather doesn't modify the set of dimensions + d += 1 result = all_gather_p.bind( x, all_gather_dimension=all_gather_dimension, @@ -1157,9 +1317,15 @@ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, ax tiled=tiled) return result, d -def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, +def _all_gather_batched_collective(axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): + frame_size, frame_name = axis_data.size, axis_data.name + if frame_name not in axis_name: + return _all_gather_batcher( + vals_in, dims_in, all_gather_dimension=all_gather_dimension, + axis_name=axis_name, axis_index_groups=axis_index_groups, + axis_size=axis_size, tiled=tiled) if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap") assert axis_size == frame_size, "axis size doesn't match" @@ -1180,7 +1346,7 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, y = _foldaxis(all_gather_dimension, y) return y, batching.not_mapped -all_gather_p = core.AxisPrimitive('all_gather') +all_gather_p = core.Primitive('all_gather') all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval) all_gather_p.def_impl(_all_gather_impl) mlir.register_lowering(all_gather_p, _all_gather_lowering) @@ -1189,9 +1355,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, partial(_all_gather_lowering, platform=p), platform=p) ad.deflinear2(all_gather_p, _all_gather_transpose_rule) -batching.primitive_batchers[all_gather_p] = _all_gather_batcher -batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective -core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[all_gather_p] = _all_gather_batched_collective +batching.skippable_batchers[all_gather_p] = partial(_names_in_param, 'axis_name') def _reduce_scatter_lowering( @@ -1208,7 +1373,7 @@ def _reduce_scatter_lowering( axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -1244,11 +1409,11 @@ def _reduce_scatter_lowering( def _reduce_scatter_effectful_abstract_eval( - x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled + x_aval, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - x_aval = core.raise_to_shaped(x) + _check_axis_names(axis_name) new_shape = list(x_aval.shape) scatter_dim_input_size = x_aval.shape[scatter_dimension] if tiled: @@ -1289,9 +1454,15 @@ def _reduce_scatter_batcher(vals_in, dims_in, *, scatter_dimension, axis_name, tiled=tiled) return result, d -def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, +def _reduce_scatter_collective(axis_data, vals_in, dims_in, scatter_dimension, axis_name, axis_index_groups, axis_size, tiled): + frame_size, frame_name = axis_data.size, axis_data.name + if frame_name not in axis_name: + return _reduce_scatter_batcher( + vals_in, dims_in, scatter_dimension=scatter_dimension, + axis_name=axis_name, axis_index_groups=axis_index_groups, + axis_size=axis_size, tiled=tiled) if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap") assert axis_size == frame_size, "axis size doesn't match" @@ -1310,21 +1481,17 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, return y, dy -reduce_scatter_p = core.AxisPrimitive("reduce_scatter") +reduce_scatter_p = core.Primitive("reduce_scatter") reduce_scatter_p.def_effectful_abstract_eval( _reduce_scatter_effectful_abstract_eval ) ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule) -batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher -batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective +batching.fancy_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective +batching.skippable_batchers[reduce_scatter_p] = partial(_names_in_param, 'axis_name') mlir.register_lowering(reduce_scatter_p, partial(_reduce_scatter_lowering, lax.add_p)) -core.axis_substitution_rules[reduce_scatter_p] = \ - partial(_subst_all_names_in_param, 'axis_name') - - def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False): """ @@ -1401,6 +1568,8 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, [12 14] [16 18]] """ + if not isinstance(axis_name, tuple): + axis_name = axis_name, axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) bind = partial( @@ -1420,6 +1589,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): raise NotImplementedError( '`axis_index` translation rule does not support multiple axis names.') axis_name, = axis_name + if axis_name not in axis_env.names: + raise NameError(f"unbound axis name: {axis_name}") axis_pos = list(axis_env.names).index(axis_name) nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( @@ -1431,7 +1602,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: device_id = hlo.partition_id() @@ -1443,51 +1614,22 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): unsigned_index) def _axis_index_lowering(ctx, *, axis_name): - return [ - _build_axis_index_lowering_hlo(ctx, axis_name, - ctx.module_context.axis_env) - ] - + return [_build_axis_index_lowering_hlo(ctx, axis_name, + ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - frame = core.axis_frame(axis_name) + _check_axis_names([axis_name]) return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} +def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): + return lax.iota(np.int32, axis_data.size), 0 + axis_index_p = core.Primitive('axis_index') +axis_index_p.def_impl(partial(dispatch.apply_primitive, axis_index_p)) mlir.register_lowering(axis_index_p, _axis_index_lowering) axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval) -core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name') - -# Axis index doesn't get any arguments, so that the default bind would have no -# way to call into a data-dependency based trace such as vmap. Each trace that -# wants to bind an axis name has to additionally implement `process_axis_index` -# and put its main trace on the axis env stack. -def _axis_index_bind(*, axis_name): - def name_idx(name): - frame = core.axis_frame(name) - dynamic = core.thread_local_state.trace_state.trace_stack.dynamic - if (frame.main_trace is None or dynamic.level > frame.main_trace.level): - return core.Primitive.bind(axis_index_p, axis_name=name) - else: - trace = frame.main_trace.with_cur_sublevel() - return trace.process_axis_index(frame) - - if not isinstance(axis_name, (tuple, list)): - return name_idx(axis_name) - else: - inner_size = 1 - index = 0 - for name in reversed(axis_name): - index += name_idx(name) * inner_size - inner_size *= psum(1, name) - return index -axis_index_p.def_custom_bind(_axis_index_bind) - -def _vmap_process_axis_index(self, frame): - assert frame.size is not None - return batching.BatchTracer(self, lax.iota(np.int32, frame.size), 0) -batching.BatchTrace.process_axis_index = _vmap_process_axis_index # type: ignore - +batching.fancy_primitive_batchers[axis_index_p] = _axis_index_batcher +batching.skippable_batchers[axis_index_p] = partial(_names_in_param, 'axis_name') def _pgather_impl(src, idx, *, axes): assert all(isinstance(axis, int) for axis in axes) @@ -1508,6 +1650,7 @@ def _pgather_impl(src, idx, *, axes): def _pgather_abstract_eval(src, idx, *, axes): # TODO: Avals with names rule: remove all axes from src, insert those from idx # The order is important, because it is ok to re-insert one of the deleted axes! + _check_axis_names(axes) shape = list(src.shape) for axis in sorted((a for a in axes if isinstance(a, int)), reverse=True): del shape[axis] @@ -1521,22 +1664,6 @@ def _pgather_parallel_lowering(ctx, src, idx, *, axes): return mlir.lower_fun(_pgather_impl, multiple_results=False)( ctx, src, idx, axes=axes) -def _pgather_batcher(vals_in, dims_in, *, axes): - src, idx = vals_in - dsrc, didx = dims_in - if didx is not batching.not_mapped and dsrc is not batching.not_mapped: - # NB: We could just go forward with it and take the diagonal along the - # two axes we get in the output, but that would be quite inefficient - raise NotImplementedError("Please open a feature request!") - elif didx is not batching.not_mapped: - return pgather_p.bind(src, idx, axes=axes), didx - elif dsrc is not batching.not_mapped: - src_last_batched = moveaxis(src, dsrc, -1) - result = pgather_p.bind(src_last_batched, idx, axes=axes) - return result, result.ndim - 1 - else: - assert False # This shouldn't get called anyway - def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, axes): src, idx = vals_in dsrc, didx = dims_in @@ -1559,11 +1686,10 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a else: return pgather_p.bind(src, idx, axes=new_axes), batching.not_mapped -pgather_p = core.AxisPrimitive('pgather') +pgather_p = core.Primitive('pgather') pgather_p.def_impl(_pgather_impl) pgather_p.def_abstract_eval(_pgather_abstract_eval) mlir.register_lowering(pgather_p, _pgather_parallel_lowering) # TODO: Transpose? That requires adding pscatter... -batching.primitive_batchers[pgather_p] = _pgather_batcher -batching.axis_primitive_batchers[pgather_p] = _pgather_collective_batcher -core.axis_substitution_rules[pgather_p] = partial(_subst_all_names_in_param, 'axes') +batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher +batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes') diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 40a04ff11d2c..c6c85ce4f6a3 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1270,6 +1270,39 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): return tuple(core.stride_dim(d, window_size=1, window_stride=s) for d, s in zip(diff, strides)) +def _get_sub_spec_size(mesh, sub_spec): + if isinstance(sub_spec, tuple): + return math.prod(mesh.shape[s] for s in sub_spec) + return mesh.shape[sub_spec] + +def _get_sharding_for_varying_out_shape(out_shape, operand, name): + """Returns a sharding when out_shape may not be the same as operand shape""" + mesh = operand.sharding.mesh + for op_sh, out_sh, op_spec in safe_zip( + operand.shape, out_shape, operand.sharding.spec): + if (op_sh != out_sh and op_spec is not None and + out_sh % _get_sub_spec_size(mesh, op_spec) != 0): + raise NotImplementedError( + f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" + f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" + f" ({op_spec}) is not implemented.") + # TODO(yashkatariya): Returning operand.sharding as is may or may not move + # data. So think about how to avoid it which might include creating a new + # mesh? For example: + # mesh = {'x': 4} + # x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))` + # ys = lax.split(x, [4, 4]) # This will create outputs of shape (4,) + # According to the current logic, ys[0].sharding.spec == P('x') + # which involves data movement. + return operand.sharding + +def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _slice_shape_rule(operand, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + return _get_sharding_for_varying_out_shape(out_shape, operand, 'slicing') + def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert ad.is_undefined_primal(operand) operand_shape = operand.aval.shape @@ -1308,7 +1341,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, out = slice(operand, new_start_indices, new_limit_indices, new_strides) return out, bdim -slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice') +slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', + sharding_rule=_slice_sharding_rule) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries @@ -1333,14 +1367,16 @@ def _slice_impl(x, start_indices, limit_indices, strides): def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): strides = strides or [1] * len(start_indices) aval_out, = ctx.avals_out - return [mlir.slice_op(ctx, x, aval_out, - start_indices=start_indices, limit_indices=limit_indices, strides=strides)] + out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(slice_p, _slice_lower) -def _dynamic_slice_shape_rule( - operand, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if operand.ndim != len(start_indices): msg = ("dynamic_slice start_indices must have length equal to the number " @@ -1363,6 +1399,12 @@ def _dynamic_slice_shape_rule( f" got indices {start_indices}") return tuple(lax._merge_dyn_shape(slice_sizes, dyn)) +def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes): + out_shape = _dynamic_slice_shape_rule( + operand, *starts_and_dyn_sizes, slice_sizes=slice_sizes) + return _get_sharding_for_varying_out_shape(out_shape, operand, 'dynamic_slice') + + def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if any(i.dtype != start_indices[0].dtype or @@ -1466,7 +1508,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + sharding_rule=_dynamic_slice_sharding_rule) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule @@ -1480,7 +1523,10 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes): aval_out, = ctx.avals_out if dyn: aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn)) - return [mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)] + out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower) @@ -1511,6 +1557,14 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): f"scalars, got indices {start_indices}") return operand.shape +def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): + if operand.sharding != update.sharding: + raise TypeError( + "dynamic_update_slice update sharding must be equal to operand" + f" sharding, got update sharding {update.sharding} for operand sharding" + f" {operand.sharding}.") + return operand.sharding + def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): lax.check_same_dtypes("dynamic_update_slice", operand, update) if any(i.dtype != start_indices[0].dtype or @@ -1576,7 +1630,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice') + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -1585,8 +1639,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): def _dynamic_update_slice_lower(ctx, x, update, *start_indices): aval_out, = ctx.avals_out - return [mlir.dynamic_update_slice(ctx, aval_out, x, update, - start_indices=start_indices)] + out = mlir.dynamic_update_slice(ctx, aval_out, x, update, + start_indices=start_indices) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index deb3c19c0a61..78d125436029 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -52,10 +52,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) least_specialized = type(max(avals, key=_get_array_abstraction_level)) - if least_specialized is core.ConcreteArray: - out = prim.impl(*[x.val for x in avals], **kwargs) - return core.ConcreteArray(out.dtype, out, weak_type=weak_type) - elif least_specialized is core.ShapedArray: + if least_specialized is core.ShapedArray: return core.ShapedArray( shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs), weak_type=weak_type, @@ -72,20 +69,21 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, raise TypeError(avals, least_specialized) def standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs): + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, + *avals, **kwargs): assert prim.multiple_results assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals least_specialized = max(map(type, avals), key=_get_array_abstraction_level) weak_types = weak_type_rule(*avals, **kwargs) - if least_specialized is core.ConcreteArray: - out_vals = prim.impl(*[x.val for x in avals], **kwargs) - return [core.ConcreteArray(val.dtype, val, weak_type=weak_type) - for val, weak_type in zip(out_vals, weak_types)] - elif least_specialized is core.ShapedArray: + if least_specialized is core.ShapedArray: out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) - return [core.ShapedArray(s, d, weak_type=weak_type) - for s, d, weak_type in zip(out_shapes, out_dtypes, weak_types)] + out_shardings = (sharding_rule(*avals, **kwargs) + if config.sharding_in_types.value else + [None] * len(out_shapes)) + return [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh) + for s, d, weak_type, sh in zip(out_shapes, out_dtypes, weak_types, + out_shardings)] elif least_specialized is core.UnshapedArray: out_dtypes = dtype_rule(*avals, **kwargs) return [core.UnshapedArray(dtype, weak_type=weak_type) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 089a77de2949..462e5fbed1c5 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -23,7 +23,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import util -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -142,14 +142,15 @@ def _get_monoid_window_reducer( return None x, = xs aval = core.get_aval(x) - if (type(aval) is ConcreteArray) and aval.shape == (): + if core.is_concrete(x) and aval.shape == (): + val = core.to_concrete_value(x) if monoid_op is lax.add: - return aval.val == 0 and _reduce_window_sum + return val == 0 and _reduce_window_sum elif monoid_op is lax.max: - return (aval.val == lax._get_max_identity(aval.dtype) + return (val == lax._get_max_identity(aval.dtype) and _reduce_window_max) elif monoid_op is lax.min: - return (aval.val == lax._get_min_identity(aval.dtype) + return (val == lax._get_min_identity(aval.dtype) and _reduce_window_min) return None diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 64bbd3268b16..5309f0b1fd9c 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -19,7 +19,7 @@ import numpy as np from jax._src.dtypes import iinfo, issubdtype from jax._src.sharding import Sharding -from jax._src.sharding_impls import AUTO as AutoSharding, is_auto +from jax._src.sharding_impls import AUTO as AutoSharding from jax._src.lib import xla_client as xc Shape = tuple[int, ...] @@ -101,7 +101,7 @@ def __init__(self, device_local_layout: LayoutOptions = None, sharding: ShardingOptions = None): # If layout is concrete and sharding is not, error. if (isinstance(device_local_layout, DeviceLocalLayout) and - (sharding is None or is_auto(sharding))): + (sharding is None or isinstance(sharding, AutoSharding))): raise ValueError( 'Sharding has to be concrete when layout is of type' f' {type(device_local_layout)}. Please pass a' diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 7068c0ef6732..1fcbd4b6b7ef 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -40,6 +40,7 @@ py_library_providing_imports_info( "//jax:version", ] + if_building_jaxlib([ "//jaxlib", + "//jaxlib/mosaic/python:gpu_dialect", "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib:cpu_feature_guard", "//jaxlib:utils", diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index ea9191b2cc7f..2810002013ac 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -120,7 +120,14 @@ def _xla_gc_callback(*args): import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 -import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401 +try: + import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error +except ImportError: + # TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36. + # Jaxlib doesn't contain Mosaic GPU dialect bindings. + mosaic_gpu_dialect = None # type: ignore + +import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401 # Version number for MLIR:Python APIs, provided by jaxlib. mlir_api_version = xla_client.mlir_api_version @@ -143,7 +150,20 @@ def _try_cuda_nvcc_import() -> str | None: from nvidia import cuda_nvcc # pytype: disable=import-error except ImportError: return None - cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent + + if hasattr(cuda_nvcc, '__file__') and cuda_nvcc.__file__ is not None: + # `cuda_nvcc` is a regular package. + cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent + elif hasattr(cuda_nvcc, '__path__') and cuda_nvcc.__path__ is not None: + # `cuda_nvcc` is a namespace package, which might have multiple paths. + cuda_nvcc_path = None + for path in cuda_nvcc.__path__: + if (pathlib.Path(path) / 'bin' / 'ptxas').exists(): + cuda_nvcc_path = pathlib.Path(path) + break + else: + return None + return str(cuda_nvcc_path) if (path := _try_cuda_root_environment_variable()) is not None: @@ -155,9 +175,5 @@ def _try_cuda_nvcc_import() -> str | None: cuda_path = _cuda_path() -if version >= (0, 4, 35): - guard_lib = xla_client._xla.guard_lib -else: - guard_lib = xla_client._xla.transfer_guard_lib - +guard_lib = xla_client._xla.guard_lib Device = xla_client._xla.Device diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 8cb1fedb9ef3..37d812dec619 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -71,7 +71,6 @@ def trans1(static_arg, *dynamic_args, **kwargs): from jax._src import config from jax._src import core from jax._src import traceback_util -from jax._src.tree_util import tree_map from jax._src.util import curry, cache_clearing_funs @@ -151,10 +150,11 @@ class WrappedFun: params: extra parameters to pass as keyword arguments to `f`, along with the transformed keyword arguments. """ - __slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info") + __slots__ = ("f", "f_transformed", "transforms", "stores", "params", "in_type", "debug_info") - def __init__(self, f, transforms, stores, params, in_type, debug_info): + def __init__(self, f, f_transformed, transforms, stores, params, in_type, debug_info): self.f = f + self.f_transformed = f_transformed self.transforms = transforms self.stores = stores self.params = params @@ -167,8 +167,14 @@ def __name__(self): def wrap(self, gen, gen_static_args, out_store) -> WrappedFun: """Add another transform and its store.""" - return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None, None) + if out_store is None: + return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) + else: + return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) def populate_stores(self, stores): """Copy the values from the `stores` into `self.stores`.""" @@ -177,47 +183,8 @@ def populate_stores(self, stores): self_store.store(other_store.val) def call_wrapped(self, *args, **kwargs): - """Calls the underlying function, applying the transforms. - - The positional `args` and keyword `kwargs` are passed to the first - transformation generator. - """ - stack = [] - for (gen, gen_static_args), out_store in zip(self.transforms, self.stores): - gen = gen(*(gen_static_args + tuple(args)), **kwargs) - args, kwargs = next(gen) - stack.append((gen, out_store)) - gen = gen_static_args = out_store = None - - try: - ans = self.f(*args, **dict(self.params, **kwargs)) - except: - # Some transformations yield from inside context managers, so we have to - # interrupt them before reraising the exception. Otherwise they will only - # get garbage-collected at some later time, running their cleanup tasks - # only after this exception is handled, which can corrupt the global - # state. - while stack: - stack.pop()[0].close() - raise - - args = kwargs = None - while stack: - gen, out_store = stack.pop() - try: - ans = gen.send(ans) - except: - # As above does for the first half of the transformation, exceptions - # raised in the second half of the transformation also require us to - # clean up references here. - while stack: - stack.pop()[0].close() - raise - if out_store is not None: - ans, side = ans - out_store.store(side) - - return ans + """Calls the transformed function""" + return self.f_transformed(*args, **kwargs) def __repr__(self): def transform_to_str(x): @@ -236,7 +203,7 @@ def __eq__(self, other): self.debug_info == other.debug_info) @curry -def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: +def transformation2(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """Adds one more transformation to a WrappedFun. Args: @@ -246,8 +213,28 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """ return fun.wrap(gen, gen_static_args, None) +# Backwards compat only. TODO: deprecate +@curry +def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + return gen_inst.send(f(*args_, **kwargs_)) + return transformation2(gen2, fun, *gen_static_args)() + +# Backwards compat only. TODO: deprecate @curry -def transformation_with_aux( +def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, store, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + ans, aux = gen_inst.send(f(*args_, **kwargs_)) + store.store(aux) + return ans + return transformation_with_aux2(gen2, fun, *gen_static_args)() + +@curry +def transformation_with_aux2( gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False ) -> tuple[WrappedFun, Callable[[], Any]]: """Adds one more transformation with auxiliary output to a WrappedFun.""" @@ -263,8 +250,9 @@ def fun_name(f): def wrap_init(f, params=None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" + params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) - return WrappedFun(f, (), (), params, None, None) + return WrappedFun(f, partial(f, **params_dict), (), (), params, None, None) def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: @@ -272,13 +260,13 @@ def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: if in_type is None: return f _check_input_type(in_type) - return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info) def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed assert type(in_type) is tuple and all(type(e) is tuple for e in in_type) assert all(isinstance(a, core.AbstractValue) and type(b) is bool - and not isinstance(a, core.ConcreteArray) for a, b in in_type) + for a, b in in_type) def valid_size(d) -> bool: if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0: @@ -319,7 +307,7 @@ def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None assert f.debug_info is None if debug_info is None: return f - return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info) def cache(call: Callable, *, explain: Callable | None = None): @@ -337,13 +325,8 @@ def cache(call: Callable, *, explain: Callable | None = None): def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore - if config.check_tracer_leaks.value: - key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args, - config.enable_x64.value, config.default_device.value, - config.trace_context()) - else: - key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value, - config.default_device.value, config.trace_context()) + key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value, + config.default_device.value, config.trace_context()) result = cache.get(key, None) if result is not None: ans, stores = result @@ -364,20 +347,9 @@ def _evict_function(f): cache_clearing_funs.add(memoized_fun.cache_clear) return memoized_fun - -def _copy_main_trace(x): - if isinstance(x, core.MainTrace): - return core.MainTrace(x.level, x.trace_type, **x.payload) - else: - return x - -_copy_main_traces = partial(tree_map, _copy_main_trace) - - - -@transformation -def hashable_partial(*args): - yield (yield args, {}) +@transformation2 +def hashable_partial(f, *args): + return f(*args) def merge_linear_aux(aux1, aux2): diff --git a/jax/_src/logging_config.py b/jax/_src/logging_config.py index d2f9d9c8fb1f..bdf588d2054a 100644 --- a/jax/_src/logging_config.py +++ b/jax/_src/logging_config.py @@ -13,19 +13,92 @@ # limitations under the License. import logging +import os import sys -_debug_handler = logging.StreamHandler(sys.stderr) -_debug_handler.setLevel(logging.DEBUG) # Example log message: # DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu' -_debug_handler.setFormatter(logging.Formatter( - "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{')) +logging_formatter = logging.Formatter( + "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{') -_debug_enabled_loggers = [] +_logging_level_set: dict[str, int] = {} +_default_TF_CPP_MIN_LOG_LEVEL = os.environ.get("TF_CPP_MIN_LOG_LEVEL", "1") + +_jax_logger_handler = logging.StreamHandler(sys.stderr) +_jax_logger_handler.setFormatter(logging_formatter) + +_nameToLevel = { + 'CRITICAL': logging.CRITICAL, + 'FATAL': logging.FATAL, + 'ERROR': logging.ERROR, + 'WARN': logging.WARNING, + 'WARNING': logging.WARNING, + 'INFO': logging.INFO, + 'DEBUG': logging.DEBUG, + 'NOTSET': logging.NOTSET, +} + +_tf_cpp_map = { + 'CRITICAL': 3, + 'FATAL': 3, + 'ERROR': 2, + 'WARN': 1, + 'WARNING': 1, + 'INFO': 0, + 'DEBUG': 0, +} + +def _set_TF_CPP_MIN_LOG_LEVEL(logging_level: str | None = None): + if logging_level in (None, "NOTSET"): + # resetting to user-default TF_CPP_MIN_LOG_LEVEL + # this is typically "1", but if the user overrode it, it can be != "1" + os.environ["TF_CPP_MIN_LOG_LEVEL"] = _default_TF_CPP_MIN_LOG_LEVEL + else: + # set cpp runtime logging level if the level is anything but NOTSET + if logging_level not in _tf_cpp_map: + raise ValueError(f"Attempting to set log level \"{logging_level}\" which" + f" isn't one of the supported:" + f" {list(_tf_cpp_map.keys())}.") + # config the CPP logging level 0 - debug, 1 - info, 2 - warning, 3 - error + os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(_tf_cpp_map[logging_level]) + +def update_logging_level_global(logging_level: str | None) -> None: + # remove previous handlers + for logger_name, level in _logging_level_set.items(): + logger = logging.getLogger(logger_name) + logger.removeHandler(_jax_logger_handler) + logger.setLevel(level) + _logging_level_set.clear() + _set_TF_CPP_MIN_LOG_LEVEL(logging_level) + + if logging_level is None: + return + + logging_level_num = _nameToLevel[logging_level] + # update jax and jaxlib root loggers for propagation + root_loggers = [logging.getLogger("jax"), logging.getLogger("jaxlib")] + for logger in root_loggers: + logger.setLevel(logging_level_num) + logger.addHandler(_jax_logger_handler) + _logging_level_set[logger.name] = logger.level -def enable_debug_logging(logger_name): +# per-module debug logging + +_jax_logger = logging.getLogger("jax") + +class _DebugHandlerFilter(logging.Filter): + def filter(self, _): + return _jax_logger.level > logging.DEBUG + +_debug_handler = logging.StreamHandler(sys.stderr) +_debug_handler.setLevel(logging.DEBUG) +_debug_handler.setFormatter(logging_formatter) +_debug_handler.addFilter(_DebugHandlerFilter()) + +_debug_enabled_loggers = [] + +def _enable_debug_logging(logger_name): """Makes the specified logger log everything to stderr. Also adds more useful debug information to the log messages, e.g. the time. @@ -34,21 +107,28 @@ def enable_debug_logging(logger_name): logger_name: the name of the logger, e.g. "jax._src.xla_bridge". """ logger = logging.getLogger(logger_name) + _debug_enabled_loggers.append((logger, logger.level)) + logger.addHandler(_debug_handler) logger.setLevel(logging.DEBUG) - _debug_enabled_loggers.append(logger) -def disable_all_debug_logging(): +def _disable_all_debug_logging(): """Disables all debug logging enabled via `enable_debug_logging`. The default logging behavior will still be in effect, i.e. WARNING and above will be logged to stderr without extra message formatting. """ - for logger in _debug_enabled_loggers: + for logger, prev_level in _debug_enabled_loggers: + logger: logging.Logger logger.removeHandler(_debug_handler) - # Assume that the default non-debug log level is always WARNING. In theory - # we could keep track of what it was set to before. This shouldn't make a - # difference if not other handlers are attached, but set it back in case - # something else gets attached (e.g. absl logger) and for consistency. - logger.setLevel(logging.WARNING) + logger.setLevel(prev_level) + _debug_enabled_loggers.clear() + +def update_debug_log_modules(module_names_str: str | None): + _disable_all_debug_logging() + if not module_names_str: + return + module_names = module_names_str.split(',') + for module_name in module_names: + _enable_debug_logging(module_name) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 8cb508378129..c7b8f692055d 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -18,6 +18,7 @@ import collections from collections.abc import Hashable, Sequence import contextlib +import enum import functools import math import threading @@ -101,6 +102,26 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names) +class AxisTypes(enum.Enum): + Auto = enum.auto() + User = enum.auto() + Collective = enum.auto() + + def __repr__(self): + return self.name + +def axis_names_to_types(axis_types) -> dict[str, AxisTypes]: + if axis_types is None: + return {} + d = {} + for t, names in axis_types.items(): + if isinstance(names, tuple): + for n in names: + d[n] = t + else: + d[names] = t + return d + _mesh_object_dict = {} # type: ignore @@ -157,9 +178,11 @@ class Mesh(contextlib.ContextDecorator): devices: np.ndarray axis_names: tuple[MeshAxisName, ...] + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None def __new__(cls, devices: np.ndarray | Sequence[xc.Device], - axis_names: str | Sequence[MeshAxisName]): + axis_names: str | Sequence[MeshAxisName], + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None): if not isinstance(devices, np.ndarray): devices = np.array(devices) if isinstance(axis_names, str): @@ -175,7 +198,10 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - key = (axis_names, devices.shape, tuple(devices.flat)) + # TODO(yashkatariya): If axis_types is None, set all axes to AUTO. + axis_types_tuple = (None if axis_types is None else + tuple(axis_types.items())) + key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple) val = _mesh_object_dict.get(key, None) if val is not None: return val @@ -184,11 +210,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], self.devices = devices.copy() self.devices.flags.writeable = False self.axis_names = axis_names + self.axis_types = axis_types + self._axis_types_tuple = axis_types_tuple _mesh_object_dict[key] = self return self def __reduce__(self): - return (type(self), (self.devices, self.axis_names)) + return (type(self), (self.devices, self.axis_names, self.axis_types)) def __eq__(self, other): if not isinstance(other, Mesh): @@ -199,12 +227,14 @@ def __eq__(self, other): return True return (self.axis_names == other.axis_names and self.devices.shape == other.devices.shape and + self._axis_types_tuple == other._axis_types_tuple and self._internal_device_list == other._internal_device_list) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( - (self.axis_names, self._internal_device_list, self.devices.shape)) + (self.axis_names, self._internal_device_list, self.devices.shape, + self._axis_types_tuple)) return self._hash def __setattr__(self, name, value): @@ -224,17 +254,17 @@ def __enter__(self): new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.mesh_context_manager.set_local( + tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return self def __exit__(self, exc_type, exc_value, traceback): thread_resources.stack.pop() thread_resources.env = thread_resources.stack[-1] - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.mesh_context_manager.set_local( + tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return False @property @@ -253,6 +283,10 @@ def shape_tuple(self): def axis_sizes(self) -> tuple[int, ...]: return self.devices.shape + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @property def size(self): return math.prod(self.shape.values()) if self.devices.ndim else 0 @@ -301,7 +335,8 @@ def __str__(self): def _repr(self): if self.empty: return "Mesh(device_ids=[], axis_names=())" - return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})" + atr = '' if self.axis_types is None else f", axis_types={self.axis_types}" + return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})" def __repr__(self): return self._repr @@ -313,7 +348,7 @@ def local_devices(self): @functools.cached_property def abstract_mesh(self): - return AbstractMesh(self.shape_tuple) + return AbstractMesh(self.shape_tuple, self.axis_types) EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -338,25 +373,32 @@ class AbstractMesh: details. """ - def __init__(self, shape_tuple: tuple[tuple[str, int], ...]): + def __init__(self, shape_tuple: tuple[tuple[str, int], ...], + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None): self.shape_tuple = shape_tuple + self.axis_types = axis_types if self.shape_tuple: self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) else: self._axis_names, self._axis_sizes = (), () + # TODO(yashkatariya): If axis_types is None, set all axes to AUTO. + self._axis_types_tuple = (None if axis_types is None else + tuple(axis_types.items())) def __hash__(self): - return hash(self.shape_tuple) + return hash((self.shape_tuple, self._axis_types_tuple)) def __eq__(self, other): if not isinstance(other, AbstractMesh): return False if id(self) == id(other): return True - return self.shape_tuple == other.shape_tuple + return (self.shape_tuple == other.shape_tuple and + self._axis_types_tuple == other._axis_types_tuple) def __repr__(self): - return f"AbstractMesh({self.shape_tuple})" + atr = '' if self.axis_types is None else f", axis_types={self.axis_types}" + return f"AbstractMesh({self.shape_tuple}{atr})" @property def axis_names(self): @@ -366,6 +408,10 @@ def axis_names(self): def axis_sizes(self) -> tuple[int, ...]: return self._axis_sizes + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @functools.cached_property def size(self): return math.prod(self._axis_sizes) if self._axis_sizes else 0 @@ -382,6 +428,12 @@ def _internal_device_list(self): def empty(self): return self.size == 0 + @functools.cached_property + def _are_all_axes_collective(self) -> bool: + if self.axis_types is None: + return False + return all(t == AxisTypes.Collective for t in self.axis_types.keys()) + @property def devices(self): _raise_value_error("devices") @@ -403,14 +455,14 @@ def local_mesh(self): _raise_value_error("local_mesh") def __enter__(self): - raise RuntimeError("AbstractMesh is not a context manager") + _raise_value_error("__enter__") def __exit__(self, exc_type, exc_value, traceback): - raise RuntimeError("AbstractMesh is not a context manager") + _raise_value_error("__exit__") @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): - jax_config.update_thread_local_jit_state(mesh_context_manager=mesh) + jax_config.abstract_mesh_context_manager.set_local(mesh) return @@ -418,3 +470,34 @@ def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): # property raises an exception unconditionally. Remove this once that is fixed. def _raise_value_error(name): raise ValueError(f"AbstractMesh does not implement {name}") + + +@contextlib.contextmanager +def set_abstract_mesh(mesh: AbstractMesh): + prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh) + try: + yield + finally: + jax_config.abstract_mesh_context_manager.set_local(prev_val) + +def get_abstract_mesh(): + return jax_config.abstract_mesh_context_manager.value + + +@contextlib.contextmanager +def set_concrete_mesh(mesh: Mesh): + prev_val = jax_config.device_context.swap_local(mesh) + try: + yield + finally: + jax_config.device_context.set_local(prev_val) + +def get_concrete_mesh(): + return jax_config.device_context.value + + +@contextlib.contextmanager +def set_mesh(mesh: Mesh): + with (set_abstract_mesh(mesh.abstract_mesh), + jax_config.sharding_in_types(True), set_concrete_mesh(mesh)): + yield diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index bb6152167658..588863d1f244 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -32,6 +32,8 @@ _TPU_V3 = 'TPU v3' _TPU_V4 = 'TPU v4' _TPU_V5_LITE = "TPU v5 lite" +_TPU_V5E = "TPU v5e" +_TPU_V5P = "TPU v5p" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. @@ -69,6 +71,7 @@ _TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) _V5E_TRAY_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 4) _V5E_TRAY_IOTA_ORDER = (0, 4, 2, 6, 1, 5, 3, 7) +_V5P_2x2x2_ORDER = (0, 1, 3, 2, 6, 7, 5, 4) def _tpu_v2_v3_create_device_mesh( mesh_shape: Sequence[int], @@ -147,6 +150,35 @@ def _v5e_create_device_mesh( return None +def _v5p_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + """Creates device assignment for selected topologies. + + Args: + mesh_shape: Logical mesh shape used by the model. + devices: TPU devices. + **unused_kwargs: ... + + Returns: + None or reordered devices reshaped as `mesh_shape`. + """ + max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) + bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 + # Our ring re-ordering makes sense only if the passed-in devices are + # sequential, which may not always be the case. reversed() changes z-minor to + # x-minor. + sequential_devices = sorted( + devices, + key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) + + if bound_x == bound_y == 2 and bound_z == 2: + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_V5P_2x2x2_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + return None + # Registers functions to create device mesh for specific device kinds. Takes # precedence over the more general logic in create_device_mesh(). Handler may # return None; in that case, it will fall back to using the default logic. @@ -157,6 +189,7 @@ def _v5e_create_device_mesh( _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, _TPU_V5_LITE: _v5e_create_device_mesh, + _TPU_V5P: _v5p_create_device_mesh, } @@ -572,16 +605,6 @@ def _generate_logical_mesh( return logical_mesh -def _bounds_from_last_device(last_device) -> Sequence[int]: - """Gets the bound from the given last device.""" - # Must be passed the device at the highest-coordinate corner of the - # relevant mesh, which is a requirement we know is satisfied by the last - # device in jax.devices(). - assert hasattr(last_device, 'coords'), 'Only TPU supported' - x, y, z = last_device.coords - return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 - - def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: r"""Rearrange TPU devices in a slice into a physical mesh. @@ -682,6 +705,15 @@ def _transpose_trick( *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] ) +def _canonicalize_axis_sizes(axis_sizes: Sequence[int] + ) -> tuple[int, ...] | None: + new_sizes = [] + for s in axis_sizes: + try: + new_sizes.append(int(s)) + except: + return None + return tuple(new_sizes) def create_device_mesh( mesh_shape: Sequence[int], @@ -717,17 +749,25 @@ def create_device_mesh( """ if devices is None: devices = xb.devices() - if np.prod(mesh_shape) != len(devices): + + new_mesh_shape = _canonicalize_axis_sizes(mesh_shape) + if new_mesh_shape is None: + raise ValueError( + f'`mesh_shape` passed to `create_device_mesh` should be a sequence of' + f' ints. Got {mesh_shape}') + del mesh_shape + + if math.prod(new_mesh_shape) != len(devices): raise ValueError( f'Number of devices {len(devices)} must equal the product ' - f'of mesh_shape {mesh_shape}' + f'of mesh_shape {new_mesh_shape}' ) last_device = devices[-1] handler = device_kind_handler_dict.get(last_device.device_kind, None) if handler is not None: result = handler( - mesh_shape, devices, contiguous_submeshes=contiguous_submeshes + new_mesh_shape, devices, contiguous_submeshes=contiguous_submeshes ) if result is not None: return result @@ -735,15 +775,15 @@ def create_device_mesh( if last_device.platform == 'tpu': physical_mesh = _get_physical_tpu_mesh(devices) if contiguous_submeshes: - physical_mesh = _transpose_trick(physical_mesh, mesh_shape) + physical_mesh = _transpose_trick(physical_mesh, new_mesh_shape) device_mesh, _ = _create_device_mesh_for_nd_torus( physical_mesh, - mesh_shape, + new_mesh_shape, allow_split_physical_axes=allow_split_physical_axes, ) return device_mesh else: - device_mesh = np.asarray(devices).reshape(mesh_shape) + device_mesh = np.asarray(devices).reshape(new_mesh_shape) return device_mesh diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 861e3d0123ff..5dfaa7b7e5f7 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -22,7 +22,6 @@ import math import numpy as np from typing import Any, Literal -import warnings import jax import jax.numpy as jnp @@ -502,7 +501,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: + initial: Unspecified = _UNSPECIFIED) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -528,10 +527,9 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ + # TODO(jakevdp): remove the initial argument after JAX v0.4.40. if initial is not _UNSPECIFIED: - # Added 2024-4-10 - warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.", - DeprecationWarning, stacklevel=2) + raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.") del initial numpy_util.check_arraylike("log_softmax", x) x_arr = jnp.asarray(x) @@ -551,7 +549,7 @@ def log_softmax(x: ArrayLike, def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: + initial: Unspecified = _UNSPECIFIED) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -577,10 +575,9 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ + # TODO(jakevdp): remove the initial argument after JAX v0.4.40. if initial is not _UNSPECIFIED: - # Added 2024-4-10 - warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.", - DeprecationWarning, stacklevel=2) + raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.") del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index eb1bb1609bbf..8086a97a3748 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -36,7 +36,6 @@ export = set_module('jax.nn.initializers') -KeyArray = Array # TODO: Import or define these to match # https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py. DTypeLikeFloat = Any @@ -48,13 +47,13 @@ @typing.runtime_checkable class Initializer(Protocol): @staticmethod - def __call__(key: KeyArray, + def __call__(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: raise NotImplementedError @export -def zeros(key: KeyArray, +def zeros(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of zeros. @@ -69,7 +68,7 @@ def zeros(key: KeyArray, return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) @export -def ones(key: KeyArray, +def ones(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of ones. @@ -100,7 +99,7 @@ def constant(value: ArrayLike, Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -126,7 +125,7 @@ def uniform(scale: RealNumeric = 1e-2, Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -152,7 +151,7 @@ def normal(stddev: RealNumeric = 1e-2, Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -189,7 +188,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2, [-3.836303 , -4.192359 , 0.6022964]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -230,7 +229,7 @@ def _compute_fans(shape: Sequence[int], fan_out = out_size * receptive_field_size return fan_in, fan_out -def _complex_uniform(key: KeyArray, +def _complex_uniform(key: Array, shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ @@ -244,7 +243,7 @@ def _complex_uniform(key: KeyArray, theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def _complex_truncated_normal(key: KeyArray, upper: ArrayLike, +def _complex_truncated_normal(key: Array, upper: ArrayLike, shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ @@ -314,7 +313,7 @@ def variance_scaling( dtype: the dtype of the weights. """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: shape = core.canonicalize_shape(shape) @@ -599,7 +598,7 @@ def orthogonal(scale: RealNumeric = 1.0, Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -654,7 +653,7 @@ def delta_orthogonal( .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393 """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 7b98a5314744..4768a8126c72 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -607,7 +607,6 @@ def __array_module__(self, types): return NotImplemented -@core.stash_axis_env() @partial(jax.jit, static_argnums=(1,2,3)) def _multi_slice(self: Array, start_indices: tuple[tuple[int, ...]], diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index 90a17000cf16..ec67d7489f30 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -24,10 +24,14 @@ arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose ) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module import numpy as np +export = set_module('jax.numpy') + + __all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] @@ -87,7 +91,7 @@ def __getitem__(self, key: slice | tuple[slice, ...]) -> Array: return stack(output_arr, 0) -mgrid = _Mgrid() +mgrid = export(_Mgrid()) class _Ogrid: @@ -129,7 +133,7 @@ def __getitem__( return meshgrid(*output, indexing='ij', sparse=True) -ogrid = _Ogrid() +ogrid = export(_Ogrid()) _IndexType = Union[ArrayLike, str, slice] @@ -279,7 +283,7 @@ class RClass(_AxisConcat): op_name = "r_" -r_ = RClass() +r_ = export(RClass()) class CClass(_AxisConcat): @@ -327,7 +331,7 @@ class CClass(_AxisConcat): op_name = "c_" -c_ = CClass() +c_ = export(CClass()) s_ = np.s_ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6b1bd9acf3ca..3d99405428de 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -52,7 +52,7 @@ from jax._src import xla_bridge from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax as lax_internal from jax._src.lax.lax import ( PrecisionLike,_array_copy, @@ -67,14 +67,17 @@ DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar, ) from jax._src.util import ( - NumpyComplexWarning, - canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) -from jax.sharding import Sharding, SingleDeviceSharding + NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, + ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, + tuple_replace) +from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, + PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map import numpy as np import opt_einsum +export = set_module('jax.numpy') + for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: try: cuda_plugin_extension = importlib.import_module( @@ -116,6 +119,7 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: printoptions = np.printoptions set_printoptions = np.set_printoptions +@export def iscomplexobj(x: Any) -> bool: """Check if the input is a complex number or an array containing complex elements. @@ -194,6 +198,12 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), {"dtype": np.dtype(np_scalar_type)}) meta.__module__ = _PUBLIC_MODULE_NAME + meta.__doc__ =\ + f"""A JAX scalar constructor of type {np_scalar_type.__name__}. + + While NumPy defines scalar types for each data type, JAX represents + scalars as zero-dimensional arrays. + """ return meta bool_ = _make_scalar_type(np.bool_) @@ -211,6 +221,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) +if dtypes.float8_e3m4 is not None: + float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +if dtypes.float8_e4m3 is not None: + float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) @@ -321,6 +335,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: return clip(val, min_val, max_val).astype(dtype) +@export def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: """Load JAX arrays from npy files. @@ -370,6 +385,7 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> ### implementations of numpy functions in terms of lax +@export @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise minimum of the input arrays. @@ -421,6 +437,7 @@ def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise maximum of the input arrays. @@ -470,6 +487,7 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: """Return True if arg1 is equal or lower than arg2 in the type hierarchy. @@ -516,6 +534,7 @@ def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) +@export def isscalar(element: Any) -> bool: """Return True if the input is a scalar. @@ -605,15 +624,18 @@ def isscalar(element: Any) -> bool: >>> jnp.isscalar(slice(10)) False """ - if (isinstance(element, (np.ndarray, jax.Array)) - or hasattr(element, '__jax_array__') - or np.isscalar(element)): + if np.isscalar(element): + return True + elif isinstance(element, (np.ndarray, jax.Array)): + return element.ndim == 0 + elif hasattr(element, '__jax_array__'): return asarray(element).ndim == 0 return False iterable = np.iterable +@export def result_type(*args: Any) -> DType: """Return the result of applying JAX promotion rules to the inputs. @@ -657,6 +679,7 @@ def result_type(*args: Any) -> DType: return dtypes.result_type(*args) +@export @jit def trunc(x: ArrayLike) -> Array: """Round input to the nearest integer towards zero. @@ -733,6 +756,7 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, return result[0, 0, out_order] +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None, @@ -808,6 +832,7 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision=precision, preferred_element_type=preferred_element_type) +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None, @@ -893,6 +918,7 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision=precision, preferred_element_type=preferred_element_type) +@export def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range: None | Array | Sequence[ArrayLike] = None, weights: ArrayLike | None = None) -> Array: @@ -944,6 +970,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, return linspace(range[0], range[1], bins_int + 1, dtype=dtype) +@export def histogram(a: ArrayLike, bins: ArrayLike = 10, range: Sequence[ArrayLike] | None = None, weights: ArrayLike | None = None, @@ -1025,6 +1052,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, return counts, bin_edges +@export def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1114,6 +1142,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = return hist, edges[0], edges[1] +@export def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1223,6 +1252,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, return hist, bin_edges_by_dim +@export def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. @@ -1301,6 +1331,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: return lax.transpose(a, axes_) +@export def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: """Permute the axes/dimensions of an array. @@ -1330,6 +1361,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: return lax.transpose(a, axes) +@export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose the last two dimensions of an array. @@ -1383,6 +1415,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return lax.transpose(x, axes) +@export @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: """Rotate an array by 90 degrees counterclockwise in the plane specified by axes. @@ -1466,6 +1499,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) +@export def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Reverse the order of elements of an array along the given axis. @@ -1533,6 +1567,7 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) +@export def fliplr(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 1. @@ -1559,6 +1594,7 @@ def fliplr(m: ArrayLike) -> Array: return _flip(asarray(m), 1) +@export def flipud(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 0. @@ -1584,6 +1620,8 @@ def flipud(m: ArrayLike) -> Array: util.check_arraylike("flipud", m) return _flip(asarray(m), 0) + +@export @jit def iscomplex(x: ArrayLike) -> Array: """Return boolean array showing where the input is complex. @@ -1607,6 +1645,8 @@ def iscomplex(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) + +@export @jit def isreal(x: ArrayLike) -> Array: """Return boolean array showing where the input is real. @@ -1631,6 +1671,7 @@ def isreal(x: ArrayLike) -> Array: return lax.eq(i, _lax_const(i, 0)) +@export @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: """Return the angle of a complex valued number or array. @@ -1682,6 +1723,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: return ufuncs.degrees(result) if deg else result +@export @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, @@ -1794,6 +1836,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, return arr +@export @jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: @@ -1856,6 +1899,8 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result + +@export @partial(jit, static_argnames=("axis", "edge_order")) def gradient( f: ArrayLike, @@ -1986,6 +2031,7 @@ def gradient_along_axis(a, h, axis): return a_grad[0] if len(axis_tuple) == 1 else a_grad +@export def isrealobj(x: Any) -> bool: """Check if the input is not a complex number or an array containing complex elements. @@ -2020,6 +2066,7 @@ def isrealobj(x: Any) -> bool: return not iscomplexobj(x) +@export def reshape( a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), @@ -2098,20 +2145,11 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(micky774): deprecated 2024-5-9, remove after deprecation expires. + # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. if not isinstance(newshape, DeprecatedArg): - if shape is not None: - raise ValueError( - "jnp.reshape received both `shape` and `newshape` arguments. Note that " - "using `newshape` is deprecated, please only use `shape` instead." - ) - deprecations.warn( - "jax-numpy-reshape-newshape", - ("The newshape argument of jax.numpy.reshape is deprecated. " - "Please use the shape argument instead."), stacklevel=2) - shape = newshape - del newshape - elif shape is None: + raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." + " Use shape instead.") + if shape is None: raise TypeError( "jnp.shape requires passing a `shape` argument, but none was given." ) @@ -2123,6 +2161,7 @@ def reshape( return asarray(a).reshape(shape, order=order) +@export @partial(jit, static_argnames=('order',), inline=True) def ravel(a: ArrayLike, order: str = "C") -> Array: """Flatten array into a 1-dimensional shape. @@ -2176,6 +2215,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: return reshape(a, (size(a),), order) +@export def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = 'raise', order: str = 'C') -> Array: """Convert multi-dimensional indices into flat indices. @@ -2267,6 +2307,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], return result +@export def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: """Convert flat indices into multi-dimensional indices. @@ -2330,6 +2371,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: for s, i in safe_zip(shape, out_indices)) +@export @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: """Return a new array with specified shape. @@ -2381,6 +2423,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: return reshape(arr, new_shape) +@export def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Remove one or more length-1 axes from array @@ -2451,6 +2494,7 @@ def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: return lax.squeeze(a, axis) +@export def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: """Insert dimensions of length 1 into array @@ -2521,6 +2565,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: return lax.expand_dims(a, axis) +@export @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: """Swap two axes of an array. @@ -2568,6 +2613,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: return lax.transpose(a, list(perm)) +@export def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: """Move an array axis to a new position @@ -2633,6 +2679,7 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) - return lax.transpose(a, perm) +@export @partial(jit, static_argnames=('equal_nan',)) def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -2777,6 +2824,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f +@export def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, @@ -2859,6 +2907,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, ) -> Array | tuple[Array, ...]: ... +@export def where(condition, x=None, y=None, /, *, size=None, fill_value=None): """Select elements from two arrays based on a condition. @@ -2934,6 +2983,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): return util._where(condition, x, y) +@export def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], @@ -3001,6 +3051,7 @@ def select( return lax.select_n(*broadcast_arrays(idx, *choicelist)) +@export def bincount(x: ArrayLike, weights: ArrayLike | None = None, minlength: int = 0, *, length: int | None = None ) -> Array: @@ -3064,6 +3115,8 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, Array([2, 1, 0, 1, 0], dtype=int32) """ util.check_arraylike("bincount", x) + if _dtype(x) == bool: + x = lax.convert_element_type(x, 'int32') if not issubdtype(_dtype(x), integer): raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}") if ndim(x) != 1: @@ -3074,7 +3127,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, x_arr = core.concrete_or_error(asarray, x, "The error occurred because of argument 'x' of jnp.bincount. " "To avoid this error, pass a static `length` argument.") - length = max(minlength, x_arr.size and int(x_arr.max()) + 1) + length = max(minlength, x_arr.size and int(max(0, x_arr.max())) + 1) else: length = core.concrete_dim_or_error(length, "The error occurred because of argument 'length' of jnp.bincount.") @@ -3091,6 +3144,7 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | core.Tracer] ) -> tuple[int | core.Tracer, ...]: ... +@export def broadcast_shapes(*shapes): """Broadcast input shapes to a common output shape. @@ -3131,6 +3185,7 @@ def broadcast_shapes(*shapes): return lax.broadcast_shapes(*shapes) +@export def broadcast_arrays(*args: ArrayLike) -> list[Array]: """Broadcast arrays to a common shape. @@ -3170,6 +3225,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: return util._broadcast_arrays(*args) +@export def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: """Broadcast an array to a specified shape. @@ -3246,6 +3302,7 @@ def _split(op: str, ary: ArrayLike, for start, end in zip(split_indices[:-1], split_indices[1:])] +@export def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3309,6 +3366,7 @@ def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, return _split("split", ary, indices_or_sections, axis=axis) +@export def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays vertically. @@ -3343,6 +3401,7 @@ def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("vsplit", ary, indices_or_sections, axis=0) +@export def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays horizontally. @@ -3383,6 +3442,7 @@ def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1) +@export def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays depth-wise. @@ -3424,6 +3484,7 @@ def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("dsplit", ary, indices_or_sections, axis=2) +@export def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3449,6 +3510,7 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array return _split("array_split", ary, indices_or_sections, axis=axis) +@export @jit def clip( arr: ArrayLike | None = None, @@ -3520,6 +3582,7 @@ def clip( return asarray(arr) +@export @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Round input evenly to the given number of decimals. @@ -3591,12 +3654,14 @@ def _round_float(x: ArrayLike) -> Array: return _round_float(a) +@export @partial(jit, static_argnames=('decimals',)) def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Alias of :func:`jax.numpy.round`""" return round(a, decimals, out) +@export @jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. @@ -3635,6 +3700,7 @@ def fix(x: ArrayLike, out: None = None) -> Array: return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) +@export @jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: ArrayLike | None = None, @@ -3700,6 +3766,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, return out +@export @partial(jit, static_argnames=('equal_nan',)) def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -3748,6 +3815,7 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, return reductions.all(isclose(a, b, rtol, atol, equal_nan)) +@export def nonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: @@ -3855,6 +3923,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, return out +@export def flatnonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array: """Return indices of nonzero elements in a flattened array @@ -3900,10 +3969,64 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None, return nonzero(ravel(a), size=size, fill_value=fill_value)[0] -@util.implements(np.unwrap) +@export @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, axis: int = -1, period: ArrayLike = 2 * pi) -> Array: + """Unwrap a periodic signal. + + JAX implementation of :func:`numpy.unwrap`. + + Args: + p: input array + discont: the maximum allowable discontinuity in the sequence. The + default is ``period / 2`` + axis: the axis along which to unwrap; defaults to -1 + period: the period of the signal, which defaults to :math:`2\\pi` + + Returns: + An unwrapped copy of ``p``. + + Examples: + Consider a situation in which you are making measurements of the position of + a rotating disk via the ``x`` and ``y`` locations of some point on that disk. + The underlying variable is an always-increating angle which we'll generate + this way, using degrees for ease of representation: + + >>> rng = np.random.default_rng(0) + >>> theta = rng.integers(0, 90, size=(20,)).cumsum() + >>> theta + array([ 76, 133, 179, 203, 230, 233, 239, 240, 255, 328, 386, 468, 513, + 567, 654, 719, 775, 823, 873, 957]) + + Our observations of this angle are the ``x`` and ``y`` coordinates, given by + the sine and cosine of this underlying angle: + + >>> x, y = jnp.sin(jnp.deg2rad(theta)), jnp.cos(jnp.deg2rad(theta)) + + Now, say that given these ``x`` and ``y`` coordinates, we wish to recover + the original angle ``theta``. We might do this via the :func:`atan2` function: + + >>> theta_out = jnp.rad2deg(jnp.atan2(x, y)).round() + >>> theta_out + Array([ 76., 133., 179., -157., -130., -127., -121., -120., -105., + -32., 26., 108., 153., -153., -66., -1., 55., 103., + 153., -123.], dtype=float32) + + The first few values match the input angle ``theta`` above, but after this the + values are wrapped because the ``sin`` and ``cos`` observations obscure the phase + information. The purpose of the :func:`unwrap` function is to recover the original + signal from this wrapped view of it: + + >>> jnp.unwrap(theta_out, period=360) + Array([ 76., 133., 179., 203., 230., 233., 239., 240., 255., 328., 386., + 468., 513., 567., 654., 719., 775., 823., 873., 957.], dtype=float32) + + It does this by assuming that the true underlying sequence does not differ by more than + ``discont`` (which defaults to ``period / 2``) within a single step, and when it encounters + a larger discontinuity it adds factors of the period to the data. For periodic signals + that satisfy this assumption, :func:`unwrap` can recover the original phased signal. + """ util.check_arraylike("unwrap", p) p = asarray(p) if issubdtype(p.dtype, np.complexfloating): @@ -3987,15 +4110,37 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str): def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array: nd = ndim(array) - constant_values = broadcast_to(constant_values, (nd, 2)) constant_values = lax_internal._convert_element_type( constant_values, array.dtype, dtypes.is_weakly_typed(array)) + constant_values_nd = ndim(constant_values) + + if constant_values_nd == 0: + widths = [(low, high, 0) for (low, high) in pad_width] + return lax.pad(array, constant_values, widths) + + if constant_values_nd == 1: + if constant_values.shape[-1] == 1: + widths = [(low, high, 0) for (low, high) in pad_width] + return lax.pad(array, squeeze(constant_values), widths) + elif constant_values.shape[-1] == 2: + widths = [(low, 0, 0) for (low, _) in pad_width] + array = lax.pad(array, constant_values[0], widths) + widths = [(0, high, 0) for (_, high) in pad_width] + return lax.pad(array, constant_values[1], widths) + else: + raise ValueError("jnp.pad: constant_values has unsupported shape " + f"{constant_values.shape}. If the shape is 1D or 2D, the " + "last dimension must be of size 1 or 2.") + + constant_values = broadcast_to(constant_values, (nd, 2)) for i in range(nd): widths = [(0, 0, 0)] * nd - widths[i] = (pad_width[i][0], 0, 0) - array = lax.pad(array, constant_values[i, 0], widths) - widths[i] = (0, pad_width[i][1], 0) - array = lax.pad(array, constant_values[i, 1], widths) + if pad_width[i][0] != 0: + widths[i] = (pad_width[i][0], 0, 0) + array = lax.pad(array, constant_values[i, 0], widths) + if pad_width[i][1] != 0: + widths[i] = (0, pad_width[i][1], 0) + array = lax.pad(array, constant_values[i, 1], widths) return array @@ -4254,6 +4399,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") +@export def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], mode: str | Callable[..., Any] = "constant", **kwargs) -> Array: """Add padding to an array. @@ -4410,6 +4556,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], ### Array-creation functions +@export def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: """Join arrays along a new axis. @@ -4476,6 +4623,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], return concatenate(new_arrays, axis=axis, dtype=dtype) +@export @partial(jit, static_argnames="axis") def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: """Unstack an array along an axis. @@ -4516,6 +4664,8 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: ) return tuple(moveaxis(x, axis, 0)) + +@export def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: """Construct an array by repeating ``A`` along specified dimensions. @@ -4579,6 +4729,7 @@ def _concatenate_array(arr: ArrayLike, axis: int | None, return lax.reshape(arr, shape, dimensions) +@export def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> Array: """Join arrays along an existing axis. @@ -4642,6 +4793,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] +@export def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: """Join arrays along an existing axis. @@ -4682,6 +4834,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: return jax.numpy.concatenate(arrays, axis=axis) +@export def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Vertically stack arrays. @@ -4742,6 +4895,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0, dtype=dtype) +@export def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Horizontally stack arrays. @@ -4802,6 +4956,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) +@export def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Stack arrays depth-wise. @@ -4862,6 +5017,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=2, dtype=dtype) +@export def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: """Stack arrays column-wise. @@ -4922,6 +5078,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: return concatenate(arrs, axis=1) +@export def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: """Construct an array by stacking slices of choice arrays. @@ -5045,9 +5202,78 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: else: return asarray(xs), 1 -@util.implements(np.block) + +@export @jit def block(arrays: ArrayLike | list[ArrayLike]) -> Array: + """Create an array from a list of blocks. + + JAX implementation of :func:`numpy.block`. + + Args: + arrays: an array, or nested list of arrays which will be concatenated + together to form the final array. + + Returns: + a single array constructed from the inputs. + + See also: + - :func:`concatenate`, :func:`concat`: concatenate arrays along an existing axis. + - :func:`stack`, :func:`vstack`, :func:`hstack`, :func:`dstack` concatenate + arrays along a new axis. + + Examples: + consider these blocks: + + >>> zeros = jnp.zeros((2, 2)) + >>> ones = jnp.ones((2, 2)) + >>> twos = jnp.full((2, 2), 2) + >>> threes = jnp.full((2, 2), 3) + + Passing a single array to :func:`block` returns the array: + + >>> jnp.block(zeros) + Array([[0., 0.], + [0., 0.]], dtype=float32) + + Passing a simple list of arrays concatenates them along the last axis: + + >>> jnp.block([zeros, ones]) + Array([[0., 0., 1., 1.], + [0., 0., 1., 1.]], dtype=float32) + + Passing a doubly-nested list of arrays concatenates the inner list along + the last axis, and the outer list along the second-to-last axis: + + >>> jnp.block([[zeros, ones], + ... [twos, threes]]) + Array([[0., 0., 1., 1.], + [0., 0., 1., 1.], + [2., 2., 3., 3.], + [2., 2., 3., 3.]], dtype=float32) + + Note that blocks need not align in all dimensions, though the size along the axis + of concatenation must match. For example, this is valid because after the inner, + horizontal concatenation, the resulting blocks have a valid shape for the outer, + vertical concatenation. + + >>> a = jnp.zeros((2, 1)) + >>> b = jnp.ones((2, 3)) + >>> c = jnp.full((1, 2), 2) + >>> d = jnp.full((1, 2), 3) + >>> jnp.block([[a, b], [c, d]]) + Array([[0., 1., 1., 1.], + [0., 1., 1., 1.], + [2., 2., 3., 3.]], dtype=float32) + + Note also that this logic generalizes to blocks in 3 or more dimensions. + Here's a 3-dimensional block-wise array: + + >>> x = jnp.arange(6).reshape((1, 2, 3)) + >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] + >>> jnp.block(blocks).shape + (5, 8, 9) + """ out, _ = _block(arrays) return out @@ -5061,6 +5287,7 @@ def atleast_1d(x: ArrayLike, /) -> Array: @overload def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 1 dimension. @@ -5115,6 +5342,7 @@ def atleast_2d(x: ArrayLike, /) -> Array: @overload def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 2 dimensions. @@ -5178,6 +5406,7 @@ def atleast_3d(x: ArrayLike, /) -> Array: @overload def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 3 dimensions. @@ -5254,6 +5483,7 @@ def _supports_buffer_protocol(obj): return True +@export def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5446,6 +5676,7 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x +@export def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: @@ -5511,6 +5742,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result +@export def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, *, copy: bool | None = None, device: xc.Device | Sharding | None = None) -> Array: @@ -5592,6 +5824,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) +@export def copy(a: ArrayLike, order: str | None = None) -> Array: """Return a copy of the array. @@ -5640,6 +5873,7 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: return array(a, copy=True, order=order) +@export def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5682,6 +5916,7 @@ def zeros_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5724,6 +5959,7 @@ def ones_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5773,6 +6009,7 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No return device +@export def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5821,6 +6058,7 @@ def full(shape: Any, fill_value: ArrayLike, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5877,6 +6115,7 @@ def full_like(a: ArrayLike | DuckTypedArray, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def zeros(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of zeros. @@ -5913,6 +6152,7 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def ones(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of ones. @@ -5949,6 +6189,7 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def empty(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an empty array. @@ -5992,6 +6233,7 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore "with a single tuple argument for the shape?") +@export def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: """Check if two arrays are element-wise equal. @@ -6033,6 +6275,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: return reductions.all(eq) +@export def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: """Check if two arrays are element-wise equal. @@ -6073,6 +6316,7 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. +@export def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: r"""Convert a buffer into a 1-D JAX array. @@ -6120,6 +6364,7 @@ def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) +@export def fromfile(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromfile. @@ -6138,6 +6383,7 @@ def fromfile(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def fromiter(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromiter. @@ -6156,6 +6402,7 @@ def fromiter(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None) -> Array: """Construct a JAX array via DLPack. @@ -6216,6 +6463,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, return from_dlpack(x, device=device, copy=copy) +@export def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = float, **kwargs) -> Array: """Create an array from a function applied over indices. @@ -6302,6 +6550,7 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) +@export def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: """Convert a string of text into 1-D JAX array. @@ -6330,6 +6579,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) +@export def eye(N: DimSize, M: DimSize | None = None, k: int | ArrayLike = 0, dtype: DTypeLike | None = None, @@ -6409,6 +6659,7 @@ def _eye(N: DimSize, M: DimSize | None = None, return (i + offset == j).astype(dtype) +@export def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: """Create a square identity matrix @@ -6442,6 +6693,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: return eye(n, dtype=dtype) +@export def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, step: ArrayLike | None = None, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -6609,6 +6861,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@export def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, @@ -6734,6 +6987,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return (result, delta) if retstep else result +@export def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: DTypeLike | None = None, axis: int = 0) -> Array: @@ -6819,6 +7073,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(ufuncs.power(base, lin), dtype) +@export def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: DTypeLike | None = None, axis: int = 0) -> Array: """Generate geometrically-spaced values. @@ -6893,6 +7148,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) +@export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: """Construct N-dimensional grid arrays from N 1-dimensional vectors. @@ -6913,6 +7169,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, A length-N list of grid arrays. See also: + - :func:`jax.numpy.indices`: generate a grid of indices. - :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax. - :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax. @@ -6973,6 +7230,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output +@export @jit def i0(x: ArrayLike) -> Array: r"""Calculate modified Bessel function of first kind, zeroth order. @@ -7022,6 +7280,7 @@ def _i0_jvp(primals, tangents): primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) +@export def ix_(*args: ArrayLike) -> tuple[Array, ...]: """Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. @@ -7085,9 +7344,39 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: ... -@util.implements(np.indices) +@export def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: + """Generate arrays of grid indices. + + JAX implementation of :func:`numpy.indices`. + + Args: + dimensions: the shape of the grid. + dtype: the dtype of the indices (defaults to integer). + sparse: if True, then return sparse indices. Default is False, which + returns dense indices. + + Returns: + An array of shape ``(len(dimensions), *dimensions)`` If ``sparse`` is False, + or a sequence of arrays of the same length as ``dimensions`` if ``sparse`` is True. + + See also: + - :func:`jax.numpy.meshgrid`: generate a grid from arbitrary input arrays. + - :obj:`jax.numpy.mgrid`: generate dense indices using a slicing syntax. + - :obj:`jax.numpy.ogrid`: generate sparse indices using a slicing syntax. + + Examples: + >>> jnp.indices((2, 3)) + Array([[[0, 0, 0], + [1, 1, 1]], + + [[0, 1, 2], + [0, 1, 2]]], dtype=int32) + >>> jnp.indices((2, 3), sparse=True) + (Array([[0], + [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32)) + """ dtypes.check_user_dtype_supported(dtype, "indices") dtype = dtype or dtypes.canonicalize_dtype(int_) dimensions = tuple( @@ -7106,6 +7395,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, return stack(output, 0) if output else array([], dtype=dtype) +@export def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) -> Array: """Construct an array from repeated elements. @@ -7250,6 +7540,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) +@export @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: @@ -7309,6 +7600,7 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) +@export def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: r"""Return an array with ones on and below the diagonal and zeros elsewhere. @@ -7365,6 +7657,7 @@ def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None return lax_internal._tri(dtype, (N, M), k) +@export @partial(jit, static_argnames=('k',)) def tril(m: ArrayLike, k: int = 0) -> Array: r"""Return lower triangle of an array. @@ -7426,6 +7719,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) +@export @partial(jit, static_argnames=('k',)) def triu(m: ArrayLike, k: int = 0) -> Array: r"""Return upper triangle of an array. @@ -7491,6 +7785,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) +@export @partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -7556,6 +7851,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int return reductions.sum(a, axis=(-2, -1), dtype=dtype) +@export def mask_indices(n: int, mask_func: Callable[[ArrayLike, int], Array], k: int = 0, *, size: int | None = None) -> tuple[Array, Array]: @@ -7615,6 +7911,7 @@ def _triu_size(n, m, k): return mk * (mk + 1) // 2 + mk * (m - k - mk) +@export def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of upper triangle of an array of size ``(n, m)``. @@ -7673,6 +7970,7 @@ def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of lower triangle of an array of size ``(n, m)``. @@ -7731,6 +8029,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. @@ -7788,6 +8087,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. @@ -7845,6 +8145,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array: """Return a copy of the array with the diagonal overwritten. @@ -7926,6 +8227,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n)) +@export def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a multidimensional array. @@ -7961,6 +8263,8 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: .format(ndim)) return (lax.iota(int_, n),) * ndim + +@export def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a given array. @@ -8002,6 +8306,8 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: return diag_indices(s[0], ndim=nd) + +@export @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: @@ -8053,6 +8359,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, return a[..., i, j] if offset >= 0 else a[..., j, i] +@export def diag(v: ArrayLike, k: int = 0) -> Array: """Returns the specified diagonal or constructs a diagonal array. @@ -8116,6 +8423,8 @@ def _diag(v, k): else: raise ValueError("diag input must be 1d or 2d") + +@export def diagflat(v: ArrayLike, k: int = 0) -> Array: """Return a 2-D array with the flattened input array laid out on the diagonal. @@ -8171,6 +8480,8 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res +# TODO(jakevdp): add support for N-dimensional inputs as in NumPy v2.2 +@export def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: """Trim leading and/or trailing zeros of the input array. @@ -8225,6 +8536,8 @@ def trim_zeros_tol(filt, tol, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] + +@export @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None @@ -8279,6 +8592,7 @@ def append( return concatenate([arr, values], axis=axis) +@export def delete( arr: ArrayLike, obj: ArrayLike | slice, @@ -8403,6 +8717,7 @@ def delete( return a[tuple(slice(None) for i in range(axis)) + (mask,)] +@export def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, axis: int | None = None) -> Array: """Insert entries into an array at specified indices. @@ -8502,6 +8817,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, return out +@export def apply_along_axis( func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs ) -> Array: @@ -8579,6 +8895,7 @@ def apply_along_axis( return func(arr) +@export def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, axes: Sequence[int]) -> Array: """Apply a function repeatedly over specified axes. @@ -8637,6 +8954,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, ### Tensor contraction operations +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -8726,6 +9044,7 @@ def dot(a: ArrayLike, b: ArrayLike, *, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -8750,7 +9069,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, Returns: array containing the matrix product of the inputs. Shape is ``a.shape[:-1]`` - if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading + if ``b.ndim == 1``, otherwise the shape is ``(..., K, M)``, where leading dimensions of ``a`` and ``b`` are broadcast together. See Also: @@ -8849,6 +9168,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( a: ArrayLike, b: ArrayLike, *, @@ -8897,6 +9217,7 @@ def vdot( preferred_element_type=preferred_element_type) +@export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -8952,6 +9273,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, signature="(n),(n)->()")(x1_arr, x2_arr) +@export def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = 2, *, precision: PrecisionLike = None, @@ -9081,6 +9403,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: ... @overload @@ -9093,8 +9416,10 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: ... +@export def einsum( subscripts, /, *operands, @@ -9103,6 +9428,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: """Einstein summation @@ -9334,11 +9660,11 @@ def einsum( contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) - einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) + einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) if spec is not None: einsum = jax.named_call(einsum, name=spec) return einsum(operands, contractions, precision, - preferred_element_type, _dot_general) + preferred_element_type, _dot_general, out_type) # Enable other modules to override einsum_contact_path. @@ -9369,6 +9695,7 @@ def einsum_path( optimize: bool | str | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... +@export def einsum_path( subscripts, /, *operands, @@ -9437,7 +9764,15 @@ def _einsum( precision, preferred_element_type, _dot_general=lax.dot_general, + out_type=None, ): + if out_type is not None and not config.sharding_in_types.value: + raise NotImplementedError("out_type only works when sharding_in_types " + "config is True.") + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError( + "`out_type` argument of `einsum` only supports NamedSharding instances." + " Please file a bug if this is not enough for your use case.") dtypes.check_user_dtype_supported(preferred_element_type, "einsum") operands = list(map(asarray, operands)) if preferred_element_type is None: @@ -9559,13 +9894,25 @@ def filter_singleton_dims(operand, names, other_shape, other_names): names = batch_names_str + remaining_rhs_names + remaining_lhs_names if names == result_names: dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch)) + k_out_type = {} if out_type is None else {'out_type': out_type} operand = _dot_general(rhs, lhs, dimension_numbers, precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + **k_out_type) else: names = batch_names_str + remaining_lhs_names + remaining_rhs_names + if (config.sharding_in_types.value and out_type is not None and + names != result_names): + spec = out_type.spec + inverse_spec = tuple(spec[result_names.index(name)] for name in names) + dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec)) + else: + dot_general_out_type = out_type # type: ignore dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch)) + dot_general_out_type = ({} if dot_general_out_type is None else # type: ignore + {'out_type': dot_general_out_type}) operand = _dot_general(lhs, rhs, dimension_numbers, precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + **dot_general_out_type) else: raise NotImplementedError # if this is actually reachable, open an issue! @@ -9578,9 +9925,11 @@ def filter_singleton_dims(operand, names, other_shape, other_names): operand = lax.transpose(operand, perm) operands.append(operand) # used in next iteration - return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type) + return lax_internal._convert_element_type(operands[0], preferred_element_type, + output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -9637,6 +9986,7 @@ def inner( preferred_element_type=preferred_element_type) +@export @partial(jit, inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """Compute the outer product of two arrays. @@ -9671,6 +10021,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: return ravel(a)[:, None] * ravel(b)[None, :] +@export @partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis')) def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int | None = None): @@ -9771,6 +10122,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) +@export @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: """Compute the Kronecker product of two input arrays. @@ -9816,6 +10168,7 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) +@export @partial(jit, static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False @@ -9879,6 +10232,7 @@ def vander( ### Misc +@export def argwhere( a: ArrayLike, *, @@ -9944,6 +10298,7 @@ def argwhere( return result.reshape(result.shape[0], ndim(a)) +@export def argmax(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the maximum value of an array. @@ -9999,22 +10354,23 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the minimum value of an array. - JAX implementation of :func:`numpy.argmax`. + JAX implementation of :func:`numpy.argmin`. Args: a: input array - axis: optional integer specifying the axis along which to find the maximum + axis: optional integer specifying the axis along which to find the minimum value. If ``axis`` is not specified, ``a`` will be flattened. out: unused by JAX keepdims: if True, then return an array with the same number of dimensions as ``a``. Returns: - an array containing the index of the maximum value along the specified axis. + an array containing the index of the minimum value along the specified axis. See also: - :func:`jax.numpy.argmax`: return the index of the maximum value. @@ -10054,6 +10410,7 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def nanargmax( a: ArrayLike, axis: int | None = None, @@ -10121,6 +10478,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export def nanargmin( a: ArrayLike, axis: int | None = None, @@ -10181,6 +10539,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def sort( a: ArrayLike, @@ -10244,6 +10603,7 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result +@export @jit def sort_complex(a: ArrayLike) -> Array: """Return a sorted copy of complex array. @@ -10281,6 +10641,7 @@ def sort_complex(a: ArrayLike) -> Array: return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) +@export @partial(jit, static_argnames=('axis',)) def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: """Sort a sequence of keys in lexicographic order. @@ -10358,6 +10719,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1] +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def argsort( a: ArrayLike, @@ -10438,6 +10800,7 @@ def argsort( return lax.rev(indices, dimensions=[dimension]) if descending else indices +@export @partial(jit, static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns a partially-sorted copy of an array. @@ -10508,6 +10871,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) +@export @partial(jit, static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns indices that partially sort an array. @@ -10612,6 +10976,8 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: dimension=ax) return a + +@export def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], axis: int | Sequence[int] | None = None) -> Array: """Roll the elements of an array along a specified axis. @@ -10665,6 +11031,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], return _roll_static(arr, shift, axis) +@export @partial(jit, static_argnames=('axis', 'start')) def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: """Roll the specified axis to a given position. @@ -10730,6 +11097,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: return moveaxis(a, axis, start) +@export @partial(jit, static_argnames=('axis', 'bitorder')) def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array: """Pack array of bits into a uint8 array. @@ -10814,6 +11182,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar return swapaxes(packed, axis, -1) +@export @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, @@ -10905,6 +11274,7 @@ def unpackbits( return swapaxes(unpacked, axis, -1) +@export def take( a: ArrayLike, indices: ArrayLike, @@ -11062,6 +11432,7 @@ def _normalize_index(index, axis_size): return lax.select(index < 0, lax.add(index, axis_size_val), index) +@export @partial(jit, static_argnames=('axis', 'mode', 'fill_value')) def take_along_axis( arr: ArrayLike, @@ -11249,6 +11620,106 @@ def replace(tup, val): mode="fill" if mode is None else mode, fill_value=fill_value) +_indices = indices # argument below named 'indices' shadows the function + + +def _make_along_axis_idx(shape, indices, axis): + return tuple_replace(_indices(shape, sparse=True), axis, indices) + + +@export +@partial(jit, static_argnames=('axis', 'inplace', 'mode')) +def put_along_axis( + arr: ArrayLike, + indices: ArrayLike, + values: ArrayLike, + axis: int | None, + inplace: bool = True, + *, + mode: str | None = None, +) -> Array: + """Put values into the destination array by matching 1d index and data slices. + + JAX implementation of :func:`numpy.put_along_axis`. + + The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which + is not possible for JAX's immutable arrays. The JAX version returns a modified + copy of the input, and adds the ``inplace`` parameter which must be set to + `False`` by the user as a reminder of this API difference. + + Args: + arr: array into which values will be put. + indices: array of indices at which to put values. + values: array of values to put into the array. + axis: the axis along which to put values. If not specified, the array will + be flattened before indexing is applied. + inplace: must be set to False to indicate that the input is not modified + in-place, but rather a modified copy is returned. + mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options, + see :attr:`jax.numpy.ndarray.at`. + + Returns: + A copy of ``a`` with specified entries updated. + + See Also: + - :func:`jax.numpy.put`: put elements into an array at given indices. + - :func:`jax.numpy.place`: place elements into an array via boolean mask. + - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing. + - :func:`jax.numpy.take`: extract values from an array at given indices. + - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis. + + Examples: + >>> from jax import numpy as jnp + >>> a = jnp.array([[10, 30, 20], [60, 40, 50]]) + >>> i = jnp.argmax(a, axis=1, keepdims=True) + >>> print(i) + [[1] + [0]] + >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False) + >>> print(b) + [[10 99 20] + [99 40 50]] + """ + if inplace: + raise ValueError( + "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays" + "are immutable. Pass inplace=False to instead return an updated array.") + + util.check_arraylike("put_along_axis", arr, indices, values) + arr = asarray(arr) + indices = asarray(indices) + values = asarray(values) + + original_axis = axis + original_arr_shape = arr.shape + + if axis is None: + arr = arr.ravel() + axis = 0 + + if not arr.ndim == indices.ndim: + raise ValueError( + "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got " + f"{arr.ndim=} and {indices.ndim=}." + ) + + try: + values = broadcast_to(values, indices.shape) + except ValueError: + raise ValueError( + "put_along_axis argument 'values' must be broadcastable to 'indices'. Got " + f"{values.shape=} and {indices.shape=}." + ) + + idx = _make_along_axis_idx(arr.shape, indices, axis) + result = arr.at[idx].set(values, mode=mode) + + if original_axis is None: + result = result.reshape(original_arr_shape) + + return result + + ### Indexing def _is_integer_index(idx: Any) -> bool: @@ -11500,6 +11971,14 @@ def _int(aval): def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], normalize_indices: bool = True) -> _Indexer: + # Check whether advanced indices are contiguous. We must do this before + # removing ellipses (https://github.com/jax-ml/jax/issues/25109) + # If advanced idexing axes do not appear contiguously, NumPy semantics + # move the advanced axes to the front. + is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray)) + or isscalar(e) for e in idx]) + advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1) + # Remove ellipses and add trailing slice(None)s. idx = _canonicalize_tuple_index(len(x_shape), idx) @@ -11516,10 +11995,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], # Check for advanced indexing: # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing - # Do the advanced indexing axes appear contiguously? If not, NumPy semantics - # move the advanced axes to the front. - advanced_axes_are_contiguous = False - advanced_indexes: Sequence[Array | np.ndarray] | None = None # The positions of the advanced indexing axes in `idx`. @@ -11538,7 +12013,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j) for e, i, j in advanced_pairs) advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs) - advanced_axes_are_contiguous = bool(np.all(np.diff(idx_advanced_axes) == 1)) x_axis = 0 # Current axis in x. y_axis = 0 # Current axis in y, before collapsing. See below. @@ -11608,7 +12082,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], except TypeError: abstract_i = None # Handle basic int indexes. - if isinstance(abstract_i, (ConcreteArray, ShapedArray)) and _int(abstract_i): + if isinstance(abstract_i, ShapedArray) and _int(abstract_i): if core.definitely_equal(x_shape[x_axis], 0): # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") @@ -11638,7 +12112,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], "arrays within JIT compiled functions).") raise IndexError(msg) - start, step, slice_size = _preprocess_slice(i, x_shape[x_axis]) + start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis]) slice_shape.append(slice_size) if core.definitely_equal(step, 1): @@ -11764,7 +12238,7 @@ def _expand_bool_indices(idx, shape): i = array(i) abstract_i = core.get_aval(i) - if not type(abstract_i) is ConcreteArray: + if not core.is_concrete(i): # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete raise errors.NonConcreteBooleanIndexError(abstract_i) elif _ndim(i) == 0: @@ -11794,7 +12268,7 @@ def _is_slice_element_none_or_constant_or_symbolic(elt): if elt is None: return True if core.is_symbolic_dim(elt): return True try: - return type(core.get_aval(elt)) is ConcreteArray + return core.is_concrete(elt) except TypeError: return False @@ -11841,66 +12315,8 @@ def _canonicalize_tuple_index(arr_ndim, idx): idx = tuple(idx) + colons return idx -def _preprocess_slice( - s: slice, - axis_size: core.DimSize - ) -> tuple[core.DimSize, core.DimSize, core.DimSize]: - """Computes the start index, step, and size of the slice `x[s]`.""" - # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding - # "this is harder to get right than you may think" - # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275) - def convert_to_index(d: DimSize) -> DimSize: - # Convert np.array and jax.Array to int, leave symbolic dimensions alone - try: - return operator.index(d) - except: - return d - - # Must resolve statically if step is {<0, ==0, >0} - step = convert_to_index(s.step) if s.step is not None else 1 - try: - if step == 0: - raise ValueError("slice step cannot be zero") - step_gt_0 = (step > 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the step ({step}) must " + - f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") - - def clamp_index(i: DimSize, which: str): - try: - i_ge_0 = (i >= 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the {which} ({i}) must " + - f"be resolved statically if it is >= 0.\nDetails: {e}") - if i_ge_0: - if step_gt_0: - return core.min_dim(axis_size, i) - else: - return core.min_dim(axis_size - 1, i) - else: - if step_gt_0: - return core.max_dim(0, axis_size + i) - else: - return core.max_dim(-1, axis_size + i) - - if s.start is None: - start = 0 if step_gt_0 else axis_size - 1 - else: - start = clamp_index(convert_to_index(s.start), "start") - - if s.stop is None: - stop = axis_size if step_gt_0 else -1 - else: - stop = clamp_index(convert_to_index(s.stop), "stop") - - gap = step if step_gt_0 else - step - distance = (stop - start) if step_gt_0 else (start - stop) - slice_size = core.max_dim(0, distance + gap - 1) // gap - return start, step, slice_size - +@export def blackman(M: int) -> Array: """Return a Blackman window of size M. @@ -11931,6 +12347,7 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) +@export def bartlett(M: int) -> Array: """Return a Bartlett window of size M. @@ -11961,6 +12378,7 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) +@export def hamming(M: int) -> Array: """Return a Hamming window of size M. @@ -11991,6 +12409,7 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) +@export def hanning(M: int) -> Array: """Return a Hanning window of size M. @@ -12021,6 +12440,7 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) +@export def kaiser(M: int, beta: ArrayLike) -> Array: """Return a Kaiser window of size M. @@ -12063,6 +12483,8 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) + +@export @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the greatest common divisor of two arrays. @@ -12109,6 +12531,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd +@export @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the least common multiple of two arrays. @@ -12156,6 +12579,7 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) +@export def extract(condition: ArrayLike, arr: ArrayLike, *, size: int | None = None, fill_value: ArrayLike = 0) -> Array: """Return the elements of an array that satisfy a condition. @@ -12217,6 +12641,7 @@ def extract(condition: ArrayLike, arr: ArrayLike, return compress(ravel(condition), ravel(arr), size=size, fill_value=fill_value) +@export def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, *, size: int | None = None, fill_value: ArrayLike = 0, out: None = None) -> Array: """Compress an array along a given axis using a boolean condition. @@ -12311,12 +12736,103 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, return moveaxis(result, 0, axis) -@util.implements(np.cov) +@export @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, bias: bool = False, ddof: int | None = None, fweights: ArrayLike | None = None, aweights: ArrayLike | None = None) -> Array: + r"""Estimate the weighted sample covariance. + + JAX implementation of :func:`numpy.cov`. + + The covariance :math:`C_{ij}` between variable *i* and variable *j* is defined + as + + .. math:: + + cov[X_i, X_j] = E[(X_i - E[X_i])(X_j - E[X_j])] + + Given an array of *N* observations of the variables :math:`X_i` and :math:`X_j`, + this can be estimated via the sample covariance: + + .. math:: + + C_{ij} = \frac{1}{N - 1} \sum_{n=1}^N (X_{in} - \overline{X_i})(X_{jn} - \overline{X_j}) + + Where :math:`\overline{X_i} = \frac{1}{N} \sum_{k=1}^N X_{ik}` is the mean of the + observations. + + Args: + m: array of shape ``(M, N)`` (if ``rowvar`` is True), or ``(N, M)`` + (if ``rowvar`` is False) representing ``N`` observations of ``M`` variables. + ``m`` may also be one-dimensional, representing ``N`` observations of a + single variable. + y: optional set of additional observations, with the same form as ``m``. If + specified, then ``y`` is combined with ``m``, i.e. for the default + ``rowvar = True`` case, ``m`` becomes ``jnp.vstack([m, y])``. + rowvar: if True (default) then each row of ``m`` represents a variable. If + False, then each column represents a variable. + bias: if False (default) then normalize the covariance by ``N - 1``. If True, + then normalize the covariance by ``N`` + ddof: specify the degrees of freedom. Defaults to ``1`` if ``bias`` is False, + or to ``0`` if ``bias`` is True. + fweights: optional array of integer frequency weights of shape ``(N,)``. This + is an absolute weight specifying the number of times each observation is + included in the computation. + aweights: optional array of observation weights of shape ``(N,)``. This is + a relative weight specifying the "importance" of each observation. In the + ``ddof=0`` case, it is equivalent to assigning probabilities to each + observation. + + Returns: + A covariance matrix of shape ``(M, M)``. + + See also: + - :func:`jax.numpy.corrcoef`: compute the correlation coefficient, a normalized + version of the covariance matrix. + + Examples: + Consider these observations of two variables that correlate perfectly. + The covariance matrix in this case is a 2x2 matrix of ones: + + >>> x = jnp.array([[0, 1, 2], + ... [0, 1, 2]]) + >>> jnp.cov(x) + Array([[1., 1.], + [1., 1.]], dtype=float32) + + Now consider these observations of two variables that are perfectly + anti-correlated. The covariance matrix in this case has ``-1`` in the + off-diagonal: + + >>> x = jnp.array([[-1, 0, 1], + ... [ 1, 0, -1]]) + >>> jnp.cov(x) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + Equivalently, these sequences can be specified as separate arguments, + in which case they are stacked before continuing the computation. + + >>> x = jnp.array([-1, 0, 1]) + >>> y = jnp.array([1, 0, -1]) + >>> jnp.cov(x, y) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + In general, the entries of the covariance matrix may be any positive + or negative real value. For example, here is the covariance of 100 + points drawn from a 3-dimensional standard normal distribution: + + >>> key = jax.random.key(0) + >>> x = jax.random.normal(key, shape=(3, 100)) + >>> with jnp.printoptions(precision=2): + ... print(jnp.cov(x)) + [[ 1.22 -0. 0.11] + [-0. 0.84 -0.1 ] + [ 0.11 -0.1 0.88]] + """ if y is not None: m, y = util.promote_args_inexact("cov", m, y) if y.ndim > 2: @@ -12379,9 +12895,82 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze() -@util.implements(np.corrcoef) +@export @partial(jit, static_argnames=('rowvar',)) def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array: + r"""Compute the Pearson correlation coefficients. + + JAX implementation of :func:`numpy.corrcoef`. + + This is a normalized version of the sample covariance computed by :func:`jax.numpy.cov`. + For a sample covariance :math:`C_{ij}`, the correlation coefficients are + + .. math:: + + R_{ij} = \frac{C_{ij}}{\sqrt{C_{ii}C_{jj}}} + + they are constructed such that the values satisfy :math:`-1 \le R_{ij} \le 1`. + + Args: + x: array of shape ``(M, N)`` (if ``rowvar`` is True), or ``(N, M)`` + (if ``rowvar`` is False) representing ``N`` observations of ``M`` variables. + ``x`` may also be one-dimensional, representing ``N`` observations of a + single variable. + y: optional set of additional observations, with the same form as ``m``. If + specified, then ``y`` is combined with ``m``, i.e. for the default + ``rowvar = True`` case, ``m`` becomes ``jnp.vstack([m, y])``. + rowvar: if True (default) then each row of ``m`` represents a variable. If + False, then each column represents a variable. + + Returns: + A covariance matrix of shape ``(M, M)``. + + See also: + - :func:`jax.numpy.cov`: compute the covariance matrix. + + Examples: + Consider these observations of two variables that correlate perfectly. + The correlation matrix in this case is a 2x2 matrix of ones: + + >>> x = jnp.array([[0, 1, 2], + ... [0, 1, 2]]) + >>> jnp.corrcoef(x) + Array([[1., 1.], + [1., 1.]], dtype=float32) + + Now consider these observations of two variables that are perfectly + anti-correlated. The correlation matrix in this case has ``-1`` in the + off-diagonal: + + >>> x = jnp.array([[-1, 0, 1], + ... [ 1, 0, -1]]) + >>> jnp.corrcoef(x) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + Equivalently, these sequences can be specified as separate arguments, + in which case they are stacked before continuing the computation. + + >>> x = jnp.array([-1, 0, 1]) + >>> y = jnp.array([1, 0, -1]) + >>> jnp.corrcoef(x, y) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + The entries of the correlation matrix are normalized such that they + lie within the range -1 to +1, where +1 indicates perfect correlation + and -1 indicates perfect anti-correlation. For example, here is the + correlation of 100 points drawn from a 3-dimensional standard normal + distribution: + + >>> key = jax.random.key(0) + >>> x = jax.random.normal(key, shape=(3, 100)) + >>> with jnp.printoptions(precision=2): + ... print(jnp.corrcoef(x)) + [[ 1. -0. 0.1 ] + [-0. 1. -0.12] + [ 0.1 -0.12 1. ]] + """ util.check_arraylike("corrcoef", x) c = cov(x, y, rowvar) if len(shape(c)) == 0: @@ -12403,9 +12992,11 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A @partial(vectorize, excluded={0, 1, 3, 4}) def _searchsorted_via_scan(unrolled: bool, sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: op = _sort_le_comparator if side == 'left' else _sort_lt_comparator + unsigned_dtype = np.uint32 if dtype == np.int32 else np.uint64 def body_fun(state, _): low, high = state - mid = (low + high) // 2 + mid = low.astype(unsigned_dtype) + high.astype(unsigned_dtype) + mid = lax.div(mid, unsigned_dtype(2)).astype(dtype) go_left = op(query, sorted_arr[mid]) return (where(go_left, low, mid), where(go_left, mid, high)), () n_levels = int(np.ceil(np.log2(len(sorted_arr) + 1))) @@ -12434,6 +13025,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) +@export @partial(jit, static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: @@ -12523,6 +13115,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', return impl(asarray(a), asarray(v), side, dtype) # type: ignore +@export @partial(jit, static_argnames=('right', 'method')) def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str | None = None) -> Array: @@ -12578,6 +13171,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, ) +@export def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], funclist: list[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: @@ -12685,6 +13279,7 @@ def _tile_to_size(arr: Array, size: int) -> Array: return arr[:size] if arr.size > size else arr +@export def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: bool = True) -> Array: """Update array elements based on a mask. @@ -12760,6 +13355,7 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape) +@export def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = None, *, inplace: bool = True) -> Array: """Put elements into an array at given indices. diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 79b47d9090af..8e35560a52ed 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -35,10 +35,13 @@ from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions, ufuncs from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg +export = set_module('jax.numpy.linalg') + + class EighResult(NamedTuple): eigenvalues: jax.Array eigenvectors: jax.Array @@ -67,6 +70,7 @@ def _H(x: ArrayLike) -> Array: def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 +@export @partial(jit, static_argnames=['upper']) def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """Compute the Cholesky decomposition of a matrix. @@ -91,8 +95,8 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: Args: a: input array, representing a (batched) positive-definite hermitian matrix. Must have shape ``(..., N, N)``. - upper: if True, compute the upper Cholesky decomposition `L`. if False - (default), compute the lower Cholesky decomposition `U`. + upper: if True, compute the upper Cholesky decomposition `U`. if False + (default), compute the lower Cholesky decomposition `L`. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition @@ -191,6 +195,7 @@ def svd( ... +@export @partial( jit, static_argnames=( @@ -311,6 +316,7 @@ def svd( ) +@export @partial(jit, static_argnames=('n',)) def matrix_power(a: ArrayLike, n: int) -> Array: """Raise a square matrix to an integer power. @@ -392,6 +398,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: return result +@export @jit def matrix_rank( M: ArrayLike, rtol: ArrayLike | None = None, *, @@ -496,6 +503,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: return sign_diag * sign_taus, log_abs_det +@export @partial(jit, static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ @@ -532,7 +540,7 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: a, = promote_dtypes_inexact(jnp.asarray(a)) a_shape = jnp.shape(a) if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: - raise ValueError("Argument to slogdet() must have shape [..., n, n], got {a_shape}") + raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}") if method is None or method == "lu": return SlogdetResult(*_slogdet_lu(a)) elif method == "qr": @@ -675,6 +683,7 @@ def _det_jvp(primals, tangents): return y, jnp.trace(z, axis1=-1, axis2=-2) +@export @jit def det(a: ArrayLike) -> Array: """ @@ -711,6 +720,7 @@ def det(a: ArrayLike) -> Array: raise ValueError(msg.format(a_shape)) +@export def eig(a: ArrayLike) -> tuple[Array, Array]: """ Compute the eigenvalues and eigenvectors of a square array. @@ -731,7 +741,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: - This differs from :func:`numpy.linalg.eig` in that the return type of :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128 for 64-bit input. - - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. + - At present, non-symmetric eigendecomposition is only implemented on the CPU and + GPU backends. For more details about the GPU implementation, see the + documentation for :func:`jax.lax.linalg.eig`. See also: - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix. @@ -754,6 +766,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: return w, v +@export @jit def eigvals(a: ArrayLike) -> Array: """ @@ -791,6 +804,7 @@ def eigvals(a: ArrayLike) -> Array: compute_right_eigenvectors=False)[0] +@export @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: @@ -846,6 +860,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None, return EighResult(w, v) +@export @partial(jit, static_argnames=('UPLO',)) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ @@ -882,6 +897,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: # TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires. +@export def pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False, *, rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: @@ -995,6 +1011,7 @@ def _pinv_jvp(rtol, hermitian, primals, tangents): return p, p_dot +@export @jit def inv(a: ArrayLike) -> Array: """Return the inverse of a square matrix @@ -1055,6 +1072,7 @@ def inv(a: ArrayLike) -> Array: arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2])) +@export @partial(jit, static_argnames=('ord', 'axis', 'keepdims')) def norm(x: ArrayLike, ord: int | str | None = None, axis: None | tuple[int, ...] | int = None, @@ -1141,35 +1159,7 @@ def norm(x: ArrayLike, ord: int | str | None = None, num_axes = len(axis) if num_axes == 1: - if ord is None or ord == 2: - return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, - keepdims=keepdims)) - elif ord == jnp.inf: - return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif ord == -jnp.inf: - return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif ord == 0: - return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, - axis=axis, keepdims=keepdims) - elif ord == 1: - # Numpy has a special case for ord == 1 as an optimization. We don't - # really need the optimization (XLA could do it for us), but the Numpy - # code has slightly different type promotion semantics, so we need a - # special case too. - return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif isinstance(ord, str): - msg = f"Invalid order '{ord}' for vector norm." - if ord == "inf": - msg += "Use 'jax.numpy.inf' instead." - if ord == "-inf": - msg += "Use '-jax.numpy.inf' instead." - raise ValueError(msg) - else: - abs_x = ufuncs.abs(x) - ord_arr = lax_internal._const(abs_x, ord) - ord_inv = lax_internal._const(abs_x, 1. / ord_arr) - out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) - return ufuncs.power(out, ord_inv) + return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims) elif num_axes == 2: row_axis, col_axis = axis # pytype: disable=bad-unpacking @@ -1220,6 +1210,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ... +@export @partial(jit, static_argnames=('mode',)) def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: """Compute the QR decomposition of an array @@ -1303,6 +1294,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: return QRResult(q, r) +@export @jit def solve(a: ArrayLike, b: ArrayLike) -> Array: """Solve a linear system of equations @@ -1406,6 +1398,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, _jit_lstsq = jit(partial(_lstsq, numpy_resid=False)) +@export def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: """ @@ -1446,6 +1439,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, return _jit_lstsq(a, b, rcond) +@export def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): r"""Compute the cross-product of two 3D vectors @@ -1491,6 +1485,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): return jnp.cross(x1, x2, axis=axis) +@export def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute the outer product of two 1-dimensional arrays. @@ -1521,7 +1516,8 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: return x1[:, None] * x2[None, :] -def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array: +@export +def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str | int = 'fro') -> Array: """Compute the norm of a matrix or stack of matrices. JAX implementation of :func:`numpy.linalg.matrix_norm` @@ -1551,6 +1547,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1)) +@export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose a matrix or stack of matrices. @@ -1606,7 +1603,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) -def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, +@export +def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: """Compute the vector norm of a vector or batch of vectors. @@ -1642,14 +1640,37 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa Array([3.7416575, 9.486833 ], dtype=float32) """ check_arraylike('jnp.linalg.vector_norm', x) - if axis is None: - result = norm(jnp.ravel(x), ord=ord) - if keepdims: - result = lax.expand_dims(result, range(jnp.ndim(x))) - return result - return norm(x, axis=axis, keepdims=keepdims, ord=ord) - + if ord is None or ord == 2: + return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, + keepdims=keepdims)) + elif ord == jnp.inf: + return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif ord == -jnp.inf: + return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif ord == 0: + return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, + axis=axis, keepdims=keepdims) + elif ord == 1: + # Numpy has a special case for ord == 1 as an optimization. We don't + # really need the optimization (XLA could do it for us), but the Numpy + # code has slightly different type promotion semantics, so we need a + # special case too. + return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif isinstance(ord, str): + msg = f"Invalid order '{ord}' for vector norm." + if ord == "inf": + msg += "Use 'jax.numpy.inf' instead." + if ord == "-inf": + msg += "Use '-jax.numpy.inf' instead." + raise ValueError(msg) + else: + abs_x = ufuncs.abs(x) + ord_arr = lax_internal._const(abs_x, ord) + ord_inv = lax_internal._const(abs_x, 1. / ord_arr) + out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) + return ufuncs.power(out, ord_inv) +@export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -1700,6 +1721,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, preferred_element_type=preferred_element_type) +@export def matmul(x1: ArrayLike, x2: ArrayLike, /, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -1760,6 +1782,7 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /, *, preferred_element_type=preferred_element_type) +@export def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, precision: PrecisionLike = None, @@ -1841,6 +1864,7 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, preferred_element_type=preferred_element_type) +@export def svdvals(x: ArrayLike, /) -> Array: """Compute the singular values of a matrix. @@ -1865,6 +1889,7 @@ def svdvals(x: ArrayLike, /) -> Array: return svd(x, compute_uv=False, hermitian=False) +@export def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: """Extract the diagonal of an matrix or stack of matrices. @@ -1905,6 +1930,7 @@ def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1) +@export def tensorinv(a: ArrayLike, ind: int = 2) -> Array: """Compute the tensor inverse of an array. @@ -1947,6 +1973,7 @@ def tensorinv(a: ArrayLike, ind: int = 2) -> Array: return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape) +@export def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array: """Solve the tensor equation a x = b for x. @@ -1996,6 +2023,7 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) return solve(a_arr, b_arr.ravel()).reshape(out_shape) +@export def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array: """Efficiently compute matrix products between a sequence of arrays. @@ -2085,9 +2113,10 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) - if arrs[-1].ndim == 1: einsum_axes[-1] = einsum_axes[-1][:1] return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload] - optimize='optimal', precision=precision) + optimize='auto', precision=precision) +@export @partial(jit, static_argnames=['p']) def cond(x: ArrayLike, p=None): """Compute the condition number of a matrix. @@ -2147,6 +2176,7 @@ def cond(x: ArrayLike, p=None): return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r) +@export def trace(x: ArrayLike, /, *, offset: int = 0, dtype: DTypeLike | None = None) -> Array: """Compute the trace of a matrix. diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 10cc90575cef..19388b903e5d 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -33,6 +33,10 @@ from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, _where) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module + + +export = set_module('jax.numpy') @jit @@ -57,6 +61,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) +@export def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: r"""Returns the roots of a polynomial given the coefficients ``p``. @@ -116,6 +121,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: return _roots_with_zeros(p_arr, num_leading_zeros) +@export @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, full: bool = False, w: ArrayLike | None = None, cov: bool = False @@ -287,6 +293,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, return c +@export @jit def poly(seq_of_zeros: ArrayLike) -> Array: r"""Returns the coefficients of a polynomial for the given sequence of roots. @@ -369,6 +376,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: return a +@export @partial(jit, static_argnames=['unroll']) def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: r"""Evaluates the polynomial at specific values. @@ -432,6 +440,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: return y +@export @jit def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the sum of the two polynomials. @@ -489,6 +498,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) +@export @partial(jit, static_argnames=('m',)) def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: r"""Returns the coefficients of the integration of specified order of a polynomial. @@ -557,6 +567,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array return true_divide(concatenate((p_arr, k_arr)), coeff) +@export @partial(jit, static_argnames=('m',)) def polyder(p: ArrayLike, m: int = 1) -> Array: r"""Returns the coefficients of the derivative of specified order of a polynomial. @@ -607,6 +618,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array: return p_arr[:-m] * coeff[::-1] +@export def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: r"""Returns the product of two polynomials. @@ -673,6 +685,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - return convolve(a1_arr, a2_arr, mode='full') +@export def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: r"""Returns the quotient and remainder of polynomial division. @@ -732,6 +745,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> return q, u_arr +@export @jit def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the difference of two polynomials. diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 1c2a4689cb85..eea734420176 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -37,9 +37,11 @@ from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( canonicalize_axis as _canonicalize_axis, maybe_named_axis, - NumpyComplexWarning) + set_module, NumpyComplexWarning) +export = set_module('jax.numpy') + _all = builtins.all _lax_const = lax_internal._const @@ -79,10 +81,24 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: return dtypes.int_ return dtype +def check_where(name: str, where: ArrayLike | None) -> Array | None: + if where is None: + return where + check_arraylike(name, where) + where_arr = lax_internal.asarray(where) + if where_arr.dtype != bool: + # Deprecation added 2024-12-05 + deprecations.warn( + 'jax-numpy-reduction-non-boolean-where', + f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.", + stacklevel=2) + return where_arr.astype(bool) + return where_arr + ReductionOp = Callable[[Any, Any], Any] -def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike, +def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, *, has_identity: bool = True, preproc: Callable[[ArrayLike], ArrayLike] | None = None, bool_op: ReductionOp | None = None, @@ -99,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") check_arraylike(name, a) + where_ = check_where(name, where_) dtypes.check_user_dtype_supported(dtype, name) axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") @@ -192,6 +209,11 @@ def _cast_to_bool(operand: ArrayLike) -> Array: def _cast_to_numeric(operand: ArrayLike) -> Array: return promote_dtypes_numeric(operand)[0] +def _require_integer(operand: ArrayLike) -> Array: + arr = lax_internal.asarray(operand) + if not dtypes.isdtype(arr, ("bool", "integral")): + raise ValueError(f"integer argument required; got dtype={arr.dtype}") + return arr def _ensure_optional_axes(x: Axis) -> Axis: def force(x): @@ -210,13 +232,14 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: - return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric, + return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric, bool_op=lax.bitwise_or, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.psum, promote_integers=promote_integers) +@export def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: @@ -291,17 +314,19 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) + @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: - return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric, + return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric, bool_op=lax.bitwise_and, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) +@export def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, @@ -381,11 +406,12 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, + return _reduction(a, "max", lax.max, -np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) +@export def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -463,11 +489,12 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None, def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, + return _reduction(a, "min", lax.min, np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) +@export def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -543,10 +570,11 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, + return _reduction(a, "all", lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether all array elements along a given axis evaluate to True. @@ -599,10 +627,11 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, + return _reduction(a, "any", lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether any of the array elements along a given axis evaluate to True. @@ -652,6 +681,100 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + arr = lax_internal.asarray(a) + init_val = np.array(-1, dtype=dtype or arr.dtype) + return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Compute log(sum(exp(a))) while avoiding precision loss.""" + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") + dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") + check_arraylike("logsumexp", a) + where = check_where("logsumexp", where) + a_arr, = promote_dtypes_inexact(a) + pos_dims, dims = _reduction_dims(a_arr, axis) + amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf) + amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))) + amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims) + exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) + sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) + result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype)) + return result if initial is None else lax.logaddexp(initial, result) + + +def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Compute log2(sum(2 ** a)) via logsumexp.""" + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") + dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") + check_arraylike("logsumexp2", a) + where = check_where("logsumexp2", where) + ln2 = float(np.log(2)) + if initial is not None: + initial *= ln2 + return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims, + where=where, initial=initial) / ln2 + + +@export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -659,6 +782,7 @@ def amin(a: ArrayLike, axis: Axis = None, out: None = None, return min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) +@export def amax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -678,6 +802,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): return size +@export def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: @@ -744,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, upcast_f16_for_computation: bool = True, where: ArrayLike | None = None) -> Array: check_arraylike("mean", a) + where = check_where("mean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") @@ -781,6 +907,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, * @overload def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ... +@export def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: """Compute the weighed average. @@ -891,6 +1018,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, return avg +@export def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -979,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("var", a) + where = check_where("var", where) dtypes.check_user_dtype_supported(dtype, "var") if out is not None: raise NotImplementedError("The 'out' argument to jnp.var is not supported.") @@ -1031,6 +1160,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy return _upcast_f16(computation_dtype), np.dtype(dtype) +@export def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -1115,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("std", a) + where = check_where("std", where) dtypes.check_user_dtype_supported(dtype, "std") if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") @@ -1123,6 +1254,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) +@export def ptp(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: r"""Return the peak-to-peak range along a given axis. @@ -1174,6 +1306,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, return lax.sub(x, y) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def count_nonzero(a: ArrayLike, axis: Axis = None, keepdims: bool = False) -> Array: @@ -1219,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None, def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], init_val: ArrayLike, nan_if_all_nan: bool, - axis: Axis = None, keepdims: bool = False, **kwargs) -> Array: + axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None, + **kwargs) -> Array: check_arraylike(name, a) + where = check_where(name, where) if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): - return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs) + return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs) out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a), - axis=axis, keepdims=keepdims, **kwargs) + axis=axis, keepdims=keepdims, where=where, **kwargs) if nan_if_all_nan: return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims), _lax_const(a, np.nan), out) @@ -1233,6 +1368,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], return out +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1315,6 +1451,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1397,6 +1534,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1480,6 +1618,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1563,6 +1702,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, where: ArrayLike | None = None) -> Array: @@ -1639,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out Array([[nan, nan, nan, nan]], dtype=float32) """ check_arraylike("nanmean", a) + where = check_where("nanmean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer): @@ -1654,6 +1795,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out return td +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1731,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: [4. ]], dtype=float32) """ check_arraylike("nanvar", a) + where = check_where("nanvar", where) dtypes.check_user_dtype_supported(dtype, "nanvar") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") @@ -1756,6 +1899,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: return lax.convert_element_type(result, dtype) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1825,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: Array([[0.5, 0.5, 0. , 0. ]], dtype=float32) """ check_arraylike("nanstd", a) + where = check_where("nanstd", where) dtypes.check_user_dtype_supported(dtype, "nanstd") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.") @@ -1838,7 +1983,7 @@ def __call__(self, a: ArrayLike, axis: Axis = None, def _cumulative_reduction( name: str, reduction: Callable[..., Array], - a: ArrayLike, axis: int | None, dtype: DTypeLike | None, out: None, + a: ArrayLike, axis: int | None, dtype: DTypeLike | None, out: None = None, fill_nan: bool = False, fill_value: ArrayLike = 0, promote_integers: bool = False) -> Array: """Helper function for implementing cumulative reductions.""" @@ -1877,6 +2022,7 @@ def _cumulative_reduction( return result +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -1913,6 +2059,7 @@ def cumsum(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -1948,6 +2095,7 @@ def cumprod(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -1997,6 +2145,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None, fill_nan=True, fill_value=0) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2053,6 +2202,7 @@ def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None, a, axis, dtype, out, promote_integers=True) +@export def cumulative_sum( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, @@ -2064,7 +2214,7 @@ def cumulative_sum( Args: x: N-dimensional array axis: integer axis along which to accumulate. If ``x`` is one-dimensional, - this argument is optional. + this argument is optional and defaults to zero. dtype: optional dtype of the output. include_initial: if True, then include the initial value in the cumulative sum. Default is False. @@ -2113,9 +2263,72 @@ def cumulative_sum( dimension=axis) return out + +@export +def cumulative_prod( + x: ArrayLike, /, *, axis: int | None = None, + dtype: DTypeLike | None = None, + include_initial: bool = False) -> Array: + """Cumulative product along the axis of an array. + + JAX implementation of :func:`numpy.cumulative_prod`. + + Args: + x: N-dimensional array + axis: integer axis along which to accumulate. If ``x`` is one-dimensional, + this argument is optional and defaults to zero. + dtype: optional dtype of the output. + include_initial: if True, then include the initial value in the cumulative + product. Default is False. + + Returns: + An array containing the accumulated values. + + See Also: + - :func:`jax.numpy.cumprod`: alternative API for cumulative product. + - :func:`jax.numpy.nancumprod`: cumulative product while ignoring NaN values. + - :func:`jax.numpy.multiply.accumulate`: cumulative product via the ufunc API. + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.cumulative_prod(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + >>> jnp.cumulative_prod(x, axis=1, include_initial=True) + Array([[ 1, 1, 2, 6], + [ 1, 4, 20, 120]], dtype=int32) + """ + check_arraylike("cumulative_prod", x) + x = lax_internal.asarray(x) + if x.ndim == 0: + raise ValueError( + "The input must be non-scalar to take a cumulative product, however a " + "scalar value or scalar array was given." + ) + if axis is None: + axis = 0 + if x.ndim > 1: + raise ValueError( + f"The input array has rank {x.ndim}, however axis was not set to an " + "explicit value. The axis argument is only optional for one-dimensional " + "arrays.") + + axis = _canonicalize_axis(axis, x.ndim) + dtypes.check_user_dtype_supported(dtype) + out = _cumulative_reduction("cumulative_prod", lax.cumprod, x, axis, dtype) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = lax_internal.concatenate( + [lax_internal.full(zeros_shape, 1, dtype=out.dtype), out], + dimension=axis) + return out + # Quantiles # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2172,6 +2385,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2299,7 +2513,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) + with jax.debug_nans(False): + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) q = lax.mul(q, n - 1) @@ -2351,7 +2566,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = result.reshape(keepdim) return lax.convert_element_type(result, a.dtype) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2407,7 +2624,9 @@ def percentile(a: ArrayLike, q: ArrayLike, return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2467,6 +2686,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, method=method, keepdims=keepdims) +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, @@ -2518,6 +2738,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, keepdims=keepdims, method='midpoint') +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 6491a7617d8d..0d5ea905becc 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -35,10 +35,12 @@ from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan from jax._src.numpy.util import check_arraylike, promote_dtypes -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module from jax._src.typing import Array, ArrayLike +export = set_module('jax.numpy') + _lax_const = lax_internal._const @@ -88,6 +90,7 @@ def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: return arr, num_unique1 + num_unique2 +@export def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set difference of two 1D arrays. @@ -175,6 +178,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value) +@export def union1d(ar1: ArrayLike, ar2: ArrayLike, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set union of two 1D arrays. @@ -278,6 +282,7 @@ def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, return where(arange(len(vals)) < num_results, vals, fill_value) +@export def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set-wise xor of elements in two arrays. @@ -417,6 +422,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as return vals +@export def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return_indices: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]: @@ -524,6 +530,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return int1d +@export def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: bool = False, invert: bool = False, *, method='auto') -> Array: @@ -652,6 +659,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo return ret[0] if len(ret) == 1 else ret +@export def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False, return_counts: bool = False, axis: int | None = None, *, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None): @@ -863,6 +871,7 @@ class _UniqueInverseResult(NamedTuple): inverse_indices: Array +@export def unique_all(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueAllResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -945,6 +954,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) +@export def unique_counts(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueCountsResult: """Return unique values from x, along with counts. @@ -1005,6 +1015,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, return _UniqueCountsResult(values=values, counts=counts) +@export def unique_inverse(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueInverseResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -1070,6 +1081,7 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None, return _UniqueInverseResult(values=values, inverse_indices=inverse_indices) +@export def unique_values(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Return unique values from x, along with indices, inverse indices, and counts. diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 3473e8a7468a..5dbd67e62a9f 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -33,6 +33,8 @@ import numpy as np +export = set_module("jax.numpy") + _AT_INPLACE_WARNING = """\ Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. @@ -40,7 +42,7 @@ """ -@set_module('jax.numpy') +@export class ufunc: """Universal functions which operation element-by-element on arrays. @@ -586,6 +588,7 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: return result.reshape(*np.shape(A), *np.shape(B)) +@export def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, *, identity: Any = None) -> ufunc: """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. @@ -598,5 +601,28 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, Returns: wrapped : jax.numpy.ufunc wrapper of func. + + Examples: + Here is an example of creating a ufunc similar to :obj:`jax.numpy.add`: + + >>> import operator + >>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0) + + Now all the standard :class:`jax.numpy.ufunc` methods are available: + + >>> x = jnp.arange(4) + >>> add(x, 10) + Array([10, 11, 12, 13], dtype=int32) + >>> add.outer(x, x) + Array([[0, 1, 2, 3], + [1, 2, 3, 4], + [2, 3, 4, 5], + [3, 4, 5, 6]], dtype=int32) + >>> add.reduce(x) + Array(6, dtype=int32) + >>> add.accumulate(x) + Array([0, 1, 3, 6], dtype=int32) + >>> add.at(x, 1, 10, inplace=False) + Array([ 0, 11, 2, 3], dtype=int32) """ return ufunc(func, nin, nout, identity=identity) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index b18b06d02f2c..de8688e491ba 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -31,13 +31,17 @@ from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax from jax._src.lax import other as lax_other -from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, - promote_shapes, _where, implements, check_no_float0s) + promote_shapes, _where, check_no_float0s) from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy import reductions +from jax._src.util import set_module + + +export = set_module('jax.numpy') _lax_const = lax._const @@ -57,6 +61,25 @@ def _to_bool(x: Array) -> Array: return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) +def unary_ufunc(func: Callable[[ArrayLike], Array]) -> ufunc: + """An internal helper function for defining unary ufuncs.""" + func_jit = jit(func, inline=True) + return ufunc(func_jit, name=func.__name__, nin=1, nout=1, call=func_jit) + + +def binary_ufunc(identity: Any, reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None) -> Callable[[Callable[[ArrayLike, ArrayLike], Array]], ufunc]: + """An internal helper function for defining binary ufuncs.""" + def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc: + func_jit = jit(func, inline=True) + return ufunc(func_jit, name=func.__name__, nin=2, nout=1, call=func_jit, + identity=identity, reduce=reduce, accumulate=accumulate, at=at, reduceat=reduceat) + return decorator + + +@export @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. @@ -101,18 +124,21 @@ def fabs(x: ArrayLike, /) -> Array: return lax.abs(*promote_args_inexact('fabs', x)) +@export @partial(jit, inline=True) def bitwise_invert(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_invert', x)) +@export @partial(jit, inline=True) def bitwise_not(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_not', x)) +@export @partial(jit, inline=True) def invert(x: ArrayLike, /) -> Array: """Compute the bitwise inversion of an input. @@ -160,8 +186,8 @@ def invert(x: ArrayLike, /) -> Array: return lax.bitwise_not(*promote_args('invert', x)) -@partial(jit, inline=True) -def _negative(x: ArrayLike, /) -> Array: +@unary_ufunc +def negative(x: ArrayLike, /) -> Array: """Return element-wise negative values of the input. JAX implementation of :obj:`numpy.negative`. @@ -205,6 +231,7 @@ def _negative(x: ArrayLike, /) -> Array: return lax.neg(*promote_args('negative', x)) +@export @partial(jit, inline=True) def positive(x: ArrayLike, /) -> Array: """Return element-wise positive values of the input. @@ -253,6 +280,7 @@ def positive(x: ArrayLike, /) -> Array: return lax.asarray(*promote_args('positive', x)) +@export @partial(jit, inline=True) def sign(x: ArrayLike, /) -> Array: r"""Return an element-wise indication of sign of the input. @@ -303,6 +331,7 @@ def sign(x: ArrayLike, /) -> Array: return lax.sign(*promote_args('sign', x)) +@export @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. @@ -341,6 +370,7 @@ def floor(x: ArrayLike, /) -> Array: return lax.floor(*promote_args_inexact('floor', x)) +@export @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: """Round input to the nearest integer upwards. @@ -379,6 +409,7 @@ def ceil(x: ArrayLike, /) -> Array: return lax.ceil(*promote_args_inexact('ceil', x)) +@export @partial(jit, inline=True) def exp(x: ArrayLike, /) -> Array: """Calculate element-wise exponential of the input. @@ -420,6 +451,7 @@ def exp(x: ArrayLike, /) -> Array: return lax.exp(*promote_args_inexact('exp', x)) +@export @partial(jit, inline=True) def log(x: ArrayLike, /) -> Array: """Calculate element-wise natural logarithm of the input. @@ -457,6 +489,7 @@ def log(x: ArrayLike, /) -> Array: return lax.log(*promote_args_inexact('log', x)) +@export @partial(jit, inline=True) def expm1(x: ArrayLike, /) -> Array: """Calculate ``exp(x)-1`` of each element of the input. @@ -501,6 +534,7 @@ def expm1(x: ArrayLike, /) -> Array: return lax.expm1(*promote_args_inexact('expm1', x)) +@export @partial(jit, inline=True) def log1p(x: ArrayLike, /) -> Array: """Calculates element-wise logarithm of one plus input, ``log(x+1)``. @@ -541,6 +575,7 @@ def log1p(x: ArrayLike, /) -> Array: return lax.log1p(*promote_args_inexact('log1p', x)) +@export @partial(jit, inline=True) def sin(x: ArrayLike, /) -> Array: """Compute a trigonometric sine of each element of input. @@ -572,6 +607,7 @@ def sin(x: ArrayLike, /) -> Array: return lax.sin(*promote_args_inexact('sin', x)) +@export @partial(jit, inline=True) def cos(x: ArrayLike, /) -> Array: """Compute a trigonometric cosine of each element of input. @@ -602,6 +638,7 @@ def cos(x: ArrayLike, /) -> Array: return lax.cos(*promote_args_inexact('cos', x)) +@export @partial(jit, inline=True) def tan(x: ArrayLike, /) -> Array: """Compute a trigonometric tangent of each element of input. @@ -632,6 +669,7 @@ def tan(x: ArrayLike, /) -> Array: return lax.tan(*promote_args_inexact('tan', x)) +@export @partial(jit, inline=True) def arcsin(x: ArrayLike, /) -> Array: r"""Compute element-wise inverse of trigonometric sine of input. @@ -673,6 +711,7 @@ def arcsin(x: ArrayLike, /) -> Array: return lax.asin(*promote_args_inexact('arcsin', x)) +@export @partial(jit, inline=True) def arccos(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric cosine of input. @@ -715,6 +754,7 @@ def arccos(x: ArrayLike, /) -> Array: return lax.acos(*promote_args_inexact('arccos', x)) +@export @partial(jit, inline=True) def arctan(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric tangent of input. @@ -755,6 +795,7 @@ def arctan(x: ArrayLike, /) -> Array: return lax.atan(*promote_args_inexact('arctan', x)) +@export @partial(jit, inline=True) def sinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic sine of input. @@ -809,6 +850,7 @@ def sinh(x: ArrayLike, /) -> Array: return lax.sinh(*promote_args_inexact('sinh', x)) +@export @partial(jit, inline=True) def cosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic cosine of input. @@ -862,6 +904,7 @@ def cosh(x: ArrayLike, /) -> Array: return lax.cosh(*promote_args_inexact('cosh', x)) +@export @partial(jit, inline=True) def arcsinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic sine of input. @@ -911,6 +954,7 @@ def arcsinh(x: ArrayLike, /) -> Array: return lax.asinh(*promote_args_inexact('arcsinh', x)) +@export @jit def arccosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic cosine of input. @@ -966,6 +1010,7 @@ def arccosh(x: ArrayLike, /) -> Array: return result +@export @partial(jit, inline=True) def tanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic tangent of input. @@ -1019,6 +1064,7 @@ def tanh(x: ArrayLike, /) -> Array: return lax.tanh(*promote_args_inexact('tanh', x)) +@export @partial(jit, inline=True) def arctanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic tangent of input. @@ -1067,6 +1113,7 @@ def arctanh(x: ArrayLike, /) -> Array: return lax.atanh(*promote_args_inexact('arctanh', x)) +@export @partial(jit, inline=True) def sqrt(x: ArrayLike, /) -> Array: """Calculates element-wise non-negative square root of the input array. @@ -1099,6 +1146,7 @@ def sqrt(x: ArrayLike, /) -> Array: return lax.sqrt(*promote_args_inexact('sqrt', x)) +@export @partial(jit, inline=True) def cbrt(x: ArrayLike, /) -> Array: """Calculates element-wise cube root of the input array. @@ -1126,8 +1174,18 @@ def cbrt(x: ArrayLike, /) -> Array: """ return lax.cbrt(*promote_args_inexact('cbrt', x)) -@partial(jit, inline=True) -def _add(x: ArrayLike, y: ArrayLike, /) -> Array: + +def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.add.at.""" + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].add(b).astype(bool) + return a.at[indices].add(b) + + +@binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) +def add(x: ArrayLike, y: ArrayLike, /) -> Array: """Add two arrays element-wise. JAX implementation of :obj:`numpy.add`. This is a universal function, @@ -1156,8 +1214,19 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) -@partial(jit, inline=True) -def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: + +def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.multiply.at.""" + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].mul(b).astype(bool) + else: + return a.at[indices].mul(b) + + +@binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) +def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: """Multiply two arrays element-wise. JAX implementation of :obj:`numpy.multiply`. This is a universal function, @@ -1186,8 +1255,9 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) -@partial(jit, inline=True) -def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and) +def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise AND operation elementwise. JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, @@ -1215,8 +1285,9 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) -@partial(jit, inline=True) -def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or) +def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise OR operation elementwise. JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, @@ -1244,8 +1315,9 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) -@partial(jit, inline=True) -def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor) +def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise XOR operation elementwise. JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, @@ -1274,6 +1346,7 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) +@export @partial(jit, inline=True) def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise. @@ -1329,12 +1402,14 @@ def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.shift_left(*promote_args_numeric("left_shift", x, y)) +@export @partial(jit, inline=True) def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.left_shift`.""" return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) +@export @partial(jit, inline=True) def equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x == y``. @@ -1384,6 +1459,7 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.eq(*promote_args("equal", x, y)) +@export @partial(jit, inline=True) def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x != y``. @@ -1433,8 +1509,13 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.ne(*promote_args("not_equal", x, y)) -@partial(jit, inline=True) -def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array: +def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.subtract.at.""" + return a.at[indices].subtract(b) + + +@binary_ufunc(identity=None, at=_subtract_at) +def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: """Subtract two arrays element-wise. JAX implementation of :obj:`numpy.subtract`. This is a universal function, @@ -1463,12 +1544,63 @@ def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.sub(*promote_args("subtract", x, y)) -@implements(np.arctan2, module='numpy') +@export @partial(jit, inline=True) def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + r"""Compute the arctangent of x1/x2, choosing the correct quadrant. + + JAX implementation of :func:`numpy.arctan2` + + Args: + x1: numerator array. + x2: denomniator array; should be broadcast-compatible with x1. + + Returns: + The elementwise arctangent of x1 / x2, tracking the correct quadrant. + + See also: + - :func:`jax.numpy.tan`: compute the tangent of an angle + - :func:`jax.numpy.atan2`: the array API version of this function. + + Examples: + Consider a sequence of angles in radians between 0 and :math:`2\pi`: + + >>> theta = jnp.linspace(-jnp.pi, jnp.pi, 9) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(theta) + [-3.14 -2.36 -1.57 -0.79 0. 0.79 1.57 2.36 3.14] + + These angles can equivalently be represented by ``(x, y)`` coordinates + on a unit circle: + + >>> x, y = jnp.cos(theta), jnp.sin(theta) + + To reconstruct the input angle, we might be tempted to use the identity + :math:`\tan(\theta) = y / x`, and compute :math:`\theta = \tan^{-1}(y/x)`. + Unfortunately, this does not recover the input angle: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.arctan(y / x)) + [-0. 0.79 1.57 -0.79 0. 0.79 1.57 -0.79 0. ] + + The problem is that :math:`y/x` contains some ambiguity: although + :math:`(y, x) = (-1, -1)` and :math:`(y, x) = (1, 1)` represent different points in + Cartesian space, in both cases :math:`y / x = 1`, and so the simple arctan + approach loses information about which quadrant the angle lies in. :func:`arctan2` + is built to address this: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.arctan2(y, x)) + [ 3.14 -2.36 -1.57 -0.79 0. 0.79 1.57 2.36 -3.14] + + The results match the input ``theta``, except at the endpoints where :math:`+\pi` + and :math:`-\pi` represent indistinguishable points on the unit circle. By convention, + :func:`arctan2` alwasy returns values between :math:`-\pi` and :math:`+\pi` inclusive. + """ return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) +@export @partial(jit, inline=True) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise minimum of the input arrays. @@ -1529,6 +1661,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.min(*promote_args("minimum", x, y)) +@export @partial(jit, inline=True) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise maximum of the input arrays. @@ -1588,6 +1721,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.max(*promote_args("maximum", x, y)) +@export @partial(jit, inline=True) def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: """Calculate element-wise base ``x`` exponential of ``y``. @@ -1634,6 +1768,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.pow(*promote_args_inexact("float_power", x, y)) +@export @partial(jit, inline=True) def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise next floating point value after ``x`` towards ``y``. @@ -1661,6 +1796,7 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.nextafter(*promote_args_inexact("nextafter", x, y)) +@export @partial(jit, inline=True) def spacing(x: ArrayLike, /) -> Array: """Return the spacing between ``x`` and the next adjacent number. @@ -1705,8 +1841,8 @@ def spacing(x: ArrayLike, /) -> Array: # Logical ops -@partial(jit, inline=True) -def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=True, reduce=reductions._reduce_logical_and) +def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical AND operation elementwise. JAX implementation of :obj:`numpy.logical_and`. This is a universal function, @@ -1725,8 +1861,9 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) -@partial(jit, inline=True) -def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=False, reduce=reductions._reduce_logical_or) +def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical OR operation elementwise. JAX implementation of :obj:`numpy.logical_or`. This is a universal function, @@ -1745,8 +1882,9 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) -@partial(jit, inline=True) -def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=False, reduce=reductions._reduce_logical_xor) +def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical XOR operation elementwise. JAX implementation of :obj:`numpy.logical_xor`. This is a universal function, @@ -1766,6 +1904,7 @@ def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) +@export @partial(jit, inline=True) def logical_not(x: ArrayLike, /) -> Array: """Compute NOT bool(x) element-wise. @@ -1811,6 +1950,8 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], lax_op(x.real, y.real)) return lax_op(x, y) + +@export @partial(jit, inline=True) def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x >= y``. @@ -1856,6 +1997,7 @@ def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) +@export @partial(jit, inline=True) def greater(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x > y``. @@ -1902,6 +2044,7 @@ def greater(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.gt, *promote_args("greater", x, y)) +@export @partial(jit, inline=True) def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x <= y``. @@ -1948,6 +2091,7 @@ def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) +@export @partial(jit, inline=True) def less(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x < y``. @@ -1993,42 +2137,58 @@ def less(x: ArrayLike, y: ArrayLike, /) -> Array: """ return _complex_comparison(lax.lt, *promote_args("less", x, y)) + # Array API aliases +@export @partial(jit, inline=True) def acos(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccos`""" return arccos(*promote_args('acos', x)) + +@export @partial(jit, inline=True) def acosh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccosh`""" return arccosh(*promote_args('acosh', x)) + +@export @partial(jit, inline=True) def asin(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsin`""" return arcsin(*promote_args('asin', x)) + +@export @partial(jit, inline=True) def asinh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsinh`""" return arcsinh(*promote_args('asinh', x)) + +@export @partial(jit, inline=True) def atan(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan`""" return arctan(*promote_args('atan', x)) + +@export @partial(jit, inline=True) def atanh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctanh`""" return arctanh(*promote_args('atanh', x)) + +@export @partial(jit, inline=True) def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan2`""" return arctan2(*promote_args('atan2', x1, x2)) + +@export @jit def bitwise_count(x: ArrayLike, /) -> Array: r"""Counts the number of 1 bits in the binary representation of the absolute value @@ -2064,6 +2224,8 @@ def bitwise_count(x: ArrayLike, /) -> Array: # Following numpy we take the absolute value and return uint8. return lax.population_count(abs(x)).astype('uint8') + +@export @partial(jit, inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. @@ -2115,12 +2277,14 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_fn(x1, x2) +@export @partial(jit, inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.right_shift`.""" return right_shift(x1, x2) +@export @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. @@ -2156,12 +2320,14 @@ def absolute(x: ArrayLike, /) -> Array: return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) +@export @partial(jit, inline=True) def abs(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.absolute`.""" return absolute(x) +@export @jit def rint(x: ArrayLike, /) -> Array: """Rounds the elements of x to the nearest integer @@ -2201,6 +2367,7 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) +@export @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. @@ -2240,6 +2407,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) +@export @partial(jit, inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the division of x1 by x2 element-wise @@ -2278,11 +2446,13 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.div(x1, x2) +@export def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.true_divide`.""" return true_divide(x1, x2) +@export @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the floor division of x1 by x2 element-wise @@ -2337,6 +2507,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _float_divmod(x1, x2)[0] +@export @jit def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: """Calculates the integer quotient and remainder of x1 by x2 element-wise @@ -2391,6 +2562,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod +@export def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise base ``x1`` exponential of ``x2``. @@ -2463,7 +2635,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # lax.pow. # Case 1: concrete integer scalar powers: - if isinstance(core.get_aval(x2), core.ConcreteArray): + if core.is_concrete(x2): try: x2 = operator.index(x2) # type: ignore[arg-type] except TypeError: @@ -2475,6 +2647,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) +@export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.power`""" return power(x1, x2) @@ -2514,7 +2687,7 @@ def _pow_int_int(x1, x2): return acc -@jit +@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp) def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. @@ -2540,17 +2713,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) -def _wrap_between(x, _a): - """Wraps `x` between `[-a, a]`.""" - a = _constant_like(x, _a) - two_a = _constant_like(x, 2 * _a) - zero = _constant_like(x, 0) - rem = lax.rem(lax.add(x, a), two_a) - rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) - return lax.sub(rem, a) - - -@jit +@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp2) def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. @@ -2578,35 +2741,11 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array(True, dtype=bool) """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) - return _logaddexp2(x1, x2) - - -@custom_jvp -def _logaddexp2(x1, x2): - amax = lax.max(x1, x2) - if dtypes.issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(lax._isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. - lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), - _constant_like(x1, np.log(2))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) - out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2)))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) - - -@_logaddexp2.defjvp -def _logaddexp2_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2) - primal_out = logaddexp2(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out + ln2 = float(np.log(2)) + return logaddexp(x1 * ln2, x2 * ln2) / ln2 +@export @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of ``x`` element-wise. @@ -2629,6 +2768,7 @@ def log2(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) +@export @partial(jit, inline=True) def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise @@ -2652,6 +2792,7 @@ def log10(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) +@export @partial(jit, inline=True) def exp2(x: ArrayLike, /) -> Array: """Calculate element-wise base-2 exponential of input. @@ -2686,6 +2827,7 @@ def exp2(x: ArrayLike, /) -> Array: return lax.exp2(x) +@export @jit def signbit(x: ArrayLike, /) -> Array: """Return the sign bit of array elements. @@ -2758,6 +2900,7 @@ def _normalize_float(x): return lax.bitcast_convert_type(x1, int_type), x2 +@export @jit def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute x1 * 2 ** x2 @@ -2807,6 +2950,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(isinf(x1) | (x1 == 0), x1, x) +@export @jit def frexp(x: ArrayLike, /) -> tuple[Array, Array]: """Split floating point values into mantissa and twos exponent. @@ -2860,6 +3004,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) +@export @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Returns element-wise remainder of the division. @@ -2907,11 +3052,13 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) +@export def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.remainder`""" return remainder(x1, x2) +@export @jit def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise floating-point modulo operation. @@ -2953,6 +3100,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.rem(*promote_args_numeric("fmod", x1, x2)) +@export @partial(jit, inline=True) def square(x: ArrayLike, /) -> Array: """Calculate element-wise square of the input array. @@ -2999,9 +3147,10 @@ def square(x: ArrayLike, /) -> Array: """ check_arraylike("square", x) x, = promote_dtypes_numeric(x) - return lax.integer_pow(x, 2) + return lax.square(x) +@export @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: r"""Convert angles from degrees to radians. @@ -3036,6 +3185,7 @@ def deg2rad(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, np.pi / 180)) +@export @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: r"""Convert angles from radians to degrees. @@ -3071,15 +3221,19 @@ def rad2deg(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, 180 / np.pi)) +@export def degrees(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.rad2deg`""" return rad2deg(x) + +@export def radians(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.deg2rad`""" return deg2rad(x) +@export @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: """Return element-wise complex-conjugate of the input. @@ -3109,11 +3263,13 @@ def conjugate(x: ArrayLike, /) -> Array: return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) +@export def conj(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.conjugate`""" return conjugate(x) +@export @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: """Return element-wise imaginary of part of the complex argument. @@ -3145,6 +3301,7 @@ def imag(val: ArrayLike, /) -> Array: return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) +@export @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: """Return element-wise real part of the complex argument. @@ -3176,6 +3333,7 @@ def real(val: ArrayLike, /) -> Array: return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) +@export @jit def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: """Return element-wise fractional and integral parts of the input array. @@ -3209,6 +3367,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: return x - whole, whole +@export @partial(jit, inline=True) def isfinite(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is finite. @@ -3249,6 +3408,7 @@ def isfinite(x: ArrayLike, /) -> Array: return lax.full_like(x, True, dtype=np.bool_) +@export @jit def isinf(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is infinite. @@ -3304,6 +3464,7 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) +@export def isposinf(x, /, out=None): """ Return boolean array indicating whether each element of input is positive infinite. @@ -3337,6 +3498,7 @@ def isposinf(x, /, out=None): return _isposneginf(np.inf, x, out) +@export def isneginf(x, /, out=None): """ Return boolean array indicating whether each element of input is negative infinite. @@ -3370,6 +3532,7 @@ def isneginf(x, /, out=None): return _isposneginf(-np.inf, x, out) +@export @partial(jit, inline=True) def isnan(x: ArrayLike, /) -> Array: """Returns a boolean array indicating whether each element of input is ``NaN``. @@ -3404,6 +3567,7 @@ def isnan(x: ArrayLike, /) -> Array: return lax.ne(x, x) +@export @jit def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the heaviside step function. @@ -3453,6 +3617,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) +@export @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: r""" @@ -3501,6 +3666,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(idx_inf, _lax_const(x, np.inf), x) +@export @partial(jit, inline=True) def reciprocal(x: ArrayLike, /) -> Array: """Calculate element-wise reciprocal of the input. @@ -3534,6 +3700,7 @@ def reciprocal(x: ArrayLike, /) -> Array: return lax.integer_pow(x, -1) +@export @jit def sinc(x: ArrayLike, /) -> Array: r"""Calculate the normalized sinc function. @@ -3604,57 +3771,3 @@ def _sinc_maclaurin(k, x): def _sinc_maclaurin_jvp(k, primals, tangents): (x,), (t,) = primals, tangents return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t - - -def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_and.reduce()") - result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - - -def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_or.reduce()") - result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - -def _add_at(a: Array, indices: Any, b: ArrayLike): - if a.dtype == bool: - a = a.astype('int32') - b = lax.convert_element_type(b, bool).astype('int32') - return a.at[indices].add(b).astype(bool) - return a.at[indices].add(b) - -def _subtract_at(a: Array, indices: Any, b: ArrayLike): - return a.at[indices].subtract(b) - -def _multiply_at(a: Array, indices: Any, b: ArrayLike): - if a.dtype == bool: - a = a.astype('int32') - b = lax.convert_element_type(b, bool).astype('int32') - return a.at[indices].mul(b).astype(bool) - else: - return a.at[indices].mul(b) - -# Generate ufunc interfaces for several common binary functions. -# We start with binary ufuncs that have well-defined identities.' -# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience? -# TODO(jakevdp): optimize some implementations. -# - define add.at/multiply.at in terms of scatter_add/scatter_mul -# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod -# - define all monoidal reductions in terms of lax.reduce -add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) -multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) -bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and) -bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or) -bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor) -logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce) -logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce) -logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor) -negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative) -subtract = ufunc(_subtract, name="subtract", nin=2, nout=1, call=_subtract, at=_subtract_at) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 27496ad99056..15cbc22dfa0d 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -13,11 +13,9 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import partial -import re -import textwrap -from typing import Any, NamedTuple, TypeVar +from typing import Any import warnings @@ -34,173 +32,6 @@ zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map -_T = TypeVar("_T") - -_parameter_break = re.compile("\n(?=[A-Za-z_])") -_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE) -_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE) -_versionadded = re.compile(r'^\s+\.\.\s+versionadded::', re.MULTILINE) -_docreference = re.compile(r':doc:`(.*?)\s*<.*?>`') - -class ParsedDoc(NamedTuple): - """ - docstr: full docstring - signature: signature from docstring. - summary: summary from docstring. - front_matter: front matter before sections. - sections: dictionary of section titles to section content. - """ - docstr: str | None - signature: str = "" - summary: str = "" - front_matter: str = "" - sections: dict[str, str] = {} - - -def _parse_numpydoc(docstr: str | None) -> ParsedDoc: - """Parse a standard numpy-style docstring. - - Args: - docstr: the raw docstring from a function - Returns: - ParsedDoc: parsed version of the docstring - """ - if docstr is None or not docstr.strip(): - return ParsedDoc(docstr) - - # Remove any :doc: directives in the docstring to avoid sphinx errors - docstr = _docreference.sub( - lambda match: f"{match.groups()[0]}", docstr) - - signature, body = "", docstr - match = _numpy_signature_re.match(body) - if match: - signature = match.group() - body = docstr[match.end():] - - firstline, _, body = body.partition('\n') - body = textwrap.dedent(body.lstrip('\n')) - - match = _numpy_signature_re.match(body) - if match: - signature = match.group() - body = body[match.end():] - - summary = firstline - if not summary: - summary, _, body = body.lstrip('\n').partition('\n') - body = textwrap.dedent(body.lstrip('\n')) - - front_matter = "" - body = "\n" + body - section_list = _section_break.split(body) - if not _section_break.match(section_list[0]): - front_matter, *section_list = section_list - sections = {section.split('\n', 1)[0]: section for section in section_list} - - return ParsedDoc(docstr=docstr, signature=signature, summary=summary, - front_matter=front_matter, sections=sections) - - -def _parse_parameters(body: str) -> dict[str, str]: - """Parse the Parameters section of a docstring.""" - title, underline, content = body.split('\n', 2) - assert title == 'Parameters' - assert underline and not underline.strip('-') - parameters = _parameter_break.split(content) - return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} - - -def implements( - original_fun: Callable[..., Any] | None, - update_doc: bool = True, - sections: Sequence[str] = ('Parameters', 'Returns', 'References'), - module: str | None = None, -) -> Callable[[_T], _T]: - """Decorator for JAX functions which implement a specified NumPy function. - - This mainly contains logic to copy and modify the docstring of the original - function. In particular, if `update_doc` is True, parameters listed in the - original function that are not supported by the decorated function will - be removed from the docstring. For this reason, it is important that parameter - names match those in the original numpy function. - - Args: - original_fun: The original function being implemented - update_doc: whether to transform the numpy docstring to remove references of - parameters that are supported by the numpy version but not the JAX version. - If False, include the numpy docstring verbatim. - sections: a list of sections to include in the docstring. The default is - ["Parameters", "Returns", "References"] - module: an optional string specifying the module from which the original function - is imported. This is useful for objects such as ufuncs, where the module cannot - be determined from the original function itself. - """ - def decorator(wrapped_fun): - wrapped_fun.__np_wrapped__ = original_fun - # Allows this pattern: @implements(getattr(np, 'new_function', None)) - if original_fun is None: - return wrapped_fun - docstr = getattr(original_fun, "__doc__", None) - name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))) - try: - mod = module or original_fun.__module__ - except AttributeError: - if config.enable_checks.value: - raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().") - else: - name = f"{mod}.{name}" - if docstr: - try: - parsed = _parse_numpydoc(docstr) - - if update_doc and 'Parameters' in parsed.sections: - code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None) - # Remove unrecognized parameter descriptions. - parameters = _parse_parameters(parsed.sections['Parameters']) - parameters = {p: desc for p, desc in parameters.items() - if (code is None or p in code.co_varnames)} - if parameters: - parsed.sections['Parameters'] = ( - "Parameters\n" - "----------\n" + - "\n".join(_versionadded.split(desc)[0].rstrip() - for p, desc in parameters.items()) - ) - else: - del parsed.sections['Parameters'] - - docstr = parsed.summary.strip() + "\n" if parsed.summary else "" - docstr += f"\nLAX-backend implementation of :func:`{name}`.\n" - docstr += "\n*Original docstring below.*\n" - - # We remove signatures from the docstrings, because they redundant at best and - # misleading at worst: e.g. JAX wrappers don't implement all ufunc keyword arguments. - # if parsed.signature: - # docstr += "\n" + parsed.signature.strip() + "\n" - - if parsed.front_matter: - docstr += "\n" + parsed.front_matter.strip() + "\n" - kept_sections = (content.strip() for section, content in parsed.sections.items() - if section in sections) - if kept_sections: - docstr += "\n" + "\n\n".join(kept_sections) + "\n" - except: - if config.enable_checks.value: - raise - docstr = original_fun.__doc__ - - wrapped_fun.__doc__ = docstr - for attr in ['__name__', '__qualname__']: - try: - value = getattr(original_fun, attr) - except AttributeError: - pass - else: - setattr(wrapped_fun, attr, value) - return wrapped_fun - return decorator - _dtype = partial(dtypes.dtype, canonicalize=True) def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index e7a0e2142327..f1e6d399b97b 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -23,9 +23,11 @@ from jax._src import config from jax import lax from jax._src.numpy import lax_numpy as jnp -from jax._src.util import safe_map as map, safe_zip as zip +from jax._src.util import set_module, safe_map as map, safe_zip as zip +export = set_module('jax.numpy') + # See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html _DIMENSION_NAME = r'\w+' _CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME) @@ -185,6 +187,7 @@ def new_func(*args, **kwargs): return new_func, dynamic_args, dynamic_kwargs +@export def vectorize(pyfunc, *, excluded=frozenset(), signature=None): """Define a vectorized function with broadcasting. diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 4ff7062ac1e8..e1bedaf93377 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -30,6 +30,7 @@ py_library( srcs = [ "__init__.py", "core.py", + "cost_estimate.py", "pallas_call.py", "primitives.py", "utils.py", diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index abbd7154d1b7..acbf0d4f7ed5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -37,6 +37,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge +from jax._src.state import types as state_types from jax._src.state.types import TransformedRef import jax.numpy as jnp @@ -139,13 +140,6 @@ def __hash__(self): self.memory_space, )) - def at_least_vspace(self): - """Vector space method needed for AD.""" - raise NotImplementedError - - def join(self, other): - raise NotImplementedError - def str_short(self, short_dtypes=False): dt_str = \ dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name @@ -225,10 +219,13 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' - def join(self, other): - assert isinstance(other, AbstractMemoryRef) - return AbstractMemoryRef(self.inner_aval.join(other.inner_aval), - self.memory_space) + @property + def sharding(self): + return self.inner_aval.sharding + + def update_weak_type(self, weak_type): + return AbstractMemoryRef( + self.inner_aval.update_weak_type(weak_type), self.memory_space) def update(self, inner_aval=None, memory_space=None): inner_aval = self.inner_aval if inner_aval is None else inner_aval @@ -239,6 +236,10 @@ def to_tangent_aval(self): return AbstractMemoryRef( self.inner_aval.to_tangent_aval(), self.memory_space) + # TODO(dougalm, sharadmv): figure out how to avoid needing this + def normalize(self): + return state.AbstractRef(self.inner_aval).normalize() + def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval and self.memory_space == other.memory_space) @@ -261,13 +262,6 @@ def __str__(self) -> str: return self.value -def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type): - return AbstractMemoryRef( - jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type), - ref_aval.memory_space) -jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped - - @dataclasses.dataclass(frozen=True) class PallasGridContext: grid: GridMappingGrid @@ -883,7 +877,7 @@ def get_grid_mapping( ) # The inputs for the index maps index_map_avals = ( - (index_map_grid_aval,) * len(grid_spec.grid)) + (index_map_grid_aval.update(sharding=None),) * len(grid_spec.grid)) index_map_tree = tree_util.tree_structure((index_map_avals, {})) num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0) @@ -1059,6 +1053,8 @@ def _core_map_abstract_eval(*args, jaxpr, mesh): raise ValueError("core_map must not return any outputs.") effs = set() for eff in jaxpr.effects: + if mesh.discharges_effect(eff): + continue if not isinstance(eff, jax_core.NamedAxisEffect): effs.add(eff) continue @@ -1068,6 +1064,53 @@ def _core_map_abstract_eval(*args, jaxpr, mesh): _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {} + + +def default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + grid, + compiler_params, + backend, + jaxpr, +): + """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" + del out_avals # Unused. + + def body(*args): + # Due to aliasing, ``args`` contains aliased inputs and outputs so we + # remove outputs. + in_refs = args[:len(in_avals)] + jax_core.eval_jaxpr(jaxpr, in_refs) + + assert len(jaxpr.outvars) == 0 + modified_idxs = sorted( + eff.input_index + for eff in jaxpr.effects + if isinstance(eff, state_types.WriteEffect) + ) + any_spec = BlockSpec(memory_space=MemorySpace.ANY) + from jax._src.pallas import pallas_call # Avoid circular dependency. + outs = pallas_call.pallas_call( + body, + out_shape=[in_avals[idx] for idx in modified_idxs], + in_specs=[any_spec] * len(in_avals), + out_specs=[any_spec] * len(modified_idxs), + input_output_aliases={ + in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) + }, + grid=grid, + compiler_params=compiler_params, + backend=backend, + )(*args) + # ``outs`` lacks the unmodified inputs. Add them back in. + all_outs = [None] * len(args) + for out_idx, in_idx in enumerate(modified_idxs): + all_outs[in_idx] = outs[out_idx] + return all_outs, () + + @state_discharge.register_discharge_rule(core_map_p) def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs): if type(mesh) not in _core_map_mesh_rules: @@ -1083,6 +1126,8 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh): jax_core.check_jaxpr(jaxpr) effs = set() for eff in jaxpr.effects: + if mesh.discharges_effect(eff): + continue if not isinstance(eff, jax_core.NamedAxisEffect): effs.add(eff) continue @@ -1090,14 +1135,3 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh): effs.add(eff) return [], effs jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule - - -def _core_map_axis_subst(params, subst, traverse): - if not traverse: - return params - def shadowed_subst(name): - return (name,) if name in params['mesh'].shape else subst(name) - with jax_core.extend_axis_env_nd(params['mesh'].shape.items()): - new_jaxpr = jax_core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst) - return dict(params, jaxpr=new_jaxpr) -jax_core.axis_substitution_rules[core_map_p] = _core_map_axis_subst diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py new file mode 100644 index 000000000000..5b322eedc837 --- /dev/null +++ b/jax/_src/pallas/cost_estimate.py @@ -0,0 +1,256 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper tool for automatic cost estimation.""" +import dataclasses +import functools +import math +from typing import Any, Sequence + +import jax +from jax._src import api_util +from jax._src import core as jax_core +from jax._src import custom_derivatives +from jax._src import linear_util as lu +from jax._src import pjit +from jax._src.state import discharge +from jax._src.pallas import core as pallas_core +from jax._src.interpreters import partial_eval as pe +from jax._src.util import safe_map +from jax._src.util import safe_zip +from jax._src.lax import lax + +map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin +zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin + +_cost_rules = {} + +@dataclasses.dataclass(frozen=True) +class CostEstimate: + flops: int + transcendentals: int + bytes_accessed: int + + def __add__(self, other: 'CostEstimate') -> 'CostEstimate': + return CostEstimate( + flops=self.flops + other.flops, + transcendentals=self.transcendentals + other.transcendentals, + bytes_accessed=self.bytes_accessed + other.bytes_accessed, + ) + +def register_cost_rule(primitive: jax_core.Primitive, rule): + _cost_rules[primitive] = rule + +@dataclasses.dataclass(frozen=True) +class Context: + avals_in: Sequence[Any] + avals_out: Sequence[Any] + +def cost_estimate_jaxpr( + jaxpr: jax_core.ClosedJaxpr, +) -> pallas_core.CostEstimate: + """Returns the cost estimate for the given Jaxpr.""" + jaxpr, _ = jaxpr.jaxpr, jaxpr.consts + total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0) + + for eqn in jaxpr.eqns: + _, bind_params = eqn.primitive.get_bind_params(eqn.params) + rule = _cost_rules.get(eqn.primitive, None) + if rule is not None: + context = Context(avals_in=[v.aval for v in eqn.invars], + avals_out=[v.aval for v in eqn.outvars]) + op_cost = rule(context, **bind_params) + total_cost = total_cost + op_cost + return pallas_core.CostEstimate( + flops=total_cost.flops, + transcendentals=total_cost.transcendentals, + bytes_accessed=total_cost.bytes_accessed, + ) + +def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate: + """Computes a cost estimate for the given function. + + Args: + fun: The function to compute the cost estimate for. + *args: The arguments to the function. Can be jax.ShapeDtypeStruct or + jax.Array. + **kwargs: The keyword arguments to the function. + + Returns: + A pallas_core.CostEstimate object containing the cost estimate. + """ + flattened_args, treedef = jax.tree.flatten(args) + partial_fun = functools.partial(fun, **kwargs) + wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun), + treedef) + avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) + estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) + input_bytes = sum( + math.prod(a.shape) * a.dtype.itemsize for a in flattened_args) + output_bytes = sum( + math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars) + return pallas_core.CostEstimate( + flops=estimate.flops, + transcendentals=estimate.transcendentals, + bytes_accessed=estimate.bytes_accessed + input_bytes + output_bytes, + ) + +def binary_cost_rule(ctx: Context, **_) -> CostEstimate: + aval_out, = ctx.avals_out + out_flops = math.prod(aval_out.shape) + return CostEstimate( + flops=out_flops, + transcendentals=0, + bytes_accessed=0, + ) +BINARY_OPS = [ + lax.add_p, + lax.mul_p, + lax.sub_p, + lax.div_p, + lax.min_p, + lax.max_p, + lax.or_p, + lax.and_p, + lax.xor_p, +] +for op in BINARY_OPS: + register_cost_rule(op, binary_cost_rule) + + +def unary_cost_rule(transcendental: bool): + def cost_rule(ctx: Context, **_) -> CostEstimate: + x_aval, = ctx.avals_in + new_flops = 0 + new_transcendentals = 0 + if transcendental: + new_transcendentals += math.prod(x_aval.shape) + else: + new_flops += math.prod(x_aval.shape) + return CostEstimate( + flops=new_flops, + transcendentals=new_transcendentals, + bytes_accessed=0, + ) + return cost_rule + +UN_OPS = [ + lax.neg_p, + lax.floor_p, + lax.ceil_p, + lax.round_p, + lax.not_p, +] +for op in UN_OPS: + register_cost_rule(op, unary_cost_rule(transcendental=False)) + +TRANSCENDENTAL_OPS = [ + lax.cos_p, + lax.sin_p, + lax.tan_p, + lax.sinh_p, + lax.cosh_p, + lax.tanh_p, + lax.acos_p, + lax.asin_p, + lax.atan_p, + lax.exp_p, + lax.log_p, + lax.logistic_p, + lax.sqrt_p, +] +for op in TRANSCENDENTAL_OPS: + register_cost_rule(op, unary_cost_rule(transcendental=True)) + +def _integer_pow_cost_rule(ctx: Context, *, y: int) -> CostEstimate: + x_aval, = ctx.avals_in + num_elements = math.prod(x_aval.shape) + if y == 0 or y == 1: + # No flops, the result is 0 or a copy of the input. + cost_per_element = 0 + else: + # We assume integer pow is implemented using repeated squaring. + # The cost is log(y) squarings, plus one multiply per non-zero bit. + highest_bit = math.floor(math.log(y, 2)) + cost_per_element = highest_bit + y.bit_count() + return CostEstimate( + flops=num_elements * cost_per_element, + transcendentals=0, + bytes_accessed=0, + ) +register_cost_rule(lax.integer_pow_p, _integer_pow_cost_rule) + +def dot_general_cost_rule(ctx: Context, + dimension_numbers: lax.DotDimensionNumbers, + **_) -> CostEstimate: + x_aval, y_aval = ctx.avals_in + x_shape, y_shape = x_aval.shape, y_aval.shape + (lhs_contracting_dims, rhs_contracting_dims), ( + lhs_batch_dims, rhs_batch_dims) = dimension_numbers + assert len(lhs_contracting_dims) == len(rhs_contracting_dims) + assert len(lhs_batch_dims) == len(rhs_batch_dims) + flops = 1 + # Flops along a contracting dim is 2*dim (addition and multiplication) + for i in range(len(lhs_contracting_dims)): + lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i] + assert x_shape[lhs_dim] == y_shape[rhs_dim] + flops *= 2 * x_shape[lhs_dim] + # Now we handle all other dimensions. + for i, lhs_dim in enumerate(x_shape): + if i in lhs_contracting_dims: + continue + flops *= lhs_dim + for i, rhs_dim in enumerate(y_shape): + if i in rhs_contracting_dims: + continue + # Don't double-count batch dims (we already counted for LHS) + if i in rhs_batch_dims: + continue + flops *= rhs_dim + return CostEstimate( + flops=flops, + transcendentals=0, + bytes_accessed=0, + ) +register_cost_rule(lax.dot_general_p, dot_general_cost_rule) + +# Higher-order primitives +def _pjit_cost_rule(ctx, *, jaxpr: jax_core.ClosedJaxpr, **_): + del ctx + inner_cost = cost_estimate_jaxpr(jaxpr) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(pjit.pjit_p, _pjit_cost_rule) + +def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): + del ctx + inner_cost = cost_estimate_jaxpr(fun_jaxpr) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) + +def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2): + inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr)) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(discharge.run_state_p, _run_state_rule) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index f52ba9ddd6cd..c9754933908f 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -74,10 +74,17 @@ py_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", + ":verification", "//jax", + "//jax:config", + "//jax:core", + "//jax:mlir", "//jax:mosaic", + "//jax:sharding_impls", "//jax:source_info_util", + "//jax:tpu_custom_call", "//jax/_src/lib", "//jax/_src/pallas", ] + py_deps("numpy"), @@ -117,6 +124,7 @@ py_library( "//jax:pallas", "//jax:util", "//jax/_src/pallas", + "//jax/extend:backend", ] + py_deps("numpy"), ) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 82fe9c2baa96..ad9a6cb13f42 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -28,10 +28,10 @@ from jax._src import dtypes from jax._src import util from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call import jax.numpy as jnp import numpy as np + map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -174,15 +174,6 @@ def get_ref_aval(self) -> AbstractMemoryRef: class AbstractSemaphore(jax_core.AbstractValue): sem_type: SemaphoreType - def join(self, other): - if not isinstance(other, AbstractSemaphore): - raise ValueError - if other.sem_type != self.sem_type: - raise ValueError - return self - -jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval - @dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): @@ -220,6 +211,10 @@ class TensorCoreMesh: def shape(self): return collections.OrderedDict(zip(self.axis_names, self.devices.shape)) + def discharges_effect(self, effect: jax_core.Effect): + del effect + return False + def create_tensorcore_mesh( axis_name: str, devices: Sequence[jax.Device] | None = None @@ -246,33 +241,19 @@ def _tensorcore_mesh_discharge_rule( mesh, jaxpr, ): - del out_avals assert isinstance(mesh, TensorCoreMesh) if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") core_axis_name, num_cores = list(mesh.shape.items())[0] - def body(*args): - # Due to aliasing, args contains aliased inputs and outputs so we remove - # outputs. - in_refs = args[:len(in_avals)] - jax_core.eval_jaxpr(jaxpr, in_refs) - assert len(jaxpr.outvars) == 0 - out = pallas_call.pallas_call( - body, - out_shape=in_avals, - in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)] - * len(in_avals), - out_specs=[pallas_core.BlockSpec( - memory_space=pallas_core.MemorySpace.ANY)] - * len(in_avals), - input_output_aliases={i: i for i in range(len(in_avals))}, + return pallas_core.default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + jaxpr=jaxpr, grid=((core_axis_name, num_cores),), - compiler_params=dict( - mosaic=dict(dimension_semantics=("parallel",)), - ), + compiler_params=TPUCompilerParams(dimension_semantics=("parallel",)), backend="mosaic_tpu", - )(*args) - return out, () + ) pallas_core._core_map_mesh_rules[TensorCoreMesh] = ( _tensorcore_mesh_discharge_rule diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 18b73a66ca23..c3211efa2031 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -42,6 +42,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -50,6 +51,7 @@ from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core @@ -61,7 +63,7 @@ from jax._src.state import primitives as state_primitives from jax._src.state.types import RefBitcaster, RefReshaper from jax._src.state.utils import dtype_bitwidth -from jax._src.typing import DTypeLike +from jax._src.typing import Array, DTypeLike from jax._src.util import safe_map from jax._src.util import safe_zip from jax._src.util import split_list @@ -501,8 +503,13 @@ def err_details(): ) else: assert rank == 1 - # TODO(necula): test this for bool. What should it do? - tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype)) + # bools get a bitwidth of 32 due to how mosaic handles them + if bm.array_shape_dtype.dtype == jnp.bool_: + bitwidth = 32 + else: + bitwidth = lax_internal._bit_width(bm.array_shape_dtype.dtype) + packing = 32 // bitwidth + tiling_size = 128 * packing evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) if not evenly_divisible: raise ValueError( @@ -837,6 +844,8 @@ def write_env(var: jax_core.Var, val): except LoweringException: raise # We only add the extra info to the innermost exception. except Exception as e: + if not pallas_call._verbose_errors_enabled(): + raise msg = (f"{type(e).__name__}: {e}\n" + "Additional diagnostics: \n" + f"Failing jaxpr equation: {eqn}\n") @@ -1307,12 +1316,20 @@ def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): ref, transforms, val, mask = args_tree.unflatten(args_flat) - ref_aval, transforms_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in) + ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten( + ctx.avals_in + ) (*prev_transforms, idx) = transforms (*_, idx_aval) = transforms_avals if mask is not None: - raise NotImplementedError + if val_aval.dtype.itemsize != 4: + raise NotImplementedError("masked swap with non-32-bit data") + if val_aval.shape != mask_aval.shape: + raise ValueError( + "Expected value and mask to have the same shape, but got" + f" value shape {val_aval.shape} vs. mask shape {mask_aval.shape}." + ) ref_block_shape, *_ = ctx.block_shapes ref, ref_block_shape = _transform_ref( @@ -1343,6 +1360,8 @@ def _masked_swap_lowering_rule( need_stride = not all((s is None or s == 1) for s in strides) if is_smem_store: + if mask is not None: + raise ValueError("SMEM store does not support masks") if val_aval.shape: raise ValueError("Can only store scalars to SMEM") result = memref.load(ref, starts) @@ -1372,7 +1391,7 @@ def _masked_swap_lowering_rule( 1 if b is pallas_core.mapped else next(mem_slice_shape_iter) for b in ref_block_shape ] - mem_aval = aval_out.update(shape=tuple(mem_slice_shape)) + mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None) mem_aval_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)) if need_stride: @@ -1391,9 +1410,15 @@ def _masked_swap_lowering_rule( result = _maybe_cast_load_to_bool(val_aval, result) if need_stride: + if mask is not None: + raise NotImplementedError("masked swap with strided store") tpu.StridedStoreOp(val, ref, starts, strides) - else: + elif jaxlib_version <= (0, 4, 35): + if mask is not None: + raise NotImplementedError("masked swap with vector store") vector.StoreOp(val, ref, starts) + else: + tpu.VectorStoreOp(val, ref, starts, [], mask=mask) return result @@ -1520,9 +1545,22 @@ def _proxy_reduce(arg, *, axes): lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule +def _broadcast_to_lowering_rule( + ctx: LoweringRuleContext, x, shape: Sequence[int] +): + raise RuntimeError( + "`broadcast_to` is a Triton-specific primitive. Please consider using" + " `jnp.broadcast_to` instead." + ) + + +lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule + + def _broadcast_in_dim_lowering_rule( - ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions + ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding ): + del sharding (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out @@ -1560,6 +1598,71 @@ def _proxy_fun(val, *, shape, broadcast_dimensions): lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule +def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape): + """Converts a jax dot dimension numbers to a tpu dot dimension numbers. + + Jax dot dimension numbers are given as a tuple of tuples of sequences of ints + of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + rhs_batch_dims)). + + TPU dot dimension numbers are given as an MLIR definition of the form + #tpu.dot_dimension_numbers - which can be found in the tpu dilect definition + # file, tpu.td . + """ + (contracting_dims, batch_dims) = dimension_numbers + lhs_contracting_dims, rhs_contracting_dims = contracting_dims + lhs_batch_dims, rhs_batch_dims = batch_dims + + lhs_total_dims = set(range(len(lhs_shape))) + rhs_total_dims = set(range(len(rhs_shape))) + + lhs_non_contracting_dims = sorted( + lhs_total_dims - set(lhs_contracting_dims) - set(lhs_batch_dims) + ) + rhs_non_contracting_dims = sorted( + rhs_total_dims - set(rhs_contracting_dims) - set(rhs_batch_dims) + ) + + # Create output_dim_order + # Note: we assume that the output dimensions are ordered as batch dims, lhs_non_contracting_dims, + # rhs_non_contracting_dims - this assumption is safe to make, as it is + # the same one made in jax's dot_general. + output_dim_order = [] + + lhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(lhs_shape)))} + rhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(rhs_shape)))} + + for dim in lhs_batch_dims: + output_dim_order.append(0) + output_dim_order.append(lhs_dim_map[dim]) + + for dim in lhs_non_contracting_dims: + output_dim_order.append(0) + output_dim_order.append(lhs_dim_map[dim]) + + for dim in rhs_non_contracting_dims: + output_dim_order.append(1) + output_dim_order.append(rhs_dim_map[dim]) + + def format_dims(dims): + return "[" + ", ".join(str(d) for d in dims) + "]" + + all_dims = ( + lhs_contracting_dims, + rhs_contracting_dims, + lhs_non_contracting_dims, + rhs_non_contracting_dims, + output_dim_order, + lhs_batch_dims, + rhs_batch_dims, + ) + tpu_dim_numbers_str = ( + f"#tpu.dot_dimension_numbers<{','.join(map(format_dims, all_dims))}>" + ) + + return ir.Attribute.parse(tpu_dim_numbers_str) + + def _dot_general_lowering_rule( ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_ ): @@ -1585,7 +1688,7 @@ def _dot_general_lowering_rule( raise NotImplementedError( f"Only 2D tensors supported in dot; received: {ctx.avals_in}" ) - lhs_aval, _ = ctx.avals_in + lhs_aval, rhs_aval = ctx.avals_in # This is really a matrix-vector product. It only looks like matrix-matrix. if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1: if ctx.avals_in[0].shape != ctx.avals_in[1].shape: @@ -1611,18 +1714,10 @@ def _dot_general_lowering_rule( ) return vector.shape_cast(out_type, red) - if lhs_dims == (1,): - transpose_lhs = False - elif lhs_dims == (0,): - transpose_lhs = True - else: - raise NotImplementedError - if rhs_dims == (0,): - transpose_rhs = False - elif rhs_dims == (1,): - transpose_rhs = True - else: - raise NotImplementedError + tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims( + dimension_numbers, lhs_aval.shape, rhs_aval.shape + ) + if precision is not None: if precision[0] != precision[1]: raise NotImplementedError("Per-operand dot precision unsupported") @@ -1639,9 +1734,12 @@ def _dot_general_lowering_rule( out_type, ir.DenseElementsAttr.get_splat(out_type, val) ) return tpu.matmul( - out_type, x, y, out_tile, - transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs, - precision=precision_attr + out_type, + x, + y, + out_tile, + dimension_numbers=tpu_dot_dims, + precision=precision_attr, ) @@ -1763,7 +1861,8 @@ def _convert_element_type_lowering_rule( lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule -def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions): +def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, + sharding): if dimensions is not None: raise NotImplementedError if any(d is None for d in new_sizes): @@ -1909,7 +2008,7 @@ def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): lowering_rules[lax.sub_p] = _sub_lowering_rule -skip_mlir_conversions.add(lax.max_p) +skip_mlir_conversions.add(lax.sub_p) def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): @@ -1929,7 +2028,7 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out - if jnp.issubdtype(aval_out.dtype, jnp.integer): + if jnp.issubdtype(aval_out.dtype, jnp.signedinteger): return arith.divsi(x, y) if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger): return arith.divui(x, y) @@ -1992,6 +2091,15 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.sign_p] = _sign_lowering_rule +def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): + return lower_fun( + pallas_utils.nextafter_lowering_helper, multiple_results=False, + )(ctx, x, y) + + +lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule + + def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): return math.rsqrt(x) @@ -2006,6 +2114,15 @@ def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule +def _square_lowering_rule(ctx: LoweringRuleContext, x): + if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer): + return arith.muli(x, x) + return arith.mulf(x, x) + + +lowering_rules[lax.square_p] = _square_lowering_rule + + def _exp_lowering_rule(ctx: LoweringRuleContext, x): return math.exp(x) @@ -2014,6 +2131,11 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): + # jax accepts float base (x) and integer/float exponent (y), and integer + # exponent is casted to float. + out_type = aval_to_ir_type(ctx.avals_out[0]) + if jnp.issubdtype(ctx.avals_in[1].dtype, jnp.integer): + y = arith.sitofp(out_type, y) if not isinstance(x, ir.Value) and x == 2.: return math.exp2(y) x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) @@ -2173,7 +2295,49 @@ def _population_count_lowering_rule(ctx: LoweringRuleContext, x): } -def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): +# The relationship between comparison operations on booleans and boolean +# algebra is as follows: +# eq(x, y) = !(x ^ y) +# ne(x, y) = x ^ y +# lt(x, y) = !x && y +# le(x, y) = !x || y +# gt(x, y) = x && !y +# ge(x, y) = x || !y +def _cmp_boolean_lowering_helper(primitive, x: Array, y: Array): + """A helper function for lowering comparison operations for boolean inputs. + + Args: + primitive: A JAX primitive representing a comparison operation, which is + one of the following: `lax.eq_p` (equals), `lax.ne_p` (not equals), + `lax.lt_p` (less than), `lax.le_p` (less than or equal to), + `lax.gt_p` (greater than), or `lax.ge_p` (greater than or equal to). + x: A boolean array representing the first operand in the comparison. + y: A boolean array representing the second operand in the comparison. + + Returns: + A boolean array that is the result of applying the comparison operation + between `x` and `y` based on the given primitive. + + Raises: + ValueError: If an unsupported comparison primitive is provided. + """ + if primitive == lax.eq_p: + return jnp.logical_not(jnp.logical_xor(x, y)) + elif primitive == lax.ne_p: + return jnp.logical_xor(x, y) + elif primitive == lax.lt_p: + return jnp.logical_and(jnp.logical_not(x), y) + elif primitive == lax.le_p: + return jnp.logical_or(jnp.logical_not(x), y) + elif primitive == lax.gt_p: + return jnp.logical_and(x, jnp.logical_not(y)) + elif primitive == lax.ge_p: + return jnp.logical_or(x, jnp.logical_not(y)) + else: + raise ValueError(f"Unsupported comparison primitive: {primitive}") + + +def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) x_aval, y_aval = ctx.avals_in if x_aval.dtype != y_aval.dtype: @@ -2182,60 +2346,22 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): ) dtype = x_aval.dtype - # For boolean comparisons, we handle them in two different ways. For `ne`, - # we directly use the xor operation since they are equivalent. For all - # other comparisons, we convert the boolean values to `int32` and use select - # operations to perform the comparison. - # - # The relationship between comparison operations on booleans and boolean - # algebra is as follows: - # - # eq(a, b) = !(a ^ b) - # ne(a, b) = a ^ b - # lt(a, b) = !a && b - # le(a, b) = !a || b - # gt(a, b) = a && !b - # ge(a, b) = a || !b - # - # However, except for `ne`, all other operations require negation, which is - # currently not supported. At present, even if negation were supported, - # it would still need to be implemented using `select` operations, making - # it equivalent to our current approach. For more details on negation support, - # see https://github.com/jax-ml/jax/issues/24243. if jnp.issubdtype(dtype, jnp.bool_): - if prim == lax.ne_p: - return arith.xori(x, y) - - i32 = ir.IntegerType.get_signless(32) - vtype = ir.VectorType.get(x_aval.shape, i32) - - # Convert `x` and `y` from `bool` to `int32` for comparison, with 2 - # for true and 0 for false. For example, comparing `x > y` is equivalent - # to `(x ? 2 : 0) > (y ? 2 : 0)`. - # - # Note that we cannot use 1 for true because the select operation will be - # misteriously eliminated. - two = arith.constant(i32, 2) - zero = arith.constant(i32, 0) - - out_aval, = ctx.avals_out - if out_aval.shape != (): - # broadcast to vectors if we are comparing vectors - two = vector.broadcast(vtype, two) - zero = vector.broadcast(vtype, zero) - - x = arith.select(x, two, zero) - y = arith.select(y, two, zero) - dtype = jnp.int32 + return lower_fun( + functools.partial(_cmp_boolean_lowering_helper, primitive), + multiple_results=False, + )(ctx, x, y) if jnp.issubdtype(dtype, jnp.integer): is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger) - pred = (_cmpui_lowering_types if is_uint else _cmpsi_lowering_types)[prim] + pred = ( + _cmpui_lowering_types if is_uint else _cmpsi_lowering_types + )[primitive] predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) return arith.cmpi(predicate, x, y) if jnp.issubdtype(dtype, jnp.floating): - pred = _cmpf_lowering_types[prim] + pred = _cmpf_lowering_types[primitive] predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) return arith.cmpf(predicate, x, y) @@ -2795,16 +2921,17 @@ def _device_id_to_logical( # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) mesh_strides = ctx.lowering_context.mesh_context.mesh_strides - def _linearize_mesh_indices(*indices): - return sum(a * b for a, b in zip(indices, mesh_strides)) - lower_ctx = LoweringRuleContext( - lowering_context=ctx.lowering_context, - avals_in=[pallas_core.index_map_grid_aval] * len(device_ids), - avals_out=[pallas_core.index_map_grid_aval], - block_shapes=(None,) * len(device_ids), + + i32 = ir.IntegerType.get_signless(32) + if len(device_ids) == 0: + return arith.constant(i32, 0) + return functools.reduce( + arith.addi, + ( + arith.muli(a, arith.constant(i32, b)) + for a, b in zip(device_ids, mesh_strides) + ), ) - return lower_fun(_linearize_mesh_indices, multiple_results=False)( - lower_ctx, *device_ids) elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: return device_id raise NotImplementedError(f"Unsupported device id type: {device_id_type}") @@ -3150,3 +3277,68 @@ def _lower_fun(shape): lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering + + +def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): + operand, padding_value = args + padding_config = kwargs["padding_config"] + + out_type: ir.VectorType = aval_to_ir_type(ctx.avals_in[0]) + if not isinstance(out_type, ir.VectorType): + raise NotImplementedError("Only vector types are supported.") + + for axis, (low, high, interior) in enumerate(padding_config): + if low == 0 and high == 0 and interior == 0: + continue + + def _pad(val): + shape = list(operand.type.shape) + shape[axis] = val + pad_vec_type = ir.VectorType.get( + shape, + operand.type.element_type, + ) + + if isinstance(padding_value, ir.OpResult): + pad = vector.BroadcastOp( + pad_vec_type, + padding_value, + ).result + else: + scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value) + pad = arith.ConstantOp( + pad_vec_type, + ir.DenseElementsAttr.get_splat( + pad_vec_type, + scalar_attr, + ), + ).result + return pad + + if low != 0: + pad_low = _pad(low) + new_shape = out_type.shape + new_shape[axis] += low + out_type = ir.VectorType.get( + new_shape, + out_type.element_type, + ) + operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis) + + if high != 0: + pad_high = _pad(high) + new_shape = out_type.shape + new_shape[axis] += high + out_type = ir.VectorType.get( + new_shape, + out_type.element_type, + ) + operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis) + + if interior > 0: + raise NotImplementedError("Not implemented: interior padding") + + return operand + + +lowering_rules[lax.pad_p] = _pad_lowering_rule diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 2bf96511b64e..ec9500c67cd7 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -19,24 +19,22 @@ import os import tempfile from typing import Any -import warnings import jax -from jax import core as jax_core from jax import dtypes from jax._src import config -from jax._src import core as jax_src_core +from jax._src import core as jax_core from jax._src import sharding_impls +from jax._src import tpu_custom_call from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering from jax._src.pallas.mosaic import verification -from jax._src import tpu_custom_call from jax.experimental import mosaic from jax.experimental.mosaic.dialects import tpu -from jax.experimental.pallas import tpu as pltpu + def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): """Casts boolean values to integers. @@ -126,27 +124,12 @@ def pallas_call_tpu_lowering_rule( else: mosaic_params = {} - if "cost_estimate" in mosaic_params: - # TODO(amagni): Remove this branch after October 22th 2024. - if cost_estimate is not None: - raise ValueError( - "Passing cost estimate via both compiler_params=dict(mosaic=...) and" - " pallas_call(..., cost_estimate=...) is not supported." - ) - - warnings.warn( - "Passing cost estimate via compiler_params=dict(cost_estimate=...) is" - " deprecated. Use pallas_call(..., cost_estimate=...) instead.", - DeprecationWarning, - ) - cost_estimate = mosaic_params["cost_estimate"] - mesh = None axis_context = ctx.module_context.axis_context if axis_context is not None: if isinstance(axis_context, sharding_impls.SPMDAxisContext): mesh = axis_context.mesh - mlir_ctx = ir.Context() + mlir_ctx = mlir.JaxIrContext() mlir_ctx.append_dialect_registry(mlir.upstream_dialects) mlir_ctx.load_all_available_dialects() tpu.register_dialect(mlir_ctx) @@ -205,7 +188,7 @@ def lower_module(for_verification: bool): # Replace in_avals to physical avals. # This step is required for mapping logical types to physical types. # (e.g. PRNG key -> uint32[2]) - physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in] + physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in] ctx = ctx.replace(avals_in=physical_avals) # Booleans are loaded into the kernel as integers. @@ -222,7 +205,7 @@ def _maybe_cast_inputs(*args): kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) output_memory_spaces = _get_memory_spaces_from_avals(out_avals) if cost_estimate is not None: - mosaic_cost_estimate = pltpu.CostEstimate( + mosaic_cost_estimate = tpu_custom_call.CostEstimate( flops=cost_estimate.flops, bytes_accessed=cost_estimate.bytes_accessed, transcendentals=cost_estimate.transcendentals, diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 0112b3cb4dbb..6ddb21e77bd8 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -33,6 +33,7 @@ from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import primitives as tpu_primitives from jax.experimental import pallas as pl +from jax.extend.backend import get_default_device import jax.numpy as jnp import numpy as np @@ -75,7 +76,7 @@ def add_leaves(i, x): @jax_util.cache(trace_context_in_key=False) def _get_tpu_generation() -> int: - kind = jax.devices()[0].device_kind + kind = get_default_device().device_kind if kind.endswith(' lite'): kind = kind[:-len(' lite')] assert kind[:5] == "TPU v", kind diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 7aab30ffc2ab..0a7bd371a639 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -593,14 +593,14 @@ def dma_start_discharge_rule(in_avals, out_avals, # Note that this code only works in SPMD mode. If not all devices execute # the DMA then the devices that do will hang. # TODO(justinfu): Verify that code only works in SPMD mode. - axis_env = jax_core.thread_local_state.trace_state.axis_env - nonempty_axes = [frame for frame in axis_env if frame.name is not None] + axis_env = jax_core.get_axis_env() + nonempty_axes = [name for name in axis_env.axis_sizes if name is not None] if device_id_type == DeviceIdType.LOGICAL: if len(nonempty_axes) > 1: raise NotImplementedError("Sharding with more than one named axis not " "implemented in dma_start_p for LOGICAL " "device_id_type.") - shard_axis = nonempty_axes[0].name + shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) elif device_id_type == DeviceIdType.MESH: device_id_len = 1 @@ -608,14 +608,14 @@ def dma_start_discharge_rule(in_avals, out_avals, device_id_len = device_id.size elif hasattr(device_id, '__len__'): device_id_len = len(device_id) - if device_id_len != len(axis_env): + if device_id_len != len(axis_env.axis_sizes): raise ValueError( - f"device_id ({device_id_len}) and mesh ({len(axis_env)}) " + f"device_id ({device_id_len}) and mesh ({len(axis_env.axis_sizes)}) " "must have same length.") if device_id_len > 1 or len(nonempty_axes) > 1: raise NotImplementedError("Meshes with more than 1 named dimension not " "implemented in dma_start_p") - shard_axis = nonempty_axes[0].name + shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) else: raise ValueError(f"Unknown device_id_type: {device_id_type}") diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index bae87226c664..08ff58770804 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -145,6 +145,15 @@ def block(self, begin: str, end: str): self.level -= 1 self.locals.append(self._indent(end) + "\n") + @contextlib.contextmanager + def comment_if_emitted(self, comment): + self.comment(comment) + yield + self.comment(comment) + if self.locals[-1] == self.locals[-2]: + self.locals.pop() + self.locals.pop() + def get(self, value: ir.Value, default: Any = _UNSPECIFIED): if default is _UNSPECIFIED: return self.env[value] @@ -358,6 +367,17 @@ def _print_op(ctx, op): return bin_op(ctx, "int", "%", *op.operands) case "arith.divsi": return bin_op(ctx, "int", "/", *op.operands) + case "arith.andi": + return bin_op(ctx, _model_type(op.result.type), "&", *op.operands) + case "arith.select": + cond, if_true, if_false = map(lambda o: ctx.get(o, None), op.operands) + if cond is None or if_true is None or if_false is None: + return NotImplemented + result_ty = _model_type(op.result.type) + return ctx.emit(result_ty, f"({cond} -> {if_true} : {if_false})") + case "arith.index_cast": + model = ctx.get(op.operands[0], None) + return ctx.emit("int", model) if model is not None else NotImplemented case "arith.cmpi": match op.predicate.value: case arith.CmpIPredicate.eq: @@ -386,12 +406,44 @@ def _print_op(ctx, op): read_refs.append(model) with ctx.block("d_step {", "}"): # Start reading for r in read_refs: + for loc in r.written_at(None): + ctx.emit(None, f"assert(!{loc})") for loc in r.readers_at(None): ctx.emit(None, f"{loc}++") with ctx.block("d_step {", "}"): # Stop reading for r in read_refs: for loc in r.readers_at(None): ctx.emit(None, f"{loc}--") + case "vector.load": + ref = ctx.get(op.operands[0]) + assert isinstance(ref, GlobalRefModel) + if (first_idx := ctx.get(op.operands[1], None)) is not None: + leading_load_len = ir.VectorType(op.result.type).shape[0] + ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_load_len) + with ctx.block("d_step {", "}"): # Start reading + for loc in ref.written_at(None): + ctx.emit(None, f"assert(!{loc})") + for loc in ref.readers_at(None): + ctx.emit(None, f"{loc}++") + with ctx.block("d_step {", "}"): # Stop reading + for loc in ref.readers_at(None): + ctx.emit(None, f"{loc}--") + return NotImplemented # We don't model the result of the load. + case "vector.store": + ref = ctx.get(op.operands[1]) # Stored value goes first + assert isinstance(ref, GlobalRefModel) + if (first_idx := ctx.get(op.operands[2], None)) is not None: + leading_store_len = ir.VectorType(op.operands[0].type).shape[0] + ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_store_len) + with ctx.block("d_step {", "}"): # Start writing + for loc in ref.readers_at(None): + ctx.emit(None, f"assert(!{loc})") + for loc in ref.written_at(None): + ctx.emit(None, f"assert(!{loc})") + ctx.emit(None, f"{loc} = 1") + with ctx.block("d_step {", "}"): # Stop reading + for loc in ref.written_at(None): + ctx.emit(None, f"{loc} = 0") case "scf.for": carrys = [ ctx.emit("int", ctx.get(arg)) @@ -419,6 +471,7 @@ def _print_op(ctx, op): ctx.emit(None, f"{c} = {ctx.get(new)}") ctx.emit(None, f"{induction_var} = {induction_var} + {step}") ctx.emit(None, ":: else -> break") + ctx.emit(None, "skip") # To avoid "Jump into d_step sequence errors" if len(carrys) == 1: return carrys[0] else: @@ -450,16 +503,27 @@ def bin_op(ctx, result_ty, op, lhs, rhs): return ctx.emit(result_ty, f"{lhs} {op} {rhs}") +def _model_type(ty): + if ir.IntegerType.isinstance(ty): + if ir.IntegerType(ty).width == 1: + return "bool" + else: + return "int" + else: + raise NotImplementedError(ty) + + def _print_block(ctx, block): for op in block: try: - results = _print_op(ctx, op) + with ctx.comment_if_emitted(op.OPERATION_NAME): + results = _print_op(ctx, op) except Exception as e: raise RuntimeError(f"Failed to print op: {op}") from e if results is NotImplemented: continue if not op.results: - assert results is None + assert results is None or results == () elif len(op.results) > 1: raise NotImplementedError(op) else: @@ -529,7 +593,8 @@ def export_promela_model( @assume_p.def_abstract_eval def _assume_abstract_eval(x, y): - return x.join(y) + assert jax_core.typematch(x, y) + return x def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y): return y if ctx.lowering_context.for_verification else x diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 91616948be49..e9461a5ceba0 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -44,6 +44,7 @@ pytype_strict_library( deps = [ ":lowering", "//jax", + "//jax:core", "//jax:mlir", "//jax:mosaic_gpu", "//jax/_src/pallas", @@ -61,6 +62,7 @@ pytype_strict_library( "//jax:mosaic_gpu", "//jax:pallas", "//jax:partial_eval", + "//jax:source_info_util", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", @@ -74,6 +76,7 @@ pytype_strict_library( "//jax", "//jax:core", "//jax:dtypes", + "//jax:effects", "//jax:mosaic_gpu", "//jax:tree_util", "//jax/_src/pallas", @@ -89,7 +92,7 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", - "//jax:effects", + "//jax:mlir", "//jax:mosaic_gpu", "//jax:tree_util", "//jax:util", @@ -97,3 +100,19 @@ pytype_strict_library( "//jax/_src/pallas", ], ) + +pytype_strict_library( + name = "pipeline", + srcs = ["pipeline.py"], + deps = [ + ":core", + ":primitives", + "//jax", + "//jax:core", + "//jax:mosaic_gpu", + "//jax:pallas", + "//jax:partial_eval", + "//jax:util", + "//jax/_src/pallas", + ], +) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 2ed8910bf3d7..d77ae4358703 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -21,14 +21,16 @@ from collections.abc import Sequence import dataclasses import enum +import itertools as it from typing import Any, ClassVar, Literal from jax._src import core as jax_core from jax._src import dtypes +from jax._src import effects from jax._src import tree_util from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call -from jax._src.state.types import Transform +from jax._src.state import indexing +from jax._src.state import types as state_types import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp from jaxlib.mlir import ir @@ -39,6 +41,20 @@ DimensionSemantics = Literal["parallel", "sequential"] +def is_trivial_index(idx, shape) -> bool: + """Checks if the index selects the entire shape.""" + + # Slices that select the entire dimension. + def _slices(d): + slices = [slice(b, e, s) for b, e, s in it.product([0, None], [d, None], [1, None])] + return [indexing.Slice(0, d, 1), *slices] + + if isinstance(idx, tuple): + return all(i in _slices(d) for d, i in zip(shape, idx)) + + return idx is ... or (len(shape) == 1 and idx in _slices(shape[0])) + + @dataclasses.dataclass(frozen=True, kw_only=True) class GPUCompilerParams(pallas_core.CompilerParams): """Mosaic GPU compiler parameters. @@ -56,12 +72,24 @@ class GPUCompilerParams(pallas_core.CompilerParams): references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. + profile_space: The number of profiler events that can be collected in a + single invocation. It is undefined behavior if a thread collects more + events than this. + profile_dir: The directory to which profiling traces will be written to. """ PLATFORM: ClassVar[str] = "mosaic_gpu" approx_math: bool = False dimension_semantics: Sequence[DimensionSemantics] | None = None max_concurrent_steps: int = 1 delay_release: int = 0 + profile_space: int = 0 + profile_dir: str = "" + + def __post_init__(self): + if bool(self.profile_space) ^ bool(self.profile_dir): + raise ValueError( + "Either both profile_space and profile_dir must be set, or neither." + ) class GPUMemorySpace(enum.Enum): @@ -80,7 +108,8 @@ def __call__( shape: tuple[int, ...], dtype: jnp.dtype, transforms: Sequence[MemoryRefTransform] = (), - ): + + ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) @@ -108,6 +137,14 @@ class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC): def to_gpu_transform(self) -> mgpu.MemRefTransform: pass + def batch(self, leading_rank: int): + """Returns a transform that accepts a ref with the extra `leading_rank` dims. + + The returned transform should leave the leading dimensions unchanged and + only apply to the suffix of the shape. + """ + raise NotImplementedError + def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: return aval.update( shape=self.to_gpu_transform().transform_shape(aval.shape) @@ -123,7 +160,6 @@ class TilingTransform(MemoryRefTransform): shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a tiling of (64, 32) will be tiled as (4, 8, 64, 32). """ - tiling: tuple[int, ...] def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: @@ -131,14 +167,17 @@ def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: ref, transforms=(*ref.transforms, UntileRef(self.tiling)) ) + def batch(self, leading_rank: int): + return self + def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) -@tree_util.register_pytree_node_class +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) -class UntileRef(Transform): - tiling: tuple[int, ...] +class UntileRef(state_types.Transform): + tiling: tuple[int, ...] = dataclasses.field(metadata=dict(static=True)) def transform_shape(self, shape): if shape is None: @@ -155,7 +194,7 @@ def transform_dtype(self, dtype): def untransform_index( self, idxs: tuple[Index, ...] - ) -> tuple[tuple[Index, ...], Transform]: + ) -> tuple[tuple[Index, ...], state_types.Transform]: untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] idxs_after_tiling = [] @@ -173,14 +212,6 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) - def tree_flatten(self): - return (), (self.tiling,) - - @classmethod - def tree_unflatten(cls, metadata, arrays): - assert not arrays - return cls(*metadata) - def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: inverse = [-1] * len(permutation) @@ -198,6 +229,11 @@ def __post_init__(self): if set(self.permutation) != set(range(len(self.permutation))): raise ValueError(f"Permutation {self.permutation} is not a permutation.") + def batch(self, leading_rank: int): + return TransposeTransform( + (*range(leading_rank), *(d + leading_rank for d in self.permutation)) + ) + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: return dataclasses.replace( ref, @@ -211,9 +247,9 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(self.permutation) -@tree_util.register_pytree_node_class +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) -class TransposeRef(Transform): +class TransposeRef(state_types.Transform): permutation: tuple[int, ...] def transform_shape(self, shape): @@ -226,7 +262,7 @@ def transform_dtype(self, dtype): def untransform_index( self, idxs: tuple[Index, ...] - ) -> tuple[tuple[Index, ...], Transform]: + ) -> tuple[tuple[Index, ...], state_types.Transform]: removed_dims = [ i for i, idx in enumerate(idxs) if not isinstance(idx, slice) ] @@ -241,14 +277,6 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) - def tree_flatten(self): - return (), (self.permutation,) - - @classmethod - def tree_unflatten(cls, metadata, arrays): - assert not arrays - return cls(*metadata) - def transpose_ref( ref: pallas_core.TransformedRef | Any, @@ -274,6 +302,9 @@ def __post_init__(self): " accepted." ) + def batch(self, leading_rank: int): + return self + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: return dataclasses.replace( ref, transforms=(*ref.transforms, UnswizzleRef(self.swizzle)) @@ -296,14 +327,14 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: return aval -@tree_util.register_pytree_node_class +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) -class UnswizzleRef(Transform): - swizzle: int +class UnswizzleRef(state_types.Transform): + swizzle: int = dataclasses.field(metadata=dict(static=True)) def untransform_index( self, idxs: tuple[Index, ...] - ) -> tuple[tuple[Index, ...], Transform]: + ) -> tuple[tuple[Index, ...], state_types.Transform]: if not idxs: return idxs, self if not all(isinstance(idx, slice) for idx in idxs[-2:]): @@ -320,14 +351,6 @@ def untransform_index( raise ValueError("Swizzled dims cannot be sliced") return idxs, self - def tree_flatten(self): - return (), (self.swizzle,) - - @classmethod - def tree_unflatten(cls, metadata, arrays): - assert not arrays - return cls(*metadata) - @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): @@ -398,19 +421,29 @@ def get_ref_aval(self) -> AbstractMemoryRef: class WGMMAAccumulatorRef: shape: tuple[int, int] dtype: jnp.dtype = jnp.float32 + _init: Any = state_types.uninitialized def get_ref_aval(self) -> AbstractMemoryRef: + if self._init is not state_types.uninitialized: + raise ValueError( + "Preinitialized WGMMAAccumulatorRef only supported in pl.run_state." + ) return WGMMAAbstractAccumulatorRef( jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS ) + @staticmethod + def init(array): + return WGMMAAccumulatorRef(array.shape, array.dtype, _init=array) -def _is_trivial_index(idx): - _is_deref1 = lambda i: i is Ellipsis or i == slice(None) - if isinstance(idx, tuple): - return all(_is_deref1(i) for i in idx) - return _is_deref1(idx) +def _wgmma_ref_type_mapping(ref: WGMMAAccumulatorRef): + aval = WGMMAAbstractAccumulatorRef( + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), GPUMemorySpace.REGS + ) + return aval, ref._init +state_types._ref_type_aval_mappings[WGMMAAccumulatorRef] = _wgmma_ref_type_mapping + class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): __slots__ = ["inner_aval", "memory_space"] @@ -418,20 +451,17 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): def __repr__(self) -> str: return f'Accumulator{{{self.inner_aval.str_short()}}}' - def join(self, other): - return _as_accum(super().join(other)) + def update_weak_type(self, weak_type): + return _as_accum(super().update_weak_type(weak_type)) def update(self, inner_aval=None, memory_space=None): return _as_accum(super().update(inner_aval=None, memory_space=None)) - def at_least_vspace(self): - return _as_accum(super().at_least_vspace()) - def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error arr = wgmma_accumulator_deref(tracer) - if not _is_trivial_index(idx): + if not is_trivial_index(idx, tracer.shape): arr = arr[idx] return arr @@ -443,10 +473,6 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: memory_space=ref.memory_space, # pytype: disable=attribute-error ) -def _ref_raise_to_shaped(ref_aval, weak_type): - return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type)) -jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped - _WARPGROUP_AXIS_NAME = object() @@ -457,6 +483,7 @@ class GPUMesh: # Those are NOT CUDA threads. On Hopper they correspond to warpgroups. num_threads: int | None = None axis_names: tuple[str, ...] = () + approx_math: bool = False def __post_init__(self): if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): @@ -484,6 +511,9 @@ def shape(self): ) return collections.OrderedDict(pairs) + def discharges_effect(self, effect: jax_core.Effect): + return effect is _wgmma_pipeline_effect or effect is _memory_effect + def _gpu_mesh_discharge_rule( in_avals, @@ -492,28 +522,34 @@ def _gpu_mesh_discharge_rule( mesh, jaxpr, ): - del out_avals assert isinstance(mesh, GPUMesh) if mesh.cluster: raise NotImplementedError if mesh.num_threads is None: raise NotImplementedError - def body(*args): - # Due to aliasing, args contains aliased inputs and outputs so we remove - # outputs. - in_refs = args[:len(in_avals)] - jax_core.eval_jaxpr(jaxpr, in_refs) - assert len(jaxpr.outvars) == 0 - any_spec = pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY) - out = pallas_call.pallas_call( - body, - out_shape=in_avals, - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(in_avals), - input_output_aliases={i: i for i in range(len(in_avals))}, + return pallas_core.default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + jaxpr=jaxpr, grid=tuple(mesh.shape.items()), backend="mosaic_gpu", - )(*args) - return out, () - + compiler_params=GPUCompilerParams(approx_math=mesh.approx_math), + ) pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule + + +class MemoryEffect(jax_core.Effect): + pass + + +effects.control_flow_allowed_effects.add_type(MemoryEffect) +_memory_effect = MemoryEffect() + + +class _WGMMAPipelineEffect(effects.Effect): + pass + + +effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect) +_wgmma_pipeline_effect = _WGMMAPipelineEffect() diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ec5584233552..ff5667d75e3b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,36 +17,42 @@ from __future__ import annotations import collections -from collections.abc import MutableMapping, MutableSequence, Sequence +from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools import itertools as it import math -from typing import Any, Hashable, Protocol, cast +from typing import Any, Protocol, cast import jax from jax import lax from jax._src import core as jax_core from jax._src import pjit from jax._src import util +from jax._src import source_info_util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import discharge from jax._src.state import indexing +from jax._src.state import types as state_types from jax._src.state import primitives as sp +from jax._src.state.types import RefReshaper import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import profiler as mgpu_profiler import jax.numpy as jnp import numpy as np @@ -64,6 +70,7 @@ # sensitive to alignment and while this is quite conservative, it gets the job # done. We should make this more refined in the future. _SMEM_ALIGNMENT = 1024 +WARPGROUP_SIZE = 128 def _align_to(x: int, alignment: int): if (rem := x % alignment): @@ -135,6 +142,7 @@ def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources: # Assume that unsupported primitives are neutral wrt resource usage. continue rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params) + return rs @@ -161,9 +169,11 @@ def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int: aval = v.aval if isinstance(aval.dtype, gpu_core.BarrierType): rs += Resources( - barrier_counts=collections.Counter( - [mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)] - ) + barrier_counts=collections.Counter([ + mgpu.Barrier( + aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape + ) + ]) ) else: rs += Resources( @@ -182,7 +192,7 @@ def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int @dataclasses.dataclass class ModuleContext: name: str - grid_mapping: pallas_core.GridMapping + grid_names: Sequence[Hashable] | None program_ids: Sequence[ir.Value] | None approx_math: bool runtime_smem: ir.Value # ir.MemRefType @@ -190,6 +200,9 @@ class ModuleContext: runtime_barriers: MutableMapping[ mgpu.Barrier, MutableSequence[mgpu.BarrierRef] ] + name_stack: source_info_util.NameStack + traceback_caches: mlir.TracebackCaches + squashed_dims: tuple[int, ...] def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: """Reserves a barrier. @@ -253,6 +266,7 @@ def scratch_view( class LoweringRuleContext: module_ctx: ModuleContext launch_ctx: mgpu.LaunchContext + predicate: ir.Value avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] @@ -263,7 +277,15 @@ class LoweringRuleContext: class LoweringResult: module: ir.Module grid: tuple[int, ...] + block: tuple[int, ...] out_structs: tuple[jax.ShapeDtypeStruct, ...] + profiler_context: ProfilerContext | None + + +@dataclasses.dataclass(frozen=True) +class ProfilerContext: + dump_path: str + spec: mgpu_profiler.ProfilerSpec class LoweringError(Exception): # pylint: disable=g-bad-exception-name @@ -340,10 +362,6 @@ def lower_jaxpr_to_module( assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims - if len(grid_mapping.grid) > 3: - raise NotImplementedError( - "Only <=3D grids are supported in Mosaic GPU lowering." - ) if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "Dynamic grid bounds not supported in the Mosaic GPU lowering." @@ -377,19 +395,25 @@ def lower_jaxpr_to_module( f" {max_concurrent_steps=}, {delay_release=}" ) - block = (128, 1, 1) - grid = grid_mapping.grid if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid[:-1] + logical_grid = grid_mapping.grid[:-1] + else: + block = (128, 1, 1) + logical_grid = grid_mapping.grid - grid = [d for i, d in enumerate(grid) if i not in sequential_axes] - if len(grid) < 3: - grid += (1,) * (3 - len(grid)) + parallel_grid = [ + d for i, d in enumerate(logical_grid) if i not in sequential_axes + ] + if len(parallel_grid) <= 3: + squashed_dims = () + parallel_grid += (1,) * (3 - len(parallel_grid)) else: - raise NotImplementedError( - "Only <=3D grids are supported in Mosaic GPU lowering." - ) + # If we have >3 parallel dimensions, we merge all leading dimensions + # into the first (Dimension.x) CUDA grid dimension. + squashed_dims = parallel_grid[:-2] + parallel_grid = [math.prod(parallel_grid[:-2]), *parallel_grid[-2:]] + if sequential_axes: # TODO(slebedev): Support multiple sequential axes. if len(sequential_axes) > 1: @@ -477,10 +501,10 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): parallel_count = it.count() program_ids_template = [ - _program_id(next(parallel_count)) + _program_id(next(parallel_count), squashed_dims=squashed_dims) if axis not in sequential_axes else None - for axis in range(len(grid_mapping.grid)) + for axis in range(len(logical_grid)) ] def make_program_ids(step: ir.Value): @@ -493,12 +517,15 @@ def make_program_ids(step: ir.Value): grouped_barriers[barrier].append(barrier_ref) module_ctx = ModuleContext( name_and_src_info.name, - grid_mapping, + grid_mapping.grid_names, None, approx_math, runtime_smem, smem_used_bytes=0, runtime_barriers=grouped_barriers, + name_stack=source_info_util.NameStack(), + traceback_caches=mlir.TracebackCaches(), + squashed_dims=squashed_dims, ) del runtime_smem, grouped_barriers, runtime_barriers @@ -589,7 +616,6 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: gmem_transform=gmem_transforms, swizzle=swizzle, arrive=False, # The caller must do ``arrive_expect_tx`` manually! - uniform=False, predicate=is_memory_thread, ) @@ -642,7 +668,6 @@ def store( gmem_slice=store_slice, gmem_transform=gmem_transforms, swizzle=swizzle, - uniform=False, predicate=do_store, ) return base_offset @@ -744,7 +769,7 @@ def _(step, carry): ) rs = _estimate_resources(jaxpr) extra_barriers = [ - mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) + mgpu.Barrier(aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape) for aval in scratch_avals if isinstance(aval.dtype, gpu_core.BarrierType) ] @@ -754,16 +779,21 @@ def _(step, carry): if not isinstance(aval.dtype, gpu_core.BarrierType) and aval.memory_space == gpu_core.SMEM ] - smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") + smem_scratch_bytes = params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = rs.smem_scratch_bytes extra_smem_scratch.append( jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ) - module, out_structs_smem, _ = mgpu_core._lower_as_gpu_kernel( + prof_ctx = prof_spec = None + if prof_space := params.get("profile_space", 0): + # Each range is 2 events, each event is 4 bytes. + prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) + prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) + module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel( body, - grid=grid, + grid=parallel_grid, cluster=(), block=block, in_shapes=in_structs_gmem, @@ -778,9 +808,12 @@ def _(step, carry): ), ), module_name=name_and_src_info.name, + prof_spec=prof_spec, ) - return LoweringResult(module, grid, out_structs_smem) + return LoweringResult( + module, parallel_grid, block, out_structs_gmem, prof_ctx + ) mosaic_lowering_rules = {} @@ -794,6 +827,19 @@ def deco(fn): return deco +def _compute_name_stack_updates( + old_name_stack: list[str], + new_name_stack: list[str] +) -> tuple[list[str], list[str]]: + common_prefix_idx = 0 + for i, (old, new) in enumerate(unsafe_zip(old_name_stack, new_name_stack)): + if old == new: + common_prefix_idx = i+1 + else: + break + return old_name_stack[common_prefix_idx:], new_name_stack[common_prefix_idx:] + + def lower_jaxpr_to_mosaic_gpu( module_ctx: ModuleContext, launch_ctx: mgpu.LaunchContext, @@ -811,35 +857,57 @@ def write_env(var: jax_core.Var, val): map(write_env, jaxpr.constvars, consts) map(write_env, jaxpr.invars, args) + # TODO(justinfu): Handle transform scopes. + last_local_name_stack: list[str] = [] + named_regions = [] for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) - if eqn.primitive not in mosaic_lowering_rules: - raise NotImplementedError( - "Unimplemented primitive in Pallas Mosaic GPU lowering: " - f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/jax-ml/jax/issues." - ) - rule = mosaic_lowering_rules[eqn.primitive] - rule_ctx = LoweringRuleContext( - module_ctx, - launch_ctx, - avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], - avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], + source_info = eqn.source_info.replace( + name_stack=module_ctx.name_stack + eqn.source_info.name_stack ) - try: - outvals = rule(rule_ctx, *invals, **eqn.params) - except LoweringError: - raise # We only add the extra info to the innermost exception. - except Exception as e: - inval_types = map(lambda t: getattr(t, "type", None), invals) - raise LoweringError( - f"Exception while lowering eqn:\n {eqn}\nWith context:\n " - f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}" - ) from e - if eqn.primitive.multiple_results: - map(write_env, eqn.outvars, outvals) - else: - write_env(eqn.outvars[0], outvals) + loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) + with source_info_util.user_context(eqn.source_info.traceback), loc: + if eqn.primitive not in mosaic_lowering_rules: + raise NotImplementedError( + "Unimplemented primitive in Pallas Mosaic GPU lowering: " + f"{eqn.primitive.name}. " + "Please file an issue on https://github.com/jax-ml/jax/issues." + ) + new_local_name_stack = [scope.name for scope in eqn.source_info.name_stack.stack] + popped, pushed = _compute_name_stack_updates(last_local_name_stack, new_local_name_stack) + last_local_name_stack = new_local_name_stack + for _ in popped: + named_regions.pop().close() + for name in pushed: + wrapper_stack = contextlib.ExitStack() + wrapper_stack.enter_context(launch_ctx.named_region(name)) + named_regions.append(wrapper_stack) + rule = mosaic_lowering_rules[eqn.primitive] + rule_ctx = LoweringRuleContext( + module_ctx, + launch_ctx, + predicate=mgpu.single_thread_predicate(per_block=False), + avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], + avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], + ) + try: + outvals = rule(rule_ctx, *invals, **eqn.params) + except LoweringError: + raise # We only add the extra info to the innermost exception. + except Exception as e: + if not pallas_call._verbose_errors_enabled(): + raise + inval_types = map(lambda t: getattr(t, "type", None), invals) + raise LoweringError( + f"Exception while lowering eqn:\n {eqn}\nWith context:\n " + f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}" + ) from e + if eqn.primitive.multiple_results: + map(write_env, eqn.outvars, outvals) + else: + write_env(eqn.outvars[0], outvals) + while named_regions: # Drain the name stack. + named_regions.pop().close() return map(read_env, jaxpr.outvars) @@ -849,12 +917,42 @@ def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): raise NotImplementedError("pl.program_id() is not supported in this context") return ctx.module_ctx.program_ids[axis] - -def _program_id(axis: int) -> ir.Value: - return arith_dialect.index_cast( - ir.IntegerType.get_signless(32), - gpu_dialect.block_id(gpu_dialect.Dimension(axis)), - ) +def _unravel_program_id( + block_id: ir.Value, + axis: int, + dimensions: tuple[int, ...], + row_major: bool = False +) -> ir.Value: + """Computes the program ID for axes compressed into one block dimension.""" + if row_major: + div_value = math.prod(dimensions[axis+1:]) + else: + div_value = math.prod(dimensions[:axis]) + div_value = _as_index(_i32_constant(div_value)) + pid = arith_dialect.divui(block_id, div_value) + axis_size = _as_index(_i32_constant(dimensions[axis])) + pid = arith_dialect.remui(pid, axis_size) + return arith_dialect.index_cast(ir.IntegerType.get_signless(32), pid) + + +def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value: + if squashed_dims: + if parallel_axis < len(squashed_dims): + # All squashed dimensions are mapped to Dimension.x. + block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) + return _unravel_program_id(block_id, parallel_axis, squashed_dims) + else: + # Handle unsquashed axes. + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_id(gpu_dialect.Dimension( + parallel_axis - len(squashed_dims) + 1)), + ) + else: + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_id(gpu_dialect.Dimension(parallel_axis)), + ) @register_lowering_rule(primitives.num_programs_p) @@ -866,6 +964,33 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): ) +def _handle_reshaping( + ref: ir.Value, transforms: Sequence[gpu_core.Transform] +) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: + is_trivial_indexer = lambda t: isinstance( + t, indexing.NDIndexer + ) and gpu_core.is_trivial_index(t.indices, t.shape) + + last_reshaper_idx = next( + reversed([i for i, t in enumerate(transforms) if isinstance(t, RefReshaper)]), + None, + ) + if last_reshaper_idx is None: + return ref, transforms + # Check that before the reshape are only trivial indexes and or + # other reshapes. + # TODO(cperivol): Reshapes should bubble up rather than being + # expected to effectively be the first ref transform. + if not all(isinstance(t, RefReshaper) or is_trivial_indexer(t) for t in transforms[:last_reshaper_idx]): + raise NotImplementedError( + "Reshapes do not compose with other transforms and indexers must be" + f" trivial (transforms: {transforms})" + ) + reshaper = cast(RefReshaper, transforms[last_reshaper_idx]) + # Skip all the reshapes and trivial indexes. + return mgpu.memref_reshape(ref, reshaper.shape), transforms[last_reshaper_idx + 1:] + + def _handle_indexing( ref: ir.Value, transforms: Sequence[gpu_core.Transform] ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: @@ -914,9 +1039,13 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") + x_aval = ctx.avals_in[0] + transforms = jax.tree.unflatten(tree, leaves) + x_smem, transforms = _handle_reshaping(x_smem, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) + match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != (64, swizzle // x_aval.dtype.itemsize): @@ -925,6 +1054,12 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle ) case (): + # Handle scalar indexing. + if not ctx.avals_out[0].shape: + is_signed = mgpu_utils.is_signed(x_aval.dtype) + val = memref_dialect.load(x_smem, []) + return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed) + return mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) ) @@ -942,6 +1077,7 @@ def _swap_lowering_rule( raise TypeError(f"Can only store to references (got {x_smem}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) + x_smem, transforms = _handle_reshaping(x_smem, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): @@ -1004,7 +1140,9 @@ def _broadcast_in_dim_lowering_rule( *, broadcast_dimensions, shape, + sharding, ): + del sharding [x_aval] = ctx.avals_in [y_aval] = ctx.avals_out x = _ensure_fa(x, x_aval.dtype) @@ -1030,6 +1168,12 @@ def _convert_element_type_lowering_rule( ) +mosaic_lowering_rules.update({ + lax.neg_p: lambda ctx, x: -x, + lax.not_p: lambda ctx, x: ~x, +}) + + def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return impl(x, y) @@ -1039,7 +1183,6 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), - lax.div_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x / y), lax.rem_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x % y), lax.and_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x & y), lax.or_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x | y), @@ -1055,6 +1198,14 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): }) +@register_lowering_rule(lax.div_p) +def _div_lowering_rule(ctx: LoweringRuleContext, x, y): + x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) + if ir.FloatType.isinstance(x.mlir_dtype): + return x / y + return x // y + + @register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): [x_aval] = ctx.avals_in @@ -1063,12 +1214,23 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): return x * x return NotImplementedError +@register_lowering_rule(lax.square_p) +def _square_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + x = _ensure_fa(x, x_aval.dtype) + return x * x @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) +@register_lowering_rule(lax.tanh_p) +def _tanh_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) + + @register_lowering_rule(lax.logistic_p) def _logistic_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in @@ -1082,24 +1244,29 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): return a.exp(approx=ctx.module_ctx.approx_math) +@register_lowering_rule(lax.exp2_p) +def _exp2_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + a = _ensure_fa(x, x_aval.dtype) + return a.exp2(approx=ctx.module_ctx.approx_math) + + @register_lowering_rule(lax.reduce_sum_p) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: case mgpu.WGStridedFragLayout(): - if axes != (0,): - raise NotImplementedError("No support for axes other than 0 yet") + if set(axes) != set(range(x_aval.ndim)): + raise NotImplementedError("No support for axes yet") scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype) with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]: - return mgpu.FragmentedArray.splat( - x.reduce_sum(scratch), (), is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) + return x.reduce_sum(scratch) case mgpu.WGMMA_LAYOUT: if axes != (x_aval.ndim - 1,): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): raise NotImplementedError - return x.reduce(arith_dialect.addf, axes[0]) + return x.reduce("add", axes[0]) case _: raise NotImplementedError(f"Unsupported layout {x.layout}") @@ -1113,23 +1280,51 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): raise NotImplementedError - return x.reduce(arith_dialect.maxnumf, axes[0]) + return x.reduce("max", axes[0]) case _: raise NotImplementedError(f"Unsupported layout {x.layout}") @register_lowering_rule(lax.axis_index_p) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): - grid_names = ctx.module_ctx.grid_mapping.grid_names + i32 = ir.IntegerType.get_signless(32) + grid_names = ctx.module_ctx.grid_names + squashed_dims = ctx.module_ctx.squashed_dims + if squashed_dims: + unsquashed_names = grid_names[-3:] + squashed_names = grid_names[:-3] + else: + # These are unused but initialized for type checkers. + unsquashed_names = () + squashed_names = () if grid_names and axis_name in grid_names: if axis_name == grid_names[-1]: - return mgpu.warpgroup_idx(sync=False) + return mgpu.warpgroup_idx(sync=True) else: - idx = grid_names.index(axis_name) - return arith_dialect.index_cast( - ir.IntegerType.get_signless(32), - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) + if squashed_dims: + if axis_name in unsquashed_names: + # We add 1 to the index because the first dimension is the + # squashed dimension. + # e.g. for the grid (a, b, c, d, wg) + # squashed = (a, b) Mapped to Dimension.x (0) + # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) + idx = unsquashed_names.index(axis_name) + 1 + return arith_dialect.index_cast( + i32, + gpu_dialect.block_id(gpu_dialect.Dimension(idx)), + ) + elif axis_name in squashed_names: + # All squashed dimensions are mapped to Dimension.x. + block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) + axis = squashed_names.index(axis_name) + return _unravel_program_id(block_id, axis, squashed_dims) + else: + if axis_name in grid_names: + idx = grid_names.index(axis_name) + return arith_dialect.index_cast( + i32, + gpu_dialect.block_id(gpu_dialect.Dimension(idx)), + ) raise ValueError( "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" ) @@ -1184,7 +1379,9 @@ def _run_scoped_lowering_rule( elif isinstance(aval.dtype, gpu_core.BarrierType): input_refs.append( ctx.module_ctx.reserve_barrier( - mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) + mgpu.Barrier( + aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape + ) ) ) should_discharge.append(False) @@ -1231,6 +1428,55 @@ def _run_scoped_lowering_rule( return outs +@register_lowering_rule(discharge.run_state_p) +def _run_state_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr: jax_core.Jaxpr, + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...], +): + del which_linear + # TODO(apaszke): This should be unified with run_scoped. + if not all(is_initialized): + raise NotImplementedError("Uninitialized Refs are not supported in lowering of run_state.") + + should_discharge = [] + new_input_vals = [] + for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out): + aval = v.aval + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) + should_discharge.append(True) + assert isinstance(out_aval, jax_core.ShapedArray) + else: + new_input_vals.append(arg) + should_discharge.append(not isinstance(out_aval, state_types.AbstractRef)) + if not any(should_discharge): + raise NotImplementedError( + "Expected at least one accumulator to in run_state." + ) + + discharged_jaxpr, new_consts = discharge.discharge_state( + jaxpr, (), should_discharge=should_discharge + ) + assert not new_consts + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, () + ) + # Await the accumulators and extract their final values. + nvvm_dialect.wgmma_wait_group_sync_aligned(0) + outs = [ + out.value if isinstance(out, mgpu.WGMMAAccumulator) else out + for out in outs + ] + # Blend the discharge results with refs we closed over. I don't fully + # understand the reasons behind this calling convention, but sharadmv@ has + # assured me that this is ok. + outs_it = iter(outs) + return [next(outs_it) if d else a for d, a in zip(should_discharge, args)] + + def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, @@ -1241,16 +1487,23 @@ def _lower_jaxpr_to_for_loop( has_loop_index: bool, ): - @mgpu.fori(length, [*args]) + _consts_avals, arg_avals = util.split_list(ctx.avals_in, [len(consts)]) + arg_avals = arg_avals[has_loop_index:] + out_avals = [] + if arg_avals: + out_avals = ctx.avals_out[-len(arg_avals):] + + @mgpu.fori(length, [*map(_ensure_fa, args, arg_avals)]) def loop(loop_index, body_args): if has_loop_index: loop_index = arith_dialect.addi(loop_index, start) jaxpr_args = [*consts, loop_index, *body_args] else: jaxpr_args = [*consts, *body_args] - return lower_jaxpr_to_mosaic_gpu( + outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args ) + return map(_ensure_fa, outs, out_avals) return loop.results @@ -1289,13 +1542,12 @@ def _scan_lowering_rule( _consts_avals, arg_avals = util.split_list(ctx.avals_in, [num_consts]) if has_loop_index: start, *args = args - index_aval, *arg_avals = arg_avals + index_aval, *_ = arg_avals start: ir.Value = _ensure_ir_value(start, index_aval.dtype) length = _ir_constant(length, start.type) else: start = _i32_constant(0) length = _i32_constant(length) - args = map(lambda arg, aval: _ensure_fa(arg, aval.dtype), args, arg_avals) for_out = _lower_jaxpr_to_for_loop( ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index ) @@ -1306,11 +1558,70 @@ def _scan_lowering_rule( return for_out +@register_lowering_rule(lax.while_p) +def _while_lowering_rule( + ctx: LoweringRuleContext, + *args, + cond_jaxpr, + body_jaxpr, + cond_nconsts, + body_nconsts, +): + # First try to lower via a simpler fori loop, which may optimize better. + fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop( + cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts + ) + del cond_jaxpr, body_jaxpr + if fori_jaxpr is None: + raise NotImplementedError(err) + + if fori_jaxpr.constvars: + raise NotImplementedError + + lb_aval, ub_aval, *_ = ctx.avals_in[body_nconsts:] + # Reflect the changes of the pattern matcher to the context. + avals_in = ( + *ctx.avals_in[cond_nconsts:body_nconsts], + ctx.avals_in[body_nconsts], # the index + *ctx.avals_in[body_nconsts + 2:], + ) + + avals_out = tuple(ctx.avals_out[2:]) + ctx = ctx.replace(avals_in=avals_in, avals_out=avals_out) + _, consts, (lb, ub, *args) = util.split_list(args, [cond_nconsts, body_nconsts]) + + lb, ub = _ensure_ir_value(lb, lb_aval.dtype), _ensure_ir_value(ub, ub_aval.dtype) + length = arith_dialect.subi(ub, lb) + + for_out = _lower_jaxpr_to_for_loop(ctx, fori_jaxpr, lb, length, consts, *args, has_loop_index=True) + return (ub, ub, *for_out) + @register_lowering_rule(lax.cond_p) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in + + def _yielded_values(outs, avals): + ret = [] + for out, aval in zip(outs, avals): + if isinstance(out, mgpu.FragmentedArray): + ret.append(out) + else: + ret.append(_ensure_ir_value(out, aval.dtype)) + return ret + + # We need the branch return mlir types in order to construct the + # switch operation. To avoid leaking information about what kind of + # mlir types are internal to FragmentedArrays and other mgpu types, + # we run one of the branches in a dummy module that we throw away to + # extract the return types + with ir.InsertionPoint(ir.Module.create().body): + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args + ) + yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))] + switch_op = scf_dialect.IndexSwitchOp( - map(mgpu_utils.dtype_to_ir_type, ctx.avals_out), + yielded_types, _as_index(_ensure_ir_value(index, index_aval.dtype)), ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), num_caseRegions=len(branches) - 1, @@ -1322,16 +1633,54 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): regions = list(switch_op.regions) # Move the default region to the back. regions = regions[1:] + regions[:1] + treedef = None for branch, region in zip(branches, regions): with ir.InsertionPoint(region.blocks.append()): outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args + ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts ) - scf_dialect.yield_([ - _ensure_ir_value(out, aval.dtype) - for out, aval in zip(outs, ctx.avals_out) - ]) - return list(switch_op.results) + + yielded_leaves, yielded_treedef = jax.tree.flatten(_yielded_values(outs, ctx.avals_out)) + if treedef is None: + treedef = yielded_treedef + else: + assert treedef == yielded_treedef + + scf_dialect.yield_(yielded_leaves) + + assert treedef is not None + return treedef.unflatten(list(switch_op.results)) + + +@register_lowering_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, operand, *, new_dtype +): + # TODO(petebu) Handle case where src and dst types have different bitwidths + [operand_aval] = ctx.avals_in + operand = _ensure_fa(operand, operand_aval.dtype) + src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype) + dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype) + assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) + assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) + if src_elem_type.width != dst_elem_type.width: + raise NotImplementedError( + f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they" + " have different widths" + ) + if ir.IntegerType.isinstance(dst_elem_type): + output_is_signed = mgpu_utils.is_signed(new_dtype) + else: + output_is_signed = None + return mgpu.FragmentedArray.bitcast( + operand, dst_elem_type, output_is_signed=output_is_signed + ) + + +@register_lowering_rule(lax.optimization_barrier_p) +def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): + args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) + return mgpu.optimization_barrier(*args) def _bcast( @@ -1399,10 +1748,14 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value: def _i32_constant(v: int) -> ir.Value: + if v < jnp.iinfo(jnp.int32).min or v > jnp.iinfo(jnp.int32).max: + raise ValueError(f"Integer constant out of range for i32: {v}") return arith_dialect.constant(ir.IntegerType.get_signless(32), v) def _i64_constant(v: int) -> ir.Value: + if v < jnp.iinfo(jnp.int64).min or v > jnp.iinfo(jnp.int64).max: + raise ValueError(f"Integer constant out of range for i64: {v}") return arith_dialect.constant(ir.IntegerType.get_signless(64), v) @@ -1417,4 +1770,4 @@ def _as_index(v: object) -> ir.Value: case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()): return _as_index(v.registers.item()) case _: - raise ValueError(f"Unsupported index: {v}") + raise ValueError(f"Unsupported index: {v} of type {type(v)}") diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 960fe7d71856..18d8baf6e95e 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -17,9 +17,13 @@ from __future__ import annotations +import os +import time from typing import Any +import warnings -from jax import core as jax_core +import jax +from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering @@ -63,10 +67,41 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - return mosaic_core._mosaic_gpu_lowering_rule( - ctx, + new_avals_out = [ + jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs + ] + outs = mosaic_core._mosaic_gpu_lowering_rule( + ctx.replace(avals_out=new_avals_out), *args, module=module.operation.get_asm(binary=True, enable_debug_info=True), out_types=lowering_result.out_structs, input_output_aliases=input_output_aliases, ) + if (prof_ctx := lowering_result.profiler_context) is not None: + *outs, prof_buffer = outs + if (dump_path := prof_ctx.dump_path) == "sponge": + dump_path = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR") # type: ignore + out_file = os.path.join( + dump_path, f"{name_and_src_info.name}-{time.time_ns()}-trace.json" + ) + def dump_profile(prof_buffer): + try: + with open(out_file, "x") as f: + prof_ctx.spec.dump( + prof_buffer, + f, + grid=lowering_result.grid, + block=lowering_result.block, + ) + except FileExistsError: + warnings.warn( + f"Failed to dump profile for pallas_call {name_and_src_info}, " + f"profile already exists at {out_file}" + ) + def do_callback(prof_buffer): + jax.debug.callback(dump_profile, prof_buffer) + return () + mlir.lower_fun(do_callback, multiple_results=True)( + ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer + ) + return outs diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py new file mode 100644 index 000000000000..ee3f03f1849f --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -0,0 +1,326 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for emitting custom GPU pipelines within a Pallas kernel.""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +import dataclasses +import functools +import itertools as it +import math +from typing import Any + +import jax +from jax import lax +from jax._src import core +from jax._src import linear_util as lu +from jax._src import util +from jax._src.interpreters import partial_eval as pe +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives +from jax.experimental import pallas as pl +import jax.numpy as jnp + + +map = util.safe_map +zip = util.safe_zip + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class BufferedRef: + spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True}) + is_index_invariant: bool = dataclasses.field(metadata={"static": True}) + gmem_ref: pallas_core.AbstractMemoryRef + # ``None`` if the ref is pinned to GMEM; otherwise, has shape + # [num_slots, *spec.block_shape]. + smem_ref: pallas_core.AbstractMemoryRef | None + + def get_ref_for_slot( + self, slot: int | jax.Array + ) -> pallas_core.AbstractMemoryRef: + if self.smem_ref is None: + return self.gmem_ref + return self.smem_ref.at[slot] + + def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: + index_map = self.spec.index_map + assert index_map is not None + return tuple( + pl.Slice(idx * size, size) # type: ignore[arg-type] + for idx, size in zip( + index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type] + ) + ) + + def copy_in(self, slot, grid_indices, barrier_ref): + if not _in_smem(self.spec): + return + assert self.smem_ref is not None + gmem_slices = self.compute_gmem_slice(grid_indices) + gpu_primitives.copy_gmem_to_smem( + self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands + self.smem_ref.at[slot], + barrier_ref.at[slot], + ) + + def copy_out(self, slot, grid_indices, predicate=None): + if not _in_smem(self.spec): + return + assert self.smem_ref is not None + gmem_slices = self.compute_gmem_slice(grid_indices) + gpu_primitives.copy_smem_to_gmem( + self.smem_ref.at[slot], + self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands + predicate=predicate, + ) + + +def _uses_arguments( + index_map: Callable[..., Any], num_args: int +) -> Sequence[bool]: + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args + ) + _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars)) + return used_inputs + + +def _is_index_invariant( + spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid +) -> bool: + if (index_map := spec.index_map) is None: + return True + return not any(_uses_arguments(index_map, len(grid))) + + +def _inc_grid_by_1( + indices: tuple[jax.Array, ...], grid: Sequence[int] +) -> tuple[jax.Array, ...]: + next_indices = [] + carry: bool | jax.Array = True + for idx, size in reversed(list(zip(indices, grid))): + next_idx = lax.select(carry, idx + 1, idx) + carry = next_idx == size + next_indices.append(lax.select(carry, 0, next_idx).astype(idx.dtype)) + return tuple(reversed(next_indices)) + + +def _in_smem(spec: pallas_core.BlockSpec) -> bool: + return spec.memory_space in (None, gpu_core.SMEM) + + +# ``pl.Slice`` uses a different pytree encoding, depending on whether the +# start/size are static or dynamic. This leads to pytree structure mismatch +# in the pipeline body. So, we define a different ``Slice`` class below. + + +@dataclasses.dataclass(frozen=True) +class _Slice: + start: int | jax.Array + size: int | jax.Array + + def __eq__(self, other: _Slice) -> jax.Array: # type: ignore + return lax.bitwise_and(self.start == other.start, self.size == other.size) + + +jax.tree_util.register_dataclass( + _Slice, data_fields=["start", "size"], meta_fields=[] +) + + +def emit_pipeline( + body: Callable[..., None], + *, + grid: pallas_core.StaticGrid, + in_specs: Sequence[pallas_core.BlockSpec] = (), + out_specs: Sequence[pallas_core.BlockSpec] = (), + max_concurrent_steps: int = 1, + delay_release: int = 0, +): + """Creates a function to emit a manual pipeline within a Pallas kernel. + + Args: + body: The pipeline body. + grid: The grid to use for the pipeline. + in_specs: The block specs for the inputs. + out_specs: The block specs for the outputs. + max_concurrent_steps: The maximum number of sequential stages that are + active concurrently. Defaults to 1. + delay_release: The number of steps to wait before reusing the input/output + references. Defaults to 0, and must be strictly smaller than + ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you + don't await the WGMMA in the body. + """ + num_steps = math.prod(grid) + + if max_concurrent_steps <= delay_release: + raise ValueError( + "max_concurrent_steps must be greater than delay_release, but" + f" {max_concurrent_steps=}, {delay_release=}" + ) + + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to + # reduce the size of the refs allocated in SMEM. + if max_concurrent_steps > num_steps: + max_concurrent_steps = num_steps + delay_release = 0 # No need to delay anything. + + def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): + in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) + in_smem_refs, out_smem_refs = util.split_list( + [ + gpu_core.SMEM( + (max_concurrent_steps, *spec.block_shape), # type: ignore + ref.dtype, + transforms=tuple( + t.batch(1) for t in getattr(spec, "transforms", ()) + ), + ) + if _in_smem(spec) + else None + for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs) + ], + [len(in_specs)], + ) + return pl.run_scoped( + functools.partial( + scoped_pipeline, + in_gmem_refs=in_gmem_refs, + out_gmem_refs=out_gmem_refs, + ), + in_smem_refs=in_smem_refs, + out_smem_refs=out_smem_refs, + barrier_ref=gpu_core.Barrier( + # TODO(slebedev): Change this to arrive only once. + sum(map(_in_smem, in_specs)), + num_barriers=max_concurrent_steps, + ), + ) + + def scoped_pipeline( + *, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref + ): + in_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref) + for spec, gmem_ref, smem_ref in zip( + in_specs, in_gmem_refs, in_smem_refs + ) + ] + out_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref) + for spec, gmem_ref, smem_ref in zip( + out_specs, out_gmem_refs, out_smem_refs + ) + ] + + for step, indices in enumerate( + it.islice(it.product(*map(range, grid)), max_concurrent_steps) + ): + map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + + def loop_body(step, carry): + slot = step % max_concurrent_steps + indices, fetch_indices, last_store_slices = carry + + if in_specs: + # Wait for the current GMEM->SMEM copy to complete. + gpu_primitives.barrier_wait(barrier_ref.at[slot]) + # Wait for the previous output SMEM->GMEM copy to complete. + gpu_primitives.wait_smem_to_gmem( + max_concurrent_steps - (1 + delay_release), wait_read_only=True + ) + + with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): + body(*( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + )) + + if not all(bref.is_index_invariant for bref in out_brefs): + gpu_primitives.commit_smem() + + # Copy the output from SMEM to GMEM. + new_store_slices = last_store_slices[:] + for idx, bref in enumerate(out_brefs): + if bref.is_index_invariant: + assert last_store_slices[idx] is None + continue + assert last_store_slices[idx] is not None + new_store_slices[idx] = tuple( + _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices) + ) + are_same_slices = map( + lambda old, new: old == new, + last_store_slices[idx], + new_store_slices[idx], + ) + slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) + is_last_step = step == num_steps - 1 + # TODO(apaszke,slebedev): This still diverges significantly from the + # TPU semantics in that it will move on to the next SMEM output slice + # even if it's not storing the previous one. + bref.copy_out( + slot, + indices, + predicate=lax.bitwise_or(slices_changed, is_last_step), + ) + + fetch_step = step + (max_concurrent_steps - delay_release) + fetch_slot = slot # (x + y) % y == x % y + jax.lax.cond( + lax.bitwise_and(fetch_step >= delay_release, fetch_step < num_steps), + lambda: map( + lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref), + in_brefs, + ), + lambda: [None] * len(in_brefs), + ) + + return ( + _inc_grid_by_1(indices, grid), + _inc_grid_by_1(fetch_indices, grid), + new_store_slices, + ) + + indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid) + fetch_indices = indices + for _ in range(max_concurrent_steps): + fetch_indices = _inc_grid_by_1(fetch_indices, grid) + last_store_slices = [ + None + if bref.is_index_invariant + else (_Slice(-1, -1),) * len(bref.spec.block_shape) + for bref in out_brefs + ] + last_indices, _, _ = lax.fori_loop( + 0, num_steps, loop_body, (indices, fetch_indices, last_store_slices) + ) + + # Outputs invariant to the sequential axis are never written from inside the + # loop. This is the only place where we store them. + if all(bref.is_index_invariant for bref in out_brefs): + gpu_primitives.commit_smem() + last_slot = (num_steps - 1) % max_concurrent_steps + for bref in out_brefs: + if bref.is_index_invariant: + bref.copy_out(last_slot, last_indices, predicate=None) + + # Finalize the pipeline. + gpu_primitives.wait_smem_to_gmem(0) + + return pipeline diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 13d76174472c..85b7364ce2cc 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -17,14 +17,18 @@ from __future__ import annotations import enum +import math from typing import Any, Literal import jax from jax._src import core as jax_core -from jax._src import effects from jax._src import state from jax._src import tree_util from jax._src import util +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import llvm as llvm_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -33,6 +37,10 @@ from jax._src.state import indexing from jax._src.state import primitives as state_primitives import jax.experimental.mosaic.gpu as mgpu +import jax.numpy as jnp + + +WARPGROUP_SIZE = 128 copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") @@ -50,20 +58,31 @@ def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, src, dst, - *flat_transforms, + *flat_args, src_transforms_treedef, dst_transforms_treedef, + has_user_predicate, ): + predicate = ctx.predicate + if has_user_predicate: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + predicate = arith_dialect.andi( + predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) + ) flat_src_transforms, flat_dst_transforms = util.split_list( - flat_transforms, + flat_args, [src_transforms_treedef.num_leaves], ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) src, src_transforms = lowering._handle_indexing(src, src_transforms) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - mgpu.commit_shared() - ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params) + ctx.launch_ctx.async_copy( + src_ref=src, + dst_ref=dst, + predicate=predicate, + **copy_params, + ) return () @@ -95,16 +114,25 @@ def _extract_smem_copy_params(transforms): def copy_smem_to_gmem( - src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef + src: pallas_core.AbstractMemoryRef, + dst: pallas_core.AbstractMemoryRef, + predicate: jax.Array | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. + Args: + src: The SMEM reference to copy from. + dst: The GMEM reference to copy to. + predicate: A boolean indicating whether the copy should be performed. If + ``None``, the copy is always performed. + See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` + :func:`jax.experimental.mosaic.gpu.commit_smem` """ if src.memory_space is not gpu_core.SMEM: raise TypeError(f"src must be a SMEM reference, got {src.memory_space}") - if dst.memory_space is not gpu_core.GMEM: + if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM: raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}") src, src_transforms = state_primitives.get_ref_and_transforms( src, None, "copy_smem_to_gmem", force_trailing_indexer=False, @@ -123,8 +151,10 @@ def copy_smem_to_gmem( dst, *flat_src_transforms, *flat_dst_transforms, + *[] if predicate is None else [predicate], src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, + has_user_predicate=predicate is not None, ) return None @@ -170,8 +200,19 @@ def _copy_gmem_to_smem_lowering( barrier = barrier.__getitem__( *map(lowering._as_index, barrier_indexer.indices) ) + dst_ty = ir.MemRefType(dst.type) + bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) + if bytes % WARPGROUP_SIZE: + raise NotImplementedError("Only aligned copies are supported") + # We arrive uniformly from each thread in the WG, so we need to divide the + # number of bytes by the number of threads in the WG. + # TODO: apaszke - Relax this. We can just select the WG leader and have it + # arrive with the whole transfer size, while everyone else arrives with 0. + # But we should continue using this scheme as it's likely to be faster. + bytes //= WARPGROUP_SIZE + barrier.arrive_expect_tx(bytes) ctx.launch_ctx.async_copy( - src_ref=src, dst_ref=dst, barrier=barrier, **copy_params + src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params ) return () @@ -179,7 +220,6 @@ def _copy_gmem_to_smem_lowering( def copy_gmem_to_smem( src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef, - *, barrier: pallas_core.AbstractMemoryRef, ) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. @@ -188,7 +228,7 @@ def copy_gmem_to_smem( :func:`jax.experimental.mosaic.gpu.barrier_arrive` :func:`jax.experimental.mosaic.gpu.barrier_wait` """ - if src.memory_space is not gpu_core.GMEM: + if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM: raise TypeError(f"src must be a GMEM reference, got {src.memory_space}") if dst.memory_space is not gpu_core.SMEM: raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}") @@ -246,15 +286,6 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: raise ValueError("Barrier does not support arbirary transforms") -class MemoryEffect(jax_core.Effect): - ... - - -effects.control_flow_allowed_effects.add_type(MemoryEffect) - -_memory_effect = MemoryEffect() - - barrier_arrive_p = jax_core.Primitive("barrier_arrive") barrier_arrive_p.multiple_results = True @@ -262,7 +293,7 @@ class MemoryEffect(jax_core.Effect): @barrier_arrive_p.def_effectful_abstract_eval def _barrier_arrive_abstract_eval(*avals, **params): del avals, params # Unused. - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(barrier_arrive_p) @@ -299,7 +330,7 @@ def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None: @barrier_wait_p.def_effectful_abstract_eval def _barrier_wait_abstract_eval(*avals, **params): del avals, params # Unused. - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(barrier_wait_p) @@ -334,28 +365,31 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None: @wait_smem_to_gmem_p.def_effectful_abstract_eval -def _wait_smem_to_gmem_abstract_eval(n): - del n # Unused. - return (), {_memory_effect} +def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): + del n, wait_read_only # Unused. + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(wait_smem_to_gmem_p) -def _wait_smem_to_gmem_lowering(ctx: lowering.LoweringRuleContext, n): - ctx.launch_ctx.await_async_copy(allow_groups=n) +def _wait_smem_to_gmem_lowering( + ctx: lowering.LoweringRuleContext, n, *, wait_read_only +): + ctx.launch_ctx.await_async_copy( + allow_groups=n, await_read_only=wait_read_only + ) return () -def wait_smem_to_gmem(n: int) -> None: - """Waits until there are no more than ``n`` SMEM->GMEM copies in flight.""" - wait_smem_to_gmem_p.bind(n) - - -class _WGMMAPipelineEffect(effects.Effect): - pass +def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None: + """Waits until there are no more than ``n`` SMEM->GMEM copies in flight. + Args: + n: The maximum number of copies in flight to wait for. + wait_read_only: If ``True``, wait for the in flight copies to finish + reading from SMEM. The writes to GMEM are not waited for. + """ + wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only) -_wgmma_pipeline_effect = _WGMMAPipelineEffect() -effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect) # WGMMA on an accumulator reference wgmma_ref_p = jax_core.Primitive("wgmma_ref") @@ -367,7 +401,7 @@ def wgmma( a, b: pallas_core.TransformedRef, ) -> None: - """Performs and asynchronous warp group matmul-accumulate on the given references. + """Performs an asynchronous warp group matmul-accumulate on the given references. Conceptually, this is equivalent to doing ``acc[...] += a[...] @ b[...]``, except that the computation is performed asynchronously. @@ -419,7 +453,7 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): if not isinstance(acc_aval, gpu_core.WGMMAAbstractAccumulatorRef): raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc_aval}") return (), { - _wgmma_pipeline_effect, + gpu_core._wgmma_pipeline_effect, state.WriteEffect(0), state.ReadEffect(0), state.ReadEffect(2), @@ -483,6 +517,25 @@ def _wgmma_lowering( gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims ): rhs_transpose = True + case ( + gpu_core.UnswizzleRef(rhs_swizzle), + gpu_core.TransposeRef((1, 0, 2, 3, 4)), + gpu_core.UntileRef(rhs_tiling), + gpu_core.TransposeRef(permutation=(1, 0, 2)), + state.types.RefReshaper(shape=new_shape), + ): + if len(rhs_tiling) != 2 or len(new_shape) != 2: + raise ValueError("WGMMA expects shapes 2D tiled into 2D tiles.") + + if any(d % t != 0 for d, t in util.safe_zip(new_shape, rhs_tiling)): + raise ValueError( + f"The last reshape {new_shape} is not divisible by the tiling" + f" {rhs_tiling}." + ) + + high_dims = [d // t for d, t in util.safe_zip(new_shape, rhs_tiling)] + b = mgpu.memref_reshape(b, (*high_dims, *rhs_tiling)) + rhs_transpose = False case _: raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.") @@ -510,7 +563,7 @@ def _wgmma_lowering( def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs): del args, kwargs return acc, { - _wgmma_pipeline_effect, + gpu_core._wgmma_pipeline_effect, state.ReadEffect(2), *([state.ReadEffect(1)] if isinstance(lhs_ref, state.AbstractRef) else []) } @@ -526,7 +579,7 @@ def wgmma_wait(n: int): @wgmma_wait_p.def_effectful_abstract_eval def wgmma_wait_effectful_abstract_eval(_): - return [], {_wgmma_pipeline_effect} + return [], {gpu_core._wgmma_pipeline_effect} @lowering.register_lowering_rule(wgmma_wait_p) @@ -549,9 +602,9 @@ def wgmma_accumulator_deref(acc): @wgmma_accumulator_deref_p.def_effectful_abstract_eval def _wgmma_accumulator_deref_abstract_eval(acc): # Dereferencing implies flushing so we have a wgmma pipeline effect. - ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc + ret = acc.inner_aval if isinstance(acc, state.AbstractRef) else acc assert isinstance(ret, jax_core.ShapedArray), acc - return ret, {_wgmma_pipeline_effect} + return ret, {gpu_core._wgmma_pipeline_effect} @discharge.register_discharge_rule(wgmma_accumulator_deref_p) @@ -601,7 +654,7 @@ def layout_cast(x: Any, new_layout: Layout): @set_max_registers_p.def_effectful_abstract_eval def _set_max_registers_abstract_eval(n, *, action): del n, action # Unused. - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(set_max_registers_p) @@ -629,7 +682,7 @@ def set_max_registers(n: int, *, action: Literal["increase", "decrease"]): @commit_smem_p.def_effectful_abstract_eval def _commit_smem_abstract_eval(): - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(commit_smem_p) @@ -641,3 +694,47 @@ def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): def commit_smem(): """Commits all writes to SMEM, making them visible to loads, TMA and WGMMA.""" commit_smem_p.bind() + + +broadcasted_iota_p = jax_core.Primitive("broadcasted_iota") + +@broadcasted_iota_p.def_abstract_eval +def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): + del layout, dimension + return jax_core.ShapedArray(shape, dtype) + +@lowering.register_lowering_rule(broadcasted_iota_p) +def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout): + del ctx + # Unsigned integers (as opposed to signless) cause MLIR verification + # errors so we only use signless like Mosaic GPU does. + # + # TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead. + mlir_dtype = ( + ir.IntegerType.get_signless(dtype.itemsize * 8) + if jnp.issubdtype(dtype, jnp.integer) + else mlir.dtype_to_ir_type(dtype) + ) + undef = llvm_dialect.mlir_undef(mlir_dtype) + is_signed = ( + jnp.issubdtype(dtype, jnp.signedinteger) + if jnp.issubdtype(dtype, jnp.integer) + else None + ) + + i32 = ir.IntegerType.get_signless(32) + def _cast(x): + if ir.FloatType.isinstance(mlir_dtype): + x = arith_dialect.index_cast(i32, x) + return arith_dialect.uitofp(mlir_dtype, x) + else: + return arith_dialect.index_cast(mlir_dtype, x) + return mgpu.FragmentedArray.splat( + undef, shape, layout.value, is_signed=is_signed + ).foreach( + lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed + ) + + +def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None): + return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 2bed4a0830ea..729d0e617a87 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -72,10 +72,6 @@ pallas_call_p.multiple_results = True def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): - if start_idx is None: - assert is_indexing is None - return value - assert is_indexing is not None start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, @@ -84,10 +80,6 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): def _maybe_dynamic_update_slice(start_idx, block_shape, value, update, is_indexing): - if start_idx is None: - assert is_indexing is None - return update - assert is_indexing is not None start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) broadcast_dims = tuple(i for i, b in enumerate(is_indexing) if not b) @@ -234,8 +226,7 @@ def _pallas_call_impl_interpret( for bm in grid_mapping.block_mappings ] block_shapes = [ - None if iid is None - else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) + tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) ] @@ -284,8 +275,9 @@ def body(carry): aval = jax_core.get_aval(s) s.aval = aval.update(dtype=jnp.int32) start_indices = [ - None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) - for bm in grid_mapping.block_mappings] + bm.compute_start_indices_interpret(loop_idx, *scalars) + for bm in grid_mapping.block_mappings + ] blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry_consts_ins, is_indexing_dim) with pallas_core.grid_env(local_grid_env): @@ -945,7 +937,15 @@ def get_size(i, x, d): ) for invar in eqn.invars ] - invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars) + try: + invar_raggedness, outvar_raggedness = rule( + eqn.params, invar_raggedness, eqn.outvars # type: ignore[arg-type] + ) + except Exception as e: + raise RuntimeError( + f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:" + f" {eqn.outvars}. Underlying reason: {e}" + ) from e for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment] if isinstance(invar, jax_core.Var): @@ -1440,6 +1440,17 @@ def _trace_kernel_to_jaxpr( " dialect, instead of Trition IR." ), ) +_PALLAS_VERBOSE_ERRORS = config.bool_flag( + "jax_pallas_verbose_errors", + default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True), + help=( + "If True, print verbose error messages for Pallas kernels." + ), +) + + +def _verbose_errors_enabled() -> bool: + return _PALLAS_VERBOSE_ERRORS.value def _unsupported_lowering_error(platform: str) -> Exception: @@ -1550,12 +1561,6 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) -def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any: - if isinstance(ref_aval, pallas_core.AbstractMemoryRef): - return ref_aval.memory_space - return pallas_core.MemorySpace.ANY - - @state_discharge.register_discharge_rule(pallas_call_p) def _pallas_call_state_discharge_rule( avals_in, diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index b41ce3632468..d77ca86c152a 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -68,8 +68,8 @@ def program_id(axis: int) -> jax.Array: """ return program_id_p.bind(axis=axis) -@program_id_p.def_custom_bind -def program_id_bind(*, axis: int): +def program_id_bind_with_trace(trace, _, params): + axis = params.pop("axis") grid_env = pallas_core.current_grid_env() if grid_env: return grid_env[axis].index @@ -77,7 +77,9 @@ def program_id_bind(*, axis: int): # Query the size of the axis to make sure it's a valid axis (and error # otherwise). _ = frame.size(axis) - return jax_core.Primitive.bind(program_id_p, axis=axis) + return jax_core.Primitive.bind_with_trace(program_id_p, trace, (), dict(axis=axis)) +# TODO(dougalm): figure out how put the grid_env contest on the relevant trace +program_id_p.def_bind_with_trace(program_id_bind_with_trace) @program_id_p.def_abstract_eval def _program_id_abstract_eval(**_): @@ -89,8 +91,8 @@ def num_programs(axis: int) -> int | jax.Array: """Returns the size of the grid along the given axis.""" return num_programs_p.bind(axis=axis) -@num_programs_p.def_custom_bind -def _num_programs_bind(*, axis: int): +def _num_programs_bind_with_trace(trace, _, params): + axis = params.pop("axis") # We might be using a local grid env grid_env = pallas_core.current_grid_env() if grid_env: @@ -99,8 +101,9 @@ def _num_programs_bind(*, axis: int): frame = pallas_core.axis_frame() size = frame.size(axis) if size is pallas_core.dynamic_grid_dim: - return jax_core.Primitive.bind(num_programs_p, axis=axis) + return jax_core.Primitive.bind_with_trace(num_programs_p, trace, (), dict(axis=axis)) return size +num_programs_p.def_bind_with_trace(_num_programs_bind_with_trace) @num_programs_p.def_abstract_eval def _num_programs_abstract_eval(**_): @@ -821,14 +824,13 @@ def debug_print_lowering_rule(ctx, *args, **params): # because they should appear as atomic JAX values to the users. # TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU # inferred by the compiler. -@lu.transformation -def wrap_with_transforms(transforms, *args): +@lu.transformation2 +def wrap_with_transforms(f, transforms, *args): new_args = tuple( state_types.TransformedRef(a, t) if t else a for a, t in zip(args, transforms) ) - res = yield new_args, {} - yield res + return f(*new_args) run_scoped_p = jax_core.Primitive("run_scoped") diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index a9babcba0577..84fae3913491 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -76,6 +76,7 @@ pytype_strict_library( ":lowering", "//jax", "//jax:config", + "//jax:core", "//jax:mlir", "//jax:util", "//jax/_src/lib", diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py index a61dfd61b9b1..097f8497e8f7 100644 --- a/jax/_src/pallas/triton/core.py +++ b/jax/_src/pallas/triton/core.py @@ -35,4 +35,4 @@ class TritonCompilerParams(pallas_core.CompilerParams): PLATFORM: ClassVar[str] = "triton" num_warps: int | None = None num_stages: int | None = None - serialized_metadata: str | None = None + serialized_metadata: bytes | None = None diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index b0a2b4dbcae0..e2376a457cdf 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -46,6 +46,7 @@ from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge @@ -86,7 +87,7 @@ class ModuleContext: class BlockInfo: full_shape_dtype: jax.ShapeDtypeStruct start_indices: Sequence[Any] - block_shape: tuple[int, ...] # TODO(necula): can this contain "mapped"? + block_shape: tuple[int | pallas_core.Mapped, ...] @dataclasses.dataclass @@ -94,7 +95,7 @@ class LoweringRuleContext: context: ModuleContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] - block_infos: Sequence[BlockInfo | None] # TODO(necula): can this be None? + block_infos: Sequence[BlockInfo | None] replace = dataclasses.replace @@ -247,7 +248,8 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): def _new_ir_context() -> ir.Context: - ctx = ir.Context() + ctx = mlir.JaxIrContext() + ctx.append_dialect_registry(mlir.upstream_dialects) tt_dialect.register_dialect(ctx) ctx.load_all_available_dialects() return ctx @@ -360,14 +362,15 @@ def read_env(atom: jax_core.Atom): def read_block_info_env(atom: jax_core.Atom): if isinstance(atom, jax_core.Literal): return None - return block_info_env.get(atom, None) + return block_info_env.get(atom) def write_env(var: jax_core.Var, val): env[var] = val if block_infos is not None: for invar, block_info in zip(jaxpr.invars, block_infos): - block_info_env[invar] = block_info + if block_info is not None: + block_info_env[invar] = block_info map(write_env, jaxpr.invars, args) @@ -390,6 +393,8 @@ def write_env(var: jax_core.Var, val): except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: + if not pallas_call._verbose_errors_enabled(): + raise inval_types = map(lambda t: getattr(t, "type", None), invals) raise LoweringError( f"Exception while lowering eqn:\n {eqn}\nWith context:\n " @@ -470,14 +475,14 @@ def _atomic_lowering_rule( args_tree, atomic_type: primitives.AtomicOpType, ): + block_info, *_ = ctx.block_infos + assert block_info is not None ptr, indexers, val, mask = args_tree.unflatten(args_flat) *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) != 1: raise NotImplementedError("Only single indexer is supported.") idx = indexers[0] - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, idx) val = _ensure_ir_value(val, value_aval) if mask is not None: mask = _ensure_ir_value(mask, mask_aval) @@ -775,6 +780,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], ), + lax.square_p: lambda ctx, x: _mul(x, x), lax.pow_p: _make_dispatch_table( "pow", cuda=[ @@ -992,7 +998,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): lax.nextafter_p: _make_dispatch_table( "nextafter", cuda=[ - _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32 ), + _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32), _Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64), ], rocm=[ @@ -1463,10 +1469,22 @@ def _float_int_cast( dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: return _not_equal(src, _full(src.type, 0), signed=signed) - elif signed: - return arith_dialect.fptosi(dst_type, src) else: - return arith_dialect.fptoui(dst_type, src) + # We clamp the float value to the min/max integer destination value + # in order to match JAX/XLA casting behavior. Note that this differs + # from numpy casting behavior. + if signed: + maxint = 2**(dst_element_type.width-1) - 1 + minint = -2**(dst_element_type.width-1) + else: + maxint = 2**dst_element_type.width - 1 + minint = 0 + src = arith_dialect.minimumf(src, _full(src.type, maxint)) + src = arith_dialect.maximumf(src, _full(src.type, minint)) + if signed: + return arith_dialect.fptosi(dst_type, src) + else: + return arith_dialect.fptoui(dst_type, src) def _int_float_cast( @@ -1493,10 +1511,12 @@ def _cast( src, _dtype_to_ir_type(dst_type), signed=jnp.issubdtype(src_type, jnp.signedinteger), + dst_signed=jnp.issubdtype(dst_type, jnp.signedinteger), ) -def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: +def _ir_cast(src: ir.Value, dst_type: ir.Type, *, + signed: bool, dst_signed: bool = False) -> ir.Value: if ir.RankedTensorType.isinstance( src.type ) and not ir.RankedTensorType.isinstance(dst_type): @@ -1521,7 +1541,8 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: dst_element_type, ir.F32Type ): return _ir_cast( - _ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False + _ir_cast(src, ir.F32Type.get(), signed=False), + dst_type, signed=False, dst_signed=dst_signed ) if isinstance(src_element_type, ir.FloatType) and isinstance( @@ -1537,7 +1558,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: if isinstance(src_element_type, ir.FloatType) and isinstance( dst_element_type, ir.IntegerType ): - return _float_int_cast(src, dst_type, signed=signed) + return _float_int_cast(src, dst_type, signed=dst_signed) if isinstance(src_element_type, ir.IntegerType) and isinstance( dst_element_type, ir.FloatType ): @@ -1586,8 +1607,9 @@ def select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, y): @register_lowering(lax.broadcast_in_dim_p) def _broadcast_in_dim_lowering_rule( - ctx: LoweringRuleContext, x, *, broadcast_dimensions, shape + ctx: LoweringRuleContext, x, *, broadcast_dimensions, shape, sharding ): + del sharding x = _ensure_ir_value(x, *ctx.avals_in) if not ir.RankedTensorType.isinstance(x.type): return _bcast_to(x, shape) @@ -1603,20 +1625,9 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions): return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None) -def _reshape(x: ir.Value, shape: Sequence[int]) -> ir.Value: - if not shape: - raise ValueError("cannot reshape to an empty shape") - ty = ir.RankedTensorType(x.type) - return tt_dialect.reshape( - ir.RankedTensorType.get(shape, ty.element_type, ty.encoding), - x, - allow_reorder=False, - ) - - @register_lowering(lax.reshape_p) def _reshape_lowering_rule( - ctx: LoweringRuleContext, a, *, new_sizes, dimensions + ctx: LoweringRuleContext, a, *, new_sizes, dimensions, sharding, ): del new_sizes # Unused. if dimensions is not None: @@ -1628,52 +1639,24 @@ def _reshape_lowering_rule( assert all(dim_size == 1 for dim_size in out_aval.shape) return _splat(a, out_aval.shape) - # TODO(slebedev): Check that the following comment still applies. - # Expand-dims or reduce-sum to handle singleton dims as `tl.reshape` is not - # currently implemented. - dst_shape = [*out_aval.shape] - i = 0 - while ( - ir.RankedTensorType.isinstance(a.type) - and (a_shape := ir.RankedTensorType(a.type).shape) != dst_shape - ): - dim_size = a_shape[i] if i < len(a_shape) else None - dst_dim_size = dst_shape[i] if i < len(dst_shape) else None - if dim_size == dst_dim_size: - i += 1 - elif dst_dim_size == 1: - a = _expand_dims(a, axis=i) - i += 1 - elif dim_size == 1: - in_shape = a_shape - out_shape = tuple(d for di, d in enumerate(a_shape) if di != i) - reduce_ctx = ctx.replace( - avals_in=[ctx.avals_in[0].update(shape=in_shape)], - avals_out=[ctx.avals_in[0].update(shape=out_shape)], - ) - a = _reduce_lowering(jnp.add, reduce_ctx, a, axes=(i,)) - else: # We expect this to fail. - return _reshape(a, dst_shape) + ty = ir.RankedTensorType(a.type) - return a + # Triton Reshape doesn't support scalar result types (only 0d tensors). + if not out_aval.shape: + return _reduce_lowering(jnp.add, ctx, a, axes=tuple(range(ty.rank))) + + return tt_dialect.reshape( + ir.RankedTensorType.get([*out_aval.shape], ty.element_type, ty.encoding), + a, + allow_reorder=False, + ) def _compute_pointers_from_indices( - root_ptr: ir.Value, - block_info: BlockInfo | None, - nd_indexer: NDIndexer, - array_shape_dtype: Any, + root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer ) -> ir.Value: - if block_info is None: # TODO(necula): is this branch dead? - full_shape = array_shape_dtype.shape - num_mapped_dims = 0 - block_shape = array_shape_dtype.shape - else: - full_shape = block_info.full_shape_dtype.shape - num_mapped_dims = sum( - b is pallas_core.mapped for b in block_info.block_shape - ) - block_shape = block_info.block_shape + full_shape = block_info.full_shape_dtype.shape + num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape) strides = pallas_utils.strides_from_shape(full_shape) indexer_shape = nd_indexer.get_indexer_shape() int_indexer_shape = nd_indexer.int_indexer_shape @@ -1681,14 +1664,10 @@ def _compute_pointers_from_indices( indices = nd_indexer.indices other_shape = indexer_shape[len(int_indexer_shape) :] other_shape_idx = 0 - if block_info is None: - start_index_offsets = [None] * len(indices) - else: - start_index_offsets = block_info.start_indices assert len(indices) + num_mapped_dims == len(full_shape) - assert len(start_index_offsets) == len(full_shape) + assert len(block_info.start_indices) == len(full_shape) - array_dtype = jnp.dtype(array_shape_dtype.dtype) + array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype) full_size = math.prod(full_shape) * array_dtype.itemsize # Use 64-bit indexing when offset might be >= 2**32 bytes. offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) @@ -1699,7 +1678,7 @@ def _compute_pointers_from_indices( indexer_iter = iter(indices) for dim_stride, dim_block_size, start_offset in zip( - strides, block_shape, start_index_offsets + strides, block_info.block_shape, block_info.start_indices ): if dim_block_size is pallas_core.mapped: index = _ir_constant(0, offset_eltype) @@ -1859,6 +1838,8 @@ def _masked_load_lowering_rule( cache_modifier, is_volatile, ): + block_info, *_ = ctx.block_infos + assert block_info is not None ptr, indexers, mask, other = args_tree.unflatten(args_flat) *_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) > 1: @@ -1867,9 +1848,7 @@ def _masked_load_lowering_rule( if not tt_dialect.PointerType.isinstance(ptr.type): assert len(ctx.avals_in) == 1 return ptr - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, idx) if mask is not None: mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape()) if other is not None: @@ -1959,14 +1938,14 @@ def _store( def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy ): + block_info, *_ = ctx.block_infos + assert block_info is not None ptr, indexers, value, mask = args_tree.unflatten(args_flat) *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) > 1: raise NotImplementedError("No support for multiple indexers yet.") idx = indexers[0] - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, idx) other = None if value is not None: value = _ensure_ir_value(value, value_aval) @@ -1982,6 +1961,8 @@ def _masked_swap_lowering_rule( @register_lowering(sp.addupdate_p) def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): + block_info, *_ = ctx.block_infos + assert block_info is not None indexers = tree_util.tree_unflatten(tree, idx) if not tt_dialect.PointerType.isinstance(ptr.type): assert len(indexers) == 0 @@ -1989,9 +1970,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): if len(indexers) > 1: raise NotImplementedError("No support for multiple indexers yet.") indexer = indexers[0] - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], indexer, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, indexer) op = tt_dialect.RMWOp.FADD if isinstance(_element_type(value.type), ir.IntegerType): op = tt_dialect.RMWOp.ADD @@ -2004,81 +1983,6 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): return tt_dialect.trans(x, permutation) -def _check_dot_operands( - x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any -): - # TODO(slebedev): Ensure that the dtypes are supported by CUDA. - return - - -def _dot( - x: ir.Value, - y: ir.Value, - acc: ir.Value | None = None, - *, - allow_tf32: bool = True, - max_num_imprecise_acc: int | None = None, - out_type: ir.Type | None = None, -) -> ir.Value: - if out_type is None: - out_type = ir.F32Type.get() - elif isinstance(out_type, ir.BF16Type): - raise NotImplementedError(f"unsupported output type: {out_type}") - - x_type = ir.RankedTensorType(x.type) - y_type = ir.RankedTensorType(y.type) - if min(*x_type.shape, *y_type.shape) < 16: - raise ValueError("all dimensions of x and y must be >= 16 ") - if x_type.element_type != y_type.element_type: - raise ValueError( - "x and y must have the same element type, but got:" - f" {x_type.element_type} and {y_type.element_type}" - ) - - _check_dot_operands(x_type, y_type, object()) - - element_type = x_type.element_type - if isinstance(element_type, ir.IntegerType): - if element_type.width != 8: - raise TypeError(f"unsupported element type: {element_type}") - element_type = ir.IntegerType.get_signless(32) - elif isinstance(element_type, (ir.F32Type, ir.BF16Type)): - element_type = ir.F32Type.get() - else: - element_type = out_type - - if element_type != out_type: - raise TypeError( - f"output type {out_type} does not match element type {element_type}" - ) - - m, _ = x_type.shape - _, n = y_type.shape - - if acc is None: - acc = _full(ir.RankedTensorType.get([m, n], element_type), 0) - - if max_num_imprecise_acc is None: - if isinstance(element_type, ir.FloatType) and element_type.width == 8: - # TODO(slebedev): Fill in from options. - raise NotImplementedError - else: - max_num_imprecise_acc = 0 - - # Ideally, replace all allow_tf32 usages with InputPrecision directly. - input_precision = tt_dialect.InputPrecision.IEEE - if allow_tf32: - input_precision = tt_dialect.InputPrecision.TF32 - - return tt_dialect.dot( - x, - y, - acc, - max_num_imprecise_acc=max_num_imprecise_acc, - input_precision=input_precision - ) - - _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) @@ -2089,10 +1993,11 @@ def _dot_general_lowering( b, *, dimension_numbers, + out_type, precision, preferred_element_type, ): - del preferred_element_type # Unused. + del preferred_element_type, out_type # Unused. ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers assert batch_dims == ((), ()) @@ -2101,27 +2006,63 @@ def _dot_general_lowering( if b_contract_dim == 1: b = tt_dialect.trans(b, (1, 0)) - if precision is None: - allow_tf32 = True + a_aval, b_aval = ctx.avals_in + [out_aval] = ctx.avals_out + + if precision is None or (precision == lax.DotAlgorithmPreset.DEFAULT): + precision = (lax.Precision.DEFAULT, lax.Precision.DEFAULT) + + if isinstance(precision, lax.DotAlgorithmPreset): + match precision: + case lax.DotAlgorithmPreset.TF32_TF32_F32: + input_precision = tt_dialect.InputPrecision.TF32 + case lax.DotAlgorithmPreset.TF32_TF32_F32_X3: + input_precision = tt_dialect.InputPrecision.TF32x3 + case lax.DotAlgorithmPreset.F32_F32_F32: + input_precision = tt_dialect.InputPrecision.IEEE + case ( + lax.DotAlgorithmPreset.F16_F16_F16 + | lax.DotAlgorithmPreset.F16_F16_F32 + | lax.DotAlgorithmPreset.BF16_BF16_BF16 + | lax.DotAlgorithmPreset.BF16_BF16_F32 + ): + input_precision = None + case _: + raise NotImplementedError(f"Unsupported dot algorithm: {precision}.") + + a = _cast(a, a_aval.dtype, precision.supported_lhs_types[0]) + b = _cast(b, b_aval.dtype, precision.supported_rhs_types[0]) + acc_dtype = precision.accumulation_type + elif isinstance(precision, tuple): + a_precision, b_precision = precision + if a_precision in _TF32_PRECISIONS or b_precision in _TF32_PRECISIONS: + input_precision = tt_dialect.InputPrecision.TF32 + elif a_aval.dtype == jnp.float32: + input_precision = tt_dialect.InputPrecision.IEEE + else: + input_precision = None + + acc_dtype = out_aval.dtype + if acc_dtype != jnp.int32 and acc_dtype != jnp.float16: + acc_dtype = jnp.float32 else: - prec_a, prec_b = precision - allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS + raise NotImplementedError(f"Unsupported dot precision: {precision}.") - [out_aval] = ctx.avals_out - out_dtype = acc_dtype = out_aval.dtype - if acc_dtype != jnp.int32 and acc_dtype != jnp.float16: - acc_dtype = jnp.dtype(jnp.float32) - - return _cast( - _dot( - a, - b, - allow_tf32=allow_tf32, - out_type=_dtype_to_ir_type(acc_dtype), - ), - acc_dtype, - out_dtype, - ) + a_type = ir.RankedTensorType(a.type) + b_type = ir.RankedTensorType(b.type) + if min(*a_type.shape, *b_type.shape) < 16: + raise ValueError("all dimensions of a and b must be >= 16 ") + if a_type.element_type != b_type.element_type: + raise ValueError( + "a and b must have the same element type, but got:" + f" {a_type.element_type} and {b_type.element_type}" + ) + + m, _ = a_type.shape + _, n = b_type.shape + acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) + return _cast(acc, acc_dtype, out_aval.dtype) def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes): @@ -2643,8 +2584,30 @@ def _i64_constant(v: int) -> ir.Value: return arith_dialect.constant(ir.IntegerType.get_signless(64), v) -def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type: +def _dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type: + dtype = jnp.dtype(dtype) if jnp.issubdtype(dtype, np.integer): # All integer types in Triton are signless. return ir.IntegerType.get_signless(dtype.itemsize * 8) return mlir.dtype_to_ir_type(dtype) + + +@register_lowering(lax.bitcast_convert_type_p) +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, operand: ir.Value, *, new_dtype +) -> ir.Value: + # TODO(petebu) Handle case where src and dst types have different bitwidths + src_elem_type = _element_type(operand.type) + dst_elem_type = _element_type(_dtype_to_ir_type(new_dtype)) + assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) + assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) + if src_elem_type.width != dst_elem_type.width: + raise NotImplementedError( + f"cannot cast {operand} to {new_dtype} because of different widths" + ) + if ir.RankedTensorType.isinstance(operand.type): + shape = ir.RankedTensorType(operand.type).shape + result_type = ir.RankedTensorType.get(shape, dst_elem_type) + else: + result_type = dst_elem_type + return tt_dialect.bitcast(result_type, operand) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 67b0bd326616..1805f8c0923a 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -19,7 +19,7 @@ import io from typing import Any -from jax import core as jax_core +import jax._src.core as jax_core from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 23fce50dc4f9..b845a4079ff4 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -19,7 +19,7 @@ from collections.abc import Sequence import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas.triton import lowering diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index e485537216ca..0dc19aa75fb6 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -301,3 +301,91 @@ def sign_lowering_helper(x): return jnp.where(jnp.isnan(x), jnp.nan, out) raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}") + + +# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L1339-L1422 +def nextafter_lowering_helper(x, y): + if x.dtype != y.dtype: + raise ValueError( + "The two inputs to `nextafter` must have the same dtype, but got" + f" {x.dtype} and {y.dtype}" + ) + + if x.dtype not in (jnp.float32, jnp.float64): + raise ValueError( + f"`nextafter` only supports float32 and float64, but got {x.dtype}" + ) + + jnp_float, jnp_uint, np_float, np_uint, np_int = ( + jnp.float32, jnp.uint32, np.float32, np.uint32, np.int32, + ) if x.dtype == jnp.float32 else ( + jnp.float64, jnp.uint64, np.float64, np.uint64, np.int64, + ) + + bitwidth = dtype_bitwidth(x.dtype) + + x_as_int = x.view(jnp_uint) + y_as_int = y.view(jnp_uint) + + # The result is NaN if either "x" or "y" are NaN. + nan_input = jnp.isnan(x) | jnp.isnan(y) + result_for_nan = jnp.full_like(x_as_int, np_float(np.nan).view(np_uint)) + + # The sign bit is the MSB. + sign_bit = jnp_uint(1 << (bitwidth - 1)) + # Discard the sign bit to make the result non-negative. + sign_mask = sign_bit + negated_sign_mask = ~sign_bit + x_abs = x_as_int & negated_sign_mask + y_abs = y_as_int & negated_sign_mask + + # When both "x" and "y" are equal, the result is "y". + x_and_y_are_equal = x == y + result_for_equal = y_as_int + + # When both "x" and "y" are 0, the result is "y". This is a separate case + # from above because "x" and "y" might have a different sign. + zero = jnp.zeros_like(x_as_int) + x_is_zero = x_abs == zero + y_is_zero = y_abs == zero + result_for_both_zero = y_as_int + + x_sign = x_as_int & sign_mask + y_sign = y_as_int & sign_mask + + # If x == 0 && y != 0, we need to return the smallest subnormal number + # signed like "y". + one = jnp.ones_like(x_as_int) + result_for_x_zero_y_non_zero = y_sign | one + + # If the sign of "x" and "y" disagree: + # - we need to make the magnitude of "from" smaller so that it is closer to + # zero. + # + # Otherwise the signs agree: + # - "x" with a magnitude larger than "y" means we need to make the magnitude + # smaller. + # - "x" with a magnitude smaller than "y" means we need to make the magnitude + # larger. + signs_disagree = x_sign != y_sign + x_magnitude_larger_than_y = x_abs > y_abs + result_has_smaller_magnitude = x_magnitude_larger_than_y | signs_disagree + minus_one = jnp.full_like(x_as_int, np_int(-1).view(np_uint)) + magnitude_adjustment = jnp.where(result_has_smaller_magnitude, minus_one, one) + result = x_as_int + magnitude_adjustment + + # Handle x == +-0. + result = jnp.where( + x_is_zero, + jnp.where(y_is_zero, result_for_both_zero, result_for_x_zero_y_non_zero), + result, + ) + + # Handle x == y. + result = jnp.where(x_and_y_are_equal, result_for_equal, result) + + # Handle isnan(x) || isnan(y). + result = jnp.where(nan_input, result_for_nan, result) + + # Cast back to the original type. + return result.view(jnp_float) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 18e7d18d931d..f9bc2b60cee9 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + class _UnconstrainedPartitionSingleton: def __repr__(self): @@ -48,3 +50,21 @@ def __repr__(self): def __reduce__(self): return (PartitionSpec, tuple(self)) + + def _normalized_spec(self, ndim: int) -> PartitionSpec: + out = [] # type: ignore + for p in self: + if p is None: + out.append(None) + elif p == self.UNCONSTRAINED: + out.append(p) + elif isinstance(p, (list, tuple)): + if len(p) == 1: + out.append(p[0]) + else: + out.append(tuple(p)) + else: + out.append(p) + if len(out) < ndim: + out.extend([None] * (ndim - len(out))) + return PartitionSpec(*out) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a69e8987b2d8..5f5a2b8b7692 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -23,7 +23,6 @@ import operator as op import weakref from typing import NamedTuple, Any, Union, cast -import threading import warnings import numpy as np @@ -67,8 +66,7 @@ from jax._src.sharding_impls import ( NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, - ParsedPartitionSpec, get_single_pspec, is_unspecified, - is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) + ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding) from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary @@ -165,6 +163,7 @@ class PjitInfo(NamedTuple): inline: bool abstracted_axes: Any | None use_resource_env: bool # False for jit, True for pjit + compiler_options_kvs: tuple[tuple[str, Any], ...] # Hash and compare PjitInfo by identity when used as a cache key. def __hash__(self): @@ -185,7 +184,19 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - out_flat = pjit_p.bind(*args_flat, **p.params) + # TODO(yashkatariya): Maybe thread this into pjit params like resource_env + # and set the context manager down the stack? + with mesh_lib.set_abstract_mesh(p.abstract_mesh): + if (core.trace_state_clean() and + not config.debug_key_reuse.value and + not config.data_dependent_tracing_fallback.value): + args_flat = map(core.full_lower, args_flat) + core.check_eval_args(args_flat) + out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) + else: + out_flat = pjit_p.bind(*args_flat, **p.params) + compiled = None + profiler = None except pxla.DeviceAssignmentMismatchError as e: fails, = e.args api_name = 'jit' if p.params['resource_env'] is None else 'pjit' @@ -215,7 +226,8 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs): _set_states(p.attrs_tracked, final_states) outs = tree_unflatten(p.out_tree, out_flat) - return outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], p.attrs_tracked + return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], + p.attrs_tracked, compiled, profiler) def _set_states(attrs_tracked, vals): @@ -286,21 +298,6 @@ def _get_fastpath_data( return fastpath_data -class _MostRecentPjitCallExecutable(threading.local): - def __init__(self): - self.weak_key_dict = weakref.WeakKeyDictionary() - self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary() - -_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable() - - -def _read_most_recent_pjit_call_executable(jaxpr): - return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None) - - -def _read_pgle_profiler(jaxpr): - return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None) - def _cpp_pjit_evict_fn(self): self._clear_cache() _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error @@ -335,10 +332,10 @@ def cache_miss(*args, **kwargs): if config.no_tracing.value: raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " "`jit`, but 'no_tracing' is set") - outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper( - fun, jit_info, *args, **kwargs) - executable = _read_most_recent_pjit_call_executable(jaxpr) - pgle_profiler = _read_pgle_profiler(jaxpr) + + (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable, + pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) + maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, @@ -358,7 +355,8 @@ def cache_miss(*args, **kwargs): in_layouts_leaves=jit_info.in_layouts_leaves, out_layouts_treedef=jit_info.out_layouts_treedef, out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env) + use_resource_env=jit_info.use_resource_env, + compiler_options_kvs=jit_info.compiler_options_kvs) cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, cache_key, tree_util.dispatch_registry, @@ -399,7 +397,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, static_argnames: str | Iterable[str] | None, device: xc.Device | None, backend: str | None, abstracted_axes: Any | None, keep_unused: bool, - inline: bool, use_resource_env: bool) -> PjitInfo: + inline: bool, compiler_options: dict[str, Any] | None, + use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. Performs any preprocessing and validation of the arguments that we can do @@ -418,10 +417,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") - if in_shardings is not None and not is_unspecified(in_shardings): + if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue): raise ValueError('If backend or device is specified on jit, then ' 'in_shardings should not be specified.') - if out_shardings is not None and not is_unspecified(out_shardings): + if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue): raise ValueError('If backend or device is specified on jit, then ' 'out_shardings should not be specified.') @@ -440,7 +439,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') user_specified_in_shardings = (in_shardings is not None and - not is_unspecified(in_shardings)) + not isinstance(in_shardings, UnspecifiedValue)) in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings) out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings) @@ -454,6 +453,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) + compiler_options_kvs = (() if compiler_options is None else + tuple(compiler_options.items())) return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -471,7 +472,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - use_resource_env=use_resource_env) + use_resource_env=use_resource_env, + compiler_options_kvs=compiler_options_kvs) def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): @@ -483,7 +485,7 @@ def lower(*args, **kwargs): @api_boundary def eval_shape(*args, **kwargs): p, _ = _infer_params(fun, jit_info, args, kwargs) - out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] + out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] # TODO(yashkatariya): Add `Layout` to SDS. out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, weak_type=x.weak_type) @@ -496,10 +498,10 @@ def trace(*args, **kwargs) -> stages.Traced: donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) lower_callable = partial(_resolve_and_lower, args_flat, **p.params, - pgle_profiler=None) + pgle_profiler=None) return stages.Traced( p.params['jaxpr'], args_info, p.params["name"], p.out_tree, - lower_callable, args_flat, p.arg_names, p.num_consts) + lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts) wrapped = _cpp_pjit(fun, jit_info) wrapped.lower = lower @@ -515,12 +517,13 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, static_argnames: str | Iterable[str] | None, device: xc.Device | None, backend: str | None, abstracted_axes: Any | None, keep_unused: bool, - inline: bool, use_resource_env: bool) -> Any: + inline: bool, compiler_options: dict[str, Any] | None, + use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" jit_info = _parse_jit_arguments( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env) + keep_unused, inline, compiler_options, use_resource_env) return _make_jit_wrapper(fun, jit_info) @@ -534,6 +537,7 @@ class PjitParams(NamedTuple): arg_names: tuple[str, ...] | None num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + abstract_mesh: AbstractMesh def _infer_params_impl( @@ -638,10 +642,15 @@ def _infer_params_impl( in_avals, in_tree, dbg, device_or_backend_set, have_kwargs) attr_token = _attr_token(flat_fun, in_type) - jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( - flat_fun, in_type, attr_token, dbg, - HashableFunction(res_paths, closure=()), - IgnoreKey(ji.inline)) + + abstract_mesh = ( + get_abstract_mesh_from_avals(in_type) + if not mesh_lib.get_abstract_mesh() else mesh_lib.get_abstract_mesh()) + with mesh_lib.set_abstract_mesh(abstract_mesh): + jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( + flat_fun, in_type, attr_token, dbg, + HashableFunction(res_paths, closure=()), + IgnoreKey(ji.inline)) _attr_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( @@ -677,10 +686,28 @@ def _infer_params_impl( name=fun_qual_name(flat_fun), keep_unused=ji.keep_unused, inline=ji.inline, + compiler_options_kvs=ji.compiler_options_kvs, ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), donated_invars, dbg.arg_names if dbg else None, len(consts), - attrs_tracked), args_flat + attrs_tracked, abstract_mesh), args_flat + + +def get_abstract_mesh_from_avals(in_avals): + if not config.sharding_in_types.value: + return None + m = None + for a in in_avals: + # TODO(yashkatariya): Remove this when mesh context can be set by the user. + if a.sharding is None: # type: ignore + continue + if m is not None and m != a.sharding.mesh: + raise ValueError( + f'Mesh for all inputs should be equal. Got one mesh: {m} and' + f' another mesh: {a.sharding.mesh}') + m = a.sharding.mesh # type: ignore + assert isinstance(m, AbstractMesh) + return m class InferParamsCacheEntry: @@ -816,6 +843,7 @@ def pjit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, + compiler_options: dict[str, Any] | None = None, ) -> JitWrapped: """Makes ``fun`` compiled and automatically partitioned across multiple devices. @@ -988,7 +1016,7 @@ def pjit( return make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env=True) + keep_unused, inline, compiler_options, use_resource_env=True) def hashable_pytree(pytree): @@ -1001,7 +1029,7 @@ def hashable_pytree(pytree): def _create_sharding_for_array(mesh, x, name, api_name): if x is None and (mesh is None or mesh.empty): return UNSPECIFIED - if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x): + if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)): return x if mesh is None: msg = ('jax.jit only supports `Sharding`s being passed to' @@ -1110,7 +1138,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves) # Only do this if original in_shardings are unspecified. If it is AUTO, go # via flatten_axis_resources. - if is_unspecified(orig_in_shardings): + if isinstance(orig_in_shardings, UnspecifiedValue): in_shardings_flat = (orig_in_shardings,) * len(in_avals) else: in_shardings_flat = flatten_axis_resources( @@ -1312,8 +1340,7 @@ def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) - if (is_unspecified(orig_out_shardings) or - isinstance(orig_out_shardings, sharding.Sharding)): + if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)): out_shardings_flat = (orig_out_shardings,) * len(out_avals) else: out_shardings_flat = flatten_axis_resources( @@ -1391,7 +1418,7 @@ def pjit_check_aval_sharding( what_aval: str, allow_uneven_sharding: bool): new_names = [''] * len(shardings) if names is None else names for aval, s, name in zip(flat_avals, shardings, new_names): - if is_unspecified_or_auto(s): + if isinstance(s, (UnspecifiedValue, AUTO)): continue name_str = f' with pytree key path {name}' if name else '' shape = aval.shape @@ -1439,7 +1466,7 @@ def check_aval_layout_compatibility( # -------------------- pjit rules -------------------- -pjit_p = core.AxisPrimitive("pjit") +pjit_p = core.Primitive("pjit") pjit_p.multiple_results = True @@ -1466,7 +1493,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): else: arg_layout, dispatch_arg_layout = None, None # Sharding can be unspecified when array is committed if it's a PmapSharding. - is_pmap_sharding = (is_unspecified(rs) or + is_pmap_sharding = (isinstance(rs, UnspecifiedValue) or isinstance(getattr(arg, 'sharding', None), PmapSharding)) if jit_in_l is None: if committed: @@ -1527,15 +1554,15 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if getattr(a, '_committed', True): committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) - resolved_in_shardings = [] + resolved_in_shardings: list[PjitSharding] = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does # not allow None as the sharding. arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True)) if hasattr(arg, 'sharding') and arg.sharding is not None else (UNSPECIFIED, False)) - if is_unspecified(pjit_in_s): - if is_unspecified(arg_s): + if isinstance(pjit_in_s, UnspecifiedValue): + if isinstance(arg_s, UnspecifiedValue): resolved_in_shardings.append(arg_s) else: if committed: @@ -1553,7 +1580,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'multiple devices is not supported.') else: if (isinstance(arg, np.ndarray) and - not pjit_in_s.is_fully_replicated and # type: ignore + not pjit_in_s.is_fully_replicated and # type: ignore[union-attr] xb.process_count() > 1): raise ValueError( 'Passing non-trivial shardings for numpy ' @@ -1572,16 +1599,16 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] # jax.jit does not allow resharding across different memory kinds even # if the argument is uncommitted. Use jax.device_put for those cases, # either outside or inside jax.jit. - if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore + if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore[union-attr] raise ValueError( 'Memory kinds passed to jax.jit does not match memory kind on the' - f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore + f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore[union-attr] f'arg memory kind: {arg_s.memory_kind} for ' f'arg shape: {shaped_abstractify(arg).str_short()}') if (committed and not isinstance(arg_s, PmapSharding) and not op_shardings.are_op_shardings_equal( - pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore + pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore[union-attr] arg_s._to_xla_hlo_sharding(arg.ndim))): raise ValueError('Sharding passed to pjit does not match the sharding ' 'on the respective arg. ' @@ -1596,43 +1623,44 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_platforms, lowering_parameters, pgle_profiler): + lowering_platforms, lowering_parameters, pgle_profiler, + compiler_options_kvs): in_shardings = _resolve_in_shardings(args, in_shardings) in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) - lowered = _pjit_lower( + return _pjit_lower( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, name, keep_unused, inline, + donated_invars, name, keep_unused, inline, compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) - return lowered + +_pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): - global _most_recent_pjit_call_executable - - compile_options = None - pgle_profiler = None - pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): + pgle_compile_options, pgle_profiler = {}, None if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: - if jaxpr not in pgle_profiler_dict: - pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler( + compilation_target_key = jaxpr + pgle_profiler = _pgle_profiler_dict.get(compilation_target_key) + if pgle_profiler is None: + pgle_profiler = profiler.PGLEProfiler( config.pgle_profiling_runs.value, config.pgle_aggregation_percentile.value) + _pgle_profiler_dict[compilation_target_key] = pgle_profiler - pgle_profiler = pgle_profiler_dict[jaxpr] # The method below will return FDO profile when module was profiled # config.jax_pgle_profiling_runs amount of times, otherwise the result will # be None. fdo_profile = pgle_profiler.consume_fdo_profile() if fdo_profile is not None: - compile_options = {'fdo_profile': fdo_profile} + pgle_compile_options['fdo_profile'] = fdo_profile - # TODO(patrios): Do not pass mutable profile session through cached lowering - # chain. Instead we need to move profilers dictionary to pxla module and use - # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode. + compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items()) + # Passing mutable PGLE profile here since it should be extracted by JAXPR to + # initialize the fdo_profile compile option. compiled = _resolve_and_lower( args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, @@ -1640,10 +1668,10 @@ def _pjit_call_impl_python( donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline, lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(), - pgle_profiler=pgle_profiler - ).compile(compile_options) + pgle_profiler=pgle_profiler, + compiler_options_kvs=compiler_options_kvs, + ).compile() - _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. if compiled._auto_spmd_lowering and config.enable_checks.value: pxla.check_array_xla_sharding_layout_match( @@ -1665,7 +1693,7 @@ def _pjit_call_impl_python( ("abstract args", map(xla.abstractify, args)), ("fingerprint", fingerprint)) try: - return compiled.unsafe_call(*args), compiled + return compiled.unsafe_call(*args), compiled, pgle_profiler except FloatingPointError as e: assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case @@ -1695,7 +1723,7 @@ def _pjit_call_impl_python( @weakref_lru_cache def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): + keep_unused, inline, compiler_options_kvs): # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to # the jaxpr defeating the purpose of weakref_lru_cache. So return a function @@ -1708,16 +1736,15 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, def _pjit_call_impl(*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, - donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): def call_impl_cache_miss(*args_, **kwargs_): - out_flat, compiled = _pjit_call_impl_python( + out_flat, compiled, pgle_profiler = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) - pgle_profiler = _read_pgle_profiler(jaxpr) + inline=inline, compiler_options_kvs=compiler_options_kvs) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, jaxpr.consts, None, pgle_profiler) @@ -1725,7 +1752,8 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline) + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs) donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) cache_key = pxla.JitGlobalCppCacheKeys( donate_argnums=donated_argnums, donate_argnames=None, @@ -1743,12 +1771,7 @@ def call_impl_cache_miss(*args_, **kwargs_): pjit_p.def_impl(_pjit_call_impl) -def _pjit_lower(*args, **kwargs): - return _pjit_lower_cached(*args, **kwargs) - - -@weakref_lru_cache -def _pjit_lower_cached( +def _pjit_lower( jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, @@ -1759,16 +1782,23 @@ def _pjit_lower_cached( name: str, keep_unused: bool, inline: bool, + compiler_options_kvs: tuple[tuple[str, Any], ...], *, lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): - mesh, api_name = ((resource_env.physical_mesh, 'pjit') - if resource_env is not None else (None, 'jit')) + if config.sharding_in_types.value: + cur_mesh = mesh_lib.get_concrete_mesh() + mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None + api_name = 'jit' + else: + mesh, api_name = ((resource_env.physical_mesh, 'pjit') + if resource_env is not None else (None, 'jit')) return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), keep_unused=keep_unused, context_mesh=mesh, + compiler_options_kvs=compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) @@ -1780,16 +1810,17 @@ def pjit_staging_rule(trace, *args, **params): params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) if (params["inline"] and - all(is_unspecified(i) for i in params["in_shardings"]) and - all(is_unspecified(o) for o in params["out_shardings"]) and + all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and + all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): if config.dynamic_shapes.value: # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, # but redundantly performs abstract evaluation again. - out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, - propagate_source_info=False) + with core.set_current_trace(trace): + out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, + propagate_source_info=False) else: out_tracers = pe.inline_jaxpr_into_trace( trace, jaxpr.jaxpr, jaxpr.consts, *args) @@ -1809,7 +1840,7 @@ def pjit_staging_rule(trace, *args, **params): trace.frame.add_eqn(eqn) elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(trace.instantiate_const, consts) + consts = map(trace.new_const, consts) in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) @@ -1830,7 +1861,7 @@ def pjit_staging_rule(trace, *args, **params): def _pjit_forwarding(jaxpr, out_shardings, out_layouts): in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol + in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) @@ -1896,8 +1927,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, func = mod_ctx.cached_primitive_lowerings.get(key, None) if func is None: - arg_shardings = [None if is_unspecified(i) else i for i in in_shardings] - result_shardings = [None if is_unspecified(o) else o for o in out_shardings] + arg_shardings = [None if isinstance(i, UnspecifiedValue) else i for i in in_shardings] + result_shardings = [None if isinstance(o, UnspecifiedValue) else o for o in out_shardings] # TODO(b/228598865): inlined calls cannot have shardings set directly on the # inputs or outputs because they are lost during MLIR->HLO conversion. # using_sharding_annotation=False means we add an identity operation instead. @@ -1912,7 +1943,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, keep_unused, inline): + donated_invars, keep_unused, inline, compiler_options_kvs): effects = list(ctx.tokens_in.effects()) output_types = map(mlir.aval_to_ir_type, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types @@ -1938,14 +1969,12 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, mlir.register_lowering(pjit_p, _pjit_lowering) -def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, - vals_in, dims_in, jaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): +def _pjit_batcher(axis_data, vals_in, dims_in, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) - new_jaxpr, axes_out = batching.batch_jaxpr2( - jaxpr, axis_size, dims_in, axis_name=axis_name, - spmd_axis_name=spmd_axis_name, main_type=main_type) + new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) if resource_env is not None: mesh = resource_env.physical_mesh @@ -1954,11 +1983,11 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs in_shardings = tuple( - _pjit_batcher_for_sharding(i, axis_in, spmd_axis_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim) if axis_in is not None else i for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals)) out_shardings = tuple( - _pjit_batcher_for_sharding(o, axis_out, spmd_axis_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim) if axis_out is not None else o for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) # TODO(yashkatariya): Figure out layouts should change under vmap. @@ -1978,21 +2007,22 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs( vals_in, vals_out, axes_out) return vals_out, resolved_axes_out -batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher -batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None) +batching.fancy_primitive_batchers[pjit_p] = _pjit_batcher +batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule def _pjit_batcher_for_sharding( s: sharding.Sharding | UnspecifiedValue, dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): - if is_unspecified(s): + if isinstance(s, UnspecifiedValue): return s - hlo_s = s._to_xla_hlo_sharding(ndim) # type: ignore + hlo_s = s._to_xla_hlo_sharding(ndim) if spmd_axis_name is None: if sharding_impls.is_op_sharding_replicated(hlo_s): return s @@ -2004,7 +2034,7 @@ def _pjit_batcher_for_sharding( tad.insert(dim, 1) new_op.tile_assignment_dimensions = tad new_gs = GSPMDSharding( - s._device_assignment, new_op, # type: ignore + s._device_assignment, new_op, _device_list=getattr(s, '_internal_device_list', None)) return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] else: @@ -2028,7 +2058,8 @@ def _pjit_batcher_for_sharding( def _pjit_jvp(primals_in, tangents_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): if any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr) mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals) @@ -2060,7 +2091,8 @@ def _filter_zeros(is_nz_l, l): donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)]) assert len(primals_out) == len(jaxpr.jaxpr.outvars) @@ -2070,25 +2102,72 @@ def _filter_zeros(is_nz_l, l): ad.primitive_jvps[pjit_p] = _pjit_jvp -@weakref_lru_cache -def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr, - in_fwd: tuple[int | None, ...]) -> core.ClosedJaxpr: - updated_jaxpr = known_jaxpr.jaxpr.replace( - outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, in_fwd) - if i is None]) - return known_jaxpr.replace(jaxpr=updated_jaxpr) +def _pjit_linearization(nzs, *primals_in, jaxpr, + in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): + primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) + # constvars will become residuals. Move them to the end of the ordinary args. + res_shardings = (UNSPECIFIED,) * num_residuals + res_layouts = (None,) * num_residuals + res_donated = (False,) * num_residuals + def tangent_fun(consts_, *tangents): + tangents_nz = _filter_zeros(nzs, tangents) + assert len(consts_) == num_residuals + return pjit_p.bind(*(*tangents_nz, *consts_), + jaxpr=tangent_jaxpr, + in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, + out_shardings=_filter_zeros(nzs_out, out_shardings), + in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts, + out_layouts=_filter_zeros(nzs_out, out_layouts), + resource_env=resource_env, + donated_invars=_filter_zeros(nzs, donated_invars) + res_donated, + name=name, + keep_unused=keep_unused, + inline=inline, + compiler_options_kvs=compiler_options_kvs) + + def _filter_zeros(is_nz_l, l): + return tuple(x for nz, x in zip(is_nz_l, l) if nz) + + ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, + in_shardings=in_shardings, + out_shardings=(*res_shardings, *out_shardings), + in_layouts=in_layouts, + out_layouts=(*res_layouts, *out_layouts), + resource_env=resource_env, + donated_invars=donated_invars, + name=name, + keep_unused=keep_unused, + inline=inline, + compiler_options_kvs=compiler_options_kvs) + residuals_ans, primal_ans = split_list(ans, [num_residuals]) + + return primal_ans, nzs_out, residuals_ans, tangent_fun + +ad.primitive_linearizations[pjit_p] = _pjit_linearization def _pjit_partial_eval(trace, *in_tracers, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, - name, keep_unused, inline): + name, keep_unused, inline, compiler_options_kvs): in_pvals = [t.pval for t in in_tracers] known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ - pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) + if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect)) + for e in jaxpr.effects): + known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \ + pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins, + False, False, None) + if num_res_ref: raise NotImplementedError + known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts) + unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts) + res_avals = unknown_jaxpr.in_avals[:num_res_val] + else: + known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ + pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) known_outs = tuple(not uk for uk in unknown_outs) num_residuals = len(res_avals) @@ -2107,7 +2186,7 @@ def keep_where(l, should_keep): # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) in_fwd = [ - fwd if is_unspecified(os) and ol is None else None + fwd if isinstance(os, UnspecifiedValue) and ol is None else None for os, ol, fwd in zip( keep_where(out_shardings, known_outs), keep_where(out_layouts, known_outs), in_fwd_primal) @@ -2140,7 +2219,8 @@ def keep_where(l, should_keep): in_layouts=keep_where(in_layouts, known_ins), out_layouts=known_out_layouts, resource_env=resource_env, donated_invars=keep_where(donated_invars, known_ins), - name=name, keep_unused=keep_unused, inline=inline) + name=name, keep_unused=keep_unused, inline=inline, + compiler_options_kvs=compiler_options_kvs) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals) @@ -2174,7 +2254,8 @@ def keep_where(l, should_keep): (False,) * num_residuals), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()] unknown_out_avals = unknown_jaxpr.out_avals unknown_tracers_out = [ @@ -2254,7 +2335,8 @@ def _pjit_transpose_trace(fun, in_avals): def _pjit_transpose(cts_in, *primals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) @@ -2305,14 +2387,15 @@ def prune_type(ty, xs, maybe_zeros): donated_invars=(False,) * len(primals_and_nz_cts_in), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) if attrs_tracked: final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) _set_states(attrs_tracked, final_states) return tree_unflatten(cts_out_treedef, nz_cts_out) -ad.reducing_transposes[pjit_p] = _pjit_transpose +ad.primitive_transposes[pjit_p] = _pjit_transpose @weakref_lru_cache @@ -2325,6 +2408,10 @@ def _dce_jaxpr_pjit( def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + dced_jaxpr, used_inputs = _dce_jaxpr_pjit( eqn.params['jaxpr'], tuple(used_outputs)) @@ -2358,9 +2445,9 @@ def _pjit_pp_rule(eqn, context, settings): del params['inline'] if not any(params['donated_invars']): del params['donated_invars'] - if all(is_unspecified(s) for s in params['in_shardings']): + if all(isinstance(s, UnspecifiedValue) for s in params['in_shardings']): del params['in_shardings'] - if all(is_unspecified(s) for s in params['out_shardings']): + if all(isinstance(s, UnspecifiedValue) for s in params['out_shardings']): del params['out_shardings'] if all(l is None for l in params['in_layouts']): del params['in_layouts'] @@ -2371,6 +2458,8 @@ def _pjit_pp_rule(eqn, context, settings): if (params['resource_env'] is None or params['resource_env'].physical_mesh.empty): del params['resource_env'] + if not params['compiler_options_kvs']: + del params['compiler_options_kvs'] # Move name= to the front to make the resulting equation easier to scan. del params["name"] @@ -2382,8 +2471,7 @@ def _pjit_pp_rule(eqn, context, settings): def _pjit_state_discharge_rule( in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, **params): - if not (all(map(is_unspecified, in_shardings)) and - all(map(is_unspecified, out_shardings))): + if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)): raise NotImplementedError if not (all(l is None for l in in_layouts) and @@ -2451,12 +2539,18 @@ def with_sharding_constraint(x, shardings): shardings_flat = [_create_sharding_for_array(mesh, a, 'shardings', 'with_sharding_constraint') for a in user_shardings_flat] + for s, u in zip(shardings_flat, user_shardings_flat): + if isinstance(s, (UnspecifiedValue, AUTO)): + raise ValueError( + f'One of with_sharding_constraint arguments got sharding {u} which is' + ' not allowed. Please only pass `jax.sharding.Sharding` instances.') + del user_shardings_flat + # TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's # already part of the shardings. unconstrained_dims = [get_unconstrained_dims(s) if isinstance(s, NamedSharding) else {} for s in shardings_flat] - del user_shardings_flat pjit_check_aval_sharding( shardings_flat, x_flat, None, "with_sharding_constraint arguments", @@ -2537,24 +2631,23 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, def _sharding_constraint_batcher( - spmd_axis_name, axis_size, axis_name, main_type, vals_in, - dims_in, sharding, layout, resource_env, unconstrained_dims): - if spmd_axis_name is not None and isinstance(sharding, NamedSharding): + axis_data, vals_in, dims_in, sharding, layout, resource_env, unconstrained_dims): + if axis_data.spmd_name is not None and isinstance(sharding, NamedSharding): used = {n for ns in sharding.spec for n in (ns if isinstance(ns, tuple) else (ns,))} - if set(spmd_axis_name) & used: - raise ValueError(f"vmap spmd_axis_name {spmd_axis_name} cannot appear in " + if set(axis_data.spmd_name) & used: + raise ValueError(f"vmap spmd_axis_name {axis_data.spmd_name} cannot appear in " "with_sharding_constraint spec, but got spec " f"{sharding.spec}") x, = vals_in d, = dims_in - + # None means unconstrained in ParsedPartitionSpec unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims} - if spmd_axis_name is None: + if axis_data.spmd_name is None: unconstrained_dims.add(d) vmapped_sharding = _pjit_batcher_for_sharding( - sharding, d, spmd_axis_name, resource_env.physical_mesh, x.ndim) + sharding, d, axis_data.spmd_name, resource_env.physical_mesh, x.ndim) if unconstrained_dims and isinstance(vmapped_sharding, NamedSharding): new_spec = list(vmapped_sharding.spec) + [None] * (x.ndim - len(vmapped_sharding.spec)) for u in unconstrained_dims: @@ -2575,9 +2668,9 @@ def _sharding_constraint_batcher( resource_env=resource_env, unconstrained_dims=unconstrained_dims) return y, d -batching.spmd_axis_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher -batching.axis_primitive_batchers[sharding_constraint_p] = partial( - _sharding_constraint_batcher, None) +batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher +batching.skippable_batchers[sharding_constraint_p] = lambda _: () + # -------------------- helpers -------------------- diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7ca7db022d89..2256e12da1d4 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -279,6 +279,11 @@ def copy(self): __hash__ = None # type: ignore[assignment] __array_priority__ = 100 + def __array__(self, dtype: np.dtype | None = None, copy: bool | None = None) -> np.ndarray: + raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array." + " Use jax.random.key_data(arr) if you wish to extract the underlying" + " integer array.") + # Overwritten immediately below @property def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override] @@ -463,12 +468,13 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts): +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts, + copy_semantics): arrs = [x._base_array for x in xs] phys_shardings = [physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] # TODO(yashkatariya): `layouts` should be converted to physical layouts. - return pxla.shard_args(phys_shardings, layouts, arrs) + return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler @@ -807,7 +813,7 @@ def _threefry2x32_abstract_eval(*args): shape = lax_internal.broadcasting_shape_rule(*args) aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32)) else: - aval = core.UnshapedArray(jnp.dtype(jnp.uint32)) + raise TypeError(f"Arguments to threefry2x32 must all be arrays, got {args}") return (aval,) * 2 @@ -885,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): return tuple(x) -_threefry2x32_lowering_rule = mlir.lower_fun( +# Since the unrolled lowering is large, emit it as an out-of-line function. +_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=False), - multiple_results=True) + multiple_results=True)) _threefry2x32_cpu_lowering_rule = mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=True), @@ -1067,8 +1074,9 @@ def threefry_2x32(keypair, count): odd_size = count.size % 2 if not isinstance(odd_size, int): - msg = ("jax.random functions have limited support for shape polymorphism. " - "In particular, the product of the known dimensions must be even.") + msg = ("jax.random functions have limited support for shape polymorphism " + "when using threefry. " + f"In particular, the array size ({count.size}) must be even.") raise core.InconclusiveDimensionOperation(msg) if odd_size: diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 9859eb64cda2..6bbcdd08471f 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -90,6 +90,14 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } +# TODO: make this unconditional when ml_dtypes>=0.5.0 is required +if _dtypes.float8_e3m4 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 +if _dtypes.float8_e4m3 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.float8_e5m2fnuz, _dtypes.bfloat16, ] + + if _dtypes.float8_e4m3 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e4m3) + if _dtypes.float8_e3m4 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e3m4) + def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/random.py b/jax/_src/random.py index 203f72d406e5..4313d9036eda 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -55,8 +55,6 @@ Shape = Sequence[int] PRNGImpl = prng.PRNGImpl -KeyArray = Array -KeyArrayLike = ArrayLike UINT_DTYPES = prng.UINT_DTYPES @@ -69,8 +67,8 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -def _check_prng_key(name: str, key: KeyArrayLike, *, - allow_batched: bool = False) -> tuple[KeyArray, bool]: +def _check_prng_key(name: str, key: ArrayLike, *, + allow_batched: bool = False) -> tuple[Array, bool]: if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): wrapped_key = key wrapped = False @@ -113,7 +111,7 @@ def _return_prng_keys(was_wrapped, key): return prng.random_unwrap(key) if was_wrapped else key -def _random_bits(key: KeyArray, bit_width: int, shape: Shape) -> Array: +def _random_bits(key: Array, bit_width: int, shape: Shape) -> Array: assert jnp.issubdtype(key.dtype, dtypes.prng_key) return prng.random_bits(key, bit_width=bit_width, shape=shape) @@ -188,7 +186,7 @@ def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl: def _key(ctor_name: str, seed: int | ArrayLike, - impl_spec: PRNGSpecDesc | None) -> KeyArray: + impl_spec: PRNGSpecDesc | None) -> Array: impl = resolve_prng_impl(impl_spec) if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key): raise TypeError( @@ -200,7 +198,7 @@ def _key(ctor_name: str, seed: int | ArrayLike, return prng.random_seed(seed, impl=impl) def key(seed: int | ArrayLike, *, - impl: PRNGSpecDesc | None = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> Array: """Create a pseudo-random number generator (PRNG) key given an integer seed. The result is a scalar array containing a key, whose dtype indicates @@ -220,7 +218,7 @@ def key(seed: int | ArrayLike, *, return _key('key', seed, impl) def PRNGKey(seed: int | ArrayLike, *, - impl: PRNGSpecDesc | None = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> Array: """Create a legacy PRNG key given an integer seed. This function produces old-style legacy PRNG keys, which are arrays @@ -248,7 +246,7 @@ def PRNGKey(seed: int | ArrayLike, *, return _return_prng_keys(True, _key('PRNGKey', seed, impl)) -def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: +def fold_in(key: ArrayLike, data: IntegerArray) -> Array: """Folds in data to a PRNG key to form a new PRNG key. Args: @@ -267,7 +265,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: return _return_prng_keys(wrapped, key_out) -def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: +def _split(key: Array, num: int | tuple[int, ...] = 2) -> Array: # Alternative to split() to use within random samplers. # TODO(frostig): remove and use split(); we no longer need to wait # to always enable_custom_prng @@ -278,7 +276,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: shape = tuple(num) if isinstance(num, Sequence) else (num,) return prng.random_split(key, shape=shape) -def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: +def split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array: """Splits a PRNG key into `num` new keys by adding a leading axis. Args: @@ -293,21 +291,22 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: return _return_prng_keys(wrapped, _split(typed_key, num)) -def _key_impl(keys: KeyArray) -> PRNGImpl: +def _key_impl(keys: Array) -> str | PRNGSpec: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) keys_dtype = typing.cast(prng.KeyTy, keys.dtype) - return keys_dtype._impl + impl = keys_dtype._impl + return impl.name if impl.name in prng.prngs else PRNGSpec(impl) -def key_impl(keys: KeyArrayLike) -> PRNGSpec: +def key_impl(keys: ArrayLike) -> str | PRNGSpec: typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) - return PRNGSpec(_key_impl(typed_keys)) + return _key_impl(typed_keys) -def _key_data(keys: KeyArray) -> Array: +def _key_data(keys: Array) -> Array: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) return prng.random_unwrap(keys) -def key_data(keys: KeyArrayLike) -> Array: +def key_data(keys: ArrayLike) -> Array: """Recover the bits of key data underlying a PRNG key array.""" keys, _ = _check_prng_key("key_data", keys, allow_batched=True) return _key_data(keys) @@ -344,7 +343,7 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) -def bits(key: KeyArrayLike, +def bits(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeUInt | None = None) -> Array: """Sample uniform bits in the form of unsigned integers. @@ -373,7 +372,7 @@ def bits(key: KeyArrayLike, return _random_bits(key, bit_width, shape) -def uniform(key: KeyArrayLike, +def uniform(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., @@ -443,7 +442,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: lax.reshape(floats * (maxval - minval) + minval, shape)) -def randint(key: KeyArrayLike, +def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, @@ -532,7 +531,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: return lax.add(minval, lax.convert_element_type(random_offset, dtype)) -def permutation(key: KeyArrayLike, +def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, independent: bool = False) -> Array: @@ -581,6 +580,10 @@ def _shuffle(key, x, axis) -> Array: # another analysis (where the keys are generated one bit at a time). exponent = 3 # see tjablin@'s analysis for explanation of this parameter uint32max = jnp.iinfo(np.uint32).max + if not core.is_constant_dim(x.size): + raise NotImplementedError( + "shape polymorphism for `permutation` or `shuffle`" + f" for arrays of non-constant size: {x.size}") num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max))) for _ in range(num_rounds): @@ -591,7 +594,7 @@ def _shuffle(key, x, axis) -> Array: return x -def choice(key: KeyArrayLike, +def choice(key: ArrayLike, a: int | ArrayLike, shape: Shape = (), replace: bool = True, @@ -640,7 +643,9 @@ def choice(key: KeyArrayLike, if n_inputs <= 0: raise ValueError("a must be greater than 0 unless no samples are taken") if not replace and n_draws > n_inputs: - raise ValueError("Cannot take a larger sample than population when 'replace=False'") + raise ValueError( + f"Cannot take a larger sample (size {n_draws}) than " + f"population (size {n_inputs}) when 'replace=False'") if p is None: if replace: @@ -653,7 +658,9 @@ def choice(key: KeyArrayLike, check_arraylike("choice", p) p_arr, = promote_dtypes_inexact(p) if p_arr.shape != (n_inputs,): - raise ValueError("p must be None or match the shape of a") + raise ValueError( + "p must be None or a 1D vector with the same size as a.shape[axis]. " + f"p has shape {p_arr.shape} and a.shape[axis] is {n_inputs}.") if replace: p_cuml = jnp.cumsum(p_arr) r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype)) @@ -665,10 +672,10 @@ def choice(key: KeyArrayLike, result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) return result.reshape(shape if arr.ndim == 0 else - np.insert(np.delete(arr.shape, axis), axis, shape)) + arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:]) -def normal(key: KeyArrayLike, +def normal(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample standard normal random values with given shape and float dtype. @@ -721,7 +728,7 @@ def _normal_real(key, shape, dtype) -> Array: return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u)) -def multivariate_normal(key: KeyArrayLike, +def multivariate_normal(key: ArrayLike, mean: RealArray, cov: RealArray, shape: Shape | None = None, @@ -804,7 +811,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: return result -def truncated_normal(key: KeyArrayLike, +def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, @@ -870,7 +877,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))) -def bernoulli(key: KeyArrayLike, +def bernoulli(key: ArrayLike, p: RealArray = np.float32(0.5), shape: Shape | None = None) -> Array: r"""Sample Bernoulli random values with given shape and mean. @@ -915,7 +922,7 @@ def _bernoulli(key, p, shape) -> Array: return uniform(key, shape, lax.dtype(p)) < p -def beta(key: KeyArrayLike, +def beta(key: ArrayLike, a: RealArray, b: RealArray, shape: Shape | None = None, @@ -976,7 +983,7 @@ def _beta(key, a, b, shape, dtype) -> Array: return gamma_a_scaled / (gamma_a_scaled + gamma_b_scaled) -def cauchy(key: KeyArrayLike, +def cauchy(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Cauchy random values with given shape and float dtype. @@ -1015,7 +1022,7 @@ def _cauchy(key, shape, dtype) -> Array: return lax.tan(lax.mul(pi, lax.sub(u, _lax_const(u, 0.5)))) -def dirichlet(key: KeyArrayLike, +def dirichlet(key: ArrayLike, alpha: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1087,7 +1094,7 @@ def _softmax(x, axis) -> Array: return unnormalized / unnormalized.sum(axis, keepdims=True) -def exponential(key: KeyArrayLike, +def exponential(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Exponential random values with given shape and float dtype. @@ -1126,7 +1133,7 @@ def _exponential(key, shape, dtype) -> Array: return lax.neg(lax.log1p(lax.neg(u))) -def _gamma_one(key: KeyArray, alpha, log_space) -> Array: +def _gamma_one(key: Array, alpha, log_space) -> Array: # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang # The algorithm can also be founded in: # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables @@ -1254,7 +1261,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space): multiple_results=False), platform='cpu') batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule -def gamma(key: KeyArrayLike, +def gamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1301,7 +1308,7 @@ def gamma(key: KeyArrayLike, return _gamma(key, a, shape=shape, dtype=dtype) -def loggamma(key: KeyArrayLike, +def loggamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1443,7 +1450,7 @@ def _poisson(key, lam, shape, dtype) -> Array: return lax.select(lam == 0, jnp.zeros_like(result), result) -def poisson(key: KeyArrayLike, +def poisson(key: ArrayLike, lam: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: @@ -1488,7 +1495,7 @@ def poisson(key: KeyArrayLike, return _poisson(key, lam, shape, dtype) -def gumbel(key: KeyArrayLike, +def gumbel(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: """Sample Gumbel random values with given shape and float dtype. @@ -1524,7 +1531,7 @@ def _gumbel(key, shape, dtype) -> Array: uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) -def categorical(key: KeyArrayLike, +def categorical(key: ArrayLike, logits: RealArray, axis: int = -1, shape: Shape | None = None) -> Array: @@ -1566,7 +1573,7 @@ def categorical(key: KeyArrayLike, axis=axis) -def laplace(key: KeyArrayLike, +def laplace(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Laplace random values with given shape and float dtype. @@ -1603,7 +1610,7 @@ def _laplace(key, shape, dtype) -> Array: return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u)))) -def logistic(key: KeyArrayLike, +def logistic(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample logistic random values with given shape and float dtype. @@ -1639,7 +1646,7 @@ def _logistic(key, shape, dtype): return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) -def pareto(key: KeyArrayLike, +def pareto(key: ArrayLike, b: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1688,7 +1695,7 @@ def _pareto(key, b, shape, dtype) -> Array: return lax.exp(e / b) -def t(key: KeyArrayLike, +def t(key: ArrayLike, df: RealArray, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: @@ -1740,7 +1747,7 @@ def _t(key, df, shape, dtype) -> Array: return n * jnp.sqrt(half_df / g) -def chisquare(key: KeyArrayLike, +def chisquare(key: ArrayLike, df: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1792,7 +1799,7 @@ def _chisquare(key, df, shape, dtype) -> Array: return chi2 -def f(key: KeyArrayLike, +def f(key: ArrayLike, dfnum: RealArray, dfden: RealArray, shape: Shape | None = None, @@ -1856,7 +1863,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: return f -def rademacher(key: KeyArrayLike, +def rademacher(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeInt = int) -> Array: r"""Sample from a Rademacher distribution. @@ -1891,7 +1898,7 @@ def _rademacher(key, shape, dtype) -> Array: return (2 * bernoulli_samples - 1).astype(dtype) -def maxwell(key: KeyArrayLike, +def maxwell(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample from a one sided Maxwell distribution. @@ -1931,7 +1938,7 @@ def _maxwell(key, shape, dtype) -> Array: return jnp.linalg.norm(norm_rvs, axis=-1) -def double_sided_maxwell(key: KeyArrayLike, +def double_sided_maxwell(key: ArrayLike, loc: RealArray, scale: RealArray, shape: Shape = (), @@ -1983,7 +1990,7 @@ def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array: return random_sign * maxwell_rvs * scale + loc -def weibull_min(key: KeyArrayLike, +def weibull_min(key: ArrayLike, scale: RealArray, concentration: RealArray, shape: Shape = (), @@ -2029,7 +2036,7 @@ def _weibull_min(key, scale, concentration, shape, dtype) -> Array: def orthogonal( - key: KeyArrayLike, + key: ArrayLike, n: int, shape: Shape = (), dtype: DTypeLikeFloat = float @@ -2064,7 +2071,7 @@ def orthogonal( return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2])) def generalized_normal( - key: KeyArrayLike, + key: ArrayLike, p: float, shape: Shape = (), dtype: DTypeLikeFloat = float @@ -2099,7 +2106,7 @@ def generalized_normal( return r * g ** (1 / p) def ball( - key: KeyArrayLike, + key: ArrayLike, d: int, p: float = 2, shape: Shape = (), @@ -2131,7 +2138,7 @@ def ball( return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None] -def rayleigh(key: KeyArrayLike, +def rayleigh(key: ArrayLike, scale: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2184,7 +2191,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array: ray = lax.mul(scale, sqrt_u) return ray -def wald(key: KeyArrayLike, +def wald(key: ArrayLike, mean: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2242,7 +2249,7 @@ def _wald(key, mean, shape, dtype) -> Array: w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x) return w -def geometric(key: KeyArrayLike, +def geometric(key: ArrayLike, p: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: @@ -2295,7 +2302,7 @@ def _geometric(key, p, shape, dtype) -> Array: return g.astype(dtype) -def triangular(key: KeyArrayLike, +def triangular(key: ArrayLike, left: RealArray, mode: RealArray, right: RealArray, @@ -2359,7 +2366,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: return tri -def lognormal(key: KeyArrayLike, +def lognormal(key: ArrayLike, sigma: RealArray = np.float32(1), shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2564,7 +2571,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array: def binomial( - key: KeyArray, + key: Array, n: RealArray, p: RealArray, shape: Shape | None = None, diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index d014e5ceb24e..2e3632700759 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -1679,7 +1679,7 @@ def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float U is a Unitary Matrix: - >>> jnp.round(U.T @ U) + >>> jnp.round(U.T @ U) # doctest: +SKIP Array([[ 1., -0., -0.], [-0., 1., 0.], [-0., 0., 1.]], dtype=float32) @@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: - r"""Construct a Toeplitz matrix + r"""Construct a Toeplitz matrix. JAX implementation of :func:`scipy.linalg.toeplitz`. @@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: Notice this implies that :math:`r_0` is ignored. Args: - c: array specifying the first column. Will be flattened - if not 1-dimensional. - r: (optional) array specifying the first row. If not specified, defaults - to ``conj(c)``. Will be flattened if not 1-dimensional. + c: array of shape ``(..., N)`` specifying the first column. + r: (optional) array of shape ``(..., M)`` specifying the first row. Leading + dimensions must be broadcast-compatible with those of ``c``. If not specified, + ``r`` defaults to ``conj(c)``. Returns: - toeplitz matrix of shape ``(c.size, r.size)``. + A Toeplitz matrix of shape ``(... N, M)``. Examples: Specifying ``c`` only: @@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64) >>> print("M is Hermitian:", jnp.all(M == M.conj().T)) M is Hermitian: True + + For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices: + + >>> c = jnp.array([[1, 2, 3], [4, 5, 6]]) + >>> jax.scipy.linalg.toeplitz(c) + Array([[[1, 2, 3], + [2, 1, 2], + [3, 2, 1]], + + [[4, 5, 6], + [5, 4, 5], + [6, 5, 4]]], dtype=int32) """ if r is None: check_arraylike("toeplitz", c) r = jnp.conjugate(jnp.asarray(c)) else: check_arraylike("toeplitz", c, r) + return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r))) - c_arr = jnp.asarray(c).flatten() - r_arr = jnp.asarray(r).flatten() - - ncols, = c_arr.shape - nrows, = r_arr.shape - +@partial(jnp.vectorize, signature="(m),(n)->(m,n)") +def _toeplitz(c: Array, r: Array) -> Array: + ncols, = c.shape + nrows, = r.shape if ncols == 0 or nrows == 0: - return jnp.empty((ncols, nrows), - dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype)) - + return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype)) nelems = ncols + nrows - 1 - elems = jnp.concatenate((c_arr[::-1], r_arr[1:])) + elems = jnp.concatenate((c[::-1], r[1:])) patches = lax.conv_general_dilated_patches( elems.reshape((1, nelems, 1)), (nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'), precision=lax.Precision.HIGHEST)[0] return jnp.flip(patches, axis=0) - @partial(jit, static_argnames=("n",)) def hilbert(n: int) -> Array: r"""Create a Hilbert matrix of order n. diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 837aa011f165..2fffe6381b97 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -16,7 +16,7 @@ from functools import partial import operator -from typing import cast, overload, Any +from typing import cast, Any import numpy as np @@ -28,7 +28,6 @@ from jax._src import core from jax._src import custom_derivatives -from jax._src import deprecations from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact @@ -67,6 +66,7 @@ def gammaln(x: ArrayLike) -> Array: return lax.lgamma(x) +@jit def gammasgn(x: ArrayLike) -> Array: r"""Sign of the gamma function. @@ -82,6 +82,13 @@ def gammasgn(x: ArrayLike) -> Array: Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Because :math:`\Gamma(x)` is never zero, no condition is required for this case. + * if :math:`x = -\infty`, NaN is returned. + * if :math:`x = \pm 0`, :math:`\pm 1` is returned. + * if :math:`x` is a negative integer, NaN is returned. The sign of gamma + at a negative integer depends on from which side the pole is approached. + * if :math:`x = \infty`, :math:`1` is returned. + * if :math:`x` is NaN, NaN is returned. + Args: x: arraylike, real valued. @@ -93,8 +100,14 @@ def gammasgn(x: ArrayLike) -> Array: - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function """ x, = promote_args_inexact("gammasgn", x) + typ = x.dtype.type floor_x = lax.floor(x) - return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0) + x_negative = x < 0 + return jnp.select( + [(x_negative & (x == floor_x)) | jnp.isnan(x), + (x_negative & (floor_x % 2 != 0)) | ((x == 0) & jnp.signbit(x))], + [typ(np.nan), typ(-1.0)], + typ(1.0)) def gamma(x: ArrayLike) -> Array: @@ -116,6 +129,13 @@ def gamma(x: ArrayLike) -> Array: \Gamma(n) = (n - 1)! + * if :math:`z = -\infty`, NaN is returned. + * if :math:`x = \pm 0`, :math:`\pm \infty` is returned. + * if :math:`x` is a negative integer, NaN is returned. The sign of gamma + at a negative integer depends on from which side the pole is approached. + * if :math:`x = \infty`, :math:`\infty` is returned. + * if :math:`x` is NaN, NaN is returned. + Args: x: arraylike, real valued. @@ -128,7 +148,8 @@ def gamma(x: ArrayLike) -> Array: - :func:`jax.scipy.special.gammasgn`: the sign of the gamma function Notes: - Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs. + Unlike the scipy version, JAX's ``gamma`` does not support complex-valued + inputs. """ x, = promote_args_inexact("gamma", x) return gammasgn(x) * lax.exp(lax.lgamma(x)) @@ -189,16 +210,8 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array: n, = promote_args_inexact("factorial", n) return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) -@overload -def beta(a: ArrayLike, b: ArrayLike) -> Array: ... - -@overload -def beta(a: ArrayLike, *, y: ArrayLike) -> Array: ... - -@overload -def beta(*, x: ArrayLike, y: ArrayLike) -> Array: ... -def beta(*args, **kwds): +def beta(a: ArrayLike, b: ArrayLike) -> Array: r"""The beta function JAX implementation of :obj:`scipy.special.beta`. @@ -220,24 +233,6 @@ def beta(*args, **kwds): - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.betaln` """ - # TODO(jakevdp): deprecation warning added 2024-06-10; finalize after 2024-09-10 - if 'x' in kwds: - msg = "The `x` parameter of jax.scipy.special.beta is deprecated, use `a` instead." - deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) - if 'a' in kwds: - raise TypeError("beta() got both parameter 'a' and parameter 'x'.") - kwds['a'] = kwds.pop('x') - if 'y' in kwds: - msg = "The `y` parameter of jax.scipy.special.beta is deprecated, use `b` instead." - deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) - if 'b' in kwds: - raise TypeError("beta() got both parameter 'b' and parameter 'y'.") - kwds['b'] = kwds.pop('y') - if extra := kwds.keys() - {'a', 'b'}: - raise TypeError(f"beta() got unexpected keyword arguments {list(extra)}") - return _beta(*args, **kwds) - -def _beta(a, b): a, b = promote_args_inexact("beta", a, b) sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b) return sign * lax.exp(betaln(a, b)) diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 08d1c0b6b538..65c457f79cc8 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -198,17 +198,16 @@ def rankdata( return jnp.apply_along_axis(rankdata, axis, a, method) arr = jnp.ravel(a) - sorter = jnp.argsort(arr) + arr, sorter = jax.lax.sort_key_val(arr, jnp.arange(arr.size)) inv = invert_permutation(sorter) if method == "ordinal": return inv + 1 - arr = arr[sorter] - obs = jnp.insert(arr[1:] != arr[:-1], 0, True) + obs = jnp.concatenate([jnp.array([True]), arr[1:] != arr[:-1]]) dense = obs.cumsum()[inv] if method == "dense": return dense - count = jnp.nonzero(obs, size=arr.size + 1, fill_value=len(obs))[0] + count = jnp.nonzero(obs, size=arr.size + 1, fill_value=obs.size)[0] if method == "max": return count[dense] if method == "min": diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index f410d08e4f3d..4343c080251c 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -51,12 +51,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) - :func:`jax.scipy.stats.gamma.logsf` """ x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale) + ok = lax.ge(x, loc) one = _lax_const(x, 1) - y = lax.div(lax.sub(x, loc), scale) + y = jnp.where(ok, lax.div(lax.sub(x, loc), scale), one) log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y) shape_terms = lax.add(gammaln(a), lax.log(scale)) log_probs = lax.sub(log_linear_term, shape_terms) - return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) + return jnp.where(ok, log_probs, -jnp.inf) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/jax/_src/shard_alike.py b/jax/_src/shard_alike.py index 574d725c4999..e2ddec15e8d4 100644 --- a/jax/_src/shard_alike.py +++ b/jax/_src/shard_alike.py @@ -15,6 +15,7 @@ from functools import partial import itertools +from jax._src import config from jax._src import core from jax._src.interpreters import ad from jax._src.interpreters import mlir @@ -24,7 +25,7 @@ from jax._src.util import safe_zip from jax._src.lib import xla_client as xc from jax._src.api_util import shaped_abstractify -from jax._src.lib.mlir import ir +from jax._src.lib.mlir import dialects, ir _next_shard_group_id = itertools.count() @@ -91,6 +92,11 @@ def _group_shard( ) -> tuple[ir.Value, ir.Value]: shard_group_id = next(_next_shard_group_id) + if config.use_shardy_partitioner.value: + dialects.sdy.ShardingGroupOp(x, shard_group_id) + dialects.sdy.ShardingGroupOp(y, shard_group_id) + return x, y + unknown_op_sharding = xc.OpSharding() unknown_op_sharding.type = xc.OpSharding.Type.UNKNOWN unknown_op_sharding.is_shard_group = True diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index cee3542f0006..23f0ef13cb00 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -43,7 +43,8 @@ def _addressable_devices_indices_map( if d.process_index == d.client.process_index()} @cache(max_size=4096, trace_context_in_key=False) -def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]: +def common_devices_indices_map( + s: Sharding, global_shape: Shape) -> Mapping[Device, Index]: s.shard_shape(global_shape) # raises a good error message hlo_sharding = s._to_xla_hlo_sharding(len(global_shape)) indices = op_sharding_to_indices(hlo_sharding, global_shape, diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index f54c39efebce..5e1def1079ac 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -37,7 +37,6 @@ are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method -from jax._src.lib import xla_extension_version import numpy as np @@ -138,9 +137,12 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - if self._manual_axes: + mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() + if t == mesh_lib.AxisTypes.Collective} + manual_axes = self._manual_axes.union(mesh_manual_axes) + if manual_axes: axis_names = self.mesh.axis_names - for manual_axis in self._manual_axes: + for manual_axis in manual_axes: special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL replicated_mesh_axes = [] @@ -242,8 +244,6 @@ class NamedSharding(sharding.Sharding): _parsed_pspec: ParsedPartitionSpec _manual_axes: frozenset[MeshAxisName] _logical_device_ids: tuple[int, ...] | None - if xla_extension_version < 292: - _logical_device_ids = None @use_cpp_method() def __init__( @@ -308,15 +308,10 @@ def _from_parsed_pspec( cls, mesh, parsed_pspec, *, memory_kind=None, _manual_axes=frozenset(), _logical_device_ids=None, ): - if xla_extension_version >= 292: - return cls(mesh, parsed_pspec.get_partition_spec(), - memory_kind=memory_kind, _parsed_pspec=parsed_pspec, - _manual_axes=_manual_axes, - _logical_device_ids=_logical_device_ids) - else: - return cls(mesh, parsed_pspec.get_partition_spec(), - memory_kind=memory_kind, _parsed_pspec=parsed_pspec, - _manual_axes=_manual_axes) + return cls(mesh, parsed_pspec.get_partition_spec(), + memory_kind=memory_kind, _parsed_pspec=parsed_pspec, + _manual_axes=_manual_axes, + _logical_device_ids=_logical_device_ids) @property def num_devices(self) -> int: @@ -368,20 +363,10 @@ def is_fully_replicated(self) -> bool: def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) - def _normalized_spec(self, ndim: int) -> PartitionSpec: - out = [] # type: ignore - for p in self._parsed_pspec: - if p is None: - raise ValueError("UNCONSTRAINED is not supported yet.") - if not p: - out.append(None) - elif isinstance(p, tuple) and len(p) == 1: - out.append(p[0]) - else: - out.append(p) - if len(out) < ndim: - out.extend([None] * (ndim - len(out))) - return PartitionSpec(*out) + def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: + if not isinstance(spec, PartitionSpec): + spec = PartitionSpec(*spec) + return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) @@ -543,7 +528,7 @@ def is_equivalent_to(self: PmapSharding, other: PmapSharding, # type: ignore # TODO(yashkatariya): Expose `sharded_dim_size` in the API if required. @classmethod - def default(cls, shape: Shape, sharded_dim: int = 0, + def default(cls, shape: Shape, sharded_dim: int | None = 0, devices: Sequence[xc.Device] | None = None) -> PmapSharding: """Creates a :class:`PmapSharding` which matches the default placement used by :func:`jax.pmap`. @@ -555,6 +540,13 @@ def default(cls, shape: Shape, sharded_dim: int = 0, device order used by pmap is used, which is the order of :func:`jax.local_devices`. """ + if sharded_dim is None: + if devices is None: + raise ValueError("One of sharded_dim or devices must be set.") + nrep = len(devices) + return cls(np.array(devices), + sharding_specs.pmap_sharding_spec(nrep, nrep, shape, None)) + # The dtype doesn't matter here. Its only used for creating the # sharding_spec. sharding_spec = sharding_specs.create_pmap_sharding_spec( @@ -573,11 +565,6 @@ def default(cls, shape: Shape, sharded_dim: int = 0, raise NotImplementedError( 'Multiple chunks in Chunked dimension not supported.') - if num_ways_sharded is None: - raise NotImplementedError( - '`None` to sharded_dim is not supported. Please file a jax ' - 'issue if you need this feature.') - if devices is None: pmap_devices: np.ndarray = np.array( xla_bridge.local_devices()[:num_ways_sharded]) @@ -965,21 +952,11 @@ def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding: return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) -def is_auto(x): - return isinstance(x, AUTO) - - class UnspecifiedValue: def __repr__(self): return "UnspecifiedValue" UNSPECIFIED = UnspecifiedValue() -def is_unspecified(x): - return isinstance(x, UnspecifiedValue) - -def is_unspecified_or_auto(x): - return is_auto(x) or is_unspecified(x) - MeshAxisName = Any @@ -1022,8 +999,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): def get_array_mapping( axis_resources: ParsedPartitionSpec | AUTO | UnspecifiedValue ) -> ArrayMappingOrAutoOrUnspecified: - # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported. - # Don't use `is_auto` here to satisfy pytype and mypy. if isinstance(axis_resources, (AUTO, UnspecifiedValue)): return axis_resources return OrderedDict((axis, i) @@ -1121,7 +1096,7 @@ def prepare_axis_resources(axis_resources, arg_name, new_entries = [] for entry in entries: - if is_unspecified_or_auto(entry) or entry is None: + if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None: new_entries.append(entry) elif isinstance(entry, sharding.Sharding): if isinstance(entry, PmapSharding): @@ -1139,8 +1114,7 @@ def prepare_axis_resources(axis_resources, arg_name, def _check_unique_resources(axis_resources, arg_name): for arg_axis_resources in axis_resources: if not arg_axis_resources: continue - if (is_unspecified_or_auto(arg_axis_resources) or - isinstance(arg_axis_resources, sharding.Sharding)): + if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)): continue constrained_dims = [d for d in arg_axis_resources if d is not None] resource_counts = collections.Counter( @@ -1740,17 +1714,25 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], """ if devices is None: devices = xla_bridge.devices() - axis_size = math.prod(axis_shapes) + new_axis_shapes = mesh_utils._canonicalize_axis_sizes(axis_shapes) + if new_axis_shapes is None: + raise ValueError( + '`axis_shapes` passed to `make_mesh` should be a sequence of ints.' + f' Got {axis_shapes}') + del axis_shapes + + axis_size = math.prod(new_axis_shapes) if axis_size > len(devices): raise ValueError( f'Number of devices {len(devices)} must be >= the product ' - f'of mesh_shape {axis_shapes}') + f'of mesh_shape {new_axis_shapes}') elif axis_size < len(devices): devices = devices[:axis_size] - if devices[0].device_kind == mesh_utils._TPU_V5_LITE: + if devices[0].device_kind in (mesh_utils._TPU_V5_LITE, mesh_utils._TPU_V5E): allow_split_physical_axes = True else: allow_split_physical_axes = False mesh_devices = mesh_utils.create_device_mesh( - axis_shapes, devices, allow_split_physical_axes=allow_split_physical_axes) + new_axis_shapes, devices, + allow_split_physical_axes=allow_split_physical_axes) return mesh_lib.Mesh(mesh_devices, axis_names) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 3a2c375b64db..db26813de8bc 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -43,7 +43,8 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.sharding_impls import is_unspecified_or_auto +from jax._src import mesh as mesh_lib +from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.layout import Layout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -649,7 +650,7 @@ def out_info(self): # PyTree of OutInfo out_avals = self._lowering.compile_args["global_out_avals"] out_shardings = self._lowering.compile_args["out_shardings"] return self.out_tree.unflatten( - [OutInfo(o.shape, o.dtype, None if is_unspecified_or_auto(s) else s) + [OutInfo(o.shape, o.dtype, None if isinstance(s, (UnspecifiedValue, AUTO)) else s) for o, s in zip(out_avals, out_shardings)]) def compile( @@ -716,13 +717,14 @@ class Traced(Stage): "_args_flat", "_arg_names", "_num_consts"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, args_flat=None, arg_names=None, - num_consts: int = 0): + lower_callable, abstract_mesh=None, + args_flat=None, arg_names=None, num_consts: int = 0): self.jaxpr = jaxpr self.args_info = args_info self.fun_name = fun_name self._out_tree = out_tree self._lower_callable = lower_callable + self._abstract_mesh = abstract_mesh self._args_flat = args_flat self._arg_names = arg_names self._num_consts = num_consts @@ -743,7 +745,10 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, self._lower_callable, lowering_platforms=lowering_platforms, lowering_parameters=_private_parameters) try: - lowering = new_callable() + # TODO(yashkatariya): Maybe thread this into pjit params like resource_env + # and set the context manager down the stack? + with mesh_lib.set_abstract_mesh(self._abstract_mesh): + lowering = new_callable() except pxla.DeviceAssignmentMismatchError as e: fails, = e.args msg = pjit._device_assignment_mismatch_error( diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index f1c4994b473b..21199b9bfd68 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -17,12 +17,12 @@ from collections.abc import Callable, Sequence import dataclasses from functools import partial +import math import operator from typing import Any, Protocol, TypeVar from jax._src import ad_util from jax._src import api_util -from jax._src import config from jax._src import core from jax._src import linear_util as lu from jax._src import source_info_util @@ -34,7 +34,14 @@ from jax._src.lax import slicing as lax_slicing from jax._src.state import indexing from jax._src.state.primitives import addupdate_p, get_p, swap_p -from jax._src.state.types import AbstractRef, RefBitcaster, RefEffect, RefReshaper +from jax._src.state.types import ( + AbstractRef, + RefBitcaster, + RefEffect, + RefReshaper, + get_ref_aval_from_value, + uninitialized, +) from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array from jax._src.util import ( @@ -44,6 +51,7 @@ safe_zip, split_dict, split_list, + unzip2, weakref_lru_cache, ) import numpy as np @@ -145,6 +153,10 @@ def _eval_jaxpr_discharge_state( [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) refs_to_discharge.add(id(outvar.aval)) + elif eqn.primitive is core.freeze_p: + [invar], [outvar] = eqn.invars, eqn.outvars + ans = env.read(invar) + refs_to_discharge.remove(id(invar.aval)) elif (any(should_discharge) or core.internal_mutable_array_effect in eqn.effects ): @@ -356,7 +368,7 @@ def transform_swap_array(x, transforms, val): case indexing.NDIndexer(): indexer = transform if _is_trivial_indexer(indexer): - _results.append(None) + _results.append(_results[-1]) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. @@ -469,40 +481,73 @@ def _closed_call_discharge_rule( run_state_p = core.Primitive("run_state") run_state_p.multiple_results = True -def _run_state_bind(*args: Any, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): - if config.enable_checks.value: - core.check_jaxpr(jaxpr) - assert len(jaxpr.invars) == len(args) - assert len(which_linear) == len(args) - return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr, - which_linear=which_linear) -run_state_p.def_custom_bind(_run_state_bind) +def _default_initialization(x): + assert hasattr(x, 'shape') + assert hasattr(x, 'dtype') + dtype = np.dtype(x) + if np.issubdtype(dtype, np.integer): + value = np.iinfo(dtype).min + else: + value = math.nan + return lax.full(x.shape, value, dtype) def _run_state_impl(*args: Any, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): del which_linear discharged_jaxpr, consts = discharge_state(jaxpr, ()) + # Initialize the args that are not initialized. + args_it = iter(args) + args = tuple( + next(args_it) if is_init else _default_initialization(var.aval) + for is_init, var in zip(is_initialized, discharged_jaxpr.invars) + ) return core.eval_jaxpr(discharged_jaxpr, consts, *args) run_state_p.def_impl(_run_state_impl) mlir.register_lowering(run_state_p, mlir.lower_fun(_run_state_impl)) def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): del which_linear + assert sum(is_initialized) == len(avals) # When we abstractly evaluate `run_state`, we want to keep track of which # input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to # "propagate" out its inner effects. Otherwise, the effects are local to this # `run_state`. + inner_to_outer_aval_mapping = {} + outer_ref_index = 0 + for i, is_init in enumerate(is_initialized): + if not is_init: + pass + inner_to_outer_aval_mapping[i] = outer_ref_index + outer_ref_index += 1 + nonlocal_effects = set() is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)} - nonlocal_effects = {e for e in jaxpr.effects - if (isinstance(e, RefEffect) and e.input_index in is_ref) - or not isinstance(e, RefEffect)} + for eff in jaxpr.effects: + if not isinstance(eff, RefEffect): + nonlocal_effects.add(eff) + continue + if eff.input_index not in inner_to_outer_aval_mapping: + # This means that this effect corresponds to an uninitialized Ref and + # should not propagate out of the primitive. + continue + # If we do propagate the effect, we need to update the input index to + # correspond to the outer index. + outer_index = inner_to_outer_aval_mapping[eff.input_index] + if outer_index in is_ref: + # This means that the effect corresponds to a Ref from an outside scope. + nonlocal_effects.add( + eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index]) + ) return avals, nonlocal_effects run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval) def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, - jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]): + jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError("Uninitialized Refs are not supported in jvp.") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) for _ in range(len(nonzero_tangents)): @@ -524,7 +569,9 @@ def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, jvp_jaxpr = hoist_consts_to_refs(jvp_jaxpr_) jvp_which_linear = (*(False,) * len(jvp_consts), *which_linear, *(True,) * len(tangents)) out = run_state_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr, - which_linear=jvp_which_linear) + which_linear=jvp_which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jvp_jaxpr.invars)) out_consts, out_primals, out_tangents = split_list(out, [len(jvp_consts), len(primals)]) del out_consts @@ -576,7 +623,12 @@ def eval_jaxpr(*refs): return jaxpr def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, - jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]): + jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in partial_eval." + ) num_inputs = len(tracers) assert num_inputs == len(jaxpr.invars) in_unknowns = [not t.pval.is_known() for t in tracers] @@ -636,7 +688,9 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res) out_flat = run_state_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, - which_linear=jaxpr_known_which_linear) + which_linear=jaxpr_known_which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_known.invars)) known_outputs, residuals = split_list(out_flat, [len(known_tracers)]) residuals = map(trace.new_instantiated_const, residuals) ref_res, nonref_res = split_list(residuals, [num_res_ref]) @@ -664,7 +718,9 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, assert len(unknown_inputs) == len(res_ref_unknown_outputs) assert len(unknown_inputs) == len(jaxpr_unknown.invars) - uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear) + uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear, + # TODO(sharadmv); compute this in the general case + is_initialized=(True,) * len(jaxpr_unknown.invars)) _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs], **uk_params) eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, @@ -682,7 +738,13 @@ def _run_state_partial_eval_custom( eqn: core.JaxprEqn): if not any(in_unknowns): return eqn, None, in_unknowns, [False] * len(in_unknowns), [] - jaxpr, which_linear = split_dict(eqn.params, ["jaxpr", "which_linear"]) + jaxpr, which_linear, is_initialized = split_dict( + eqn.params, ["jaxpr", "which_linear", "is_initialized"] + ) + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in partial_eval_custom." + ) num_inputs = len(eqn.invars) # We first need to run a fixpoint to determine which of the `Ref`s are unknown # after running the for loop. However, the jaxpr has no outputs. Instead, we @@ -709,7 +771,8 @@ def _run_state_partial_eval_custom( break in_unknowns = map(operator.or_, in_unknowns, out_unknowns) else: - if num_inputs > 0: raise Exception("Invalid fixpoint") + if num_inputs > 0: + raise Exception("Invalid fixpoint") del out_unknowns # Redundant since it's the same as `in_unknowns` new_inst = [x for x, already, inst in zip(eqn.invars, in_inst, out_inst) if type(x) is core.Var and inst and not already] @@ -748,7 +811,9 @@ def _run_state_partial_eval_custom( jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res) known_and_res_invars = [*known_invars, *ref_resvars, *nonref_resvars] - known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear) + known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_known.invars)) _, known_effects = run_state_p.abstract_eval( *[v.aval for v in known_and_res_invars], **known_params) eqn_known = pe.new_jaxpr_eqn(known_and_res_invars, @@ -760,7 +825,9 @@ def _run_state_partial_eval_custom( _, staged_which_linear = partition_list(in_unknowns, which_linear) which_linear_unknown = (*[False] * num_res, *staged_which_linear) - staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown) + staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_staged.invars)) rejiggered_resvars = [*nonref_resvars, *ref_resvars] _, staged_invars = partition_list(in_unknowns, eqn.invars) res_staged_invars = [*rejiggered_resvars, *staged_invars] @@ -791,8 +858,12 @@ def staged(*args): return eqn_known, eqn_staged, in_unknowns, in_unknowns, new_vars pe.partial_eval_jaxpr_custom_rules[run_state_p] = _run_state_partial_eval_custom -def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool] - ) -> tuple[core.Jaxpr, Any]: +def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool], + is_initialized: tuple[bool, ...]) -> tuple[core.Jaxpr, Any]: + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in transpose." + ) def trans(*args): # First we want to run the computation to read all the residual refs. We can # do that by using partial evaluation with all linear inputs unknown. @@ -811,8 +882,14 @@ def trans(*args): all_avals = [*res_avals, *[v.aval for v in res_jaxpr_.outvars]] empty_res = map(ad.zeros_like_aval, all_avals) res_jaxpr, _ = _convert_outputs_to_writes(res_jaxpr_) - res = run_state_p.bind(*res_args, *empty_res, jaxpr=res_jaxpr, - which_linear=(False,) * (len(res_args) + len(empty_res))) + res = run_state_p.bind( + *res_args, + *empty_res, + jaxpr=res_jaxpr, + which_linear=(False,) * (len(res_args) + len(empty_res)), + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(res_jaxpr.invars), + ) res = res[len(res_args):] ref_res_, nonref_res_ = split_list(res, [num_res_ref]) @@ -835,7 +912,12 @@ def trans(*args): return jaxpr_trans, consts def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in transpose." + ) # if any in_ct is nonzero, we definitely want it in args_ (and the # corresponding x in args could be an undefined primal, but doesn't have to be) # for non-res stuff: @@ -859,12 +941,19 @@ def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x): # the loop was 'getting and setting', grab that cotangent! transpose_args.append(ct) - jaxpr_transpose_, consts = _transpose_jaxpr(jaxpr, which_linear) + jaxpr_transpose_, consts = _transpose_jaxpr( + jaxpr, which_linear, is_initialized + ) jaxpr_transpose = hoist_consts_to_refs(jaxpr_transpose_) which_linear = (*[False] * len(consts), *which_linear) - const_all_outs = run_state_p.bind(*consts, *transpose_args, - jaxpr=jaxpr_transpose, - which_linear=which_linear) + const_all_outs = run_state_p.bind( + *consts, + *transpose_args, + jaxpr=jaxpr_transpose, + which_linear=which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_transpose.invars), + ) _, all_outs = split_list(const_all_outs, [len(consts)]) ct_outs = [ct if ad.is_undefined_primal(x) else None for x, ct in zip(args, all_outs)] @@ -875,9 +964,15 @@ def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], *args: Any, jaxpr: core.Jaxpr, - which_linear: Sequence[bool]): + which_linear: Sequence[bool], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in discharge." + ) del out_avals - out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear) + out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear, + is_initialized=is_initialized) new_invals = [] for aval, out_val in zip(in_avals, out_vals): new_invals.append(out_val if isinstance(aval, AbstractRef) else None) @@ -896,16 +991,23 @@ def _initial_style_jaxpr(fun, in_tree, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) return jaxpr, consts, out_tree_thunk() + T = TypeVar('T') def run_state(f: Callable[..., None]) -> Callable[[T], T]: def wrapped(args): flat_args, in_tree = tree_util.tree_flatten(args) - avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args] - jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals)) + ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) + # There may be some uninitialized values here in ref_args. + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals) jaxpr = hoist_consts_to_refs(jaxpr_) - which_linear = (False,) * (len(consts) + len(flat_args)) - out_const_flat = run_state_p.bind(*consts, *flat_args, jaxpr=jaxpr, - which_linear=which_linear) + which_linear = (False,) * (len(consts) + len(ref_args)) + refs_is_initialized = tuple(r is not uninitialized for r in ref_args) + init_args = tuple(r for r in ref_args if r is not uninitialized) + # Consts are always initialized. + is_initialized = (True,) * len(consts) + refs_is_initialized + out_const_flat = run_state_p.bind(*consts, *init_args, jaxpr=jaxpr, + which_linear=which_linear, + is_initialized=is_initialized) _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped @@ -913,12 +1015,19 @@ def wrapped(args): def run_state_reference(f: Callable[..., None]): def wrapped(args): flat_args, in_tree = tree_util.tree_flatten(args) - avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args] - jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals)) + ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals) jaxpr = hoist_consts_to_refs(jaxpr_) discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) + + # Initialize any uninitialized values here in ref_args in the reference. + ref_args = [ + _default_initialization(aval) if r is uninitialized else r + for r, aval in zip(ref_args, ref_avals) + ] + out_const_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, - *consts, *args) + *consts, *ref_args) _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index cb653547baff..2da93e3d8e80 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -46,11 +46,11 @@ def __post_init__(self): @property def is_dynamic_start(self): - return not isinstance(self.start, int) + return not core.is_dim(self.start) @property def is_dynamic_size(self): - return not isinstance(self.size, int) + return not core.is_dim(self.size) def tree_flatten(self): # If `start` is statically known, we treat it as static information @@ -72,10 +72,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice: @classmethod def from_slice(cls, slc: slice, size: int) -> Slice: - start, stop, step = slc.indices(size) + start, step, size = core.canonicalize_slice(slc, size) if step < 1: raise ValueError(f"slice must have a step >= 1 (found: {step})") - return cls(start, max((stop - start + step - 1) // step, 0), step) + return cls(start, size, step) def dslice( @@ -123,12 +123,7 @@ def _maybe_concretize(x: Any): # This is roughly the same logic as core.concrete_or_error, but we avoid # calling that because constructing the ConcretizationTypeError can be # expensive as the size of the tracing context (i.e. the jaxpr) grows. - if isinstance(x, core.Tracer): - if isinstance(x.aval, core.ConcreteArray): - return x.aval.val - else: - return None - return x + return core.to_concrete_value(x) @tree_util.register_pytree_node_class @dataclasses.dataclass diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index a0f70a126c8e..8b8d189b3e97 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -127,7 +127,7 @@ def ref_get( swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) -def swap_ragged_prop_rule(invar_raggedness, outvars): +def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 2 invar_raggedness_lhs = invar_raggedness[0] invar_raggedness_rhs = invar_raggedness[1] @@ -214,7 +214,10 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) - out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype) + # TODO(yashkatariya): Transform the sharding too instead of setting it to + # None. + out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype, + sharding=None) else: if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") @@ -230,7 +233,6 @@ def _swap_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - val_aval = core.raise_to_shaped(val_aval) assert isinstance(val_aval, core.ShapedArray) expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) @@ -262,7 +264,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - val_aval = core.raise_to_shaped(val_aval) out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) assert isinstance(val_aval, core.ShapedArray) @@ -653,3 +654,8 @@ def _broadcast_to_abstract_eval(aval, *, shape): mlir.register_lowering( broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False) ) + +# === AD rules for mutable arrays === + +ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g)) +ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g)) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 993eeb814e30..df3c63606ba4 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,7 +18,7 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Protocol, Union +from typing import Any, Callable, Protocol, Union from jax._src import core from jax._src import dtypes @@ -291,15 +291,14 @@ def weak_type(self) -> bool: raise AttributeError return self.inner_aval.weak_type + def update_weak_type(self, weak_type): + return AbstractRef(self.inner_aval.update_weak_type(weak_type)) + def update(self, inner_aval=None): if inner_aval is None: return AbstractRef(self.inner_aval) return AbstractRef(inner_aval) - def join(self, other): - assert isinstance(other, AbstractRef) - return AbstractRef(self.inner_aval.join(other.inner_aval)) - ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) @@ -365,10 +364,6 @@ def __eq__(self, other): def __hash__(self): return hash((self.__class__, self.inner_aval)) -def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type): - return AbstractRef(core.raise_to_shaped(ref_aval.inner_aval, weak_type)) -core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped - def _map_ref(size, axis, ref_aval): return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval)) @@ -404,3 +399,26 @@ def _unshard_ref(mesh, names, ref_aval: AbstractRef): raise NotImplementedError("Can't unshard a Ref") return ref_aval core.unshard_aval_handlers[AbstractRef] = _unshard_ref + + +# Sentinel type for indicating an uninitialized value. +class Uninitialized: + pass +uninitialized = Uninitialized() + + +_ref_type_aval_mappings: dict[ + type[Any], Callable[[Any], tuple[AbstractRef, Array | Uninitialized]], +] = {} + + +def _default_value_to_ref_aval(x: Any) -> tuple[AbstractRef, Array]: + # Default type mapping just creates an AbstractRef from the array's aval. + aval = core.raise_to_shaped(core.get_aval(x)) + return AbstractRef(aval), x + + +def get_ref_aval_from_value(x: Any): + if type(x) in _ref_type_aval_mappings: + return _ref_type_aval_mappings[type(x)](x) + return _default_value_to_ref_aval(x) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 81737f27540b..0bd5c7b139a1 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -25,6 +25,7 @@ import logging import math import os +import platform import re import sys import tempfile @@ -44,11 +45,13 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes as _dtypes +from jax._src import lib as _jaxlib from jax._src import linear_util as lu from jax._src import monitoring from jax._src import pjit as pjit_lib from jax._src import stages from jax._src import xla_bridge +from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir from jax._src.interpreters import pxla @@ -451,13 +454,25 @@ def assert_num_jit_and_pmap_compilations(times): f"but executed {count[0]}") +def jaxlib_version() -> tuple[int, ...]: + return _jaxlib.version + + def device_under_test(): return _TEST_DUT.value or xla_bridge.get_backend().platform def supported_dtypes(): if device_under_test() == "tpu": types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64} + np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64, + _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e5m2} + elif device_under_test() == "gpu": + types = {np.bool_, np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + _dtypes.bfloat16, np.float16, np.float32, np.float64, + np.complex64, np.complex128, _dtypes.float8_e4m3fn, + _dtypes.float8_e5m2} elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: @@ -538,6 +553,14 @@ def is_cuda_compute_capability_at_least(capability: str) -> bool: current = tuple(int(x) for x in d.compute_capability.split(".")) return current >= target +def is_cuda_compute_capability_equal(capability: str) -> bool: + if not is_device_cuda(): + return False + d, *_ = jax.local_devices(backend="gpu") + target = tuple(int(x) for x in capability.split(".")) + current = tuple(int(x) for x in d.compute_capability.split(".")) + return current == target + def _get_device_tags(): """returns a set of tags defined for the device under test""" if is_device_rocm(): @@ -957,6 +980,31 @@ def fn(shape, dtype): size=shape, replace=False) return fn +def rand_indices_unique_along_axis(rng): + """Sample an array of given shape containing indices up to dim (exclusive), + such that the indices are unique along the given axis. + Optionally, convert some of the resulting indices to negative indices.""" + def fn(dim, shape, axis, allow_negative=True): + batch_size = math.prod(shape[:axis] + shape[axis:][1:]) + idx = [ + rng.choice(dim, size=shape[axis], replace=False) + for _ in range(batch_size) + ] + idx = np.array(idx).reshape(batch_size, shape[axis]) + idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],)) + idx = np.moveaxis(idx, -1, axis) + + # assert that indices are unique along the given axis + count = partial(np.bincount, minlength=dim) + assert (np.apply_along_axis(count, axis, idx) <= 1).all() + + if allow_negative: + mask = rng.choice([False, True], idx.shape) + idx[mask] -= dim + return idx + + return fn + def rand_bool(rng): def generator(shape, dtype): return _cast_to_shape( @@ -1162,10 +1210,8 @@ class JaxTestCase(parameterized.TestCase): _compilation_cache_exit_stack: ExitStack | None = None - # TODO(mattjj): this obscures the error messages from failures, figure out how - # to re-enable it - # def tearDown(self) -> None: - # assert core.reset_trace_state() + def tearDown(self) -> None: + assert core.reset_trace_state() def setUp(self): super().setUp() @@ -1397,6 +1443,16 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) +def with_user_mesh(sizes, names): + def decorator(fn): + def mesh_fn(*args, **kwargs): + mesh = create_mesh(sizes, names) + with mesh_lib.set_mesh(mesh): + return fn(*args, **kwargs, mesh=mesh) + return mesh_fn + return decorator + + def create_mesh(mesh_shape, axis_names, iota_order=False): size = math.prod(mesh_shape) if len(jax.devices()) < size: @@ -1433,10 +1489,19 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ - _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, - _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + float_dtypes = [ + _dtypes.bfloat16, + _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e4m3fn, + _dtypes.float8_e4m3fnuz, + _dtypes.float8_e5m2, + _dtypes.float8_e5m2fnuz, + ] + if _dtypes.float8_e3m4 is not None: + float_dtypes += [_dtypes.float8_e3m4] + if _dtypes.float8_e4m3 is not None: + float_dtypes += [_dtypes.float8_e4m3] + return self.supported(float_dtypes) @_cached_property def floating(self): @@ -1452,8 +1517,7 @@ def integer(self): @_cached_property def all_integer(self): - return self.supported([ - _dtypes.int4, np.int8, np.int16, np.int32, np.int64]) + return self.supported([np.int8, np.int16, np.int32, np.int64]) @_cached_property def unsigned(self): @@ -1461,8 +1525,7 @@ def unsigned(self): @_cached_property def all_unsigned(self): - return self.supported([ - _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64]) + return self.supported([np.uint8, np.uint16, np.uint32, np.uint64]) @_cached_property def complex(self): @@ -1653,6 +1716,10 @@ def complex_plane_sample(dtype, size_re=10, size_im=None): size_im = size_re finfo = np.finfo(dtype) + machine = platform.machine() + is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm') + smallest = np.nextafter(finfo.tiny, finfo.max) if is_arm_cpu and platform.system() == 'Darwin' else finfo.tiny + def make_axis_points(size): prec_dps_ratio = 3.3219280948873626 logmin = logmax = finfo.maxexp / prec_dps_ratio @@ -1671,8 +1738,8 @@ def make_axis_points(size): axis_points[1] = finfo.min axis_points[-2] = finfo.max if size > 0: - axis_points[size] = -finfo.tiny - axis_points[-size - 1] = finfo.tiny + axis_points[size] = -smallest + axis_points[-size - 1] = smallest axis_points[0] = -np.inf axis_points[-1] = np.inf return axis_points diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 265d36d62b50..9e54f62d9ea0 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -29,7 +29,7 @@ from typing import Any import jax -from jax import core +from jax._src import core from jax._src import config from jax._src import sharding_impls from jax._src.interpreters import mlir @@ -62,6 +62,11 @@ help="Allow hlo dialects in Mosaic", ) + +# This tracks the latest Mosaic IR version with a monthly delay. +FWD_COMPAT_IR_VERSION = 3 + + tpu_custom_call_p = core.Primitive("tpu_custom_call") tpu_custom_call_p.def_impl( functools.partial(xla.apply_primitive, tpu_custom_call_p)) @@ -204,9 +209,6 @@ def to_json(self) -> bytes: if i + 1 != len(self.flags): config.write(b",") config.write(b"]") - # Prevent the compiler from sharding the custom call beyond what Mosaic does - # based on user annotations - config.write(b', "implicit_sharding": {"type": "MANUAL"}') config.write(b"}") return config.getvalue() @@ -281,6 +283,7 @@ def _lower_tpu_kernel( module: ir.Module, hardware_generation: int, target_shape: tuple[int, int], + kernel_name: str | None = None, ) -> ir.Module: """Runs MLIR passes lowering the given module to an MLIR module. @@ -306,8 +309,7 @@ def _lower_tpu_kernel( tpu.register_dialect(ctx) mhlo.register_mhlo_dialect(ctx) mhlo.register_mhlo_passes() - - dump_mlir(module, "original") + dump_mlir(module, "original", kernel_name) if _MOSAIC_ALLOW_HLO.value: # Run hlo dialect conversion: hlo -> linalg -> vector. @@ -409,6 +411,8 @@ def _lower_mosaic_module_to_asm( *, backend: str, device_type: str | None, + kernel_name: str | None, + ir_version: int | None = None, ) -> tuple[ir.Module, tuple[bool, bool, bool, bool]]: has_communication, has_custom_barrier = tpu.private_has_communication( module.operation @@ -417,10 +421,11 @@ def _lower_mosaic_module_to_asm( needs_layout_passes = not device_type # We'll mutate the module, so clone it with module.context as ctx, module.operation.location as _: - module = ir.Module.parse( - module.operation.get_asm(binary=True, enable_debug_info=True) - ) if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: + module = ir.Module.parse( + module.operation.get_asm(binary=True, enable_debug_info=True) + ) + module_op = module.operation some_tpu = jax.devices(backend)[0] device_kind = some_tpu.device_kind if not device_kind.startswith("TPU v"): @@ -431,19 +436,30 @@ def _lower_mosaic_module_to_asm( hardware_generation = int(device_kind[len("TPU v")]) target_shape = get_target_shape(hardware_generation) module = _lower_tpu_kernel( - module, hardware_generation, target_shape=target_shape + module, hardware_generation, target_shape=target_shape, kernel_name=kernel_name, ) needs_hlo_passes = False needs_layout_passes = False + else: + module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True + # TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37. + if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37): + target_version = "" + else: + target_version = ( + f"target-version={ir_version}" if ir_version is not None else "" + ) try: - pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})") - pipeline.run(module.operation) + pipeline = PassManager.parse( + "builtin.module(mosaic-serde{serialize=true " + target_version + "})" + ) + pipeline.run(module_op) finally: ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects bytecode_buffer = io.BytesIO() - module.operation.write_bytecode(bytecode_buffer, desired_version=0) + module_op.write_bytecode(bytecode_buffer, desired_version=0) asm = bytecode_buffer.getvalue() return asm, ( has_communication, @@ -453,6 +469,44 @@ def _lower_mosaic_module_to_asm( ) +def _get_device_type(module: ir.Module) -> str | None: + """Determines the device type based on the core_type annotations.""" + sparsecore_func_found = False + tensorcore_func_found = False + + def assign_device_type_based_on_core_type(op: ir.Operation) -> ir.WalkResult: + nonlocal sparsecore_func_found + nonlocal tensorcore_func_found + if op.name == "func.func": + if "tpu.core_type" in op.attributes: + core_type = op.attributes["tpu.core_type"] + if str(core_type) in [ + f"#tpu.core_type<{c}>" + for c in ["sc_scalar_subcore", "sc_vector_subcore"] + ]: + sparsecore_func_found = True + if tensorcore_func_found: + return ir.WalkResult.INTERRUPT + return ir.WalkResult.SKIP + if str(core_type) == "#tpu.core_type": + tensorcore_func_found = True + return ir.WalkResult.SKIP + raise ValueError(f"Unknown core type: {core_type}") + return ir.WalkResult.ADVANCE + + module.operation.walk( + assign_device_type_based_on_core_type, walk_order=ir.WalkOrder.PRE_ORDER + ) + if tensorcore_func_found and sparsecore_func_found: + raise ValueError( + "A single Mosaic kernel cannot contain both " + "TensorCore and SparseCore functions." + ) + if sparsecore_func_found: + return "sparsecore" + return None + + def _lower_to_custom_call_config( module: ir.Module, *, @@ -466,6 +520,8 @@ def _lower_to_custom_call_config( collective_id: int | None, serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + kernel_name: str | None = None, + ir_version: int | None = None, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -476,6 +532,8 @@ def _lower_to_custom_call_config( module, backend=backend, device_type=device_type, + kernel_name=kernel_name, + ir_version=ir_version, ) return _lowered_to_custom_call_config( lowered_module_asm, @@ -575,6 +633,8 @@ def lower_module_to_custom_call( device_type=device_type, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, + kernel_name=kernel_name, + ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, ) return _tpu_custom_call_lowering( ctx, @@ -592,7 +652,6 @@ def as_tpu_kernel( *, cost_estimate: CostEstimate | None = None, backend: str | xla_client.Client = "tpu", - device_type: str | None = None, kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, @@ -604,6 +663,7 @@ def as_tpu_kernel( output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" + device_type = _get_device_type(module) config = _lower_to_custom_call_config( module, backend=backend, @@ -616,6 +676,7 @@ def as_tpu_kernel( collective_id=collective_id, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, + kernel_name=kernel_name, ) return _as_jax_callable( config, @@ -697,7 +758,7 @@ def apply_kernel(*args): return jax.jit(apply_kernel) -def dump_mlir(module: ir.Module, name: str): +def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None): """A helper function to dump mosaic mlir module""" try: should_dump = FLAGS["xla_mosaic_dump_to"].value @@ -706,6 +767,8 @@ def dump_mlir(module: ir.Module, name: str): if should_dump == "sponge": outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None) if outdir: + if kernel_name: + name = f"{kernel_name}-{name}" path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}-py.txt") with open(path, "w") as f: f.write(str(module)) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 474bdfe4ec04..77871f3a908f 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -16,13 +16,12 @@ import collections from collections.abc import Callable, Hashable, Iterable, Sequence import dataclasses -from dataclasses import dataclass import difflib import functools from functools import partial import operator as op import textwrap -from typing import Any, NamedTuple, TypeVar, Union, overload +from typing import Any, NamedTuple, TypeVar, overload from jax._src import traceback_util from jax._src.lib import pytree @@ -209,12 +208,21 @@ def all_leaves(iterable: Iterable[Any], _Children = TypeVar("_Children", bound=Iterable[Any]) _AuxData = TypeVar("_AuxData", bound=Hashable) +KeyEntry = TypeVar("KeyEntry", bound=Any) +KeyLeafPair = tuple[KeyEntry, Any] +KeyLeafPairs = Iterable[KeyLeafPair] +KeyPath = tuple[KeyEntry, ...] @export -def register_pytree_node(nodetype: type[T], - flatten_func: Callable[[T], tuple[_Children, _AuxData]], - unflatten_func: Callable[[_AuxData, _Children], T]) -> None: +def register_pytree_node( + nodetype: type[T], + flatten_func: Callable[[T], tuple[_Children, _AuxData]], + unflatten_func: Callable[[_AuxData, _Children], T], + flatten_with_keys_func: ( + Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None + ) = None, +) -> None: """Extends the set of types that are considered internal nodes in pytrees. See :ref:`example usage `. @@ -279,9 +287,15 @@ def register_pytree_node(nodetype: type[T], >>> jax.jit(f)(m) Array([1., 2., 3., 4., 5.], dtype=float32) """ - default_registry.register_node(nodetype, flatten_func, unflatten_func) - none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) - dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) + default_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + ) + none_leaf_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + ) + dispatch_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + ) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) @@ -452,21 +466,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool return all(tree_leaves(tree, is_leaf=is_leaf)) -register_pytree_node( - collections.OrderedDict, - lambda x: (tuple(x.values()), tuple(x.keys())), - lambda keys, values: collections.OrderedDict(safe_zip(keys, values))) - -def _flatten_defaultdict(d): - keys = tuple(sorted(d)) - return tuple(d[k] for k in keys), (d.default_factory, keys) - -register_pytree_node( - collections.defaultdict, - _flatten_defaultdict, - lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) - - class _HashableCallableShim: """Object that delegates __call__, __hash__, and __eq__ to another object.""" @@ -578,11 +577,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any, # flatten_one_level is not exported. -def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: +def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]: """Flatten the given pytree node by one level. Args: - pytree: A valid pytree node, either built-in or registered via + tree: A valid pytree node, either built-in or registered via :func:`register_pytree_node` or related functions. Returns: @@ -601,9 +600,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: >>> meta ('a', 'b') """ - out = default_registry.flatten_one_level(pytree) + out = default_registry.flatten_one_level(tree) if out is None: - raise ValueError(f"can't tree-flatten type: {type(pytree)}") + raise ValueError(f"can't tree-flatten type: {type(tree)}") else: return out @@ -704,45 +703,10 @@ def _equality_errors(path, t1, t2, is_leaf): yield from _equality_errors((*path, k), c1, c2, is_leaf) -@export -@dataclass(frozen=True) -class SequenceKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - idx: int - def __str__(self): - return f'[{self.idx!r}]' - - -@export -@dataclass(frozen=True) -class DictKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - key: Hashable - def __str__(self): - return f'[{self.key!r}]' - - -@export -@dataclass(frozen=True) -class GetAttrKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - name: str - def __str__(self): - return f'.{self.name}' - - -@export -@dataclass(frozen=True) -class FlattenedIndexKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - key: int - def __str__(self): - return f'[]' - -BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey] - -KeyEntry = TypeVar("KeyEntry", bound=Hashable) -KeyPath = tuple[KeyEntry, ...] +SequenceKey: Any = pytree.SequenceKey # type: ignore +DictKey: Any = pytree.DictKey # type: ignore +GetAttrKey: Any = pytree.GetAttrKey # type: ignore +FlattenedIndexKey: Any = pytree.FlattenedIndexKey # type: ignore @export @@ -764,6 +728,7 @@ def keystr(keys: KeyPath): return ''.join(map(str, keys)) +# TODO(ivyzheng): remove this after _child_keys() also moved to C++. class _RegistryWithKeypathsEntry(NamedTuple): flatten_with_keys: Callable[..., Any] unflatten_func: Callable[..., Any] @@ -780,7 +745,6 @@ def flatten_with_keys(xs): flatten_with_keys, _registry[ty].from_iter ) - _registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {} _register_keypaths( @@ -803,13 +767,9 @@ def flatten_with_keys(xs): @export def register_pytree_with_keys( nodetype: type[T], - flatten_with_keys: Callable[ - [T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData] - ], + flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]], unflatten_func: Callable[[_AuxData, Iterable[Any]], T], - flatten_func: None | ( - Callable[[T], tuple[Iterable[Any], _AuxData]] - ) = None, + flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None, ): """Extends the set of types that are considered internal nodes in pytrees. @@ -870,7 +830,9 @@ def flatten_func_impl(tree): return [c for _, c in key_children], treedef flatten_func = flatten_func_impl - register_pytree_node(nodetype, flatten_func, unflatten_func) + register_pytree_node( + nodetype, flatten_func, unflatten_func, flatten_with_keys + ) _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( flatten_with_keys, unflatten_func ) @@ -927,8 +889,8 @@ class that defines how it could be flattened with keys. @export def register_dataclass( nodetype: Typ, - data_fields: Sequence[str], - meta_fields: Sequence[str], + data_fields: Sequence[str] | None = None, + meta_fields: Sequence[str] | None = None, drop_fields: Sequence[str] = (), ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. @@ -945,24 +907,33 @@ def register_dataclass( attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among ``meta_fields`` or ``data_fields``. - meta_fields: auxiliary data field names. These fields *must* contain static, - hashable, immutable objects, as these objects are used to generate JIT cache - keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or - :class:`numpy.ndarray` objects. - data_fields: data field names. These fields *must* be JAX-compatible objects - such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or - pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be - ``None``, as this is recognized by JAX as an empty pytree. + meta_fields: metadata field names: these are attributes which will be treated as + {term}`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is + optional only if ``nodetype`` is a dataclass, in which case individual fields can + be marked static via :func:`dataclasses.field` (see examples below). + Metadata fields *must* be static, hashable, immutable objects, as these objects + are used to generate JIT cache keys. In particular, metadata fields cannot contain + :class:`jax.Array` or :class:`numpy.ndarray` objects. + data_fields: data field names: these are attributes which will be treated as non-static + when this pytree is passed to :func:`jax.jit`. ``data_fields`` is optional only if + ``nodetype`` is a dataclass, in which case fields are assumed data fields unless + marked via :func:`dataclasses.field` (see examples below). + Data fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array` + or :class:`numpy.ndarray`), scalars, or pytrees whose leaves are arrays or scalars. + Note that ``None`` is a valid data field, as JAX recognizes this as an empty pytree. Returns: The input class ``nodetype`` is returned unchanged after being added to JAX's - pytree registry. This return value allows ``register_dataclass`` to be partially - evaluated and used as a decorator as in the example below. + pytree registry, so that :func:`register_dataclass` can be used as a decorator. Examples: + In JAX v0.4.35 or older, you must specify ``data_fields`` and ``meta_fields`` + in order to use this decorator: + + >>> import jax >>> from dataclasses import dataclass >>> from functools import partial - >>> + ... >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) @@ -976,7 +947,26 @@ def register_dataclass( >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') - Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`: + Starting in JAX v0.4.36, the ``data_fields`` and ``meta_fields`` arguments are optional + for :func:`~dataclasses.dataclass` inputs, with fields defaulting to ``data_fields`` + unless marked as static using `static` metadata in :func:`dataclasses.field`. + + >>> import jax + >>> from dataclasses import dataclass, field + ... + >>> @jax.tree_util.register_dataclass + ... @dataclass + ... class MyStruct: + ... x: jax.Array # defaults to non-static data field + ... y: jax.Array # defaults to non-static data field + ... op: str = field(metadata=dict(static=True)) # marked as static meta field. + ... + >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') + >>> m + MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') + + Once this class is registered, it can be used with functions in :mod:`jax.tree` and + :mod:`jax.tree_util`: >>> leaves, treedef = jax.tree.flatten(m) >>> leaves @@ -987,7 +977,8 @@ def register_dataclass( MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') In particular, this registration allows ``m`` to be passed seamlessly through code - wrapped in :func:`jax.jit` and other JAX transformations: + wrapped in :func:`jax.jit` and other JAX transformations, with ``data_fields`` being + treated as dynamic arguments, and ``meta_fields`` being treated as static arguments: >>> @jax.jit ... def compiled_func(m): @@ -999,6 +990,21 @@ def register_dataclass( >>> compiled_func(m) Array([1., 2., 3.], dtype=float32) """ + if data_fields is None or meta_fields is None: + if (data_fields is None) != (meta_fields is None): + raise TypeError("register_dataclass: data_fields and meta_fields must both be specified" + f" when either is specified. Got {data_fields=} {meta_fields=}.") + if not dataclasses.is_dataclass(nodetype): + raise TypeError("register_dataclass: data_fields and meta_fields are required when" + f" nodetype is not a dataclass. Got {nodetype=}.") + data_fields = [f.name for f in dataclasses.fields(nodetype) + if not f.metadata.get('static', False)] + meta_fields = [f.name for f in dataclasses.fields(nodetype) + if f.metadata.get('static', False)] + + assert meta_fields is not None + assert data_fields is not None + # Store inputs as immutable tuples in this scope, because we close over them # for later evaluation. This prevents potentially confusing behavior if the # caller were to pass in lists that are later mutated. @@ -1048,6 +1054,23 @@ def flatten_func(x): return nodetype +register_pytree_with_keys( + collections.OrderedDict, + lambda x: (tuple((DictKey(k), x[k]) for k in x.keys()), tuple(x.keys())), + lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), +) + +def _flatten_defaultdict_with_keys(d): + keys = tuple(sorted(d)) + return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys) + +register_pytree_with_keys( + collections.defaultdict, + _flatten_defaultdict_with_keys, + lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), +) + + @export def register_static(cls: type[H]) -> type[H]: """Registers `cls` as a pytree with no leaves. @@ -1100,8 +1123,7 @@ def tree_flatten_with_path( which contains a leaf and its key path. The second element is a treedef representing the structure of the flattened tree. """ - _, tree_def = tree_flatten(tree, is_leaf) - return _generate_key_paths(tree, is_leaf), tree_def + return default_registry.flatten_with_path(tree, is_leaf) @export @@ -1120,51 +1142,17 @@ def tree_leaves_with_path( - :func:`jax.tree_util.tree_leaves` - :func:`jax.tree_util.tree_flatten_with_path` """ - return _generate_key_paths(tree, is_leaf) + return tree_flatten_with_path(tree, is_leaf)[0] # generate_key_paths is not exported. def generate_key_paths( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tuple[KeyPath, Any]]: - return list(_generate_key_paths_((), tree, is_leaf)) + return tree_leaves_with_path(tree, is_leaf) _generate_key_paths = generate_key_paths # alias for backward compat -# The overall logic should be same as PyTreeDef::FlattenIntoImpl -def _generate_key_paths_( - key_path: KeyPath, - tree: Any, - is_leaf: Callable[[Any], bool] | None = None, -) -> Iterable[tuple[KeyPath, Any]]: - if is_leaf and is_leaf(tree): - yield key_path, tree - return - key_handler = _registry_with_keypaths.get(type(tree)) - if key_handler: - key_children, _ = key_handler.flatten_with_keys(tree) - for k, c in key_children: - yield from _generate_key_paths_((*key_path, k), c, is_leaf) - return - - flat = default_registry.flatten_one_level(tree) - if flat is None: - yield key_path, tree # strict leaf type - return - - if (isinstance(tree, tuple) and hasattr(tree, '_fields') and - flat[1] == type(tree)): - # handle namedtuple as a special case, based on heuristic - key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields] - for k, c in key_children: - yield from _generate_key_paths_((*key_path, k), c, is_leaf) - return - - for i, c in enumerate(flat[0]): - k = FlattenedIndexKey(i) - yield from _generate_key_paths_((*key_path, k), c, is_leaf) - - @export def tree_map_with_path(f: Callable[..., Any], tree: Any, *rest: Any, diff --git a/jax/_src/util.py b/jax/_src/util.py index fce342c493ed..8dcc5eaa5804 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -453,6 +453,10 @@ def tuple_update(t, idx, val): assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] +def tuple_replace(tupl, index, item): + # unlike tuple_update, works with negative indices as well + return tupl[:index] + (item,) + tupl[index:][1:] + class HashableFunction: """Decouples function equality and hash from its identity. diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 23b255ef1750..28148761c8a4 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -90,6 +90,13 @@ help="Mock number of JAX processes in GPU client. Value zero turns " "off mocking.", ) +_MOCK_GPU_TOPOLOGY = config.string_flag( + name="jax_mock_gpu_topology", + default="", + help='Mock multi-host GPU topology in GPU client. The value should ' + 'be of the form " x x ' + '". Empty string turns off mocking.', +) _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( name="jax_cpu_enable_gloo_collectives", @@ -425,6 +432,14 @@ def _version_check(name: str, f'following issues with CUDA components:\n' f'{join_str.join(errors)}') +def _get_num_nodes_from_gpu_topology(topology: str) -> int: + try: + slices_str, hosts_per_slice_str, _ = topology.split("x", 2) + return int(slices_str) * int(hosts_per_slice_str) + except (IndexError, ValueError): + raise ValueError('Mock topology must be of the form ' + '" x x ' + '".') def make_gpu_client( *, platform_name: str, visible_devices_flag: config.Flag[str] @@ -434,12 +449,14 @@ def make_gpu_client( if visible_devices != "all": allowed_devices = {int(x) for x in visible_devices.split(",")} - use_mock_gpu_client = _MOCK_NUM_GPU_PROCESSES.value > 0 - num_nodes = ( - _MOCK_NUM_GPU_PROCESSES.value - if use_mock_gpu_client - else distributed.global_state.num_processes - ) + mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None + mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + + use_mock_gpu_client = mock_num_gpu_processes > 0 + num_nodes = (mock_num_gpu_processes if use_mock_gpu_client + else distributed.global_state.num_processes) + if platform_name == "cuda": if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): _check_cuda_versions() @@ -634,10 +651,14 @@ def _options_from_jax_configs(plugin_name): visible_devices = CUDA_VISIBLE_DEVICES.value if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_processes = _MOCK_NUM_GPU_PROCESSES.value - options['enable_mock_nccl'] = mock_gpu_processes > 0 - if options['enable_mock_nccl']: - options['num_nodes'] = mock_gpu_processes + mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + options['enable_mock_nccl'] = mock_num_processes > 0 + if mock_num_processes > 0: + options['num_nodes'] = mock_num_processes + if mock_gpu_topology: + options['mock_gpu_topology'] = mock_gpu_topology return options diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 94b482e2dea4..94e26eeefa65 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -41,15 +41,13 @@ def set_xla_metadata(*args, **kwargs): thread_local_metadata.val, new_metadata, ) - config.update_thread_local_jit_state( - xla_metadata_context_manager=tuple( - (v, k) for k, v in sorted(new_metadata.items()))) + config.xla_metadata_context_manager.set_local( + tuple((v, k) for k, v in sorted(new_metadata.items())) + ) try: yield finally: thread_local_metadata.val = prev_metadata - config.update_thread_local_jit_state( - xla_metadata_context_manager=tuple( - (v, k) for k, v in sorted(prev_metadata.items()) - ) + config.xla_metadata_context_manager.set_local( + tuple((v, k) for k, v in sorted(prev_metadata.items())) ) diff --git a/jax/core.py b/jax/core.py index 9682d106e202..4d1742bc28ea 100644 --- a/jax/core.py +++ b/jax/core.py @@ -19,10 +19,10 @@ AbstractToken as AbstractToken, AbstractValue as AbstractValue, Atom as Atom, + axis_frame as axis_frame, AxisSize as AxisSize, + AxisName as AxisName, CallPrimitive as CallPrimitive, - ClosedJaxpr as ClosedJaxpr, - ConcreteArray as ConcreteArray, ConcretizationTypeError as ConcretizationTypeError, DShapedArray as DShapedArray, DropVar as DropVar, @@ -33,43 +33,29 @@ InDBIdx as InDBIdx, InconclusiveDimensionOperation as InconclusiveDimensionOperation, InputType as InputType, - Jaxpr as Jaxpr, JaxprDebugInfo as JaxprDebugInfo, - JaxprEqn as JaxprEqn, JaxprPpContext as JaxprPpContext, JaxprPpSettings as JaxprPpSettings, JaxprTypeError as JaxprTypeError, - Literal as Literal, - MainTrace as MainTrace, MapPrimitive as MapPrimitive, nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 OpaqueTraceState as OpaqueTraceState, - NameGatheringSubst as NameGatheringSubst, OutDBIdx as OutDBIdx, OutputType as OutputType, ParamDict as ParamDict, - Primitive as Primitive, ShapedArray as ShapedArray, - Sublevel as Sublevel, TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING, - ThreadLocalState as ThreadLocalState, - Token as Token, Trace as Trace, - TraceStack as TraceStack, - TraceState as TraceState, Tracer as Tracer, unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401 unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401 + unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE, # noqa: F401 UnshapedArray as UnshapedArray, Value as Value, - Var as Var, abstract_token as abstract_token, - apply_todos as apply_todos, aval_mapping_handlers as aval_mapping_handlers, - axis_frame as axis_frame, call as call, - call_bind_with_continuation as call_bind_with_continuation, call_impl as call_impl, call_p as call_p, check_jaxpr as check_jaxpr, @@ -77,69 +63,49 @@ concrete_aval as concrete_aval, concrete_or_error as concrete_or_error, concretization_function_error as concretization_function_error, - cur_sublevel as cur_sublevel, custom_typechecks as custom_typechecks, dedup_referents as dedup_referents, - do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr, ensure_compile_time_eval as ensure_compile_time_eval, escaped_tracer_error as escaped_tracer_error, eval_context as eval_context, eval_jaxpr as eval_jaxpr, - extend_axis_env as extend_axis_env, extend_axis_env_nd as extend_axis_env_nd, find_top_trace as find_top_trace, - full_lower as full_lower, gensym as gensym, get_aval as get_aval, get_type as get_type, get_referent as get_referent, + is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, - jaxpr_as_fun as jaxpr_as_fun, - jaxpr_uses_outfeed as jaxpr_uses_outfeed, jaxprs_in_params as jaxprs_in_params, join_effects as join_effects, - lattice_join as lattice_join, leaked_tracer_error as leaked_tracer_error, literalable_types as literalable_types, - map_bind as map_bind, - map_bind_with_continuation as map_bind_with_continuation, mapped_aval as mapped_aval, maybe_find_leaked_tracers as maybe_find_leaked_tracers, max_dim as max_dim, min_dim as min_dim, - new_base_main as new_base_main, new_jaxpr_eqn as new_jaxpr_eqn, - new_main as new_main, - new_sublevel as new_sublevel, no_axis_name as no_axis_name, no_effects as no_effects, - outfeed_primitives as outfeed_primitives, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, - primitive_uses_outfeed as primitive_uses_outfeed, - process_env_traces_call as process_env_traces_call, - process_env_traces_map as process_env_traces_map, pytype_aval_mappings as pytype_aval_mappings, - raise_as_much_as_possible as raise_as_much_as_possible, - raise_to_shaped as raise_to_shaped, raise_to_shaped_mappings as raise_to_shaped_mappings, reset_trace_state as reset_trace_state, - stash_axis_env as stash_axis_env, + set_current_trace as set_current_trace, str_eqn_compact as str_eqn_compact, subjaxprs as subjaxprs, - subst_axis_names as subst_axis_names, - subst_axis_names_eqn as subst_axis_names_eqn, - subst_axis_names_jaxpr as subst_axis_names_jaxpr, - subst_axis_names_var as subst_axis_names_var, substitute_vars_in_output_ty as substitute_vars_in_output_ty, - thread_local_state as thread_local_state, + take_current_trace as take_current_trace, + trace_ctx as trace_ctx, trace_state_clean as trace_state_clean, + TraceTag as TraceTag, traverse_jaxpr_params as traverse_jaxpr_params, typecheck as typecheck, typecompat as typecompat, typematch as typematch, unmapped_aval as unmapped_aval, - used_axis_names as used_axis_names, used_axis_names_jaxpr as used_axis_names_jaxpr, valid_jaxtype as valid_jaxtype, ) @@ -147,6 +113,37 @@ from jax._src import core as _src_core _deprecations = { + # Added 2024-12-10 + "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.ClosedJaxpr), + "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.full_lower), + "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Jaxpr), + "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.JaxprEqn), + "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.jaxpr_as_fun), + "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.lattice_join), + "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Literal), + "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Primitive), + "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.raise_to_shaped), + "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Token), + "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Var), # Added 2024-08-14 "check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn), "check_type": ("jax.core.check_type is deprecated.", _src_core.check_type), @@ -167,28 +164,6 @@ "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), - # Finalized 2024-05-13; remove after 2024-08-13 - "DimSize": ( - "jax.core.DimSize is deprecated. Use DimSize = int | Any.", - None, - ), - "Shape": ( - "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].", - None, - ), - # Finalized 2024-06-24; remove after 2024-09-24 - "canonicalize_shape": ( - "jax.core.canonicalize_shape is deprecated.", None, - ), - "dimension_as_value": ( - "jax.core.dimension_as_value is deprecated. Use jnp.array.", None, - ), - "definitely_equal": ( - "jax.core.definitely_equal is deprecated. Use ==.", None, - ), - "symbolic_equal_dim": ( - "jax.core.symbolic_equal_dim is deprecated. Use ==.", None, - ), # Added Jan 8, 2024 "non_negative_dim": ( "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim, @@ -197,10 +172,21 @@ import typing if typing.TYPE_CHECKING: + ClosedJaxpr = _src_core.ClosedJaxpr + Jaxpr = _src_core.Jaxpr + JaxprEqn = _src_core.JaxprEqn + Literal = _src_core.Literal + Primitive = _src_core.Primitive + Token = _src_core.Token + Var = _src_core.Var check_eqn = _src_core.check_eqn check_type = _src_core.check_type check_valid_jaxtype = _src_core.check_valid_jaxtype + full_lower = _src_core.full_lower + jaxpr_as_fun = _src_core.jaxpr_as_fun + lattice_join = _src_core.lattice_join non_negative_dim = _src_core.non_negative_dim + raise_to_shaped = _src_core.raise_to_shaped else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 3df3eb25c40d..c525185e6449 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -25,6 +25,7 @@ from absl.testing import parameterized import jax import jax.numpy as jnp +from jax._src import config from jax._src import test_util as jtu from jax._src import array from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding @@ -375,6 +376,8 @@ def cb1(index): @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) def test_checkpointing_with_int4(self, input_dtype): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -580,6 +583,8 @@ def test_load_with_layout(self): self.assertArraysEqual(s.data, np_inp[s.index]) def test_deserialization_with_int4(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") if jtu.test_device_matches(['gpu']): self.skipTest("Fails on GPU. Enable after it's fixed") dtype = jnp.int4 diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 62da0f231d50..b4adbadfa6c5 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -14,18 +14,20 @@ from __future__ import annotations -from contextlib import contextmanager from typing import Any from jax._src import core +from jax._src import source_info_util from jax._src import api_util from jax._src import linear_util as lu +from jax._src.ad_util import (Zero) from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, treedef_tuple) from jax._src.util import unzip2, safe_map, safe_zip, split_list +from jax._src.dtypes import dtype, float0 map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -35,23 +37,13 @@ register = api_util.register_class_with_attrs -@contextmanager -def top_trace(): - stack = core.thread_local_state.trace_state.trace_stack.stack - main = stack.pop() - try: - trace = main.with_cur_sublevel() - yield trace - finally: - stack.append(main) - def jax_getattr(obj: Any, attr: str): - with top_trace() as trace: - return trace.process_getattr(obj, attr) + with core.take_current_trace() as t: + return t.process_getattr(obj, attr) def jax_setattr(obj: Any, attr: str, val: Pytree): - with top_trace() as trace: - return trace.process_setattr(obj, attr, val) + with core.take_current_trace() as t: + return t.process_setattr(obj, attr, val) def _getattr_impl(_, obj, attr): return getattr(obj, attr) @@ -62,7 +54,7 @@ def _setattr_impl(_, obj, attr, val): core.EvalTrace.process_setattr = _setattr_impl def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): - frame = trace.main.jaxpr_stack[-1] # type: ignore + frame = trace.frame def new_tracer(x): aval = core.raise_to_shaped(core.get_aval(x)) @@ -105,48 +97,51 @@ def jvp(f, primals, tangents, attr_tangents): out_tangents = tree_unflatten(out_tree(), out_tangents_flat) return out_primals, out_tangents, tangent_attrs_out -@lu.transformation -def _set_attrs(attrs, attr_vals, *args): +@lu.transformation2 +def _set_attrs(f, attrs, attr_vals, *args): for (o, a), x in zip(attrs, attr_vals): jax_setattr(o, a, x) - yield (yield args, {}) + return f(*args) def _jvp(fun: lu.WrappedFun): return jvpfun2(jvp_subtrace2(fun)) -@lu.transformation -def jvpfun2(primals, tangents): - with core.new_main(ad.JVPTrace) as main: - out_primals, out_tangents, tangent_attrs_out = \ - yield (main, primals, tangents), {} - del main - yield out_primals, out_tangents, tangent_attrs_out - -@lu.transformation -def jvp_subtrace2(main, primals, tangents): - main.attrs_tracked = [] # attrs written to - trace = main.with_cur_sublevel() - in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x - for x, t in zip(primals, tangents)] - ans = yield in_tracers, {} - out_tracers = map(trace.full_raise, ans) - out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in out_tracers) - tangent_attrs_out = [] - for (obj, name) in main.attrs_tracked: - tracer = trace.full_raise(jax_getattr(obj, name)) - jax_setattr(obj, name, tracer.primal) - if type(tracer.tangent) is not ad.Zero: - tangent_attrs_out.append((obj, name, tracer.tangent)) - del main.attrs_tracked - yield out_primals, out_tangents, tangent_attrs_out +@lu.transformation2 +def jvpfun2(f, primals, tangents): + tag = core.TraceTag() + tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) + and dtype(t) == float0 else t for t in tangents] + ctx = source_info_util.transform_name_stack('jvp') + with ctx: + out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents) + return out_primals, out_tangents, tangent_attrs_out + +@lu.transformation2 +def jvp_subtrace2(f, tag, primals, tangents): + with core.take_current_trace() as parent_trace: + trace = ad.JVPTrace(parent_trace, tag) + tag.attrs_tracked = [] # attrs written to + in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x + for x, t in zip(primals, tangents)] + with core.set_current_trace(trace): + ans = f(*in_tracers) + out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) + tangent_attrs_out = [] + for (obj, name) in tag.attrs_tracked: + primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name)) + jax_setattr(obj, name, primal) + if type(tangent) is not ad.Zero: + tangent_attrs_out.append((obj, name, tangent)) + del tag.attrs_tracked + return out_primals, out_tangents, tangent_attrs_out def _setattr_jvp(trace, obj, attr, maybe_tracer): - tracer = trace.full_raise(maybe_tracer) - if isinstance(tracer.tangent, ad.Zero): - return setattr(obj, attr, tracer.primal) - if (obj, attr) not in trace.main.attrs_tracked: - trace.main.attrs_tracked.append((obj, attr)) - return setattr(obj, attr, tracer) + primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) + if isinstance(tangent, ad.Zero): + return setattr(obj, attr, primal) + if (obj, attr) not in trace.tag.attrs_tracked: + trace.tag.attrs_tracked.append((obj, attr)) + return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent)) ad.JVPTrace.process_setattr = _setattr_jvp def _getattr_jvp(trace, obj, attr): @@ -180,11 +175,12 @@ def _linearize(traceable: lu.WrappedFun, *primals): return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], jaxpr, consts, attrs()) -@lu.transformation_with_aux -def _split_attrs(*args, **kwargs): - primals, tangents, tangent_attrs = yield args, kwargs +@lu.transformation_with_aux2 +def _split_attrs(f, store, *args, **kwargs): + primals, tangents, tangent_attrs = f(*args, **kwargs) attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) - yield (primals, tangents, tangent_attr_vals), attrs + store.store(attrs) + return primals, tangents, tangent_attr_vals def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): in_tree, out_tree = io_tree diff --git a/jax/experimental/colocated_python/__init__.py b/jax/experimental/colocated_python/__init__.py new file mode 100644 index 000000000000..2e9b4f967cd7 --- /dev/null +++ b/jax/experimental/colocated_python/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python API.""" + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +# pylint: disable=useless-import-alias +from jax.experimental.colocated_python.api import ( + colocated_cpu_devices as colocated_cpu_devices, + colocated_python as colocated_python, +) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py new file mode 100644 index 000000000000..770820b39222 --- /dev/null +++ b/jax/experimental/colocated_python/api.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python top-level API.""" + +from __future__ import annotations + +import collections +from typing import Any, Callable, Sequence + +import jax +from jax._src import api_util +from jax.experimental.colocated_python.func import make_callable + + +def colocated_cpu_devices( + devices: Sequence[jax.Device], +) -> Sequence[jax.Device]: + """Finds CPU devices colocated with the given devices.""" + cpu_devices_by_colocation_id = collections.defaultdict(list) + for device in devices[0].client._get_all_devices(): # pylint: disable=protected-access + if device.device_kind == "cpu": + cpu_devices_by_colocation_id[device.colocation_id].append(device) + if not cpu_devices_by_colocation_id: + raise ValueError("No CPU devices found") + + colocated_cpu_devices = [] + for device in devices: + matches = cpu_devices_by_colocation_id[device.colocation_id] + if not matches: + raise ValueError(f"Device {device} has no colocated devices") + elif len(matches) > 1: + raise ValueError( + f"Ambiguous colocated devices; device {device} has" + f" {len(matches)} colocated devices: f{matches}" + ) + colocated_cpu_devices.append(matches[0]) + return colocated_cpu_devices + + +def colocated_python(fun: Callable[..., Any]) -> Any: + """Executes the given Python function on the same device as the arguments.""" + return make_callable( + fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun) + ) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py new file mode 100644 index 000000000000..5567f2f765c1 --- /dev/null +++ b/jax/experimental/colocated_python/func.py @@ -0,0 +1,466 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python function API implementation.""" + +from __future__ import annotations + +import dataclasses +import inspect +import random +import threading +from typing import Any, Callable, Sequence + +import jax +from jax._src import api +from jax._src import tree_util +from jax._src.interpreters import pxla +from jax._src.lib import xla_client as xc +from jax._src.traceback_util import api_boundary +from jax._src.util import wraps +from jax.experimental.colocated_python import func_backend +from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs +from jax.extend.ifrt_programs import ifrt_programs + +ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct] + + +@dataclasses.dataclass(frozen=True, slots=True) +class FunctionInfo: + """User function wrapped by colocated_python.""" + + fun: Callable[..., Any] + fun_sourceinfo: str | None + fun_signature: inspect.Signature | None + + +@dataclasses.dataclass(frozen=True, slots=True) +class Specialization: + """Specialization for a colocated_python function.""" + + in_specs_treedef: tree_util.PyTreeDef | None = None + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None + out_specs_treedef: tree_util.PyTreeDef | None = None + out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None + devices: xc.DeviceList | None = None + + def update( + self, + *, + in_specs_treedef: tree_util.PyTreeDef | None = None, + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, + out_specs_treedef: tree_util.PyTreeDef | None = None, + out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, + devices: Sequence[jax.Device] | xc.DeviceList | None = None, + ) -> Any: + """Creates a new specialization with overrides.""" + if in_specs_treedef is None: + in_specs_treedef = self.in_specs_treedef + elif self.in_specs_treedef is not None: + raise ValueError("in_specs already specified") + if in_specs_leaves is None: + in_specs_leaves = self.in_specs_leaves + elif self.in_specs_leaves is not None: + raise ValueError("in_specs already specified") + + if out_specs_fn is None: + out_specs_fn = self.out_specs_fn + elif self.out_specs_fn is not None: + raise ValueError("out_specs_fn already specified") + + if out_specs_treedef is None: + out_specs_treedef = self.out_specs_treedef + elif self.out_specs_treedef is not None: + raise ValueError("out_specs already specified") + if out_specs_leaves is None: + out_specs_leaves = self.out_specs_leaves + elif self.out_specs_leaves is not None: + raise ValueError("out_specs already specified") + + if devices is None: + devices = self.devices + elif self.devices is not None: + raise ValueError("devices already specified") + elif not isinstance(devices, xc.DeviceList): + devices = xc.DeviceList(tuple(devices)) + + return Specialization( + in_specs_treedef, + in_specs_leaves, + out_specs_fn, + out_specs_treedef, + out_specs_leaves, + devices, + ) + + +def _get_spec(x: Any) -> api.ShapeDtypeStruct: + """Extracts a spec for a value, which must be a JAX Array.""" + # TODO(hyeontaek): Allow Python values and automatically apply `shard_arg` + # with a suitable sharding and layout. + if not isinstance(x, jax.Array): + raise ValueError( + "colocated_python only supports jax.Array as input and output, but got" + f" {type(x)}." + ) + return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + + +def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None: + """Returns a representative device list from function call arguments.""" + device_list_set: set[xc.DeviceList] = set() + for x in args: + sharding = getattr(x, "sharding", None) + if sharding is not None: + device_list_set.add(x.sharding._internal_device_list) + if not device_list_set: + return None + if len(device_list_set) != 1: + raise ValueError( + "All arguments must use the same device list, but got" + f" multiple device lists: {device_list_set}." + ) + return device_list_set.pop() + + +def _compile_to_executable( + name: str, + fun: Callable[..., Any], + in_specs_treedef: tree_util.PyTreeDef, + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...], + out_specs_treedef: tree_util.PyTreeDef, + out_specs_leaves: tuple[api.ShapeDtypeStruct, ...], + devices: xc.DeviceList, +) -> Callable[..., Any]: + """Compiles a Python function into a runtime executable.""" + fun_and_specialization = ( + fun, + in_specs_treedef, + in_specs_leaves, + out_specs_treedef, + out_specs_leaves, + devices, + ) + pickled_function = _serialize(fun_and_specialization) + program = ifrt_programs.make_colocated_python_program( + name, pickled_function, devices, in_specs_leaves, out_specs_leaves + ) + ifrt_client = devices[0].client + out_sdss = tuple( + jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves + ) + out_shardings = tuple(sds.sharding for sds in out_specs_leaves) + try: + compile_options = ifrt_programs.make_colocated_python_compile_options() + loaded_executable = ifrt_client.compile_ifrt_program( + program, compile_options + ) + out_handlers = pxla.global_avals_to_results_handler( + out_sdss, out_shardings, committed=True + ).handlers + + def call(*args, **kwargs): + args_leaves = tree_util.tree_leaves((args, kwargs)) + execute_result = loaded_executable.execute_sharded( + args_leaves, with_tokens=False + ) + results = execute_result.consume_with_handlers(out_handlers) + return tree_util.tree_unflatten(out_specs_treedef, results) + + return call + except jax.errors.JaxRuntimeError as e: + # TODO(hyeontaek): Implement colocated Python support in McJAX and remove + # this fallback path. + if "PjRtCompiler requires an HloProgram" in str(e): + return fun + raise + + +def _make_output_specs_and_push_result_fun( + info: FunctionInfo, specialization: Specialization, uid: int +) -> Callable[..., Any]: + """Creates a function that computes output specs and pushes the result to the result store.""" + assert specialization.in_specs_treedef is not None + assert specialization.in_specs_leaves is not None + assert specialization.out_specs_treedef is None + assert specialization.out_specs_leaves is None + assert specialization.devices is not None + + devices = specialization.devices + + def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]: + result = info.fun(*args, **kwargs) + result_leaves, out_treedef = tree_util.tree_flatten(result) + out_spec_leaves = tuple(_get_spec(x) for x in result_leaves) + func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves) + return _serialize_specs(out_treedef, out_spec_leaves, devices) + + out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( + _make_specs_for_serialized_specs(specialization.devices), + ) + name = getattr(info.fun, "__name__", "unknown") + name = f"{name}_output_specs_and_push_result" + return _compile_to_executable( + name=name, + fun=lowered_fun, + in_specs_treedef=specialization.in_specs_treedef, + in_specs_leaves=specialization.in_specs_leaves, + out_specs_treedef=out_specs_treedef, + out_specs_leaves=tuple(out_specs_leaves), + devices=specialization.devices, + ) + + +def _make_pop_result_fun( + info: FunctionInfo, specialization: Specialization, uid: int +) -> Callable[..., Any]: + """Makes a function that pops results from the result store.""" + assert specialization.out_specs_treedef is not None + assert specialization.out_specs_leaves is not None + assert specialization.devices is not None + + out_specs_treedef = specialization.out_specs_treedef + + def lowered_fun() -> Any: + result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid) + return tree_util.tree_unflatten(out_specs_treedef, result_leaves) + + in_specs_leaves, in_specs_treedef = tree_util.tree_flatten(( + # args + (), + # kwargs + {}, + )) + name = getattr(info.fun, "__name__", "unknown") + name = f"{name}_pop_result" + return _compile_to_executable( + name=name, + fun=lowered_fun, + in_specs_treedef=in_specs_treedef, + in_specs_leaves=tuple(in_specs_leaves), + out_specs_treedef=specialization.out_specs_treedef, + out_specs_leaves=specialization.out_specs_leaves, + devices=specialization.devices, + ) + + +def _make_async_execution_fun( + info: FunctionInfo, specialization: Specialization +) -> Callable[..., Any]: + """Makes a function that asynchronously executes the function.""" + assert specialization.in_specs_treedef is not None + assert specialization.in_specs_leaves is not None + assert specialization.out_specs_treedef is not None + assert specialization.out_specs_leaves is not None + assert specialization.devices is not None + + name = getattr(info.fun, "__name__", "unknown") + return _compile_to_executable( + name=name, + fun=info.fun, + in_specs_treedef=specialization.in_specs_treedef, + in_specs_leaves=specialization.in_specs_leaves, + out_specs_treedef=specialization.out_specs_treedef, + out_specs_leaves=specialization.out_specs_leaves, + devices=specialization.devices, + ) + + +@jax.util.cache(max_size=None) +def _get_specialized_func( + info: FunctionInfo, specialization: Specialization +) -> Callable[..., Any]: + """Returns a specialized function for the given specialization.""" + assert specialization.in_specs_treedef is not None + assert specialization.in_specs_leaves is not None + assert specialization.devices is not None + uid = random.getrandbits(63) + + mutex = threading.Lock() + # Asynchronous execution function that has known output_specs. + async_execution_func = None + + def specialized_func(*args, **kwargs) -> Any: + """Specialized function to be executed with given args and kwargs.""" + nonlocal specialization, async_execution_func + with mutex: + if async_execution_func is None: + if specialization.out_specs_treedef is None: + if specialization.out_specs_fn is None: + serialized_out_specs = _make_output_specs_and_push_result_fun( + info, specialization, uid + )(*args, **kwargs) + + # Waits for the output_specs. This may block. + out_specs_treedef, out_specs_leaves = _deserialize_specs( + serialized_out_specs + ) + + # Subsequent calls would use async_execution_func with discovered + # output_specs. + specialization = specialization.update( + out_specs_treedef=out_specs_treedef, + out_specs_leaves=out_specs_leaves, + ) + async_execution_func = _make_async_execution_fun( + info, specialization + ) + + return _make_pop_result_fun(info, specialization, uid)() + else: + # Compute out_specs using out_specs_fn and inputs. + args_specs, kwargs_specs = tree_util.tree_map( + _get_spec, (args, kwargs) + ) + out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs) + # Type checking is ignored to silence mypy error: Incompatible types + # in assignment (expression has type "list[Any]", variable has type + # "tuple[ShapeDtypeStruct, ...]") [assignment] + out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( # type: ignore[assignment] + out_specs + ) + specialization = specialization.update( + out_specs_treedef=out_specs_treedef, + out_specs_leaves=tuple(out_specs_leaves), + ) + async_execution_func = _make_async_execution_fun( + info, specialization + ) + # Fall-through. + else: + async_execution_func = _make_async_execution_fun(info, specialization) + # Fall-through. + + # Asynchronous execution runs outside of the mutex to allow concurrent + # execution for inline executors. + return async_execution_func(*args, **kwargs) + + return specialized_func + + +def make_callable( + fun: Callable[..., Any], + fun_sourceinfo: str | None, + fun_signature: inspect.Signature | None, +) -> Callable[..., Any]: + """Makes a colocated Python callable.""" + return _make_callable( + FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() + ) + + +def _make_callable( + info: FunctionInfo, + specialization: Specialization, +) -> Callable[..., Any]: + """Internal implementation of make_callable.""" + + def specialize( + in_specs: ShapeDtypeStructTree | None = None, + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, + devices: Sequence[jax.Device] | None = None, + ) -> Callable[..., Any]: + """Returns a colocated Python callable with extra specialization. + + Args: + in_specs: Optionally specifies the expected input specs. Input specs are + expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a + function call. + out_specs_fn: Optionally specifies a function that computes the output + specs from input specs. If unspecified, colocated_python will compute + the output specs during the very first execution, and this execution + will be synchronous. + devices: Optionally specifies the devices to execute the function on. Must + be provided if in_specs has no leaves because devices cannot be inferred + from input specs or arguments. + + Returns: + A colocated Python callable with extra specialization. + """ + # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if + # `out_specs_fn(in_specs)` returns at least one leaf that we can use for + # inferring `devices`. + if in_specs is None: + in_specs_leaves, in_specs_treedef = None, None + else: + in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs) + in_specs_leaves = tuple(in_specs_leaves_list) + return _make_callable( + info, + specialization.update( + in_specs_treedef=in_specs_treedef, + in_specs_leaves=in_specs_leaves, + out_specs_fn=out_specs_fn, + devices=devices, + ), + ) + + @api_boundary + def __call__(*args, **kwargs) -> Any: + """Executes the function. + + If the output specs are not known, the very first execution will be + synchronous. + """ + args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) + + in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) + if specialization.in_specs_treedef is None: + # Allow input polymorphism by applying input_specs specialization + # temporarily for this call. + return _make_callable( + info, + specialization.update( + in_specs_treedef=in_specs_treedef, + in_specs_leaves=in_specs_leaves, + ), + )(*args, **kwargs) + + if specialization.devices is None: + devices = _infer_devices_from_args(args_leaves) + if devices is None: + raise ValueError( + "No devices found. colocated_python function without input" + " arguments must be first specialized with devices." + ) + # Allow device polymorphism by applying devices specialization temporarily + # for this call. + return _make_callable(info, specialization.update(devices=devices))( + *args, **kwargs + ) + + # Assertion is added to silence mypy error: Unsupported operand types for != + # ("PyTreeDef" and "None") [operator] + assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef) + + # If input_specs is known, verify that it matches actual inputs. + if (specialization.in_specs_treedef != in_specs_treedef + or specialization.in_specs_leaves != in_specs_leaves): + raise ValueError( + "Input specs in specialization and input specs of arguments must have" + " the same pytree structure, but they have the following structural" + " differences:\n" + + ("\n".join( + f" - {tree_util.keystr(path)} is a {thing1} in value 1 and" + f" a {thing2} in value 2, so {explanation}.\n" + for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef( + specialization.in_specs_treedef, in_specs_treedef + )))) + + return _get_specialized_func(info, specialization)(*args, **kwargs) + + __call__ = wraps(info.fun)(__call__) + __call__.specialize = specialize + return __call__ diff --git a/jax/experimental/colocated_python/func_backend.py b/jax/experimental/colocated_python/func_backend.py new file mode 100644 index 000000000000..aa514015004d --- /dev/null +++ b/jax/experimental/colocated_python/func_backend.py @@ -0,0 +1,44 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Backend for colocated_python.func.""" + +from __future__ import annotations + +import threading +from typing import Sequence + +import jax + + +class _ResultStore: + """Temporarily stores results from synchronous execution of functions.""" + + def __init__(self) -> None: + self._lock = threading.Lock() + self._storage: dict[int, Sequence[jax.Array]] = {} + + def push(self, uid: int, out: Sequence[jax.Array]) -> None: + with self._lock: + if uid in self._storage: + raise ValueError(f"uid {uid} already exists") + self._storage[uid] = out + + def pop(self, uid: int) -> Sequence[jax.Array]: + with self._lock: + if uid not in self._storage: + raise ValueError(f"uid {uid} does not exist") + return self._storage.pop(uid) + + +SINGLETON_RESULT_STORE = _ResultStore() diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py new file mode 100644 index 000000000000..7e7654d4642a --- /dev/null +++ b/jax/experimental/colocated_python/serialization.py @@ -0,0 +1,242 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python serialization utilities.""" + +# TODO(jmudigonda): Use a string-typed array for output structure when it +# becomes available. Using a fixed uint8 array is only for prototyping. + +from __future__ import annotations + +import collections +import io +from typing import Any, Callable, Sequence + +try: + import cloudpickle # type: ignore[import-not-found] +except ImportError: + cloudpickle = None + +import jax +from jax._src import api +from jax._src import tree_util +from jax._src import xla_bridge as xb +from jax._src.lib import xla_client as xc +import numpy as np + +DeviceList = xc.DeviceList + +# Hard-coded limit for serialized specs size. +# TODO(jmudigonda): Use a string-typed array for output structure when it +# becomes available. Using a fixed uint8 array is only for prototyping. +_MAX_SERIALIZED_SPECS_SIZE = 1048576 + + +@jax.util.cache(max_size=None) +def _get_cpu_device_map() -> dict[int, jax.Device]: + """Returns a map from a device id to a matching device.""" + cpu_device_map: dict[int, jax.Device] = {} + # TODO(hyeontaek): We should look up CPU devices for a specific CPU backend. + # When deserializing a device on the controller, the backend should be the one + # associated with colocated_python. When deserializing on the colocated_python + # executor, it should be the CPU backend visible to the user function running + # under colocated_python. + + # Look for CPU devices in the default backend. + for d in xb.local_devices()[0].client._get_all_devices(): # pylint: disable=protected-access + if d.device_kind == "cpu": + if d.id in cpu_device_map: + raise ValueError( + f"Multiple CPU devices with id {d.id} found:" + f" {cpu_device_map[d.id]} and {d}" + ) + cpu_device_map[d.id] = d + if cpu_device_map: + return cpu_device_map + + # Fall back to searching CPU devices in all backends. + for backend in xb.backends().values(): + for d in backend._get_all_devices(): # pylint: disable=protected-access + if d.device_kind == "cpu": + if d.id in cpu_device_map: + raise ValueError( + f"Multiple CPU devices with id {d.id} found:" + f" {cpu_device_map[d.id]} and {d}" + ) + cpu_device_map[d.id] = d + return cpu_device_map + + +def _reduce_mesh( + mesh: jax.sharding.Mesh, +) -> tuple[Callable[..., jax.sharding.Mesh], Any]: + def make_mesh( + mesh_device_ids: np.ndarray, axis_names: Any + ) -> jax.sharding.Mesh: + cpu_device_map = _get_cpu_device_map() + mesh_devices = np.vectorize(lambda device_id: cpu_device_map[device_id])( + mesh_device_ids + ) + return jax.sharding.Mesh(mesh_devices, axis_names) + + mesh_device_ids = np.vectorize(lambda d: d.id, otypes=[int])(mesh.devices) + return make_mesh, (mesh_device_ids, mesh.axis_names) + + +def _reduce_device_list( + device_list: DeviceList, +) -> tuple[Callable[..., DeviceList], Any]: + def make_device_list(device_ids: Sequence[int]) -> DeviceList: + cpu_device_map = _get_cpu_device_map() + devices = np.vectorize(lambda device_id: cpu_device_map[device_id])( + device_ids + ) + return DeviceList(tuple(devices)) + + device_ids = [d.id for d in device_list] + return make_device_list, (device_ids,) + + +def _reduce_single_device_sharding( + sharding: jax.sharding.SingleDeviceSharding, +) -> tuple[Callable[..., jax.sharding.SingleDeviceSharding], Any]: + + def make_single_device_sharding(device_id: int): + cpu_device_map = _get_cpu_device_map() + return jax.sharding.SingleDeviceSharding(cpu_device_map[device_id]) + + return make_single_device_sharding, (sharding.device_set.pop().id,) + + +def _serialize(obj: Any) -> bytes: + """Serializes callables and input/output spec objects. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. + + This module contains utility functions used internally for implementiong + `colocated_python` when it ships callables and input/output specs through + IFRT. The pickled data is produced and consumed in an ephermeral fashion + without any persistence, and it does not expect any version compatibility + (which cloudpickle does not guarantee). Furthermore, serialization and + deserialization is expected to be done on machine(s) that are controlled by a + single tenant, which allows unpickling done during deserialization to be + trusted. + + Raises: + ModuleNotFoundError: If cloudpickle is not available. + """ + if cloudpickle is None: + raise ModuleNotFoundError('No module named "cloudpickle"') + + class _CustomPickler(cloudpickle.Pickler): + dispatch_table = collections.ChainMap( + {jax.sharding.Mesh: _reduce_mesh}, + {DeviceList: _reduce_device_list}, + {jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding}, + cloudpickle.CloudPickler.dispatch_table, # pylint: disable=attribute-error + ) + dispatch = dispatch_table + + with io.BytesIO() as file: + _CustomPickler(file).dump(obj) + return file.getvalue() + + +def _deserialize(serialized: bytes) -> Any: + """Deserializes callables and input/output spec objects. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. See serialize() for details. + + Raises: + ModuleNotFoundError: If cloudpickle is not available. + """ + if cloudpickle is None: + raise ModuleNotFoundError('No module named "cloudpickle"') + + return cloudpickle.loads(serialized) + + +def _make_specs_for_serialized_specs( + devices: DeviceList, +) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]: + """Makes output specs for serialized specs.""" + # TODO(jmudigonda): Use a string-typed array for output structure when it + # becomes available. Using a fixed uint8 array is only for prototyping. + mesh = jax.sharding.Mesh(tuple(devices), ("x",)) + replicated_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + return ( + api.ShapeDtypeStruct( + shape=(), dtype=np.int32, sharding=replicated_sharding + ), + api.ShapeDtypeStruct( + shape=(_MAX_SERIALIZED_SPECS_SIZE,), + dtype=np.uint8, + sharding=replicated_sharding, + ), + ) + + +def _serialize_specs( + specs_treedef: tree_util.PyTreeDef, + specs_leaves: tuple[api.ShapeDtypeStruct, ...], + devices: DeviceList, +) -> tuple[jax.Array, ...]: + """Serializes the output specs into a tuple of arrays. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. See serialize() for details. + """ + s = _serialize((specs_treedef, specs_leaves)) + assert ( + len(s) <= _MAX_SERIALIZED_SPECS_SIZE + ), f"Too large serialized spec size: {len(s)}" + # TODO(jmudigonda): Use a string-typed array for output structure when it + # becomes available. Using a fixed uint8 array is only for prototyping. + mesh = jax.sharding.Mesh(tuple(devices), ("x",)) + replicated_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + len_array = jax.make_array_from_callback( + shape=(), + sharding=replicated_sharding, + data_callback=lambda _: np.array(len(s), dtype=np.int32), + ) + data_array = jax.make_array_from_callback( + shape=(_MAX_SERIALIZED_SPECS_SIZE,), + sharding=replicated_sharding, + data_callback=lambda _: np.frombuffer( + s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)), + dtype=np.uint8, + ), + ) + return len_array, data_array + + +def _deserialize_specs( + serialized_specs: tuple[jax.Array, ...], +) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]: + """Deserializes the specs from the serialized specs. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. See serialize() for details. + """ + # TODO(jmudigonda): Use a string-typed array for output structure when it + # becomes available. Using a fixed uint8 array is only for prototyping. + len_array, data_array = serialized_specs + length = int(len_array.addressable_shards[0].data) + data = np.asarray(data_array.addressable_shards[0].data).tobytes() + return _deserialize(data[:length]) diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py deleted file mode 100644 index d49aa296328a..000000000000 --- a/jax/experimental/export/__init__.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2023 The JAX Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -_deprecation_message = ( - "The jax.experimental.export module is deprecated. " - "Use jax.export instead. " - "See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export." -) - -from jax._src.export import _export as _src_export -from jax._src.export import shape_poly as _src_shape_poly -from jax._src.export import serialization as _src_serialization -# Import only to set the shape poly decision procedure -from jax._src.export import shape_poly_decision -del shape_poly_decision - -# All deprecations added Jun 14, 2024 -_deprecations = { - # Added Jun 13, 2024 - "Exported": (_deprecation_message, _src_export.Exported), - "DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck), - "export": (_deprecation_message, _src_export.export_back_compat), - "call": (_deprecation_message, _src_export.call), - "call_exported": (_deprecation_message, _src_export.call_exported), - "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), - "minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version), - "maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version), - - "serialize": (_deprecation_message, _src_serialization.serialize), - "deserialize": (_deprecation_message, _src_serialization.deserialize), - - "SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope), - "is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim), - "symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape), - "symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs), -} - -import typing -if typing.TYPE_CHECKING: - Exported = _src_export.Exported - DisabledSafetyCheck = _src_export.DisabledSafetyCheck - export = _src_export.export_back_compat - call = _src_export.call - call_exported = _src_export.call_exported - default_lowering_platform = _src_export.default_lowering_platform - - serialize = _src_serialization.serialize - deserialize = _src_serialization.deserialize - - SymbolicScope = _src_shape_poly.SymbolicScope - is_symbolic_dim = _src_shape_poly.is_symbolic_dim - symbolic_shape = _src_shape_poly.symbolic_shape - symbolic_args_specs = _src_shape_poly.symbolic_args_specs -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing -del _src_export -del _src_serialization -del _src_shape_poly diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index caf63df17bf1..da33a677ba07 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -237,7 +237,7 @@ params_vars = tf.nest.map_structure(tf.Variable, params) prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs) my_model = tf.Module() -# Tell the model saver what are the variables. +# Tell the model saver what the variables are. my_model._variables = tf.nest.flatten(params_vars) my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False) tf.saved_model.save(my_model) @@ -760,7 +760,7 @@ symbolic constraints: We plan to improve somewhat this area in the future. * Equality constraints are treated as normalization rules. E.g., `floordiv(a, b) = c` works by replacing all - occurences of the left-hand-side with the right-hand-side. + occurrences of the left-hand-side with the right-hand-side. You can only have equality constraints where the left-hand-side is a multiplication of factors, e.g, `a * b`, or `4 * a`, or `floordiv(a, b)`. Thus, the left-hand-side cannot contain @@ -1048,7 +1048,7 @@ jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32 tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64)) ``` -When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types +When the `JAX_ENABLE_X64` flag is set, JAX uses 64-bit types for Python scalars and respects the explicit 64-bit types: ```python @@ -1245,7 +1245,7 @@ Applies to both native and non-native serialization. trackable classes during attribute assignment. Python Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper classes. -In most situation, these Wrapper classes work exactly as the standard +In most situations, these Wrapper classes work exactly as the standard Python data types. However, the low-level pytree data structures are different and this can lead to errors. @@ -1499,7 +1499,7 @@ during lowering we try to generate one TensorFlow op for one JAX primitive. We expect that the lowering that XLA does is similar to that done by JAX before conversion. (This is a hypothesis, we have not yet verified it extensively.) -There is one know case when the performance of the lowered code will be different. +There is one known case when the performance of the lowered code will be different. JAX programs use a [stateless deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md) and it has an internal JAX primitive for it. diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index baae52403053..2321a8a035f7 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -510,10 +510,17 @@ def _call_tf_lowering( else: captured_inputs.append(inp) - captured_ops = tuple( - mlir.ir_constant(np.asarray(inp)) - for inp in captured_inputs - ) + # The following use case happens when we call_tf a restored saved model that + # includes parameters (hence functions closing over tf.Variable), and then + # we jax2tf.convert it with native serialization, under tf.function (or + # for saving to saved model). The `np.asarray(inp)` fails because it thinks + # it is in TF graph mode. The `tf.init_scope()` lifts out of function-building + # graph scopes, and allows us to read the values of the variables + with tf.init_scope(): + captured_ops = tuple( + mlir.ir_constant(np.asarray(inp)) + for inp in captured_inputs + ) if call_tf_graph: with jax2tf_internal.inside_call_tf(): diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 310cbaab6d59..0d8c95d42676 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -364,6 +364,7 @@ def _conv_general_dilated( def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, + out_type=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index a5cfa5f9b928..188ffeb6d670 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -119,6 +119,10 @@ def _sanitize_scope_name(name): # Line below is different externally and internally. allow_enable_xla_false = lambda: True +# TODO(b/353437398): Deprecate support for `native_serialization=False`. +# Line below is different externally and internally. +allow_native_serialization_false = lambda: True + # A value suitable in a TF tracing context: tf.Tensor, tf.Variable, # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.) TfVal = Any @@ -294,8 +298,8 @@ def convert(fun_jax: Callable, See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) for more details. - polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of - the form `e1 >= e2` or `e1 <= e2`. + polymorphic_constraints: a sequence of constraints on symbolic dimension + expressions, of the form `e1 >= e2` or `e1 <= e2`. See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. with_gradient: if set (default), add a tf.custom_gradient to the lowered function, by converting the ``jax.vjp(fun)``. This means that reverse-mode @@ -332,28 +336,38 @@ def convert(fun_jax: Callable, tuple/lists/dicts thereof), and returns TfVals as outputs, and uses only TensorFlow ops and thus can be called from a TensorFlow program. """ - if not enable_xla: - if allow_enable_xla_false(): - warnings.warn("jax2tf.convert with enable_xla=False is deprecated.", - DeprecationWarning, - stacklevel=2) - else: - raise ValueError("jax2tf.convert with enable_xla=False is not supported.") - if native_serialization is DEFAULT_NATIVE_SERIALIZATION: if not enable_xla: native_serialization = False else: native_serialization = config.jax2tf_default_native_serialization.value - if not native_serialization: - warnings.warn( - "jax2tf.convert with native_serialization=False is deprecated.", - DeprecationWarning, - stacklevel=2) - if native_serialization and not enable_xla: - raise ValueError( - "native_serialization is not supported with enable_xla=False") + if not enable_xla: + if allow_enable_xla_false(): + warnings.warn( + "jax2tf.convert with enable_xla=False has been deprecated " + "since July 2024.", + DeprecationWarning, + stacklevel=2) + if native_serialization: + raise ValueError( + "native_serialization is not supported with enable_xla=False") + else: + raise ValueError( + "jax2tf.convert with enable_xla=False has been deprecated " + "since July 2024 and it is not supported anymore.") + + elif not native_serialization: + if allow_native_serialization_false(): + warnings.warn( + "jax2tf.convert with native_serialization=False has been deprecated " + "since July 2024.", + DeprecationWarning, + stacklevel=2) + else: + raise ValueError( + "jax2tf.convert with native_serialization=False has been deprecated " + "since July 2024 and it is not supported anymore.") if not native_serialization and polymorphic_constraints: raise ValueError( @@ -385,7 +399,7 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal: # It is Ok to nest convert when we are inside a call_tf raise ValueError( "convert must be used outside all JAX transformations." + - f"Trace state: {core.thread_local_state.trace_state.trace_stack}") + f"Trace state: {core.trace_ctx}") global _has_registered_tf_source_path if not _has_registered_tf_source_path: @@ -518,7 +532,16 @@ def __init__(self, fun_jax, *, self.convert_kwargs = dict(native_serialization=True, native_serialization_platforms=native_serialization_platforms, native_serialization_disabled_checks=native_serialization_disabled_checks) - self.fun_jax = fun_jax + if hasattr(fun_jax, "trace"): + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. + fun_jit = fun_jax + else: + # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also + # convert(f_jax), in which case a "jit" is implied. In that case we raise + # an error if the lowered function contains non-replicated sharding annotations. + fun_jit = jax.jit(fun_jax) + self.fun_jax = fun_jit self.args_specs = args_specs self.kwargs_specs = kwargs_specs self.native_serialization_disabled_checks = native_serialization_disabled_checks @@ -533,9 +556,9 @@ def _restore_context(): self._restore_context = _restore_context _exported_device_assignment = [None] - self.exported = _export.export_back_compat( + self.exported = _export._export_internal( self.fun_jax, - lowering_platforms=self.native_serialization_platforms, + platforms=self.native_serialization_platforms, disabled_checks=self.native_serialization_disabled_checks, _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment, )(*self.args_specs, **self.kwargs_specs) @@ -830,15 +853,11 @@ def _interpret_fun_jax( extra_name_stack: str | None, fresh_constant_cache: bool = False, ) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]: - with core.new_base_main(TensorFlowTrace) as main: - subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals) - with _extended_name_stack(extra_name_stack): - with core.new_sublevel(): - out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \ - _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf, - fresh_constant_cache=fresh_constant_cache) - del main - + subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), args_avals) + with _extended_name_stack(extra_name_stack): + out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \ + _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf, + fresh_constant_cache=fresh_constant_cache) return util.unzip2(out_vals) @@ -1021,20 +1040,20 @@ def impl_multiple_results_jax(*args_jax): return wrapped_tf -@lu.transformation -def _interpret_subtrace(main: core.MainTrace, - in_avals: Sequence[core.ShapedArray], +@lu.transformation2 +def _interpret_subtrace(f, in_avals: Sequence[core.ShapedArray], *in_vals: TfVal): - trace = TensorFlowTrace(main, core.cur_sublevel()) + trace = TensorFlowTrace() in_tracers = tuple( TensorFlowTracer(trace, val, aval) for val, aval in zip(in_vals, in_avals)) - outs = yield in_tracers, {} # type: Sequence[TfVal] + with core.set_current_trace(trace): + outs = f(*in_tracers) out_tracers: Iterable[TensorFlowTracer] = ( - map(trace.full_raise, outs)) + map(trace.to_tf_tracer, outs)) out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( tuple((t.val, t.aval) for t in out_tracers)) - yield out_vals_with_avals + return out_vals_with_avals def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal, @@ -1307,13 +1326,14 @@ class TensorFlowTrace(core.Trace): those will introduce their own MainTrace, and any operations involving those will be done on those traces, i.e., not a concern for TFT. """ - def pure(self, val: TfVal) -> TensorFlowTracer: + def to_tf_tracer(self, val: TfVal) -> TensorFlowTracer: """Lifts a non-Tracer into the TensorFlowTracer. - - This function may be called by way of trace.full_raise. """ + if isinstance(val, TensorFlowTracer): + return val if hasattr(val, "__jax_array__"): - val = val.__jax_array__() + with core.set_current_trace(self): + val = val.__jax_array__() if isinstance(val, TensorFlowTracer): return val tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True) @@ -1321,20 +1341,10 @@ def pure(self, val: TfVal) -> TensorFlowTracer: self, tf_val, core.ShapedArray(np.shape(val), jax_dtype, weak_type=dtypes.is_weakly_typed(val))) - def lift(self, val: core.Tracer) -> TensorFlowTracer: - # This would be called when we need to raise a tracer from a lower-level - # main into the TensorFlowTrace. Since the TensorFlowTrace is never nested - # inside another transform, there are no lower-level main traces. - assert False - - def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer: - # This is called when we need to raise a tracer from the same main, - # but a lower sublevel. This could come from a nested jit. - return TensorFlowTracer(self, val.val, val._aval) - def process_primitive(self, primitive: core.Primitive, tracers: Sequence[TensorFlowTracer], params) -> TensorFlowTracer: + tracers = map(self.to_tf_tracer, tracers) impl, impl_needs_avals = self.get_primitive_impl(primitive) args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) # This is a bit conservative, doing abstract_eval even in op-by-op execution @@ -1410,39 +1420,18 @@ def invoke_impl() -> TfVal: def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun, tracers: Sequence[TensorFlowTracer], params): assert call_primitive.multiple_results + tracers = map(self.to_tf_tracer, tracers) vals: Sequence[TfVal] = [t.val for t in tracers] avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) - interpreted_fun = _interpret_subtrace(fun, self.main, avals) + interpreted_fun = _interpret_subtrace(fun, avals) extra_name_stack = None with _extended_name_stack(extra_name_stack): - with core.new_sublevel(): - vals_out = interpreted_fun.call_wrapped(*vals) + vals_out = interpreted_fun.call_wrapped(*vals) return [TensorFlowTracer(self, v, a) for v, a in vals_out] - def post_process_call(self, call_primitive: core.Primitive, - out_tracers: Sequence[TensorFlowTracer], params): - # We encountered a call primitive whose result (out_tracers) include - # TensorFlowTracer that were not passed through its arguments (captured from - # the environment). - vals = tuple(t.val for t in out_tracers) - main = self.main - - def todo(vals: Sequence[TfVal]): - # TODO: is name_stack correct? - trace = TensorFlowTrace(main, core.cur_sublevel()) - return [ - TensorFlowTracer(trace, v, out_tracer.aval) - for v, out_tracer in zip(vals, out_tracers) - ] - - return vals, todo - def process_map(self, map_primitive, f, tracers, params): raise NotImplementedError("process_map") - def post_process_map(self, map_primitive, out_tracers, params): - raise NotImplementedError("post_process_map") - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Drop the custom differentiation rule and act like a call primitive. This # behavior is desirable because jax2tf stages code out of the JAX system, so @@ -1450,9 +1439,6 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): del jvp, symbolic_zeros # Unused. return self.process_call(core.call_p, fun, tracers, {}) - def post_process_custom_jvp_call(self, out_tracers, _): - assert False # unreachable assuming jax2tf runs with clean trace state - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): # Drop the custom differentiation rule and act like a call primitive. This @@ -1461,12 +1447,6 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, del fwd, bwd, out_trees, symbolic_zeros # Unused. return self.process_call(core.call_p, fun, tracers, {}) - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable assuming jax2tf runs with clean trace state - - def post_process_custom_vjp_call_fwd(self, *_, **__): - assert False # unreachable assuming jax2tf runs with clean trace state - def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]: # Returns the primitive implementation and whether the implementation # takes abstract values (see definition of tf_impl_with_avals) @@ -1686,8 +1666,37 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], tf_impl[lax.sinh_p] = tf.math.sinh tf_impl[lax.cos_p] = tf.math.cos tf_impl[lax.cosh_p] = tf.math.cosh -tf_impl_with_avals[lax.acos_p] = _convert_jax_impl( - lax_internal.acos_impl, multiple_results=False) + + +def _acos_impl(x): + if x.dtype.is_complex: + result = tf.multiply(tf.constant(1j, dtype=x.dtype), tf.math.acosh(x)) + # By convention, numpy chooses the branch with positive real part. + rpart = tf.math.real(result) + return tf.where( + tf.math.greater(rpart, tf.constant(0, dtype=rpart.dtype)), + result, + tf.math.negative(result), + ) + else: + return tf.where( + tf.math.not_equal(x, tf.constant(-1.0, dtype=x.dtype)), + tf.multiply( + tf.constant(2, dtype=x.dtype), + tf.math.atan2( + tf.math.sqrt( + tf.math.subtract( + tf.constant(1, dtype=x.dtype), tf.math.square(x) + ) + ), + tf.math.add(tf.constant(1, dtype=x.dtype), x), + ), + ), + tf.broadcast_to(tf.constant(np.pi, dtype=x.dtype), tf.shape(x)), + ) + + +tf_impl_with_avals[lax.acos_p] = _acos_impl tf_impl_with_avals[lax.asin_p] = _convert_jax_impl( lax_internal.asin_impl, multiple_results=False) tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( @@ -1717,6 +1726,7 @@ def _atan2(y, x, **kwargs): tf_impl[lax.asinh_p] = tf.math.asinh tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.square_p] = tf.math.square tf_impl[lax.rsqrt_p] = tf.math.rsqrt def _cbrt(x): @@ -2183,11 +2193,12 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None): def _dot_general(lhs, rhs, *, dimension_numbers, precision: lax_internal.CanonicalPrecision, preferred_element_type: DType | None, + out_type=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" # TODO(b/293247337): we ought to turn on this safety check, but this leads to - # failures. Since we are going to turn on native serializaton soon, wait + # failures. Since we are going to turn on native serialization soon, wait # until then to turn on this check. # lhs_aval, rhs_aval = _in_avals # if lhs_aval.dtype != rhs_aval.dtype: @@ -2252,7 +2263,7 @@ def _dot_general_convert_to_common_dtype( convert_result = lambda res: res return (lhs, rhs, convert_result) -def _broadcast_in_dim(operand, *, shape, broadcast_dimensions, +def _broadcast_in_dim(operand, *, shape, broadcast_dimensions, sharding=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): # for i in range(len(operand.shape)): @@ -2280,7 +2291,7 @@ def _empty(*, dtype): tf_impl[lax_internal.empty_p] = _empty -def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval): +def _reshape(operand, *, new_sizes, dimensions, sharding, _in_avals, _out_aval): if dimensions is None: dimensions = tf.range(tf.rank(operand)) new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype) @@ -3522,7 +3533,7 @@ def split_to_logical_devices(tensor: TfVal, def _xla_compatible_sharding_to_hlo_sharding( s: sharding.Sharding, aval: core.ShapedArray) -> xla_client.HloSharding | None: - if sharding_impls.is_unspecified(s): + if isinstance(s, sharding_impls.UnspecifiedValue): return None return s._to_xla_hlo_sharding(aval.ndim) @@ -3580,6 +3591,7 @@ def _pjit(*args: TfVal, name: str, keep_unused: bool, inline: bool, + compiler_options_kvs, _in_avals: Sequence[core.ShapedArray], _out_aval: Sequence[core.ShapedArray]) -> TfVal: del donated_invars diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 492dfad4c855..f23bd58c48d3 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -90,7 +90,7 @@ def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( message=( - "(jax2tf.convert with native_serialization=False is deprecated" + "(jax2tf.convert with native_serialization=False has been deprecated" "|Calling from_dlpack with a DLPack tensor is deprecated)" ) ) @@ -391,6 +391,20 @@ def fun_tf(x): res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x) self.assertAllClose((x * 3. + 4. + 2.) * 3. + 5., res, check_dtypes=False) + def test_with_capture_then_convert_again(self): + captured_by_tf = tf.Variable(np.arange(1024, dtype=np.float32)) + def tf_fn(x): + return tf.math.add(x, captured_by_tf) + + x = np.arange(1024, dtype=np.float32) + res = jax2tf.convert(jax2tf.call_tf(tf_fn))(x) + self.assertAllClose(res, 2 * x) + + # The bug appears only when we use non-eager mode on the converted func + res = tf.function(jax2tf.convert(jax2tf.call_tf(tf_fn)), + autograph=False)(x) + self.assertAllClose(res, 2 * x) + @_parameterized_jit def test_grad(self, with_jit=False): x = np.float32(3.) @@ -789,7 +803,7 @@ def f_jax(x): lowering_platforms = ("tpu", "cpu", "cuda") exp = export.export(jax.jit(f_jax), - lowering_platforms=lowering_platforms)(x) + platforms=lowering_platforms)(x) for jax_platform in jax_and_tf_platforms: with self.subTest(jax_platform): jax_device = jax.devices(jax_platform)[0] @@ -883,7 +897,7 @@ def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( message=( - "(jax2tf.convert with native_serialization=False is deprecated" + "(jax2tf.convert with native_serialization=False has been deprecated" "|Calling from_dlpack with a DLPack tensor is deprecated)" ) ) @@ -957,6 +971,13 @@ def f_jax(param, x): restored_jax = jax2tf.call_tf(restored_model.f) self.assertAllClose(f_jax(param, x), restored_jax(x)) self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x)) + self.assertAllClose(f_jax(param, x), jax2tf.convert(restored_jax)(x)) + self.assertAllClose(f_jax(param, x), + tf.function(jax2tf.convert(restored_jax), + autograph=False)(x)) + self.assertAllClose(f_jax(param, x), + tf.function(jax2tf.convert(restored_jax), + autograph=True)(x)) def test_saved_model_shape_poly(self): tracing_count = 0 @@ -1182,7 +1203,7 @@ def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( message=( - "(jax2tf.convert with native_serialization=False is deprecated" + "(jax2tf.convert with native_serialization=False has been deprecated" "|Calling from_dlpack with a DLPack tensor is deprecated)" ) ) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 01ed4eed21fd..7d3313be6c92 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -44,20 +44,8 @@ import numpy as np import tensorflow as tf -# pylint: disable=g-direct-tensorflow-import -from tensorflow.compiler.tf2xla.python import xla as tfxla -# pylint: enable=g-direct-tensorflow-import config.parse_flags_with_absl() -_exit_stack = contextlib.ExitStack() - -# TODO(necula): Remove once tensorflow is 2.10.0 everywhere. -def setUpModule(): - if not hasattr(tfxla, "optimization_barrier"): - _exit_stack.enter_context(jtu.global_config_context(jax_remat_opt_barrier=False)) - -def tearDownModule(): - _exit_stack.close() class Jax2TfTest(tf_test_util.JaxToTfTestCase): @@ -79,7 +67,7 @@ def setUpClass(cls): def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) self.warning_ctx.__enter__() @@ -770,6 +758,7 @@ def test_checkpoint_wrapper_types(self): self.assertLen(jax.tree_util.tree_leaves(m.b), 2) self.assertLen(jax.tree_util.tree_leaves(m.c), 2) + @unittest.skip("Test fails at head") def test_issue_10586(self): class JaxModule(tf.Module): @@ -978,8 +967,8 @@ def caller_jax(x): self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) else: graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) + if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: + self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) def test_bfloat16_constant(self): # Re: https://github.com/jax-ml/jax/issues/3942 @@ -1530,7 +1519,7 @@ def apply_transform(func, transform: str): _ = func_to_convert(*args) exported = export.export( (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert), - lowering_platforms=("tpu",) + platforms=("tpu",) )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) if transform1 == "shard_map": @@ -1689,6 +1678,22 @@ def f_jax(x): res, x + _testing_multi_platform_to_add[tf_device_jax_platform]) + def test_dot_algorithm(self): + # ref: https://github.com/jax-ml/jax/issues/24236 + if tf.version.VERSION.split(".") <= ["2", "18", "0"]: + self.skipTest("Because of an XLA bug this test segfaults with TF v2.18.0") + + if jtu.test_device_matches(["tpu"]): + algorithm = "BF16_BF16_F32" + else: + algorithm = "F32_F32_F32" + + def f_jax(x): + return jax.lax.dot(x, x, precision=algorithm) + + f_tf = jax2tf.convert(f_jax, native_serialization=True) + f_tf(np.ones((128, 128), dtype=np.float32)) # no crash + def test_dot_algorithm_non_native_unsupported(self): def f_jax(x): return jax.lax.dot(x, x, precision="F32_F32_F32") @@ -1704,7 +1709,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) self.warning_ctx.__enter__() @@ -1745,7 +1750,7 @@ def setUp(self): super().setUp() @jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) def test_simple(self): self.ConvertAndCompare(jnp.sin, 0.7) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 2863ca4ed616..76d5b4cde6c7 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -183,6 +183,8 @@ def test_primitive_coverage(self): continue if p.name == "pallas_call": continue + if p.name == "ragged_all_to_all": + continue if p.name == "ffi_call": continue if p.name == "tpu_custom_call": diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 38af6d9d76d5..2a58e29dbbad 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -1031,7 +1031,7 @@ def f_jax(x): # A function whose gradient is a constant self.assertAllClose(f_jax(x), restored_f(x)) @jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) def test_readme_examples(self): """Some of the examples from the README.""" @@ -1124,31 +1124,6 @@ def f2_jax(x): # f32[b, b] # JAX with static shapes sees that x.shape[0] != x.shape[1] self.assertEqual(jnp.sum(x45), f2_jax(x45)) - # In graph serialization eager mode, we catch the broken assumption b >= 1 - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, - re.escape( - "Found inconsistency between dimension size args[0].shape[1] (= 5) " - "and the specification 'b' (= 4)")): - jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"], - native_serialization=False)(x45) - - # In graph serialization graph mode we also catch it (except on TPU, where - # the behavior is as for jit_compile=1) - - f2_tf = tf.function( - jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"], - native_serialization=False), - autograph=False, - ).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32)) - if jtu.test_device_matches(["tpu"]): - self.assertEqual(1. + jnp.sum(x45), f2_tf(x45)) - else: - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, - r"Found inconsistency"): - _ = f2_tf(x45) - # We also catch the error with native serialization with self.assertRaisesRegex( tf.errors.InvalidArgumentError, @@ -2114,7 +2089,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] polymorphic_shapes=[None, "b0, ..."], expect_error=( (core.InconclusiveDimensionOperation, - "the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else (None, None)), + "array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)), override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [ diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index ffe362974dcb..2681ad1a2a7b 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -141,40 +141,43 @@ def jet(fun, primals, series): if not treedef_is_leaf(treedef): raise ValueError(f"term {j} for argument {i} is not an array") - @lu.transformation_with_aux - def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) + @lu.transformation_with_aux2 + def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, tree = tree_flatten(ans) + store.store(tree) + return ans f, out_tree = flatten_fun_output(lu.wrap_init(fun)) out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series) return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms) -@lu.transformation -def jet_fun(order, primals, series): - with core.new_main(JetTrace) as main: - main.order = order - out_primals, out_terms = yield (main, primals, series), {} - del main +@lu.transformation2 +def jet_fun(f, order, primals, series): + tag = core.TraceTag() + out_primals, out_terms = f(tag, order, primals, series) out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s for p, s in zip(out_primals, out_terms)] - yield out_primals, out_terms - -@lu.transformation -def jet_subtrace(main, primals, series): - trace = JetTrace(main, core.cur_sublevel()) - in_tracers = map(partial(JetTracer, trace), primals, series) - ans = yield in_tracers, {} - out_tracers = map(trace.full_raise, ans) - out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers) - yield out_primals, out_terms - -@lu.transformation_with_aux -def traceable(in_tree_def, *primals_and_series): + return out_primals, out_terms + +@lu.transformation2 +def jet_subtrace(f, tag, order, primals, series): + with core.take_current_trace() as parent_trace: + trace = JetTrace(tag, parent_trace, order) + in_tracers = map(partial(JetTracer, trace), primals, series) + with core.set_current_trace(trace): + ans = f(*in_tracers) + + out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans)) + return out_primals, out_terms + +@lu.transformation_with_aux2 +def traceable(f, store, in_tree_def, *primals_and_series): primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series) - primals_out, series_out = yield (primals_in, series_in), {} + primals_out, series_out = f(primals_in, series_in) out_flat, out_tree_def = tree_flatten((primals_out, series_out)) - yield out_flat, out_tree_def + store.store(out_tree_def) + return out_flat class JetTracer(core.Tracer): @@ -198,33 +201,44 @@ def full_lower(self): class JetTrace(core.Trace): - def pure(self, val): - return JetTracer(self, val, zero_series) - - def lift(self, val): - return JetTracer(self, val, zero_series) + def __init__(self, tag, parent_trace, order): + self.tag = tag + self.parent_trace = parent_trace + self.order = order - def sublift(self, val): - return JetTracer(self, val.primal, val.terms) + def to_primal_terms_pair(self, val): + if isinstance(val, JetTracer) and val._trace.tag is self.tag: + return val.primal, val.terms + else: + return val, zero_series def process_primitive(self, primitive, tracers, params): - order = self.main.order # pytype: disable=attribute-error - primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) + order = self.order # pytype: disable=attribute-error + primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers)) + + if all(t is zero_series for t in series_in): + primal_out = primitive.bind_with_trace(self.parent_trace, primals_in, params) + if primitive.multiple_results: + return [JetTracer(self, p, zero_series) for p in primal_out] + else: + return JetTracer(self, primal_out, zero_series) + series_in = [[zero_term] * order if s is zero_series else s for s in series_in] - # TODO(mattjj): avoid always instantiating zeros - series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x)) - if t is zero_term else t for t in series] - for x, series in zip(primals_in, series_in)] - rule = jet_rules[primitive] - primal_out, terms_out = rule(primals_in, series_in, **params) + with core.set_current_trace(self.parent_trace): + # TODO(mattjj): avoid always instantiating zeros + series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x)) + if t is zero_term else t for t in series] + for x, series in zip(primals_in, series_in)] + rule = jet_rules[primitive] + primal_out, terms_out = rule(primals_in, series_in, **params) if not primitive.multiple_results: return JetTracer(self, primal_out, terms_out) else: return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)] def process_call(self, call_primitive, f, tracers, params): - primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) + primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers)) primals_and_series, in_tree_def = tree_flatten((primals_in, series_in)) f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def) update_params = call_param_updaters.get(call_primitive) @@ -234,17 +248,6 @@ def process_call(self, call_primitive, f, tracers, params): primals_out, series_out = tree_unflatten(out_tree_def(), result) return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)] - def post_process_call(self, call_primitive, out_tracers, params): - primals, series = unzip2((t.primal, t.terms) for t in out_tracers) - out, treedef = tree_flatten((primals, series)) - del primals, series - main = self.main - def todo(x): - primals, series = tree_unflatten(treedef, x) - trace = JetTrace(main, core.cur_sublevel()) - return map(partial(JetTracer, trace), primals, series) - return out, todo - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): # TODO(mattjj): don't just ignore custom jvp rules? @@ -329,6 +332,7 @@ def linear_prop(prim, primals_in, series_in, **params): deflinear(lax.reduce_sum_p) deflinear(lax.reduce_window_sum_p) deflinear(lax.fft_p) +deflinear(lax.copy_p) deflinear(dispatch.device_put_p) def _dynamic_slice_jet_rule(primals_in, series_in, **params): @@ -404,6 +408,7 @@ def def_comp(prim, comp): def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) def_comp(lax.log1p_p, lambda x: lax.log(1 + x)) def_comp(lax.sqrt_p, lambda x: x ** 0.5) +def_comp(lax.square_p, lambda x: x * x) def_comp(lax.rsqrt_p, lambda x: x ** -0.5) def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1))) def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1))) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 5d8a4dd9fc14..c1daa33576bb 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== from jax import ShapeDtypeStruct as ShapeDtypeStruct +from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401 + from .core import ( Barrier as Barrier, ClusterBarrier as ClusterBarrier, @@ -25,6 +27,15 @@ Union as Union, as_gpu_kernel as as_gpu_kernel, ) + +if dialect is not None: + from .dialect_lowering import ( + gpu_address_space_to_nvptx as gpu_address_space_to_nvptx, + lower_mgpu_dialect as lower_mgpu_dialect + ) +else: + gpu_address_space_to_nvptx, lower_mgpu_dialect = None, None + from .fragmented_array import ( FragmentedArray as FragmentedArray, FragmentedLayout as FragmentedLayout, @@ -32,6 +43,7 @@ WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, + optimization_barrier as optimization_barrier, ) from .utils import ( BarrierRef as BarrierRef, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 8ee1bda2f41e..b03c3a5b54fc 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -77,7 +77,7 @@ os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) -mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True @@ -97,7 +97,7 @@ def _mosaic_gpu_lowering_rule( out_types, input_output_aliases: tuple[tuple[int, int], ...] = (), ): - del out_types # Unused. + assert len(out_types) == len(ctx.avals_out) kernel_id = hashlib.sha256(module).digest() # Note that this is technically only a half measure. Someone might load a # compiled module with a hash collision from disk. But that's so unlikely with @@ -133,6 +133,14 @@ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: raise NotImplementedError("Subclasses should override this method") + def batch(self, leading_rank: int) -> 'MemRefTransform': + """Returns a transform that accepts a ref with the extra `leading_rank` dims. + + The returned transform should leave the leading dimensions unchanged and + only apply to the suffix of the shape. + """ + raise NotImplementedError("Subclasses should override this method") + @dataclasses.dataclass(frozen=True) class TileTransform(MemRefTransform): @@ -198,6 +206,9 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: *self.tiling, ) + def batch(self, leading_rank: int) -> MemRefTransform: + return self + @dataclasses.dataclass(frozen=True) class TransposeTransform(MemRefTransform): @@ -217,6 +228,62 @@ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: return tuple(shape[p] for p in self.permutation) + def batch(self, leading_rank: int) -> MemRefTransform: + return TransposeTransform( + (*range(leading_rank), *(d + leading_rank for d in self.permutation)) + ) + + +@dataclasses.dataclass(frozen=True) +class CollapseLeadingIndicesTransform(MemRefTransform): + """Collapses leading indices into one.""" + strides: tuple[int, ...] + + @functools.cached_property + def common_stride(self) -> int: + return math.gcd(*self.strides) + + def apply(self, ref: ir.Value) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + strides, offset = ref_ty.get_strides_and_offset() + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + raise NotImplementedError("Dynamic offsets are not supported") + max_bound = sum( + (d - 1) * s // self.common_stride + for d, s in zip( + ref_ty.shape[: len(self.strides)], strides[: len(self.strides)] + ) + ) + 1 + new_shape = [max_bound, *ref_ty.shape[len(self.strides):]] + new_strides = [self.common_stride, *strides[len(self.strides):]] + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ref_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + return memref.reinterpret_cast( + new_ref_ty, ref, [], [], [], + static_offsets=[offset], + static_sizes=new_shape, + static_strides=new_strides, + ) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + flat_idx = c(0, index) + for i, s in zip(idx[:len(self.strides)], self.strides): + flat_idx = arith.addi( + flat_idx, arith.muli(i, c(s // self.common_stride, index)) + ) + return (flat_idx, *idx[len(self.strides):]) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + if any(s != 1 for s in shape[:len(self.strides)]): + raise ValueError("Expected leading indices to be squeezed") + return (1, *shape[len(self.strides):]) + + def batch(self, leading_rank: int) -> MemRefTransform: + raise NotImplementedError # Unused + OnDeviceProfiler = profiler.OnDeviceProfiler @@ -340,7 +407,7 @@ def async_copy( arrive: bool | None = None, uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, - predicate: ir.Value | None = None, + predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. ): index = ir.IndexType.get() i16 = ir.IntegerType.get_signless(16) @@ -355,6 +422,8 @@ def async_copy( f"Expected same element type, got {element_type} and" f" {dst_ref_ty.element_type}" ) + if predicate is not None and not uniform: + raise ValueError("Predicate can only be defined when uniform is True") if not isinstance(gmem_transform, tuple): gmem_transform = (gmem_transform,) @@ -379,6 +448,17 @@ def async_copy( or gmem_ref.owner.opview.OPERATION_NAME != expected_name ): raise ValueError("GMEM reference in async_copy must be a kernel argument") + gmem_ref_ty = ir.MemRefType(gmem_ref.type) + gmem_strides, _ = gmem_ref_ty.get_strides_and_offset() + if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape): + raise NotImplementedError( + "async_copy assumes the GMEM reference is contiguous" + ) + if any(s * element_bytewidth % 16 != 0 for s in gmem_strides[:-1]): + raise ValueError( + "async_copy requires all GMEM strides except the last one to be a" + " multiple of 16 bytes" + ) base_indices, slice_shape, is_squeezed = utils.parse_indices( gmem_slice, ir.MemRefType(gmem_ref.type).shape @@ -386,16 +466,42 @@ def async_copy( dyn_base_indices = tuple( c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices ) + squeezed_dims = [i for i, squeezed in enumerate(is_squeezed) if squeezed] + sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed] + # Indexing is really slicing + squeezing, and user transforms are meant to + # apply after that. However, we actually have to apply the indexing last + # (it's fused into the TMA) and so we need to commute it with all the user + # transforms. For slicing this is done using transform_index and + # transform_shape. For squeezing we actually move all the squeezed dims to + # the front, and then batch each transform, making it ignore the extra dims. + if squeezed_dims: + gmem_transform = (TransposeTransform((*squeezed_dims, *sliced_dims)), + *(t.batch(len(squeezed_dims)) for t in gmem_transform)) + slice_shape = tuple(slice_shape) for t in gmem_transform: dyn_base_indices = t.transform_index(dyn_base_indices) slice_shape = t.transform_shape(slice_shape) - for dim, squeezed in enumerate(is_squeezed): - if squeezed: - smem_ref = utils.memref_unsqueeze(smem_ref, dim) - smem_ref_ty = ir.MemRefType(smem_ref.type) - if slice_shape != tuple(smem_ref_ty.shape): + num_squeezed_dims = len(squeezed_dims) + if len(slice_shape) > 5: + # We can try to collapse all squeezed dims into one. + if len(slice_shape) - num_squeezed_dims + 1 > 5: + raise ValueError( + "Async copies only support striding up to 5 dimensions" + ) + collapse = CollapseLeadingIndicesTransform( + tuple(gmem_strides[d] for d in squeezed_dims) + ) + gmem_transform = (*gmem_transform, collapse) + dyn_base_indices = collapse.transform_index(dyn_base_indices) + slice_shape = collapse.transform_shape(slice_shape) + num_squeezed_dims = 1 + del squeezed_dims, sliced_dims # Those no longer make sense. + + smem_ref_ty = ir.MemRefType(smem_ref.type) + # We moved all squeezed dims to the front. + if slice_shape[num_squeezed_dims:] != tuple(smem_ref_ty.shape): raise ValueError( "Expected the SMEM reference to have the same shape as the" f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" @@ -409,6 +515,7 @@ def async_copy( dyn_base_indices = list(dyn_base_indices) slice_shape = list(slice_shape) + assert all(d == 1 for d in slice_shape[:num_squeezed_dims]) collective_size = 1 if collective is not None: if isinstance(collective, gpu.Dimension): @@ -416,13 +523,16 @@ def async_copy( collective_size = math.prod(self.cluster_size[d] for d in collective) if collective_size > 1: def partition_dim(dim: int, idx: ir.Value, num_chunks: int): + # No need to partition squeezed dims. They don't even exist in smem_ref. + assert dim >= num_squeezed_dims nonlocal smem_ref slice_shape[dim] //= num_chunks block_offset = arith.muli(idx, c(slice_shape[dim], index)) dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) smem_ref = utils.memref_slice( smem_ref, - (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) + (slice(None),) * (dim - num_squeezed_dims) + + (utils.ds(block_offset, slice_shape[dim]),), ) stride = 1 idx = c(0, index) @@ -438,10 +548,12 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): rem_collective_size = 1 break elif rem_collective_size % slice_size == 0: - dim_idx = arith.remui(idx, c(slice_size, index)) - partition_dim(dim, dim_idx, slice_size) - idx = arith.divui(idx, c(slice_size, index)) - rem_collective_size //= slice_size + # This is an optimization and it lets us skip squeezed dims. + if slice_size > 1: + dim_idx = arith.remui(idx, c(slice_size, index)) + partition_dim(dim, dim_idx, slice_size) + idx = arith.divui(idx, c(slice_size, index)) + rem_collective_size //= slice_size else: break # We failed to partition the leading dimensions. del idx # We overwrote the block index in the loop. @@ -470,13 +582,10 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): uniform_ctx = ( functools.partial(utils.single_thread, per_block=False) - if uniform + if uniform and predicate is None else contextlib.nullcontext ) - rank = len(slice_shape) - if rank > 5: # TODO: apaszke - Implement stride compression - raise ValueError("Async copies only support striding up to 5 dimensions") if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" @@ -697,16 +806,6 @@ def _launch( ) ) - smem_ref_tree = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers - ) - # TODO(apaszke): Skip the following if no barriers were initialized. - nvvm.fence_mbarrier_init() - if math.prod(cluster) != 1: - nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) - nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - gpu.barrier() - if profiler_spec: prof_smem = memref.view( ir.MemRefType.get( @@ -723,7 +822,19 @@ def _launch( ptr_ty = ir.Type.parse("!llvm.ptr") scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree + ctx = LaunchContext(launch_op, scratch_ptr, cluster, prof) + with ctx.named_region("Init"): + smem_ref_tree = _construct_smem_reftree( + cluster, dynamic_smem, smem_buffers + ) + # TODO(apaszke): Skip the following if no barriers were initialized. + nvvm.fence_mbarrier_init() + if math.prod(cluster) != 1: + nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) + nvvm.cluster_wait(aligned=ir.UnitAttr.get()) + gpu.barrier() + + yield ctx, smem_ref_tree if prof is not None: prof.finalize(grid=grid, block=block) gpu.terminator() @@ -738,6 +849,7 @@ def _lower_as_gpu_kernel( out_shape, smem_scratch_shape: ShapeTree | Union[ShapeTree], module_name: str, + kernel_name: str | None = None, prof_spec: profiler.ProfilerSpec | None = None, ): ptr_ty = ir.Type.parse("!llvm.ptr") @@ -764,6 +876,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: module = ir.Module.create() attrs = module.operation.attributes attrs["sym_name"] = ir.StringAttr.get(module_name) + if kernel_name is None: + kernel_name = getattr(body, "__name__", "anonymous") with ir.InsertionPoint(module.body): _declare_runtime_functions() gmem_scratch_bytes = 0 @@ -773,7 +887,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: ir.Attribute.parse("#llvm.linkage"), addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. ) - @func.FuncOp.from_py_func(ptr_ty, ptr_ty) + @func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}") def main(token_ptr, buffers): nonlocal gmem_scratch_bytes token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) @@ -838,6 +952,7 @@ def as_gpu_kernel( prof_spec: profiler.ProfilerSpec | None = None, cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", + kernel_name: str | None = None, ): if isinstance(in_shape, list): in_shape = tuple(in_shape) @@ -847,7 +962,7 @@ def as_gpu_kernel( module, out_shape, unwrap_output_tuple = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec + module_name, kernel_name, prof_spec ) ) @@ -905,6 +1020,7 @@ def as_torch_gpu_kernel( prof_spec: profiler.ProfilerSpec | None = None, cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", + kernel_name: str | None = None, ): try: import torch @@ -923,7 +1039,7 @@ def as_torch_gpu_kernel( module, out_shape, unwrap_output_tuple = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec + module_name, kernel_name, prof_spec ) ) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py new file mode 100644 index 000000000000..9bda5b5b7191 --- /dev/null +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -0,0 +1,128 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lowering rules and pass for the MLIR Mosaic GPU dialect.""" + +from collections.abc import Callable +import functools +import operator +from typing import Sequence, Type + +from jax._src.interpreters import mlir as mlir_interpreter +from jax._src.lib import mosaic_gpu_dialect as mgpu + +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import nvvm +from .utils import c, ptr_as_memref, single_thread_predicate + +# mypy: ignore-errors + + +MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]] + + +_lowerings: dict[str, MlirLoweringRule] = {} + + +# TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36. +# Jaxlib doesn't contain Mosaic GPU dialect bindings. +InitializeBarrierOp = mgpu.InitializeBarrierOp if mgpu is not None else None + +def _register_lowering( + op: str | Type[ir.OpView] +) -> Callable[[MlirLoweringRule], MlirLoweringRule]: + def wrapper(f): + op_name = op if isinstance(op, str) else op.OPERATION_NAME # pytype: disable=attribute-error + _lowerings[op_name] = f + return f + + return wrapper + + +def _lowered_barrier_type() -> ir.Type: + return ir.IntegerType.get_signless(64) + + +def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int: + match address_space: + case gpu.AddressSpace.Global: + return 1 + case gpu.AddressSpace.Workgroup: + return 3 + case _: + raise NotImplementedError(f"address_space not supported: {address_space}") + + +@_register_lowering(InitializeBarrierOp) +def _initialize_barrier_op_lowering_rule( + initialize_barrier_op: InitializeBarrierOp) -> Sequence[ir.Value]: + + shape = initialize_barrier_op.barriers_ref.type.shape + num_barriers = functools.reduce(operator.mul, shape, 1) + + i32 = ir.IntegerType.get_signless(32) + workgroup_nvptx_address_space = gpu_address_space_to_nvptx( + gpu.AddressSpace.Workgroup) + ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") + + lowered_barrier_type = _lowered_barrier_type() + + predicate = single_thread_predicate(per_block=True) + for i in range(num_barriers): + nvvm.mbarrier_init_shared( + llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i], + lowered_barrier_type), + c(initialize_barrier_op.arrival_count.value, i32), + predicate=predicate + ) + + barrier_base_ptr = llvm.getelementptr( + ir.Type.parse("!llvm.ptr"), + initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type) + + return ptr_as_memref( + barrier_base_ptr, initialize_barrier_op.barriers_ref.type), + + +def lower_mgpu_dialect(module: ir.Module): + module.context.append_dialect_registry(mlir_interpreter.upstream_dialects) + module.context.load_all_available_dialects() + + lowered_operations: set[ir.Operation | ir.OpView] = set() + + def _lower_op(op: ir.OpView): + if op.name not in _lowerings: + return + lowering_rule = _lowerings[op.name] + new_results = lowering_rule(op) + for old, new in zip(op.results, new_results): + old.replace_all_uses_with(new) + lowered_operations.add(op) + + def _traverse_and_lower_op(op: ir.OpView): + for region in op.operation.regions: + for block in region: + for block_op in list(block): + with ir.InsertionPoint(block_op): + _traverse_and_lower_op(block_op) + _lower_op(op) + + with ir.InsertionPoint(module.body): + for op in module.body: + _traverse_and_lower_op(op) + + for lowered_op in lowered_operations: + lowered_op.erase() diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index daacefb135e9..4728f00a9243 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -300,9 +300,7 @@ def kv_loop(kv_step, carry): with ir.InsertionPoint(if_compute.else_block): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) with single_thread(per_block=False): - k_tr = ( - TileTransform(tiling), TransposeTransform((0, 2, 1, 3, 4)), - ) + k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): @@ -396,10 +394,7 @@ def kv_copy_init(slot, kv_seq_base): with single_thread(per_block=False): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) - k_tr = ( - TileTransform(tiling), - TransposeTransform((0, 2, 1, 3, 4)), - ) + k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)): ctx.async_copy( @@ -605,9 +600,9 @@ def ref(q, k, v): if __name__ == "__main__": if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): + not jtu.is_cuda_compute_capability_equal("9.0")): warnings.warn( - "Mosaic GPU Flash Attention requires compute capability 9.0 to run, " + "Mosaic GPU Flash Attention requires compute capability 9.0a to run, " "skipping.") exit(0) @@ -649,7 +644,9 @@ def ref(q, k, v): matmul_flops = ( 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size ) - peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS + # Table 1 in + # https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper + peak_flops = 989.4 * 1e12 # f16 TensorCore peak optimal_time = matmul_flops / peak_flops * 1e6 # us achieved_tc_util = optimal_time / runtime_us * 100 has_tma_warp = impl == Implementation.TWO_COMPUTE_ONE_TMA_WG diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index ce99bf423bae..7aa96e7fa5d3 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -20,6 +20,7 @@ import jax from jax import random +from jax._src import test_util as jtu # noqa: F401 from jax._src.interpreters import mlir from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu import * # noqa: F403 @@ -378,7 +379,7 @@ def ref_f(x, y): x, y, dimension_numbers=dimension_numbers, - preferred_element_type=jnp.float32, + preferred_element_type=out_dtype, ).astype(out_dtype) ref, ref_runtime = profiler.measure(ref_f, x, y) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 98c56de9ccda..7e40d86f2ae4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -14,10 +14,13 @@ # ============================================================================== """Utilities for code generator.""" +from __future__ import annotations + import dataclasses import functools import math -from typing import Callable +from collections.abc import Callable +from typing import Iterable, Protocol, Sequence, TypeVar import jax from jaxlib.mlir import ir @@ -35,44 +38,269 @@ # mypy: ignore-errors +T = TypeVar("T") WARPGROUP_SIZE = utils.WARPGROUP_SIZE +WARP_SIZE = 32 +WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE +SMEM_BANKS = 32 +SMEM_BANK_BYTES = 4 c = utils.c @dataclasses.dataclass(frozen=True) -class WGSplatFragLayout: - """A fragmented array where all the values are equal represented as a register per thread. +class Tiling: + """A tiling expression describing a permutation of elements of an nd-array. - FragmentedArrays in this layout can be are always the result of a - splat, each thread in the warpgroup has a single copy of the value, - while the FragmentedArray pretends it has whatever shape the user - wants. This means we can trivially broadcast, reshape and do - elementwise operations with all other layouts. + To apply one level of tiling to an array, each of the trailing dimensions (up + to the rank of the tile) is unfolded into two dimensions: first equal to the + ratio of the dimension size and the tile size, and second equal to the tile + size. Then, all newly unfolded minor dimensions are transposed to appear at + the end. - Examples: + This expression describes multi-level tiling, by applying each element of + `tiles` in sequence to the array. - To load a value in - ``` - FragmentedArray.splat(memref.load(ref_1d, [1]), (10,20,2)) - ``` + See https://openxla.org/xla/tiled_layout for a more detailed explanation. + """ + tiles: tuple[tuple[int, ...], ...] - A shape is always provided for sanity check reasons. + def __post_init__(self): + max_rank = math.inf + for tile in self.tiles: + if not tile: + raise ValueError("Tiles must not be empty") + if len(tile) > max_rank: + raise ValueError("Tile ranks must be non-increasing") + max_rank = len(tile) + if any(d <= 0 for d in tile): + raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") + + def __str__(self): + return f"Tiling({''.join(map(str, self.tiles))})" + + def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Computes the shape of an array after tiling.""" + def fail(): + raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") + for tile in self.tiles: + if len(tile) > len(shape): + fail() + untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):] + if any(s % t != 0 for s, t in zip(tiled_dims, tile)): + fail() + shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile) + return shape + + def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Computes the shape of an array before tiling from its tiled shape.""" + def fail(): + raise ValueError("Shape does not look like it's been tiled?") + for tile in reversed(self.tiles): + if len(tile) > len(shape): + fail() + untiled_dims = shape[:-2 * len(tile)] + tiled_dims = shape[-2 * len(tile):-len(tile)] + tiling_dims = shape[-len(tile):] + if tiling_dims != tile: + fail() + shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) + return shape + + def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: + """Computes the strides of an array after tiling.""" + for tile in self.tiles: + untiled, tiled = strides[:-len(tile)], strides[-len(tile):] + strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) + return strides + + def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: + for tile in self.tiles: + untiled, tiled = indices[:-len(tile)], indices[-len(tile):] + indices = ( + *untiled, + *(i // t for i, t in zip(tiled, tile)), + *(i % t for i, t in zip(tiled, tile)), + ) + return indices + def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: + for tile in reversed(self.tiles): + untiled = indices[:-2 * len(tile)] + outer = indices[-2 * len(tile):-len(tile)] + inner = indices[-len(tile):] + indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile))) + return indices + +def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: + """Like built-in enumerate, but returns negative indices into the sequence.""" + offset = len(elems) + for i, e in enumerate(elems): + yield i - offset, e + + +@dataclasses.dataclass(frozen=True) +class TiledLayout: + """A FragmentedArray layout derived from a tiling expression. + + A logical array is transformed according to the tiling expression, and then + split across warps (within a warpgroup), lanes, and vectorized according to + the dimension indices. All dimension indices must be negative and should refer + to the dimensions after tiling is applied. + + Note that warp_dim and vector_dim could be sets as well, but we don't have a + usecase for that yet. + + To better understand this layout, consider the example of WGMMA-related tiling + from https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as + applied to a 128x128 array. The corresponding TiledLayout has a tiling of: + + (64, 8)(16, 8)(8, 8)(1, 2) + + and warp_dim=-8, lane_dims={-4, -3}, vector_dim=-1. + + We begin by applying the tiling (note that it always applies to a suffix): + + Tiled shape Remaining tiling actions + =========================================================================== + 128 128 (64, 8)(16, 8)(8, 8)(1, 2) + 2 16 64 8 (16, 8)(8, 8)(1, 2) + 2 16 4 1 16 8 (8, 8)(1, 2) + 2 16 4 1 2 1 8 8 (1, 2) + 2 16 4 1 2 1 8 4 1 2 + + The last expression is our final shape. At this stage, we're ready to + interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the + end is partitioned over 4 warps in a warpgroup (and so it must be of size 4). + lane_dims={-4, -3} indicate that those two dimensions are partitioned over + the lanes within a warp (their product must be equal to 32, i.e. warp size). + Finally, vector_dim=-1 indicates that each (logical) register is a vector + containing 2 elements (there are no shape restrictions here). + + Given the above, the shape of the (logical) register array used to represent + the array in each thread is: (2, 16, 1, 1, 2, 1, 1, 1, 1, 1). We have set all + the dimensions above to 1, since each thread is a member of a single warp, + a single lane, and the elements along the vectorized dimension are represented + by a single (logical) register. """ + tiling: Tiling + warp_dim: int + lane_dims: frozenset[int] + vector_dim: int - shape: tuple[int, ...] = () + def __post_init__(self): + if not self.tiling.tiles: + raise ValueError("Tiling must have at least one tile") + min_shape = self.tiling.tiles[0] + min_tiled_shape = self.tiling.tile_shape(min_shape) + dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} + if len(dims_set) != len(self.lane_dims) + 2: + raise ValueError + for d in dims_set: + if d >= 0: + raise ValueError("All dimensions must be negative") + if d < -(len(min_tiled_shape) - len(min_shape)): + raise ValueError("Dimension out of range") + if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: + raise ValueError + if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: + raise ValueError - def can_broadcast_to(self, shape) -> bool: - """Check that the shape can be broadcast. + @property + def base_tile_shape(self) -> int: + """The shape of the first tile in the tiling expression. - Only dimensions of size 1 can be broadcast. All other dimensions - must be the same as the argument shape. + This tile acts as the divisibility constraint for a suffix of arrays to + which this layout applies. """ - return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + return self.tiling.tiles[0] - def thread_idxs(self, shape): - assert shape == self.shape - raise NotImplementedError + @functools.cached_property + def tiled_tiling_shape(self) -> tuple[int, ...]: + """The shape of the suffix of the array after tiling. + + We only allow our repeated tiling actions to further subdivide the + dimensions created by previous tiling actions (except for the first one), + so the tiled shape always ends with this suffix, no matter what array shape + it's applied to. + """ + return self.tiling.tile_shape(self.base_tile_shape) + + @property + def vector_length(self) -> int: + return self.tiled_tiling_shape[self.vector_dim] + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + tiled_shape = list(self.tiling.tile_shape(shape)) + tiled_shape[self.warp_dim] = 1 + for d in self.lane_dims: + tiled_shape[d] = 1 + tiled_shape[self.vector_dim] = 1 + return tuple(tiled_shape) + + def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the logical shape of an array given its register array shape. + + Inverse to `registers_shape`. + """ + tiled_tiling = self.tiled_tiling_shape + shape = list(shape) + shape[self.warp_dim] = WARPS_IN_WARPGROUP + for d in self.lane_dims: + shape[d] = tiled_tiling[d] + shape[self.vector_dim] = tiled_tiling[self.vector_dim] + return self.tiling.untile_shape(tuple(shape)) + + def lane_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = tuple( + d if i in self.lane_dims else 1 + for i, d in enumerate_negative(self.tiled_tiling_shape) + ) + assert math.prod(tiled_shape) == WARP_SIZE + lane_strides = utils.get_contiguous_strides(tiled_shape) + lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) + # TODO(apaszke): Rewrite so that we can be sure that this never actually + # does arithmetic for any dimensions that are not in lane_dims. + return tuple( + arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) + for stride, size in zip(lane_strides, tiled_shape) + ) + + def warp_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = tuple( + d if i == self.warp_dim else 1 + for i, d in enumerate_negative(self.tiled_tiling_shape) + ) + assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP + warp_idx = arith.remui( + arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), + c(WARPS_IN_WARPGROUP, i32), + ) + indices = [arith.constant(i32, 0)] * len(tiled_shape) + indices[self.warp_dim] = warp_idx + return tuple(indices) + + +def _tiled_wgmma_layout(shape: tuple[int, ...]): + """Returns the tiled layout relevant for WGMMA operations. + + The tiled layout is equivalent to one described here in PTX documentation: + https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d + + This tiled layout is equivalent to WGMMAFragLayout and will subsume it. + """ + if len(shape) != 2: + raise ValueError(f"Shape {shape} is not 2D") + if shape[0] % 64 != 0 or shape[1] % 8 != 0: + raise ValueError(f"Shape {shape} is not a multiple of 64x8") + return TiledLayout( + Tiling(((64, 8), (16, 8), (8, 8), (1, 2))), + warp_dim=-8, + lane_dims=frozenset((-4, -3)), + vector_dim=-1, + ) @dataclasses.dataclass(frozen=True) @@ -96,6 +324,11 @@ def thread_idxs(self, shape): row = arith.addi(row_base, c(row_group + row_subgroup, index)) yield row, arith.addi(col_base, c(col_group, index)) + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + assert len(shape) == 2 + assert shape[0] % 64 == 0 and shape[1] % 8 == 0 + return (shape[0] // 64, shape[1] // 8, 2, 1) + @dataclasses.dataclass(frozen=True) class WGMMARowFragLayout: @@ -105,6 +338,42 @@ def thread_idxs(self, shape): raise NotImplementedError +@dataclasses.dataclass(frozen=True) +class WGSplatFragLayout: + """A fragmented array where all the values are equal represented as a register per thread. + + FragmentedArrays in this layout can be are always the result of a + splat, each thread in the warpgroup has a single copy of the value, + while the FragmentedArray pretends it has whatever shape the user + wants. This means we can trivially broadcast, reshape and do + elementwise operations with all other layouts. + + Examples: + + To load a value in + ``` + FragmentedArray.splat(memref.load(ref_1d, [1]), (10,20,2)) + ``` + + A shape is always provided for sanity check reasons. + + """ + + shape: tuple[int, ...] = () + + def can_broadcast_to(self, shape) -> bool: + """Check that the shape can be broadcast. + + Only dimensions of size 1 can be broadcast. All other dimensions + must be the same as the argument shape. + """ + return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + + def thread_idxs(self, shape): + assert shape == self.shape + raise NotImplementedError + + @dataclasses.dataclass(frozen=True) class WGStridedFragLayout: """Convert the array to 1D and then shard across threads.""" @@ -162,7 +431,7 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout | TiledLayout WGMMA_LAYOUT = WGMMAFragLayout() @@ -196,8 +465,8 @@ def __init__( if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype): raise TypeError( - "is_signed must only be non-None if the MLIR type is an integer" - f" type, got {_is_signed=} for {self.mlir_dtype}" + "is_signed must be non-None if and only if the MLIR type is an" + f" integer type, got {_is_signed=} for {self.mlir_dtype}" ) match self.layout: @@ -230,6 +499,14 @@ def __init__( if _registers.size != 1: raise ValueError(f"Invalid register array shape: {_registers.shape}") + case TiledLayout(): + try: + self.layout.shape_from_registers_shape(_registers.shape) + except ValueError: + raise ValueError( + "Register array shape does not match the tiled layout" + ) from None + case _: raise NotImplementedError @@ -304,15 +581,21 @@ def shape(self): return shape case WGSplatFragLayout(shape=shape): return shape + case TiledLayout(): + return self.layout.shape_from_registers_shape(self.registers.shape) + case _: + raise NotImplementedError @property def mlir_dtype(self): reg_ty = self.registers.flat[0].type match self.layout: - case WGMMAFragLayout() | WGStridedFragLayout(): + case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type case WGMMARowFragLayout() | WGSplatFragLayout(): return reg_ty + case _: + raise NotImplementedError def to_layout(self, new_layout: FragmentedLayout): """Converts the fragmented array to the given layout. @@ -321,6 +604,17 @@ def to_layout(self, new_layout: FragmentedLayout): """ if self.layout == new_layout: return self + shape = self.shape + if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0: + tiled_layout = _tiled_wgmma_layout(shape) + if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or ( + self.layout == tiled_layout and new_layout == WGMMA_LAYOUT + ): + return FragmentedArray( + _registers=self.registers.reshape(new_layout.registers_shape(shape)), + _layout=new_layout, + _is_signed=self.is_signed, + ) if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" @@ -331,9 +625,28 @@ def to_layout(self, new_layout: FragmentedLayout): ) def _pointwise(self, op, *other, output_is_signed: bool | None = None): - is_signed = ( - output_is_signed if output_is_signed is not None else self.is_signed - ) + # If our layout is a splat, then we should either dispatch to a non-splat + # layout, or broadcast ourselves to the output shape first. + if isinstance(self.layout, WGSplatFragLayout): + output_shape = self.shape + for i, o in enumerate(other): + if not isinstance(o, FragmentedArray): + continue + elif not isinstance(o.layout, WGSplatFragLayout): + return o._pointwise( + lambda o, this, *args: op(this, *args[:i], o, *args[i:]), + self, + *other[:i], + *other[i + 1 :], + output_is_signed=output_is_signed, + ) + else: + output_shape = np.broadcast_shapes(output_shape, o.shape) + # If we get here then we haven't found any non-splat layout. + if self.shape != output_shape: + return self.broadcast(output_shape)._pointwise( + op, *other, output_is_signed=output_is_signed + ) other_arrs = [] for o in other: @@ -344,17 +657,18 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): raise NotImplementedError(o) o = FragmentedArray.splat( - o, shape=self.shape, layout=self.layout, is_signed=is_signed + o, shape=self.shape, layout=self.layout, is_signed=self.is_signed ) if isinstance(o.layout, WGSplatFragLayout): if not o.layout.can_broadcast_to(self.shape): - raise ValueError("Can't broadcast shape.") + raise ValueError( + f"Cannot broadcast shape {self.shape} to layout {o.layout}") o = FragmentedArray.splat( o.registers.flat[0], shape=self.shape, layout=self.layout, - is_signed=is_signed, + is_signed=o.is_signed, ) else: if self.layout != o.layout: @@ -367,8 +681,13 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) + reg_ty = new_regs.flat[0].type + if ir.VectorType.isinstance(reg_ty): + reg_ty = ir.VectorType(reg_ty).element_type + if output_is_signed is None and ir.IntegerType.isinstance(reg_ty): + output_is_signed = self.is_signed return FragmentedArray( - _registers=new_regs, _layout=self.layout, _is_signed=is_signed + _registers=new_regs, _layout=self.layout, _is_signed=output_is_signed ) def __pos__(self): @@ -384,7 +703,7 @@ def __neg__(self): def __add__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.addf, other) + return self._pointwise(addf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.addi, other) else: @@ -395,7 +714,7 @@ def __radd__(self, other): def __mul__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.mulf, other) + return self._pointwise(mulf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.muli, other) else: @@ -406,7 +725,7 @@ def __rmul__(self, other): def __sub__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.subf, other) + return self._pointwise(subf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.subi, other) else: @@ -414,7 +733,7 @@ def __sub__(self, other): def __rsub__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(lambda s, o: arith.subf(o, s), other) + return self._pointwise(lambda s, o: subf(o, s), other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(lambda s, o: arith.subi(o, s), other) else: @@ -430,6 +749,32 @@ def __rtruediv__(self, other): return NotImplemented return self._pointwise(lambda s, o: arith.divf(o, s), other) + def __floordiv__(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise( + lambda s, o: mlir_math.floor(arith.divf(s, o)), other + ) + elif ir.IntegerType.isinstance(self.mlir_dtype): + if self.is_signed: + return self._pointwise(arith.floordivsi, other) + else: + return self._pointwise(arith.divui, other) + else: + return NotImplemented + + def __rfloordiv__(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise( + lambda s, o: mlir_math.floor(arith.divf(o, s)), other + ) + elif ir.IntegerType.isinstance(self.mlir_dtype): + if self.is_signed: + return self._pointwise(lambda s, o: arith.floordivsi(o, s), other) + else: + return self._pointwise(lambda s, o: arith.divui(o, s), other) + else: + return NotImplemented + def __mod__(self, other): if not ir.IntegerType.isinstance(self.mlir_dtype): return NotImplemented @@ -446,6 +791,35 @@ def __rmod__(self, other): else: return self._pointwise(lambda s, o: arith.remui(o, s), other) + def __invert__(self): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self ^ ~0 + + def __or__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self._pointwise(arith.ori, other) + + def __ror__(self, other): + return self | other + + def __and__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self._pointwise(arith.andi, other) + + def __rand__(self, other): + return self & other + + def __xor__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self._pointwise(arith.xori, other) + + def __rxor__(self, other): + return self ^ other + def __eq__(self, other): return self._compare( other, @@ -508,35 +882,52 @@ def _compare(self, other, *, f_pred, si_pred, ui_pred): def max(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.maximumf, other) + maximumf = arith.maximumf + if ir.F32Type.isinstance(self.mlir_dtype): + maximumf = self._lift_fast_instr("max.NaN.f32") + return self._pointwise(maximumf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise( arith.maxsi if self.is_signed else arith.maxui, other ) else: - return NotImplemented + return NotImplementedError + + def min(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.minimumf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise( + arith.minsi if self.is_signed else arith.minui, other + ) + else: + return NotImplementedError def exp(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx: - f32 = ir.F32Type.get() - if self.mlir_dtype != f32: - raise NotImplementedError - log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634)) - def fast_exp(x): - scaled = arith.mulf(x, log2e) - return llvm.inline_asm(f32, [scaled], "ex2.approx.f32 $0, $1;", "=f,f") - return self._pointwise(self._lift_fast_unary(fast_exp)) + dtype = self.mlir_dtype + log2e = arith.constant(dtype, ir.FloatAttr.get(dtype, 1.4426950408889634)) + return (self * log2e).exp2() return self._pointwise(mlir_math.exp) + def exp2(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + if approx: + if not ir.F32Type.isinstance(self.mlir_dtype): + raise NotImplementedError(self.mlir_dtype) + return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32")) + return self._pointwise(mlir_math.exp2) + def sin(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( - self._lift_fast_unary("sin.approx.f32") if approx else mlir_math.sin + self._lift_fast_instr("sin.approx.f32") if approx else mlir_math.sin ) def cos(self, *, approx: bool = False): @@ -545,7 +936,16 @@ def cos(self, *, approx: bool = False): if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( - self._lift_fast_unary("cos.approx.f32") if approx else mlir_math.cos + self._lift_fast_instr("cos.approx.f32") if approx else mlir_math.cos + ) + + def tanh(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + if approx and self.mlir_dtype != ir.F32Type.get(): + raise NotImplementedError + return self._pointwise( + self._lift_fast_instr("tanh.approx.f32") if approx else mlir_math.tanh ) def rsqrt(self, *, approx: bool = False): @@ -554,42 +954,47 @@ def rsqrt(self, *, approx: bool = False): if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( - self._lift_fast_unary("rsqrt.approx.f32") if approx else mlir_math.rsqrt + self._lift_fast_instr("rsqrt.approx.f32") if approx else mlir_math.rsqrt ) @staticmethod - def _lift_fast_unary( + def _lift_fast_instr( instr: str | Callable[[ir.Value], ir.Value], ) -> Callable[[ir.Value], ir.Value]: - def fast_instr(x): + def fast_instr(*args): f32 = ir.F32Type.get() - if x.type == f32: + arg_ty = args[0].type + assert all(a.type == arg_ty for a in args) + if arg_ty == f32: if isinstance(instr, str): - return llvm.inline_asm(f32, [x], instr + " $0, $1;", "=f,f") + args_ptx = ", ".join(f"${i}" for i in range(len(args) + 1)) + return llvm.inline_asm( + f32, args, f"{instr} {args_ptx};", "=f" + ",f" * len(args) + ) else: - return instr(x) - elif ir.VectorType.isinstance(x.type): + return instr(*args) + elif ir.VectorType.isinstance(arg_ty): index = ir.IndexType.get() - result = llvm.mlir_undef(x.type) - for i in range(2): - v = vector.extractelement(x, position=c(i, index)) - vr = fast_instr(v) + result = llvm.mlir_undef(arg_ty) + [vec_len] = ir.VectorType(arg_ty).shape + for i in range(vec_len): + vs = [vector.extractelement(a, position=c(i, index)) for a in args] + vr = fast_instr(*vs) result = vector.insertelement(vr, result, position=c(i, index)) return result else: - raise NotImplementedError(x.type) + raise NotImplementedError(arg_ty) return fast_instr - def __and__(self, other): - if not ir.IntegerType.isinstance(self.mlir_dtype): - raise ValueError( - "Bitwise operations only defined for integer types, not" - f" {self.mlir_dtype}" + def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): + if (output_is_signed is not None) != ir.IntegerType.isinstance(elt): + raise TypeError( + "output_is_signed must be non-None if and only if the MLIR type is an" + f" integer type, got {output_is_signed=} for {elt}" ) - return self._pointwise(arith.andi, other) - - def bitcast(self, elt: ir.Type): + if elt == self.mlir_dtype: + return self reg_type = self.registers.flat[0].type if ir.VectorType.isinstance(reg_type): reg_shape = ir.VectorType(reg_type).shape @@ -597,7 +1002,9 @@ def bitcast(self, elt: ir.Type): else: ty = elt - return self._pointwise(lambda x: arith.bitcast(ty, x)) + return self._pointwise( + lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed + ) def __getitem__(self, idx): if self.layout != WGMMA_LAYOUT: @@ -640,12 +1047,11 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): ) reg_type = self.registers.flat[0].type is_vector_reg = ir.VectorType.isinstance(reg_type) - reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else () - if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,): + reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,) + [vector_len] = reg_shape # This is meant to be a 1D assertion. + if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}: new_registers = np.empty_like(self.registers) - for idx, reg in np.ndenumerate(self.registers): - reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) - val_16 = llvm.extractelement(reg_16, c(0, i32)) + def upcast_to_bf16(reg, high): # We first embed the s8 into a bf16 with the exponent equal to # bias + mantissa bits. Then, we zero the msb that didn't fit into the # mantissa, zero out all bits other than msb, and subtract the last @@ -653,24 +1059,36 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): # lsb of the exponent (msb of the second byte) is zero, which allows us # to losslesly pack the msb there. When 1, it doubles the value of s2, # making the result negative. - new_val_32 = llvm.inline_asm( + return llvm.inline_asm( i32, - [val_16], - """ - { + [reg], + f""" + {{ .reg .b32 s<3>; - prmt.b32 s0, $1, 0x43, 0x4140; + prmt.b32 s0, $1, 0x43, {0x4342 if high else 0x4140}; and.b32 s1, s0, 0xff7fff7f; and.b32 s2, s0, 0xff80ff80; sub.bf16x2 $0, s1, s2; - } + }} """, "=r,r", ) - new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32)) - new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32)) + empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32)) + for idx, reg in np.ndenumerate(self.registers): + if vector_len == 2: + reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) + new_reg_32 = upcast_to_bf16(reg_16, high=False) + new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32)) + elif vector_len == 4: + reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg) + low = upcast_to_bf16(reg_32, high=False) + high = upcast_to_bf16(reg_32, high=True) + new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32)) + new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32)) + else: + raise NotImplementedError(vector_len) new_registers[idx] = vector.bitcast( - ir.VectorType.get((2,), new_dtype), new_vec + ir.VectorType.get((vector_len,), new_dtype), new_vec_32 ) return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed @@ -698,10 +1116,9 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") new_registers = np.empty_like(self.registers) match self.layout: - case WGMMAFragLayout(): - new_reg_ty = ir.VectorType.get((2,), new_dtype) - case WGStridedFragLayout(vec_size=vec_size): - new_reg_ty = ir.VectorType.get((vec_size,), new_dtype) + case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout(): + shape = ir.VectorType(self.registers.flat[0].type).shape + new_reg_ty = ir.VectorType.get(shape, new_dtype) case WGMMARowFragLayout() | WGSplatFragLayout(): new_reg_ty = new_dtype case _: @@ -713,9 +1130,9 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): ) # NOTE: scratch can be reused immediately once this function returns. - def reduce_sum(self, scratch) -> ir.Value: + def reduce_sum(self, scratch): if ir.FloatType.isinstance(self.mlir_dtype): - op = arith.addf + op = addf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.addi else: @@ -752,9 +1169,29 @@ def reduce_sum(self, scratch) -> ir.Value: utils.warpgroup_barrier() result = memref.load(scratch, [zero_index]) utils.warpgroup_barrier() # Make sure everyone is done using scratch. - return result - - def reduce(self, op, axis): + return FragmentedArray.splat(result, (), is_signed=self.is_signed) + + def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): + if isinstance(op, str): + match op: + case "add": + if ir.FloatType.isinstance(self.mlir_dtype): + op = addf + elif ir.IntegerType.isinstance(self.mlir_dtype): + op = arith.addi + else: + raise NotImplementedError(self.mlir_dtype) + case "max": + if ir.F32Type.isinstance(self.mlir_dtype): + op = self._lift_fast_instr("max.NaN.f32") + elif ir.FloatType.isinstance(self.mlir_dtype): + op = arith.maximumf + elif ir.IntegerType.isinstance(self.mlir_dtype): + op = arith.maxsi if self.is_signed else arith.maxui + else: + raise NotImplementedError(self.mlir_dtype) + case _: + raise ValueError(f"Unrecognized reduction operator: {op}") if self.layout != WGMMA_LAYOUT: raise NotImplementedError(self.layout) if axis != 1: @@ -846,17 +1283,36 @@ def select(self, on_true, on_false): or ir.IntegerType(self.mlir_dtype).width != 1 ): raise NotImplementedError - return self._pointwise(arith.select, on_true, on_false) + # We change the receiver here, because the return type is defined by + # `on_true` and `on_false` and not the predicate `self`. + return on_true._pointwise( + lambda t, p, f: arith.select(p, t, f), self, on_false, + ) - def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): + def foreach( + self, + fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None], + *, + create_array=False, + is_signed=None, + ): """Call a function for each value and index.""" index = ir.IndexType.get() - for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True): - assert len(idx) == len(self.shape), (idx, self.shape) + new_regs = None + if create_array: + new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type)) + for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): + reg = self.registers[reg_idx] + assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) [elems] = ir.VectorType(reg.type).shape for i in range(elems): i = c(i, index) - fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i))) + val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if create_array: + new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + + if create_array: + return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) def store_untiled(self, ref: ir.Value): if not ir.MemRefType.isinstance(ref.type): @@ -869,6 +1325,8 @@ def store_untiled(self, ref: ir.Value): self._store_untiled_splat(ref) case WGStridedFragLayout(): self._store_untiled_wg_strided(ref) + case TiledLayout(): + self._store_untiled_tiled(ref) case _: raise NotImplementedError(self.layout) @@ -935,42 +1393,97 @@ def c(x): col = arith.addi(col_base, c(col_tile * 8 + col_idx)) memref.store(value, ref, [row, col]) + def _store_untiled_tiled(self, ref: ir.Value): + """Stores an array with a tiled layout. Not optimized at the moment.""" + i32 = ir.IntegerType.get_signless(32) + layout = self.layout + assert isinstance(layout, TiledLayout) + ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() + if ref_strides[layout.vector_dim] != 1: + raise NotImplementedError( + "Can't use vector stores with non-unit minormost stride" + ) + strides = layout.tiling.tile_strides(ref_strides) + ptr = utils.memref_ptr(ref) + # Fold warp and lane offsets into the pointer once, since they are dynamic. + dyn_strides = [arith.constant(i32, s) for s in strides] + warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) + lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) + dyn_offset = arith.addi(warp_offset, lane_offset) + ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) + # All warp tile offsets are static and can be fused into the store. + for tile_idx, reg in np.ndenumerate(self.registers): + lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True)) + reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) + llvm.store(reg, reg_ptr) + def store_tiled(self, ref, swizzle: int | None): - if self.layout != WGMMA_LAYOUT: - raise NotImplementedError - dtype = self.mlir_dtype - bw = mgpu.bytewidth(dtype) - m, n = self.shape - assert m % 64 == 0 # This is implied by the layout. - cols_per_tile = swizzle // bw - expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] - if n < cols_per_tile: # We allow singular tiles shorter than swizzle. - expected_shape = [m // 64, 1, 64, cols_per_tile] - if ir.MemRefType(ref.type).shape != expected_shape: - raise ValueError(ref.type, (m, n)) - for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): - vector.store(get(self.registers), ref, idxs) + match self.layout: + case WGMMAFragLayout(): + dtype = self.mlir_dtype + bw = mgpu.bytewidth(dtype) + m, n = self.shape + assert m % 64 == 0 # This is implied by the layout. + cols_per_tile = swizzle // bw + expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] + if n < cols_per_tile: # We allow singular tiles shorter than swizzle. + expected_shape = [m // 64, 1, 64, cols_per_tile] + if ir.MemRefType(ref.type).shape != expected_shape: + raise ValueError(ref.type, (m, n)) + for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): + vector.store(get(self.registers), ref, idxs) + case TiledLayout(): + layout, shape = self.layout, self.shape + for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): + llvm.store(get(self.registers), ptr) + case _: + raise NotImplementedError(self.layout) @classmethod def load_tiled( - cls, ref, swizzle: int | None, *, is_signed: bool | None = None + cls, + ref, + swizzle: int | None, + *, + is_signed: bool | None = None, + layout: FragmentedLayout = WGMMA_LAYOUT, ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type - bw = mgpu.bytewidth(dtype) - m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape - if m_tile_size != 64 or n_tile_size != (swizzle // bw): - raise ValueError - m, n = m_tiles * m_tile_size, n_tiles * n_tile_size - assert m % 64 == 0 # This is implied by the layout. - registers = np.full( - (m_tiles, n // 8, 2, 1), - vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)), - dtype=object, - ) - for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle): - update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs)) - return cls(_registers=registers, _layout=WGMMA_LAYOUT, _is_signed=is_signed) + match layout: + case TiledLayout(): + ref_ty = ir.MemRefType(ref.type) + tiled_shape = ref_ty.shape + if len(tiled_shape) % 2: + raise ValueError("Tiled reference must have even rank") + tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],)) + shape = tiling.untile_shape(tiled_shape) + zero = ( + vector.splat( + ir.VectorType.get((layout.vector_length,), dtype), c(0, dtype) + ), + ) + registers = np.full(layout.registers_shape(shape), zero, dtype=object) + reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) + for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): + update(registers, llvm.load(reg_ty, ptr)) + case WGMMAFragLayout(): + bw = mgpu.bytewidth(dtype) + m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape + if m_tile_size != 64 or n_tile_size != (swizzle // bw): + raise ValueError + m, n = m_tiles * m_tile_size, n_tiles * n_tile_size + assert m % 64 == 0 # This is implied by the layout. + registers = np.full( + (m_tiles, n // 8, 2, 1), + vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)), + dtype=object, + ) + for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle): + update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs)) + case _: + raise NotImplementedError(layout) + return cls(_registers=registers, _layout=layout, _is_signed=is_signed) @staticmethod def transfer_tiled(shape, dtype, swizzle: int | None): @@ -1056,6 +1569,116 @@ def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx): regs[right_idx] = arith.select(is_stagger_left, regs[right_idx], new) yield get_register, update_registers, idx + @staticmethod + def transfer_tiled2( + ref: ir.Value, + swizzle: int | None, + layout: TiledLayout, + shape: tuple[int, ...], + ): + """Generate a transfer schedule for a tiled layout. + + Given a ref with one level tiling applied to it (we assume all dimensions + have been tiled), this function generates an iterable describing a good + schedule for swizzled SMEM loads/stores. + + At each step, the iterable yields a tuple of three values: + * a function that takes a register array and returns the register to be + stored at the current address + * a function that takes a register array and a register loaded from the + current address, and updates the register array with that register + * the current address for load/store instructions + """ + # TODO(apaszke): Use ldmatrix/stmatrix when possible. + c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + tiling = layout.tiling + + ref_ty = ir.MemRefType(ref.type) + dtype = ref_ty.element_type + if ref_ty.rank % 2: + raise ValueError("Tiled refence must have even rank") + ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:]) + ref_tiling = Tiling((ref_tiling_shape,)) + ref_strides, _ = ref_ty.get_strides_and_offset() + if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape: + raise ValueError() + if len(layout.base_tile_shape) > len(ref_tiling_shape): + raise ValueError("Memory tiling must be a multiple of the register tiling") + ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):] + if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)): + raise ValueError("Memory tiling must be a multiple of the register tiling") + + if swizzle not in {32, 64, 128}: + raise ValueError("Only swizzled transfers supported") + bw = mgpu.bytewidth(dtype) + swizzle_tile_elems = 16 // bw + swizzle_group_elems = 128 // bw + swizzle_groups_per_block = swizzle // 16 + swizzle_block_elems = swizzle_groups_per_block * swizzle_group_elems + + tiled_strides = list(tiling.tile_strides(tuple(ref_strides))) + tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape))) + lane_strides = [tiled_strides[d] for d in layout.lane_dims] + lane_shape = [tiled_shape[d] for d in layout.lane_dims] + if tiled_strides[layout.vector_dim] != 1: + raise ValueError("Stride of the vectorized dimension should be 1") + for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): + tiled_shape[d] = 1 + full_tiling = Tiling((ref_tiling_shape, *tiling.tiles)) + full_layout = dataclasses.replace(layout, tiling=full_tiling) + + plan = plan_tiled_transfer( + tiled_shape, tiled_strides, lane_shape, lane_strides, layout, bw, swizzle + ) + + dyn_tiled_strides = [c(s) for s in tiled_strides] + lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides) + warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides) + dyn_offset = arith.addi(lane_offset, warp_offset) + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError("Tiled stores can be performed into SMEM") + ptr = utils.memref_ptr(ref, memory_space=3) + _as_consts = lambda consts: [c(const) for const in consts.tolist()] + # This has bits set only for the offset bits that influence swizzling. + swizzle_mask = swizzle_block_elems - swizzle_tile_elems + for tile_idx in np.ndindex(*tiled_shape): + indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms]) + const_offset = np.dot(indices, tiled_strides) + # We split the offset into a part that interacts with swizzling and a + # part that doesn't. This lets us generate better code because constant + # offsets can be fused into load and store instructions. + const_offset_swizzle = const_offset & swizzle_mask + const_offset_no_swizzle = const_offset - const_offset_swizzle + offset_pre_swizzle = arith.addi( + dyn_offset, plan.select(_as_consts(const_offset_swizzle)) + ) + swizzle_group = arith.remui( + arith.divui(offset_pre_swizzle, c(swizzle_group_elems)), + c(swizzle_groups_per_block), + ) + swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems)) + offset = arith.xori(offset_pre_swizzle, swizzle_bits) + reg_ptr = utils.getelementptr(ptr, [offset], dtype) + offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle)) + reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], dtype) + reg_idxs = [ + tiling.tile_indices(full_tiling.untile_indices(idx)) + for idx in indices.tolist() + ] + def get_register(regs, reg_idxs=reg_idxs): + return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) + def update_registers(regs, new, reg_idxs=reg_idxs): + # TODO(apaszke): If the staggering forms a permutation with a small + # cycle length, then instead of blending at each step we could construct + # a small routing network (kind of like a sorting network) to fix up + # each cycle separately after all the loads are performed. + # This would be especially useful for dims that are powers of two and + # staggered by another power of 2, since all cycles are of length 2 (and + # we could save half the selects). + for i, reg_idx in enumerate(reg_idxs): + regs[reg_idx] = plan.select_if_group(i, regs[reg_idx], new) + yield get_register, update_registers, reg_ptr + def tree_flatten(self): aux = self.layout, self.registers.shape, self.is_signed return list(self.registers.flat), aux @@ -1065,3 +1688,287 @@ def tree_unflatten(cls, aux, flat_registers): layout, reg_shape, is_signed = aux registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) + + +class TransferPlan(Protocol): + IndexTransform = Callable[[tuple[int, ...]], tuple[int, ...]] + tile_index_transforms: tuple[IndexTransform, ...] + + def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: + """Selects the value corresponding to the group of the current thread. + + The argument must be of the same length as tile_index_transforms. + """ + raise NotImplementedError + + def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: + """Returns `new` if the current thread belongs to the given group and `old` otherwise. + + group_idx must be between 0 and len(tile_index_transforms) - 1. + """ + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class TrivialTransferPlan(TransferPlan): + @property + def tile_index_transforms(self): + return (lambda x: x,) + + def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: + assert len(group_elems) == 1 + return group_elems[0] + + def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: + assert group_idx == 0 + return new + + +@dataclasses.dataclass(frozen=True) +class StaggeredTransferPlan(TransferPlan): + stagger: int + dim: int + size: int + group_pred: ir.Value + + @property + def tile_index_transforms(self): + dim = self.dim + def rotate(idx: tuple[int, ...]) -> tuple[int, ...]: + return ( + *idx[:dim], (idx[dim] + self.stagger) % self.size, *idx[dim + 1 :], + ) + return (lambda x: x, rotate) + + def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: + assert len(group_elems) == 2 + return arith.select(self.group_pred, group_elems[1], group_elems[0]) + + def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: + assert 0 <= group_idx <= 1 + sides = [old, new] if group_idx == 0 else [new, old] + return arith.select(self.group_pred, *sides) + + +def plan_tiled_transfer( + tiled_shape: Sequence[int], + tiled_strides: Sequence[int], + lane_shape: Sequence[int], + lane_strides: Sequence[int], + layout: TiledLayout, + bw: int, + swizzle: int, +) -> TransferPlan: + i32 = ir.IntegerType.get_signless(32) + c = lambda x: arith.constant(i32, x) + swizzle_tile_elems = 16 // bw + swizzle_group_elems = 128 // bw + # Below, all calculations are in elements, not in bytes, since it should + # generalize better to sub-byte types. + # Here, we verify two conditions: + # 1. Each vector transfer only accesses addresses that fall within a single + # swizzle tile (if not we'd need to split it and swizzle parts differently). + transfer_alignment = math.gcd(*( + s + for i, (s, d) in enumerate_negative(list(zip(tiled_strides, tiled_shape))) + if d > 1 or i in {layout.warp_dim, *layout.lane_dims} + )) + if ( + swizzle_tile_elems % transfer_alignment + and layout.vector_length <= transfer_alignment + ): + raise ValueError( + "Failed to prove that vector transfers don't cross swizzle tile" + " boundaries. This check is incomplete, and does not guarantee that" + " this is a user error, but it might be." + str(transfer_alignment) + ) + + # 2. The transfer pattern does not cause bank conflicts. + # TODO(apaszke): For now, when performing transfers narrower than a bank, + # we simply narrow each bank to the transfer width. The truth is more likely + # that bank conflicts only don't occur if the addresses mapping to the same + # bank are contiguous, but that's a more complicated check to perform. + transfer_bytes = layout.vector_length * bw + if transfer_bytes > SMEM_BANK_BYTES * 4: + raise NotImplementedError + if bw > SMEM_BANK_BYTES: + raise NotImplementedError + smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes) + num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes) + elems_per_bank = smem_bank_bytes // bw + num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) + wavefront_lanes = WARP_SIZE // num_wavefronts + + lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) + def has_bank_conflicts(tile_idx_transform): + tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) + tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] + lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] + assert lane_tile_idx.shape[1] in {1, WARP_SIZE} + lane_tile_offsets = np.dot(lane_tile_idx, tiled_strides) + offsets = lane_tile_offsets + lane_offsets_in_tile # [#tiles, #lanes] + assert offsets.shape[-1] == WARP_SIZE + swizzle_groups = (offsets // swizzle_group_elems) % (swizzle // 16) + swizzle_bits = swizzle_groups * swizzle_tile_elems + lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks + wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) + # Order of threads within the wavefront is unimportant. + wavefront_banks = np.sort(wavefront_banks, axis=-1) + # There are no conflicts if each wavefront only contains unique banks. + return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) + + # We don't need any special treatment if there are no conflicts when each lane + # transfers the same tile at a time. + if not has_bank_conflicts(lambda tile_idx: tile_idx): + return TrivialTransferPlan() + + # Otherwise, we will try to partition the lanes into two groups and have + # each group store to different tile. The only tile dimensions that can help + # us with bank conflicts are those that have multiple elements and a stride + # that's not a multiple of the number of banks. + # + # Note that the code is set up so that we could also consider partitioning + # the lanes into more groups, but the selects will become more expensive if + # we do that. It's a possibility we have if we need it. + candidate_dims = ( + i for i, (s, d) in enumerate(zip(tiled_strides, tiled_shape)) + if d > 1 and s % (SMEM_BANKS * elems_per_bank) + ) + for dim in candidate_dims: + for group_stride in (1, 2, 4, 8, 16): + # We change the group assignment each group_stride lanes. + lane_id = np.arange(WARP_SIZE)[:, None] + lane_group = (lane_id // group_stride) % 2 + # We only consider a transformation where the second group stores to a + # tile that's a constant offset (modulo dim size) from the first one. + for stagger in range(1, tiled_shape[dim]): + offset = np.zeros(len(tiled_shape), np.int64) + offset[dim] = stagger + transform = lambda idx: (idx + offset * lane_group) % tiled_shape + if not has_bank_conflicts(transform): + # We've found a strategy that avoids bank conflicts! + lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE)) + group_idx = arith.remui(arith.divui(lane_idx, c(group_stride)), c(2)) + group_pred = arith.cmpi(arith.CmpIPredicate.ne, group_idx, c(0)) + return StaggeredTransferPlan( + stagger, dim, tiled_shape[dim], group_pred + ) + raise ValueError( + "Failed to synthesize a transfer pattern that avoids bank conflicts" + ) + +# We allow contractions, to potentially take advantage of FMA instructions. +# They can change the results, but the precision should only increase. +def addf(a: ir.Value, b: ir.Value): + return arith.addf(a, b, fastmath=arith.FastMathFlags.contract) + +def subf(a: ir.Value, b: ir.Value): + return arith.subf(a, b, fastmath=arith.FastMathFlags.contract) + +def mulf(a: ir.Value, b: ir.Value): + return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract) + + +def optimization_barrier(*arrays: mgpu.FragmentedArray): + """Acts as an optimization barrier for LLVM. + + Passing arrays through this function will make sure that they are computed + before any side-effecting operations that follow this barrier. + """ + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + + regs = [] + reg_dtypes = [] + reg_constraints = [] + ptx_lines = ["// Optimization barrier"] + repack_fns = [] + # We unpack each array into a flat list of registers, and prepare the + # functions that invert the transform in repack_fns. + for array in arrays: + ptx_lines.append("// Next array") + reg_ty = array.registers.flat[0].type + dtype = array.mlir_dtype + num_prev_cstr = len(reg_constraints) + if ir.F32Type.isinstance(dtype): + if ir.VectorType.isinstance(reg_ty): + [vec_len] = ir.VectorType(reg_ty).shape + array_regs = [ # pylint: disable=g-complex-comprehension + vector.extractelement(reg, position=c(pos, index)) + for reg in array.registers.flat + for pos in range(vec_len) + ] + def _repack(regs, reg_ty=reg_ty): + reg = llvm.mlir_undef(reg_ty) + [vec_len] = ir.VectorType(reg_ty).shape + for i_elem in range(vec_len): + reg = llvm.insertelement( + reg, next(regs), arith.constant(i32, i_elem) + ) + return reg + repack_fns.append(_repack) + else: + array_regs = list(array.registers.flat) + repack_fns.append(lambda regs: next(regs)) + reg_constraint = "f" + elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): + if not ir.VectorType.isinstance(reg_ty): + raise NotImplementedError(array.mlir_dtype) + [vec_len] = ir.VectorType(reg_ty).shape + if vec_len != 2: + raise NotImplementedError(vec_len) + i32_reg_ty = ir.VectorType.get((1,), i32) + array_regs = [ + vector.extractelement( + vector.bitcast(i32_reg_ty, reg), position=c(0, index) + ) + for reg in array.registers.flat + ] + reg_constraint = "r" + def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): + return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) + repack_fns.append(_repack) + else: + raise NotImplementedError(array.mlir_dtype) + regs += array_regs + reg_dtypes += [array_regs[0].type] * len(array_regs) + reg_constraints += [f"={reg_constraint}"] * len(array_regs) + reg_constraints += [reg_constraint] * len(array_regs) + ptx_lines += [ + f"mov.b32 ${i}, ${len(array_regs)+i}" + for i in range(num_prev_cstr, num_prev_cstr + len(array_regs)) + ] + reg_constraints = ",".join(reg_constraints) + ptx = ";\n\t".join(ptx_lines) + ";" + struct_ty = ir.Type.parse( + f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" + ) + result_struct = llvm.inline_asm( + struct_ty, regs, ptx, reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [ + llvm.extractvalue(dtype, result_struct, [i]) + for i, dtype in enumerate(reg_dtypes) + ] + i32 = ir.IntegerType.get_signless(32) + results = [] + regs_it = iter(regs) + for array, repack_fn in zip(arrays, repack_fns, strict=True): + num_regs = array.registers.size + reg_ty = array.registers.flat[0].type + if ir.VectorType.isinstance(reg_ty): + reg_ty = ir.VectorType(reg_ty) + new_registers = np.empty((num_regs,), dtype=object) + for i_vreg in range(num_regs): + reg = repack_fn(regs_it) + assert reg.type == reg_ty, (reg.type, reg_ty) + new_registers[i_vreg] = reg + results.append( + FragmentedArray( + _registers=new_registers.reshape(array.registers.shape), + _layout=array.layout, + _is_signed=array.is_signed, + ) + ) + return results[0] if len(arrays) == 1 else results diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index bf6631cbca16..e51a7b842931 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -14,16 +14,15 @@ # ============================================================================== import contextlib -import ctypes -import functools import itertools import json import math +from typing import Callable, ParamSpec, TypeVar import warnings import jax -from jax._src.interpreters import mlir from jax._src.lib import xla_client +from jax.extend import ffi import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -34,72 +33,80 @@ from .utils import * # noqa: F403 - try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib - - xla_client.register_custom_call_target( - "mosaic_gpu_record_event", - mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(), - platform="CUDA", - ) except ImportError: - pass + has_registrations = False +else: + # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36. + has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations") + if has_registrations: + for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): + xla_client.register_custom_call_target( + name, handler, platform="CUDA", api_version=1 + ) # ruff: noqa: F405 # mypy: ignore-errors +T = TypeVar("T") +P = ParamSpec("P") -record_event_p = jax.core.Primitive("record_event") -record_event_p.multiple_results = True - -@record_event_p.def_abstract_eval -def _record_event_abstract_eval(*args, event): - del event # Unused. - return args - -@functools.partial(mlir.register_lowering, record_event_p, platform="cuda") -def _record_event_lowering_rule(ctx, *args, event): - ptr_bytes = ctypes.cast(event, ctypes.c_void_p).value.to_bytes( - 8, byteorder="little" - ) # pytype: disable=attribute-error - op = mlir.custom_call( - "mosaic_gpu_record_event", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - backend_config=ptr_bytes, - operand_output_aliases={i: i for i in range(len(args))}, - ) - return op.results - -def _record_event(args, event): +def _event_record(args, *, copy_before): flat_args, treedef = jax.tree.flatten(args) - return jax.tree.unflatten( - treedef, record_event_p.bind(*flat_args, event=event) - ) - -def measure(f, *args, **kwargs): - # TODO(apaszke): Raise if this is called under jit. - start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() - end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() - try: - - @jax.jit - def run(*args, **kwargs): - flat_args, treedef = jax.tree.flatten((args, kwargs)) - flat_args = _record_event(flat_args, start_event) - args, kwargs = jax.tree.unflatten(treedef, flat_args) - return _record_event(f(*args, **kwargs), end_event) - - jax.block_until_ready(run(*args, **kwargs)) # Warmup. - results = jax.block_until_ready(run(*args, **kwargs)) - elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed( - start_event, end_event + event, *flat_outs = ffi.ffi_call( + "mgpu_event_record", + result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args), + input_output_aliases={i: i + 1 for i in range(len(flat_args))}, + )(*flat_args, copy_before=copy_before) + return event, treedef.unflatten(flat_outs) + + +def _event_elapsed(start_event, end_event): + return ffi.ffi_call( + "mgpu_event_elapsed", + result_shape_dtypes=jax.core.ShapedArray((), jnp.float32), + )(start_event, end_event) + + +def measure( + f: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> tuple[T, float]: + """Measures the time it takes to execute the function on the GPU. + + Args: + f: The function to measure. It must accept at least one argument and return + at least one output to be measurable. + *args: The arguments to pass to ``f``. + **kwargs: The keyword arguments to pass to ``f``. + + Returns: + The return value of ``f`` and the elapsed time in milliseconds. + """ + if not has_registrations: + raise RuntimeError( + "This function requires jaxlib >=0.4.36 with CUDA support." ) - finally: - mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(start_event) - mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(end_event) - return results, elapsed + + if not (args or kwargs): + # We require at least one argument and at least one output to ensure + # that there is a data dependency between `_event_record` calls in + # the resulting HLO program. + raise ValueError("Can only measure functions with arguments") + + @jax.jit + def run(*args, **kwargs): + start_event, (args, kwargs) = _event_record( + (args, kwargs), copy_before=True + ) + end_event, outs = _event_record(f(*args, **kwargs), copy_before=False) + if jax.tree.structure(outs).num_leaves == 0: + raise ValueError("Can only measure functions with at least one output") + return outs, _event_elapsed(start_event, end_event) + + jax.block_until_ready(run(*args, **kwargs)) # Warmup. + outs, elapsed = run(*args, **kwargs) + return outs, float(elapsed) class ProfilerSpec: @@ -203,7 +210,8 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): "tid": 1 + wg_idx + warpgroups_per_block * block_idx, }) else: # If we didn't break - events.append(block_events) + if block_events: + events.append(block_events) events = sorted(events, key=lambda x: x[0]["ts"]) flat_events = list(itertools.chain.from_iterable(events)) return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 87ffe09291fc..f6cab5654e64 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -40,32 +40,35 @@ WARPGROUP_SIZE: int = 128 DYNAMIC = -9223372036854775808 +DYNAMIC32 = -2147483648 # pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes def ptr_as_memref(ptr, memref_ty: ir.MemRefType): - if len(memref_ty.shape) == 0: - raise NotImplementedError i64 = ir.IntegerType.get_signless(64) rank = len(memref_ty.shape) - desc_ty = ir.Type.parse( - f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>" - ) + if rank > 0: + desc_ty = ir.Type.parse( + f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>" + ) + else: + desc_ty = ir.Type.parse("!llvm.struct<(ptr, ptr, i64)>") desc = llvm.UndefOp(desc_ty) desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base desc = llvm.InsertValueOp( desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, 0)), [2] ) - for i, s in enumerate(memref_ty.shape): - desc = llvm.InsertValueOp( - desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] - ) - for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): - desc = llvm.InsertValueOp( - desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] - ) + if rank > 0: + for i, s in enumerate(memref_ty.shape): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] + ) + for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] + ) return builtin.unrealized_conversion_cast([memref_ty], [desc]) @@ -104,28 +107,43 @@ def c(val: int | float, ty): raise NotImplementedError(ty) return arith.constant(ty, attr) +def _debug_scalar_ty_format(arg): + if ir.IndexType.isinstance(arg.type): + return "%llu", arg + if ir.IntegerType.isinstance(arg.type): + if ir.IntegerType(arg.type).width < 64: + arg = arith.extui(ir.IntegerType.get_signless(64), arg) + return "%llu", arg + if ir.F32Type.isinstance(arg.type): + return "%f", arg + if ir.F16Type.isinstance(arg.type): + arg = arith.extf(ir.F32Type.get(), arg) + return "%f", arg + raise NotImplementedError(f"Can't print the type {arg.type}") def debug_print(fmt, *args, uniform=True): type_formats = [] new_args = [] for arg in args: - ty_format = None - if ir.IndexType.isinstance(arg.type): - ty_format = "%llu" - if ir.IntegerType.isinstance(arg.type): - width = ir.IntegerType(arg.type).width - ty_format = "%llu" - if width < 64: - arg = arith.extui(ir.IntegerType.get_signless(64), arg) - if ir.F32Type.isinstance(arg.type): - ty_format = "%f" - if ir.F16Type.isinstance(arg.type): - ty_format = "%f" - arg = arith.extf(ir.F32Type.get(), arg) + if ir.VectorType.isinstance(arg.type): + index = ir.IndexType.get() + vec_ty = ir.VectorType(arg.type) + if len(vec_ty.shape) > 1: + raise NotImplementedError(vec_ty) + vec_args = [ + vector.extractelement(arg, position=c(i, index)) + for i in range(vec_ty.shape[0]) + ] + ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args)) + ty_format = f"[{','.join(ty_formats)}]" + new_args += args + else: + ty_format, arg = _debug_scalar_ty_format(arg) + new_args.append(arg) + if ty_format is None: raise NotImplementedError(arg.type) type_formats.append(ty_format) - new_args.append(arg) ctx = ( functools.partial(single_thread, per_block=False) if uniform @@ -293,6 +311,12 @@ def globaltimer(kind: Literal["low", "high"] | None = None): def bytewidth(ty: ir.Type): + # The actual width of TF32 is 19 bits. However, sinc we need to treat it as + # 32 bits for compatibility reasons. TF32 used to be 32 bits wide in upstream + # MLIR, but it changed in + # https://github.com/llvm/llvm-project/commit/67a1fdb014790a38a205d28e1748634de34471dd. + if ir.FloatTF32Type.isinstance(ty): + return 4 if ir.IntegerType.isinstance(ty): return ir.IntegerType(ty).width // 8 if ir.FloatType.isinstance(ty): @@ -649,33 +673,17 @@ def __getitem__(self, offset: ir.Value | int) -> "BarrierRef": 1, ) - def wait_parity(self, parity, expect_wait=False): - i1 = ir.IntegerType.get_signless(1) + def wait_parity(self, parity): i32 = ir.IntegerType.get_signless(32) - ticks = c(10000000, i32) - address = self.get_ptr() + ticks = arith.constant(i32, 10000000) parity = arith.extui(i32, parity) - if expect_wait: - nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks) - return - barrier_ready = llvm.inline_asm( - i1, - [address, parity], - "mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;", - "=b,l,r", - has_side_effects=True, - ) - should_wait = arith.xori(barrier_ready, c(1, i1)) - should_wait = llvm.intr_expect(should_wait, c(0, i1)) - with ir.InsertionPoint(scf.IfOp(should_wait).then_block): - nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks) - scf.yield_([]) + nvvm.mbarrier_try_wait_parity_shared(self.get_ptr(), parity, ticks) - def wait(self, expect_wait=False): + def wait(self): parities = memref.load(self.phases, []) parity, new_parities = self.update_parities(parities) memref.store(new_parities, self.phases, []) - self.wait_parity(parity, expect_wait=expect_wait) + self.wait_parity(parity) def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: i32 = ir.IntegerType.get_signless(32) @@ -1036,3 +1044,15 @@ def is_signed(dtype: jax.typing.DTypeLike) -> bool | None: elif jnp.issubdtype(dtype, jnp.integer): return jnp.issubdtype(dtype, jnp.signedinteger) return None + + +def getelementptr( + ptr: ir.Value, indices: Sequence[ir.Value | int], dtype: ir.Type +) -> ir.Value: + static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices] + dyn_indices = [i for i in indices if not isinstance(i, int)] + return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype) + + +def dyn_dot(x, y): + return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y))) diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index ba0f130364ff..6f4d96fbd218 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -23,6 +23,7 @@ from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import vector +from jaxlib.mlir.dialects import nvvm import numpy as np import jax.experimental.mosaic.gpu as mgpu @@ -445,58 +446,13 @@ def wgmma( def wgmma_fence(array: mgpu.FragmentedArray): """Fences the array construction from WGMMA instructions. - This is a little workaround to force LLVM to initialize the PTX registers - before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats - in-register computation as pure and can move it after the fence, which is - explicitly disallowed by the PTX programming model. + LLVM treats in-register computation as pure and can move it after the fence, + which is explicitly disallowed by the PTX programming model. For that reason, + we insert an LLVM optimization barrier before the fence. """ - i32 = ir.IntegerType.get_signless(32) - index = ir.IndexType.get() - dtype = array.mlir_dtype - src_vec_ty = ir.VectorType(array.registers.flat[0].type) - assert src_vec_ty.shape == [2] - - if dtype == ir.F32Type.get(): - regs = [ # pylint: disable=g-complex-comprehension - vector.extractelement(reg, position=c(pos, index)) - for reg in array.registers.flat - for pos in range(2) - ] - reg_dtype = dtype - reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs) - ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))] - elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): - regs = [_as_i32_reg(reg) for reg in array.registers.flat] - reg_dtype = i32 - reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs) - ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))] - else: - raise NotImplementedError(dtype) - reg_constraints = ",".join(reg_constraints_list) - # Copy over the registers. ptxas should be able to remove the moves. - ptx_lines.append("wgmma.fence.sync.aligned") - ptx = ";\n".join(ptx_lines) + ";\n" - dtype_str = str(reg_dtype) - struct_ty = ir.Type.parse( - f"!llvm.struct<({','.join(dtype_str for _ in regs)})>" - ) - acc_struct = llvm.inline_asm( - struct_ty, regs, ptx, reg_constraints, - asm_dialect=0, has_side_effects=True, - ) - regs = [ - llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs)) - ] - if dtype == ir.F32Type.get(): - registers = _as_fragmented_reg_ndarray( - regs, array.mlir_dtype, array.registers.shape - ) - elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): - regs = [_unpack_i32(src_vec_ty, r) for r in regs] - registers = np.asarray(regs, dtype=object).reshape(array.registers.shape) - else: - raise NotImplementedError(dtype) - return mgpu.FragmentedArray(_registers=registers, _layout=array.layout, _is_signed=array.is_signed) + array = mgpu.optimization_barrier(array) + nvvm.wgmma_fence_aligned() + return array def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]): diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 803efa19056e..79989583fc28 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -75,7 +75,7 @@ def pre_jit(x): return host_local_array_to_global_array(inp, global_mesh, pspec) def post_jit(x): - return np.asarray(x.addressable_data(0)) + return jax.device_get(x.addressable_data(0)) in_tree = jax.tree.map(pre_jit, in_tree) out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding( @@ -359,22 +359,18 @@ def ltg_abstract_eval(arr, *, global_mesh, pspec): lambda ct, _, **params: ( host_local_array_to_global_array_p.bind(ct, **params),)) -def ltg_batcher(insert_axis, spmd_axis_name, axis_size, - axis_name, main_type, vals_in, dims_in, - global_mesh, pspec): +def ltg_batcher(insert_axis, axis_data, vals_in, dims_in, global_mesh, pspec): x, = vals_in d, = dims_in - new_parts = None if spmd_axis_name is None else spmd_axis_name + new_parts = None if axis_data.spmd_name is None else axis_data.spmd_name new_pspec = list(pspec) new_pspec.insert(d, new_parts) new_pspec = P(*new_pspec) y = host_local_array_to_global_array_p.bind( x, global_mesh=global_mesh, pspec=new_pspec) return y, d -batching.spmd_axis_primitive_batchers[host_local_array_to_global_array_p] = partial( +batching.fancy_primitive_batchers[host_local_array_to_global_array_p] = partial( ltg_batcher, False) -batching.axis_primitive_batchers[host_local_array_to_global_array_p] = partial( - ltg_batcher, False, None) def _ltg_lowering(ctx, x, *, global_mesh, pspec): return [x] diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index b8e3daee48c8..987e461a39b2 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -47,12 +47,12 @@ def ravel_first_arg(f, unravel): return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped -@lu.transformation -def ravel_first_arg_(unravel, y_flat, *args): +@lu.transformation2 +def ravel_first_arg_(f, unravel, y_flat, *args): y = unravel(y_flat) - ans = yield (y,) + args, {} + ans = f(y, *args) ans_flat, _ = ravel_pytree(ans) - yield ans_flat + return ans_flat def interp_fit_dopri(y0, y1, k, dt): # Fit a polynomial to the results of a Runge-Kutta step. diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 34cb5328f36a..7e6527ad999a 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -30,6 +30,7 @@ from jax._src.pallas.core import no_block_spec as no_block_spec from jax._src.pallas.core import Unblocked as Unblocked from jax._src.pallas.core import unblocked as unblocked +from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost from jax._src.pallas.pallas_call import pallas_call as pallas_call from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p from jax._src.pallas.primitives import atomic_add as atomic_add diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md new file mode 100644 index 000000000000..40b109d102d5 --- /dev/null +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -0,0 +1,207 @@ +# Debugging Pallas + + + + + +[TOC] + +This document contains a collection of tips and tricks for debugging Pallas +programs. For any specific requests or ideas for improvement, please create +a ticket on https://github.com/jax-ml/jax/issues. + +## Debugging Tools + +### Interpret (HLO) Mode + +Passing in `interpret=True` into `pl.pallas_call` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. + +Note that interpret mode will not be able to fully replicate the behavior or programs that use communication (DMAs) between devices. This is because low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations. + +### debug_print + +The `pl.debug_print` function can be used to print runtime values inside of a kernel. The implementation is currently limited to scalar values, but we are working on lifting this limitation. + +For TPUs only, the kernel must be compiled with the 'xla_tpu_enable_log_recorder' option. + + +```python +kernel = pl.pallas_call(...) +compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) +result = compiled_kernel(x) +``` + +### Runtime Asserts + +Checkify can be used to insert runtime asserts, nan checks, out of bounds errors, etc. inside of a kernel. +Pallas implements two options for assertions: a *hard assert* which will crash the TPU if failed, and a *functionalized assertion* which will simulate a runtime assertion that can be thrown +as a Python error after the kernel has successfully executed. + +#### Hard assertion + +Hard assertions can be inserted with `checkify.check` +and running your program with the `--jax_pallas_enable_runtime_assert` flag. + +Your code will look like the following: + +```python +from jax.experimental import checkify + +def kernel(...): + checkify.check(x > y, "Check x > y failed") # Will halt if x <= y +``` + +This will print a relatively lengthy dump which resembles the following: + +``` +E1001 15:22:33.275768 4353 real_program_continuator.cc:1350] 0x0x0_TC0: [Physical location: dldgr4:pe1:1] generic::internal: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x169 (from TensorCoreSequencer:1:0x213): Check x > y failed HLO: main; HLO computation: main.3 +``` + +The benefit of a hard assertion is that it is guaranteed to either pass or +halt the TPU. The kernel will never proceed past the assertion if it fails. +However, the downside is that if the assertion fails you will +likely have to restart the program in order to run any other TPU operations, +and there is no Python error thrown that can be caught. + +#### Functionalized assertion +Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op like so: + +```python +from jax.experimental import checkify + +def kernel(...): + checkify.check(x > y, "Check x > y failed") # Will throw an error if x <= y + +kernel = pl.pallas_call(...) +checkified_kernel = checkify.checkify(kernel, + errors=checkify.all_checks) +error, result = checkified_kernel(x) +error.throw() +``` + +This will throw a Python error if any checks failed, such as if a NaN occurred +or if an out-of-bounds index was accessed. + +The benefit of a functionalized assert is that it will throw Python errors +that can be caught, and it will not interfere with downstream TPU operations. +However, it requires the kernel to successfully complete, meaning if your +error would have caused a TPU crash, the crash would still happen and +the error would not be thrown. + + +### Dumping Jaxprs + +Passing in `debug=True` into `pl.pallas_call` will print out the Jaxpr of the kernel as well as the lowered Mosaic code. + +```python +def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +x = jnp.ones((8, 128), dtype=jnp.float32) +pl.pallas_call( + kernel, + out_shape=jax.ShapeDTypeStruct((8, 128), jnp.float32) + debug=True, + name="my_call", +)(x, x) +``` + +This will output: + +``` +The kernel jaxpr for the pallas_call my_call for kernel function kernel at ...:1000: +{ lambda ; a:MemRef{float32[8,128]} b:MemRef{float32[8,128]} c:MemRef{float32[8,128]}. let + d:f32[8,128] <- a[:,:] + e:f32[8,128] <- b[:,:] + f:f32[8,128] = add d e + c[:,:] <- f + in () } + +The Mosaic module for the pallas_call my_call for kernel function kernel at ...:1000: +module { + func.func @main(%arg0: memref<8x128xf32, #tpu.memory_space>, %arg1: memref<8x128xf32, #tpu.memory_space>, %arg2: memref<8x128xf32, #tpu.memory_space>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} { + %c0 = arith.constant 0 : index + %c0_0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0_0] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + %c0_1 = arith.constant 0 : index + %c0_2 = arith.constant 0 : index + %1 = vector.load %arg1[%c0_1, %c0_2] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + %2 = arith.addf %0, %1 : vector<8x128xf32> + %c0_3 = arith.constant 0 : index + %c0_4 = arith.constant 0 : index + %3 = vector.load %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + vector.store %2, %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + return + } +} +``` + +### Dumping Mosaic Passes + +Mosaic is the underlying TPU compiler for Pallas. It can be useful to dump Mosaic if you are running into errors that are originating from the Mosaic compiler to see what code is actually being generated. + +Passing the `--xla_mosaic_dump_to=` argument will dump the output of all intermediate Mosaic passes. The names of the files contain either the parameter `name` passed to the `pallas_call`, or the name of the kernel function. A useful option is to dump to Sponge with `--test_arg=--xla_mosaic_dump_to=sponge` after which you will see all passes under the “Artifacts” tab in sponge. + +### Static Verification + +The static verification tool can be used to automatically detect race conditions in distributed kernels. +Because this tool uses formal verification, it is best used for small kernels (<=2 devices). + +Verification can be performed by running your kernel with the `--jax_pallas_dump_promela_to=`, +which will output a Promela dump file. Afterwards, the dump file can be +analyzed using the [`spin`](https://spinroot.com) tool. For example, with a dump named `dump.pml`, run: + +``` +spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan +``` + + + +## Useful Command line flags + +* OOB Checks: `--xla_mosaic_on_device_checks=bounds` +* Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true` + +* Dump Mosaic: `--xla_mosaic_dump_to=` +* Enable trace markers in XProf: `--xla_enable_transpose_trace` + +## Common Errors + +### INTERNAL Mosaic failed to compile TPU Kernel + +`INTERNAL Mosaic failed to compile TPU Kernel: Not implemented X` + +This error means that you hit an unimplemented case in the underlying Mosaic compiler. +Our recommended course of action here is to file a ticket if one does not already +exist for your specific error. + +In some cases, your error may be due to an operation which cannot be implemented +efficiently in the compiler, in which your best course of action is to find a workaround. This +is most commonly seen in `layout` and `shape_cast` errors. The important tip +to remember regarding layouts is that the last 2 dimensions of arrays in Pallas +are physically tiled into registers, so any reshapes, slicing, transposes, etc. +on the last 2 dimensions may trigger a relayout. + + +### VerificationError + +A verification error indicates that Pallas produced invalid code for Mosaic. + +This is a bug in Pallas, so please file a bug under https://github.com/jax-ml/jax/issues. + +### LoweringError + +This is a catch-all error type during Pallas to Mosaic lowering and can have many causes. +In most cases the error message should hint at what is wrong. + +For specific errors: + +* `Mixed dtype operands in cmp` when using `jnp.mod`: Use lax.rem instead of jnp.mod + + diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index fbb3a3857c68..8da2a5095927 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -28,8 +28,10 @@ from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef +from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait +from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem diff --git a/jax/experimental/export/BUILD b/jax/experimental/pallas/ops/gpu/BUILD similarity index 50% rename from jax/experimental/export/BUILD rename to jax/experimental/pallas/ops/gpu/BUILD index 1246b0d407af..20ff2152c356 100644 --- a/jax/experimental/export/BUILD +++ b/jax/experimental/pallas/ops/gpu/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,31 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -# JAX-export provides APIs for exporting StableHLO for serialization purposes. - -load("@rules_python//python:defs.bzl", "py_library") -load( - "//jaxlib:jax.bzl", - "py_deps", +package( + default_applicable_licenses = [], + default_visibility = ["//jax:__subpackages__"], ) -licenses(["notice"]) +exports_files( + srcs = glob(["*.py"]), +) -package( - default_applicable_licenses = [], - default_visibility = ["//visibility:private"], +filegroup( + name = "triton_ops", + srcs = glob( + ["*.py"], + exclude = ["*_mgpu.py"], + ), ) -py_library( - name = "export", - srcs = [ - "__init__.py", - ], - srcs_version = "PY3", - # TODO: b/255503696: enable pytype - tags = ["pytype_unchecked_annotations"], - visibility = ["//visibility:public"], - deps = [ - "//jax", - ] + py_deps("numpy") + py_deps("flatbuffers"), +filegroup( + name = "mgpu_ops", + srcs = glob(["*_mgpu.py"]), ) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 66c9dea39734..198340ec0d11 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -177,13 +177,14 @@ def mha( debug: bool = False, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) + batch_size, q_seq_len, num_heads, head_dim = q.shape + kv_seq_len = k.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) # Heuristics. grid_ = grid if grid_ is None: - grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) + grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) num_warps_ = num_warps if num_warps_ is None: @@ -198,16 +199,16 @@ def mha( (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None - else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) + else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) return pl.pallas_call( @@ -243,13 +244,14 @@ def _mha_forward( debug: bool, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) + batch_size, q_seq_len, num_heads, head_dim = q.shape + kv_seq_len = k.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) # Heuristics. grid_ = grid if grid_ is None: - grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) + grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) num_warps_ = num_warps if num_warps_ is None: @@ -260,7 +262,7 @@ def _mha_forward( out_shape = [ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out jax.ShapeDtypeStruct( - shape=(batch_size, num_heads, seq_len), dtype=jnp.float32 # lse + shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32 # lse ), ] in_specs = [ @@ -268,16 +270,16 @@ def _mha_forward( (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None - else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) + else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) out, lse = pl.pallas_call( kernel, @@ -362,7 +364,8 @@ def mha_backward_kernel( block_d: int, ): del out_ref # Not needed - seq_len = q_ref.shape[0] + q_seq_len = q_ref.shape[0] + kv_seq_len = k_ref.shape[0] # Scan #1: dK and dV # 1. Load a block of K and V of size (block_k1, head_dim) in SMEM. @@ -423,7 +426,7 @@ def inner_loop_dkdv(start_q, carry): lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0 dv, dk = lax.fori_loop( - lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk) + lower_bound, pl.cdiv(q_seq_len, block_q1), inner_loop_dkdv, (dv, dk) ) dv_ref[...] = dv.astype(dv_ref.dtype) dk_ref[...] = dk.astype(dk_ref.dtype) @@ -486,7 +489,7 @@ def inner_loop_dq(start_k, dq): if causal: upper_bound = lax.div((start_q + 1) * block_q2, block_k2) else: - upper_bound = pl.cdiv(seq_len, block_k2) + upper_bound = pl.cdiv(kv_seq_len, block_k2) dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) dq_ref[...] = dq.astype(dq_ref.dtype) @@ -508,9 +511,10 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, segment_ids, )[1](do) elif backward_pass_impl == "triton": - batch_size, seq_len, num_heads, head_dim = q.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) + batch_size, q_seq_len, num_heads, head_dim = q.shape + kv_seq_len = k.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) delta = _preprocess_backward(out, do, lse, block_q, debug, interpret) out_shapes = [ jax.ShapeDtypeStruct(q.shape, q.dtype), @@ -520,29 +524,29 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), - pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), - pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), + pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), + pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] else: - in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0))) + in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0))) - grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k)) + grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_k)) num_warps = 8 dq, dk, dv = pl.pallas_call( functools.partial( diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py new file mode 100644 index 000000000000..e8c818b884b5 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -0,0 +1,301 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FlashAttention3 implementation (using Mosaic GPU as the backend).""" + +import dataclasses +import functools +import itertools +import math +import jax +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +import jax.numpy as jnp +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + block_q: int + block_kv: int + max_concurrent_steps: int + + def __post_init__(self): + if self.block_q % 64: + raise ValueError(f"{self.block_q=} must be a multiple of 64") + if self.block_kv % 64: + raise ValueError(f"{self.block_kv=} must be a multiple of 64") + if self.max_concurrent_steps < 2: + raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") + + +@functools.partial(jax.jit, static_argnames=["config"]) +def attention(q, k, v, config: TuningConfig): + if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: + raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") + batch_size, q_seq_len, num_q_heads, head_dim = q.shape + _, kv_seq_len, num_kv_heads, _ = k.shape + kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim) + if k.shape != kv_shape: + raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)") + if k.shape != kv_shape: + raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)") + if (dtype := q.dtype) != k.dtype or dtype != v.dtype: + raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}") + if num_q_heads % num_kv_heads: + raise ValueError(f"{num_q_heads=} must be divisible by and {num_kv_heads=}") + q_heads_per_kv_head = num_q_heads // num_kv_heads + if head_dim % 64: + raise ValueError(f"{head_dim=} must be divisible by 64") + if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): + raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") + + max_concurrent_steps = min( + config.max_concurrent_steps, kv_seq_len // config.block_kv + ) + block_q, block_kv = config.block_q, config.block_kv + + def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + batch = lax.axis_index("batch") + smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped + wg_idx = lax.axis_index("wg") + qo_smem2, k_smem, v_smem = smem_buffers + k_barriers, v_barriers, q_barriers = buffer_barriers + k_consumed_barriers, v_consumed_barriers = consumed_barriers + def perform_schedule_barrier(): + plgpu.barrier_arrive(schedule_barrier) + plgpu.barrier_wait(schedule_barrier) + + @pl.when(wg_idx < 2) + def _compute_wg(): + plgpu.set_max_registers(232, action="increase") + qo_smem = qo_smem2.at[wg_idx] + q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q + q_head = lax.axis_index("heads") + + plgpu.copy_gmem_to_smem( + q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], + qo_smem, + q_barriers.at[wg_idx], + ) + plgpu.barrier_wait(q_barriers.at[wg_idx]) + + m_i = plgpu.layout_cast( + jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, + ) + l_i = plgpu.layout_cast( + jnp.full((block_q,), 0, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, + ) + acc = plgpu.layout_cast( + jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + + plgpu.barrier_wait(k_barriers.at[0]) + + pl.when(wg_idx == 1)(perform_schedule_barrier) + def kv_loop(kv_step, carry): + acc, m_i, l_i = carry + slot = lax.rem(kv_step, max_concurrent_steps) + + # QK + def compute_qk(acc_ref): + plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem.at[slot], (1, 0))) + perform_schedule_barrier() + return acc_ref[...] + qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32)) + plgpu.barrier_arrive(k_consumed_barriers.at[slot]) + + # Softmax + # We keep m scaled by log2e to use FMA instructions when computing p. + log2e = math.log2(math.e) + m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e) + alpha = jnp.exp2(m_i - m_ij) + m_i = m_ij + p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0])) + acc *= lax.broadcast_in_dim(alpha, acc.shape, [0]) + l_i *= alpha + p16 = p.astype(dtype) + + def end_softmax_barriers(): + plgpu.barrier_arrive(schedule_barrier) # Done with softmax! + plgpu.barrier_wait(v_barriers.at[slot]) + plgpu.barrier_wait(schedule_barrier) # Wait until TensorCore is free. + # Can't fully explain why, but empirically the ordering here influences + # the performance of the final kernel quite significantly. + if head_dim <= 128: + l_i += p.sum(axis=1) + acc, l_i, m_i, p16 = lax.optimization_barrier((acc, l_i, m_i, p16)) + end_softmax_barriers() + else: + end_softmax_barriers() + l_i += p.sum(axis=1) + + # PV + def compute_pv(acc_ref): + plgpu.wgmma(acc_ref, p16, v_smem.at[slot]) + + wait_step = kv_step + 1 + wait_slot = lax.rem(wait_step, max_concurrent_steps) + @pl.when(wait_step < kv_seq_len // block_kv) + def _wait(): + plgpu.barrier_wait(k_barriers.at[wait_slot]) + acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc)) + plgpu.barrier_arrive(v_consumed_barriers.at[slot]) + return acc, m_i, l_i + if kv_seq_len % block_kv: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") + acc, m_i, l_i = lax.fori_loop( + 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) + ) + pl.when(wg_idx == 0)(perform_schedule_barrier) + del m_i # Not needed anymore + + # TODO(apaszke): Invert and multiply to avoid expensive divisions. + acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) + qo_smem[...] = acc.astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], + ) + plgpu.wait_smem_to_gmem(0) + @pl.when(wg_idx == 2) + def _memory_wg(): + plgpu.set_max_registers(40, action="decrease") + kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) + for i in range(max_concurrent_steps): + s = (batch, pl.ds(i * block_kv, block_kv), kv_head) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) + + def kv_loop(kv_step, _): + tma_step = kv_step + max_concurrent_steps + tma_slot = lax.rem(kv_step, max_concurrent_steps) + s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) + plgpu.barrier_wait(k_consumed_barriers.at[tma_slot]) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) + plgpu.barrier_wait(v_consumed_barriers.at[tma_slot]) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) + lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) + + def run(refs): + q_ref, k_ref, v_ref, out_ref = refs + + num_q_tiles, rem = divmod(q_seq_len, block_q * 2) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") + mesh = plgpu.GPUMesh( + grid=(batch_size, num_q_tiles, num_q_heads), + num_threads=3, + axis_names=("batch", "q_seq", "heads", "wg"), + approx_math=True, + ) + @pl.core_map(mesh) + def _kernel_entry(): + compute_wgs = 2 + tiling = plgpu.TilingTransform((64, 64)) + swizzle = plgpu.SwizzleTransform(128) + qo_scratch = plgpu.SMEM( + (compute_wgs, block_q, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + k_scratch = plgpu.SMEM( + (max_concurrent_steps, block_kv, head_dim), jnp.float16, + transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle), + ) + v_scratch = plgpu.SMEM( + (max_concurrent_steps, block_kv, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + pl.run_scoped( + lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args), + (qo_scratch, k_scratch, v_scratch), + ( + plgpu.Barrier(1, num_barriers=max_concurrent_steps), + plgpu.Barrier(1, num_barriers=max_concurrent_steps), + plgpu.Barrier(1, num_barriers=compute_wgs), + ), + (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2, + plgpu.Barrier(num_arrivals=compute_wgs), + ) + + _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) + return out + + +@jax.jit +def attention_reference(q, k, v): + batch_size, q_seq_len, num_q_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) + q_reshaped = q.reshape( + batch_size, q_seq_len, num_kv_heads, num_q_heads // num_kv_heads, head_dim + ) + logits = jnp.einsum("bqHhc,bkHc->bqHhk", q_reshaped, k) + m = logits.max(axis=-1, keepdims=True) + unnormalized = jnp.exp(logits - m) + l = unnormalized.sum(axis=-1, keepdims=True) + weights = unnormalized / l + return jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) + + +def main(unused_argv): + num_q_heads = 16 + num_kv_heads = 16 + problem_it = itertools.product((1,), (4096, 32768,), (64, 128, 256,)) + for batch_size, seq_len, head_dim in problem_it: + q_seq_len = kv_seq_len = seq_len + print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}" + f"{num_q_heads=:<4} {head_dim=:<6} ====") + k1, k2, k3 = jax.random.split(jax.random.key(42), 3) + q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + block_q = 64 + best = None + for block_kv in (256, 128, 64): + config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2) + try: + out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v) + if seq_len < 32768: + out_ref = attention_reference(q, k, v) + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + continue + raise + runtime_us = runtime_ms * 1e3 + matmul_flops = ( + 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size + ) + peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + print( + f"block_q={block_q:<4}block_kv={block_kv:<4}: {runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) + if best is None or runtime_us < best[0]: + best = (runtime_us, achieved_tc_util) + break # Remove this for full autotuning. + if best is not None: + print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization") + + +if __name__ == "__main__": + from absl import app + import jax + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index a7e1b33e1f35..1c558c220ea9 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -14,6 +14,7 @@ """Module containing decode attention.""" from __future__ import annotations +import math import functools from typing import Any @@ -24,82 +25,115 @@ from jax.experimental.pallas import triton as plgpu import jax.numpy as jnp - def attn_forward_kernel( - q_ref, # [num_heads, head_dim] - k_ref, # [k_seq_len, head_dim] - v_ref, # [k_seq_len, head_dim] - o_ref: Any, # [num_heads, head_dim] + # inputs + q_ref, # [num_heads, head_dim] + k_ref, # [k_seq_len, head_dim] + v_ref, # [k_seq_len, head_dim] + start_idx_ref, # [] (i.e., scalar) + kv_seq_len_ref, # [] (i.e., scalar) + # outputs + o_ref: Any, # [num_heads, head_dim] *residual_refs: Any, # Residual outputs: [num_heads,], [num_heads,] sm_scale: float, block_k: int, + block_h: int, + num_heads: int, ): - block_h, head_dim = q_ref.shape - k_seq_len, _ = k_ref.shape - start_q = pl.program_id(0) + _, head_dim = q_ref.shape + split_k_seq_len, _ = k_ref.shape + prog_i, prog_j = pl.program_id(0), pl.program_id(1) + q_slice = pl.ds(0, block_h) + q_mask = (jnp.arange(block_h) < num_heads - block_h * prog_i)[:, None] + + def _compute(start_idx, kv_seq_len, o, m_i, l_i): + # Load q: it will stay in L1 throughout. Indices form a matrix because we + # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. + # q tile has shape [block_h, head_dim]. + q = pl.load(q_ref, (q_slice, pl.ds(None)), mask=q_mask) + + def _dot(a, b): + # if a.shape[0] == 1: + # # Use matrix vector product + # return (a.T * b).sum(axis=0, keepdims=True) + return pl.dot(a, b) + + mask_indices = jnp.arange(block_k) + + # Loop over blocks of kv to process entire kv seq_len. + # Grid loops over q blocks over num_heads. + def body(start_k, carry): + o_prev, m_prev, l_prev = carry + curr_k_slice = pl.ds(start_k * block_k, block_k) + + k = pl.load(k_ref, (curr_k_slice, slice(None))) + qk = _dot(q, k.T) # [block_h, block_k] + if sm_scale != 1.0: + qk *= sm_scale # [block_h, block_k] + + # apply mask if start or sequence length is specified + if start_idx_ref is not None or kv_seq_len_ref is not None: + indices = (prog_j * split_k_seq_len + start_k * block_k + mask_indices) + mask = ((indices >= start_idx) & (indices < kv_seq_len))[None, :] + qk += (~mask) * (0.7 * jnp.finfo(qk.dtype).min) + + m_curr = qk.max(axis=-1) + m_next = jnp.maximum(m_prev, m_curr) + correction = jnp.exp(m_prev - m_next) + l_prev_corr = correction * l_prev + s_curr = jnp.exp( + qk - m_next[:, None] + ) # Use m_next instead of m_curr to avoid a correction on l_curr + l_curr = s_curr.sum(axis=-1) + l_next = l_prev_corr + l_curr + v = pl.load(v_ref, (curr_k_slice, slice(None))) + o_curr = _dot(s_curr.astype(v.dtype), v) + + # flash2 unscaled_o + o_next = correction[:, None] * o_prev + o_curr + return o_next, m_next, l_next + + max_it = jnp.minimum(pl.cdiv((kv_seq_len - prog_j * split_k_seq_len), + block_k), split_k_seq_len // block_k) + (o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i)) + return o, m_i, l_i # o is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention2 paper) are updated during the k,v loop. - m_i = jnp.zeros(block_h, dtype=jnp.float32) - float("inf") + m_i = jnp.zeros(block_h, dtype=jnp.float32) + jnp.finfo(jnp.float32).min l_i = jnp.zeros(block_h, dtype=jnp.float32) o = jnp.zeros((block_h, head_dim), dtype=jnp.float32) - # Load q: it will stay in L1 throughout. Indices form a matrix because we - # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. - # q tile has shape [block_h, head_dim]. - curr_q_slice = pl.dslice(start_q * block_h, block_h) - q = pl.load(q_ref, (curr_q_slice, pl.dslice(None))) - - def _dot(a, b): - # if a.shape[0] == 1: - # # Use matrix vector product - # return (a.T * b).sum(axis=0, keepdims=True) - return pl.dot(a, b) - - # Loop over blocks of kv to process entire kv seq_len. - # Grid loops over q blocks over num_heads. - def body(start_k, carry): - o_prev, m_prev, l_prev = carry - curr_k_slice = pl.dslice(start_k * block_k, block_k) - - k = pl.load(k_ref, (curr_k_slice, slice(None))) - qk = _dot(q, k.T) # [block_h, block_k] - if sm_scale != 1.0: - qk *= sm_scale # [block_h, block_k] - - m_curr = qk.max(axis=-1) - m_next = jnp.maximum(m_prev, m_curr) - correction = jnp.exp(m_prev - m_next) - l_prev_corr = correction * l_prev - s_curr = jnp.exp( - qk - m_next[:, None] - ) # Use m_next instead of m_curr to avoid a correction on l_curr - l_curr = s_curr.sum(axis=-1) - l_next = l_prev_corr + l_curr - v = pl.load(v_ref, (curr_k_slice, slice(None))) - o_curr = _dot(s_curr.astype(v.dtype), v) - - # flash2 unscaled_o - o_next = correction[:, None] * o_prev + o_curr - return o_next, m_next, l_next - - upper_bound = pl.cdiv(k_seq_len, block_k) - # o is left unscaled; it will be scaled in the final reduction step - o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) + start_idx = split_k_seq_len * prog_j + if start_idx_ref is not None: + start_idx = jnp.maximum(start_idx, pl.load(start_idx_ref, ())) + kv_seq_len = (prog_j + 1) * split_k_seq_len # lower bound on actual k_seq_len + if kv_seq_len_ref is not None: + kv_seq_len = jnp.minimum(kv_seq_len, pl.load(kv_seq_len_ref, ())) + + if start_idx_ref is None and kv_seq_len is None: + o, m_i, l_i = _compute(start_idx, kv_seq_len, o, m_i, l_i) + else: + o, m_i, l_i = jax.lax.cond( + start_idx >= kv_seq_len, lambda: (o, m_i, l_i), + lambda: _compute(start_idx, kv_seq_len, o, m_i, l_i)) + # Write output to dram. if residual_refs: l_ref, m_ref = residual_refs - pl.store(l_ref, (curr_q_slice,), l_i) - pl.store(m_ref, (curr_q_slice,), m_i) - # Write output to dram. + vec_q_mask = q_mask.reshape(-1) if q_mask is not None else None + pl.store(l_ref, q_slice, l_i, mask=vec_q_mask) + pl.store(m_ref, q_slice, m_i, mask=vec_q_mask) o = o.astype(o_ref.dtype) - pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o) + pl.store(o_ref, (q_slice, pl.ds(None)), o, mask=q_mask) -def attn_unbatched( - q, # [num_heads, head_dim] - k, # [k_seq_len, head_dim] - v, # [k_seq_len, head_dim] +def decode_attn_unbatched( + q, # [num_heads, head_dim] + k, # [k_seq_len, head_dim] + v, # [k_seq_len, head_dim] + start_idx, # [] + kv_seq_len, # [] sm_scale: float, block_h: int, block_k: int, @@ -109,16 +143,11 @@ def attn_unbatched( grid: tuple[int, ...] | None, interpret: bool, debug: bool, + return_residuals: bool ): num_heads, head_dim = q.shape k_seq_len, _ = k.shape # Pad num query heads to 16 if needed, and slice output at the end. - original_num_heads = None - if num_heads < 16: - q = jnp.pad(q, ((0, 16 - num_heads), (0, 0))) - original_num_heads = num_heads - num_heads = q.shape[0] - block_h = min(block_h, num_heads) head_splits = pl.cdiv(num_heads, block_h) grid_ = grid if grid_ is None: @@ -127,11 +156,16 @@ def attn_unbatched( assert ( k_seq_len % k_splits == 0 ), f"{k_seq_len=} must be divisible by {k_splits=}" + assert k_seq_len // k_splits >= 16, ( + f"{k_seq_len=} divided by {k_splits=} must be >= 16.") + assert block_k >= 16, "block_k must be >= 16" k = k.reshape(k_splits, k_seq_len // k_splits, head_dim) v = v.reshape(k_splits, k_seq_len // k_splits, head_dim) - k_seq_len = k_seq_len // k_splits - assert min(num_heads, head_dim, k_seq_len) >= 16, "Minimum pl.dot size is 16" - block_k = min(block_k, k_seq_len) + split_k_seq_len = k_seq_len // k_splits + block_k = min(block_k, split_k_seq_len) + assert split_k_seq_len % block_k == 0, ( + f"Sequence length ({k_seq_len=}) split by {k_splits=} must by divisible by" + f" {block_k=}") num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 @@ -139,48 +173,53 @@ def attn_unbatched( attn_forward_kernel, sm_scale=sm_scale, block_k=block_k, + block_h=block_h, + num_heads=num_heads, ) o, l, m = pl.pallas_call( - kernel, - grid=grid_, - in_specs=[ - pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)), - pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), - pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), - ], - out_specs=[ - pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o - pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l - pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m - ], - compiler_params=plgpu.TritonCompilerParams( - num_warps=num_warps_, num_stages=num_stages - ), - out_shape=[ - jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o - jax.ShapeDtypeStruct( - shape=(k_splits, num_heads), dtype=jnp.float32 - ), # l - jax.ShapeDtypeStruct( - shape=(k_splits, num_heads), dtype=jnp.float32 - ), # m - ], - debug=debug, - interpret=interpret, - name="mha_forward", - )(q, k, v) + kernel, + grid=grid_, + in_specs=[ + pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)), + pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)), + pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)), + ] + + [None if start_idx is None else pl.BlockSpec((), lambda i, j: ())] + + [None if kv_seq_len is None else pl.BlockSpec((), lambda i, j: ())], + out_specs=[ + pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m + ], + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps_, num_stages=num_stages + ), + out_shape=[ + jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o + jax.ShapeDtypeStruct( + shape=(k_splits, num_heads), dtype=jnp.float32 + ), # l + jax.ShapeDtypeStruct( + shape=(k_splits, num_heads), dtype=jnp.float32 + ), # m + ], + debug=debug, + interpret=interpret, + name="mha_forward", + )(q, k, v, start_idx, kv_seq_len) # final round of flash m_next = m.max(axis=0) correction = jnp.exp(m - m_next[None]) - o = o * correction[:, :, None] + o = o * correction[:, :, None].astype(o.dtype) l_next = (l * correction).sum(axis=0) - o = o.sum(axis=0) / l_next[:, None] - - if original_num_heads is not None: - o = o[:original_num_heads, :] - return o + eps = jnp.finfo(l_next.dtype).eps + o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps) + if return_residuals: + return o, (l_next, m_next) + else: + return o @functools.partial( @@ -195,13 +234,16 @@ def attn_unbatched( "grid", "interpret", "debug", + "return_residuals" ], ) def mqa( - q, # [batch_size, num_heads, head_dim] - k, # [batch_size, k_seq_len, head_dim] - v, # [batch_size, k_seq_len, head_dim] - sm_scale: float = 1.0, + q, # [batch_size, num_heads, head_dim] + k, # [batch_size, k_seq_len, head_dim] + v, # [batch_size, k_seq_len, head_dim] + start_idx=None, # [batch_size] + kv_seq_len=None, # [batch_size] + sm_scale: float | None = None, block_h: int = 16, block_k: int = 256, k_splits: int = 16, @@ -210,9 +252,16 @@ def mqa( grid: tuple[int, ...] | None = None, interpret: bool = False, debug: bool = False, + return_residuals: bool = False ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) + bs = q.shape[0] + if start_idx is not None: + start_idx = jnp.broadcast_to(start_idx, (bs,)) + if kv_seq_len is not None: + kv_seq_len = jnp.broadcast_to(kv_seq_len, (bs,)) inner = functools.partial( - attn_unbatched, + decode_attn_unbatched, sm_scale=sm_scale, block_h=block_h, block_k=block_k, @@ -222,8 +271,9 @@ def mqa( grid=grid, interpret=interpret, debug=debug, + return_residuals=return_residuals ) - return jax.vmap(inner)(q, k, v) + return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len) @functools.partial( @@ -238,26 +288,39 @@ def mqa( "grid", "interpret", "debug", + "return_residuals" ], ) def gqa( - q, # [batch_size, num_q_heads, head_dim] - k, # [batch_size, k_seq_len, num_kv_heads, head_dim] - v, # [batch_size, k_seq_len, num_kv_heads, head_dim] - sm_scale: float = 1.0, + q, # [batch_size, num_q_heads, head_dim] + k, # [batch_size, k_seq_len, num_kv_heads, head_dim] + v, # [batch_size, k_seq_len, num_kv_heads, head_dim] + start_idx=None, # [batch_size] + kv_seq_len=None, # [batch_size] + sm_scale: float | None = None, block_h: int = 16, - block_k: int = 256, + block_k: int = 128, k_splits: int = 16, num_warps: int | None = None, num_stages: int = 2, grid: tuple[int, ...] | None = None, interpret: bool = False, debug: bool = False, + return_residuals: bool = False, ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) batch_size, q_heads, head_dim = q.shape - kv_heads = k.shape[2] + k_seq_len, kv_heads = k.shape[1], k.shape[2] assert kv_heads == v.shape[2] assert q_heads % kv_heads == 0 + if start_idx is not None: + assert start_idx.ndim in (0, 1) + start_idx = jnp.broadcast_to(jnp.asarray(start_idx)[..., None], + (batch_size, kv_heads)) + if kv_seq_len is not None: + assert kv_seq_len.ndim in (0, 1) + kv_seq_len = jnp.broadcast_to(jnp.asarray(kv_seq_len)[..., None], + (batch_size, kv_heads)) q_heads_per_kv_head = q_heads // kv_heads q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim) k_transposed = jnp.swapaxes( @@ -267,7 +330,7 @@ def gqa( v, 1, 2 ) # [batch_size, num_kv_heads, k_seq_len, head_dim] inner = functools.partial( - attn_unbatched, + decode_attn_unbatched, sm_scale=sm_scale, block_h=block_h, block_k=block_k, @@ -277,44 +340,100 @@ def gqa( grid=grid, interpret=interpret, debug=debug, + return_residuals=return_residuals, ) with_kv_heads = jax.vmap(inner) - o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed) - return o.reshape(batch_size, q_heads, head_dim) + o, *res = jax.vmap(with_kv_heads)( + q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len + ) + o = o.reshape(batch_size, q_heads, head_dim) + if return_residuals: + l, m = res[0] + l = l.reshape(batch_size, q_heads) + m = m.reshape(batch_size, q_heads) + return o, (l, m) + else: + return o -@functools.partial(jax.jit, static_argnames=["sm_scale"]) +@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"]) def mqa_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, head_dim] - v, # [bs, k_seq_len, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, head_dim] + v, # [bs, k_seq_len, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, + return_residuals=False ): + original_dtype = q.dtype + q = q.astype(jnp.float32) + k = k.astype(jnp.float32) + bs = q.shape[0] + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32) - weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) - return jnp.einsum("bns,bsd->bnd", weights, v) + if sm_scale is not None and sm_scale != 1.0: + logits = logits * sm_scale + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) + + m = logits.max(axis=-1) + s = jnp.exp(logits - m[..., None]) + l = s.sum(axis=-1) + s = s / l[..., None] + o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype) + + if return_residuals: + return o, (l, m) + else: + return o @functools.partial(jax.jit, static_argnames=["sm_scale"]) def mha_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, num_k_heads, head_dim] - v, # [bs, k_seq_len, num_v_heads, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, num_k_heads, head_dim] + v, # [bs, k_seq_len, num_v_heads, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, ): + bs = q.shape[0] + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) assert q.shape[1] == k.shape[2] logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32) + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) return jnp.einsum("bns,bsnd->bnd", weights, v) -@functools.partial(jax.jit, static_argnames=["sm_scale"]) +@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"]) def gqa_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, num_k_heads, head_dim] - v, # [bs, k_seq_len, num_v_heads, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, num_k_heads, head_dim] + v, # [bs, k_seq_len, num_v_heads, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, + return_residuals=False ): + original_dtype = q.dtype + q = q.astype(jnp.float32) + k = k.astype(jnp.float32) + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) bs, num_q_heads, head_dim = q.shape num_kv_heads = k.shape[2] assert num_q_heads % num_kv_heads == 0 @@ -330,6 +449,27 @@ def gqa_reference( logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype( jnp.float32 ) - weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) - o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed) - return o.reshape(bs, num_q_heads, head_dim) + if sm_scale is not None and sm_scale != 1.0: + logits = logits * sm_scale + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) + + m = logits.max(axis=-1) + s = jnp.exp(logits - m[..., None]) + l = s.sum(axis=-1) + s = s / l[..., None] + o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype) + o = o.reshape(bs, num_q_heads, head_dim) + + if return_residuals: + l = l.reshape(bs, num_q_heads) + m = m.reshape(bs, num_q_heads) + return o, (l, m) + else: + return o diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index 7d11e4faf299..d37afaf4d9e0 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -94,7 +94,7 @@ def layer_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -215,7 +215,7 @@ def layer_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index 2a7824315a0f..ff224c6dfde7 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -196,7 +196,7 @@ def rms_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 82bcde8153ef..0cb3d798d09e 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -582,20 +582,15 @@ def _fwd_cost_estimate( kernel_inputs_specs, kernel_outputs_specs, ) -> pl.CostEstimate | None: - full_cost = ( - mha_reference.lower( - q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale - ) - .compile() - .cost_analysis() + body_cost = pl.estimate_cost( + mha_reference, + q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale ) - if not full_cost: - return None input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs)) output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs)) return pl.CostEstimate( - flops=full_cost[0]["flops"], - transcendentals=full_cost[0]["transcendentals"], + flops=body_cost.flops, + transcendentals=body_cost.transcendentals, bytes_accessed=input_bytes + output_bytes, ) diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py new file mode 100644 index 000000000000..d1e6bf1fd93d --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/random/threefry.py @@ -0,0 +1,156 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the Threefry PRNG as a Pallas kernel.""" +from typing import Sequence +import jax +from jax import lax +from jax._src import prng +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + +Shape = Sequence[int] + +BLOCK_SIZE = (256, 256) + +_round_up = lambda x, y: (x + y - 1) // y * y + + +def blocked_iota(block_shape: Shape, + total_shape: Shape): + """Computes a sub-block of a larger shaped iota. + + Args: + block_shape: The output block shape of the iota. + total_shape: The total shape of the input tensor. + Returns: + Result of the blocked iota. + """ + iota_data = jnp.zeros(block_shape, dtype=jnp.uint32) + multiplier = 1 + for dim in range(len(block_shape)-1, -1, -1): + block_mult = 1 + counts_lo = lax.broadcasted_iota( + dtype=jnp.uint32, shape=block_shape, dimension=dim + ) + iota_data += counts_lo * multiplier * block_mult + multiplier *= total_shape[dim] + return iota_data + + +def _compute_scalar_offset(iteration_index, + total_size: Shape, + block_size: Shape): + ndims = len(iteration_index) + dim_size = 1 + total_idx = 0 + for i in range(ndims-1, -1, -1): + dim_idx = iteration_index[i] * block_size[i] + total_idx += dim_idx * dim_size + dim_size *= total_size[i] + return total_idx + + +def threefry_2x32_count(key, + shape: Shape, + unpadded_shape: Shape, + block_size: tuple[int, int]): + """Generates random bits using the Threefry hash function. + + This function is a fusion of prng.shaped_iota and prng.threefry_2x32 from + the JAX core library. + + Args: + key: A threefry key of shape (2,). + shape: The shape of the output. Must be divisible by `block_size`. + unpadded_shape: If `shape` is padded, then this is the shape of the + output tensor if it were not padded. This is important for indexing + calculations within the kernel. If `shape` is not padded, then this + should be equal to `shape`. + block_size: The block size of the kernel. + + Returns: + A tensor of random bits of shape `shape`. + """ + shape = tuple(shape) + if np.prod(shape) > jnp.iinfo(jnp.uint32).max: + raise ValueError( + f"Shape too large: {np.prod(shape)} > {np.iinfo(jnp.uint32).max}") + + if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0): + raise ValueError( + f"Shape dimension {shape[-2:]} must be divisible by {block_size}") + grid_dims = shape[:-2] + ( + shape[-2] // block_size[-2], shape[-1] // block_size[1],) + + def kernel(key_ref, out_ref): + counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims))) + offset = _compute_scalar_offset(counts_idx, unpadded_shape, block_shape) + counts_lo = blocked_iota(block_size, unpadded_shape) + counts_lo = counts_lo + offset + counts_lo = counts_lo.astype(jnp.uint32) + # TODO(justinfu): Support hi bits on count. + counts_hi = jnp.zeros_like(counts_lo) + k1 = jnp.reshape(key_ref[0, 0], (1, 1)) + k2 = jnp.reshape(key_ref[0, 1], (1, 1)) + o1, o2 = prng.threefry2x32_p.bind( + k1, k2, counts_hi, counts_lo) + out_bits = o1 ^ o2 + out_ref[...] = out_bits.reshape(out_ref.shape) + + key = key.reshape((1, 2)) + out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32) + block_shape = (1,) * (len(shape)-2) + block_size + result = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs), + grid=grid_dims, + out_shape=out, + )(key) + return result + +def plthreefry_random_bits(key, bit_width: int, shape: Shape): + if bit_width != 32: + raise ValueError("Only 32-bit PRNG supported.") + if len(shape) == 0: + return plthreefry_random_bits(key, bit_width, (1, 1))[0, 0] + elif len(shape) == 1: + return plthreefry_random_bits(key, bit_width, (1, *shape))[0] + + requires_pad = ( + shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0) + if requires_pad: + padded_shape = tuple(shape[:-2]) + ( + _round_up(shape[-2], BLOCK_SIZE[-2]), + _round_up(shape[-1], BLOCK_SIZE[-1]), + ) + padded_result = threefry_2x32_count( + key, padded_shape, shape, block_size=BLOCK_SIZE) + return padded_result[..., :shape[-2], :shape[-1]] + else: + return threefry_2x32_count(key, shape, shape, block_size=BLOCK_SIZE) + + +plthreefry_prng_impl = prng.PRNGImpl( + key_shape=(2,), + seed=prng.threefry_seed, + split=prng.threefry_split, + random_bits=plthreefry_random_bits, + fold_in=prng.threefry_fold_in, + name="pallas_threefry2x32", + tag="plfry") + +prng.register_prng(plthreefry_prng_impl) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index d00e0e90cd3e..41f0de3a0f61 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -49,12 +49,6 @@ from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key -# Remove this import after October 22th 2024. -from jax._src.tpu_custom_call import CostEstimate as CostEstimate - -# TODO(cperivol): Temporary alias to the global run_scoped. Remove -# this once everyone has migrated to the pallas core one. -from jax._src.pallas.primitives import run_scoped as run_scoped import types from jax._src.pallas.mosaic.verification import assume diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index a711c6bc472c..8ba7eb25d646 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -22,5 +22,3 @@ AUTO as AUTO, UNSPECIFIED as _UNSPECIFIED, ) - -from jax._src.pjit import _pjit_lower_cached, _pjit_lower diff --git a/jax/experimental/roofline/__init__.py b/jax/experimental/roofline/__init__.py new file mode 100644 index 000000000000..8d76c46858c7 --- /dev/null +++ b/jax/experimental/roofline/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from jax.experimental.roofline.roofline import ( + RooflineRuleContext as RooflineRuleContext, +) +from jax.experimental.roofline.roofline import RooflineShape as RooflineShape +from jax.experimental.roofline.roofline import RooflineResult as RooflineResult +from jax.experimental.roofline.roofline import roofline as roofline +from jax.experimental.roofline.roofline import register_roofline as register_roofline +from jax.experimental.roofline.roofline import ( + register_standard_roofline as register_standard_roofline, +) +from jax.experimental.roofline.roofline import roofline_and_grad as roofline_and_grad + + +import jax.experimental.roofline.rooflines as rooflines + +del rooflines diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py new file mode 100644 index 000000000000..42f72f005034 --- /dev/null +++ b/jax/experimental/roofline/roofline.py @@ -0,0 +1,342 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Protocol, Sequence +import numpy as np + +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax._src import api +from jax._src import core +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src import util +from jax._src.api import make_jaxpr +from jax._src.interpreters.partial_eval import dce_jaxpr +from jax._src.interpreters.xla import abstractify +from jax._src.mesh import AbstractMesh, Mesh +from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map +from jax.experimental import shard_map + + +ShapeDtypeStructTree = Any + + +map = util.safe_map + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineRuleContext: + name_stack: source_info_util.NameStack + primitive: core.Primitive + avals_in: Sequence[core.AbstractValue] + avals_out: Sequence[core.AbstractValue] + jaxpr_eqn_ctx: core.JaxprEqnContext + mesh: Mesh | AbstractMesh + pin_lhs_in_vmem: bool + pin_rhs_in_vmem: bool + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineShape: + shape: tuple[int, ...] + dtype: np.dtype + + @classmethod + def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape": + if not isinstance(aval, core.ShapedArray): + raise TypeError(f"Expected ShapedArray, got {type(aval)}.") + if not isinstance(aval.dtype, np.dtype): + raise TypeError(f"Expected numpy dtype, got {type(aval.dtype)}.") + return cls(shape=aval.shape, dtype=aval.dtype) + + @property + def size(self) -> int: + return int(np.prod(self.shape)) + + @property + def bytes(self) -> int: + return int(self.size * self.dtype.itemsize) + + @classmethod + def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int: + return sum(cls.from_aval(aval).bytes for aval in avals) + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineResult: + flops: int = 0 + ici_bytes: dict[str, int] = field(default_factory=dict) + ici_latency: dict[str, int] = field(default_factory=dict) + hbm_bytes: int = 0 + peak_hbm_bytes: int = 0 + + @classmethod + def zeros(cls) -> "RooflineResult": + return cls() + + def __add__(self, other: "RooflineResult") -> "RooflineResult": + def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: + return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)} + + return RooflineResult( + flops=self.flops + other.flops, + ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes), + ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency), + hbm_bytes=self.hbm_bytes + other.hbm_bytes, + peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes), + ) + + def __mul__(self, constant: int | float) -> "RooflineResult": + return RooflineResult( + flops=int(self.flops * constant), + ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()}, + ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()}, + hbm_bytes=int(self.hbm_bytes * constant), + peak_hbm_bytes=int(self.peak_hbm_bytes * constant), + ) + + def __rmul__(self, constant: int | float) -> "RooflineResult": + return self.__mul__(constant) + + +class _RooflineRule(Protocol): + def __call__( + self, ctx: RooflineRuleContext, *args: RooflineShape, **kw + ) -> RooflineResult: ... + + +_rooflines: dict[core.Primitive, _RooflineRule] = {} + + +def _roofline_interpreter( + f_name: str, + jaxpr: core.Jaxpr, + mesh: Mesh | AbstractMesh, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, +) -> RooflineResult: + name_stack = source_info_util.new_name_stack(util.wrap_name(f_name, "roofline")) + + result = RooflineResult.zeros() + + env: dict[core.Var, RooflineShape] = {} + + def write(v: core.Var, node: RooflineShape): + assert node is not None + env[v] = node + + def read(v: core.Atom) -> RooflineShape: + if type(v) is core.Literal: + return RooflineShape.from_aval(abstractify(v.val)) + else: + assert isinstance(v, core.Var) + return env[v] + + def aval(v: core.Atom) -> core.AbstractValue: + if type(v) is core.Literal: + return abstractify(v.val) + else: + return v.aval + + def calculate_peak_hbm_bytes() -> int: + return int( + sum(np.prod(shape.shape) * shape.dtype.itemsize for shape in env.values()) + ) + + make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x)) + map( + write, + jaxpr.constvars, + map(make_roofline_shape, jaxpr.constvars), + ) + map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars)) + last_used = core.last_used(jaxpr) + for eqn in jaxpr.eqns: + source_info = eqn.source_info.replace( + name_stack=name_stack + eqn.source_info.name_stack + ) + with source_info_util.user_context( + eqn.source_info.traceback, name_stack=source_info.name_stack + ): + if "jaxpr" in eqn.params: + result += _roofline_interpreter( + util.wrap_name(f_name, eqn.primitive.name), + eqn.params["jaxpr"], + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) + else: + if eqn.primitive not in _rooflines: + msg = f"No roofline rule for {eqn.primitive}." + for attr in dir(eqn): + if not attr.startswith("_"): + msg += f"\n{attr}: {getattr(eqn, attr)}" + raise NotImplementedError(msg) + rule = _rooflines[eqn.primitive] + result += rule( + RooflineRuleContext( + name_stack=source_info.name_stack, + primitive=eqn.primitive, + avals_in=map(aval, eqn.invars), + avals_out=map(aval, eqn.outvars), + jaxpr_eqn_ctx=eqn.ctx, + mesh=mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ), + *map(read, eqn.invars), + **eqn.params, + ) + + map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars)) + core.clean_up_dead_vars(eqn, env, last_used) + result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes()) + + return result + + +def _f_with_vjp(f: Callable): + @util.wraps(f) + def wrapped(*args): + primals, f_vjp = api.vjp(f, *args) + return f_vjp(tree_map(jnp.bfloat16, primals)) + + return wrapped + + +def roofline( + f: Callable, + mesh: Mesh | AbstractMesh, + in_specs: shard_map.Specs, + out_specs: shard_map.Specs, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, + vjp: bool = False, + print_jaxpr: bool = False, +) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]: + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + wrapped_f = shard_map.shard_map(f, mesh, in_specs, out_specs) + if vjp: + wrapped_f = _f_with_vjp(wrapped_f) + + jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args) + + def make_sharded_shape_dtype_struct( + shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs + ) -> api.ShapeDtypeStruct: + return api.ShapeDtypeStruct( + shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec) + ) + + out_specs_flat = broadcast_prefix(out_specs, out_shapes) + flat_out_shapes, treedef = tree_flatten(out_shapes) + flat_out_shapes = map( + make_sharded_shape_dtype_struct, flat_out_shapes, out_specs_flat + ) + out_shapes = tree_unflatten(treedef, flat_out_shapes) + + used_outputs = (True,) * len(jaxpr.jaxpr.outvars) + jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs) + try: + jaxpr = [e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p][ + -1 + ].params["jaxpr"] + except KeyError: + raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.") + + if print_jaxpr: + print(jaxpr) + + return out_shapes, _roofline_interpreter( + util.fun_qual_name(f), + jaxpr, + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) + + return wrapped + + +def register_roofline(prim: core.Primitive): + def register(rule: _RooflineRule): + _rooflines[prim] = rule + return rule + + return register + + +def register_standard_roofline(prim: core.Primitive): + def standard_rule(ctx: RooflineRuleContext, *args, **kwargs): + return RooflineResult.zeros() + + _rooflines[prim] = standard_rule + + +def roofline_and_grad( + f: Callable, + mesh: Mesh | AbstractMesh, + in_specs: shard_map.Specs, + out_specs: shard_map.Specs, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, + print_jaxpr: bool = False, +) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]: + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + primal_shapes, fwd_result = roofline( + f, + mesh, + in_specs, + out_specs, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + print_jaxpr=print_jaxpr, + )(*args) + + return ( + primal_shapes, + fwd_result, + roofline( + f, + mesh, + in_specs, + out_specs, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + vjp=True, + print_jaxpr=print_jaxpr, + )( + *tree_map( + lambda x: api.ShapeDtypeStruct( + x.shape, + jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16, + sharding=x.sharding, + ), + args, + ) + )[1], + ) + + return wrapped diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py new file mode 100644 index 000000000000..cfdb6358bc76 --- /dev/null +++ b/jax/experimental/roofline/rooflines.py @@ -0,0 +1,270 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +from dataclasses import replace +import itertools as it +import numpy as np + +from jax._src import ad_util +from jax._src import core, util +from jax._src import ops +from jax._src import prng +from jax._src import random +from jax._src.lax import ( + ann, + convolution, + fft, + lax, + linalg, + parallel as lax_parallel, + slicing, + special, + windowed_reductions, +) +from jax.experimental import roofline +from jax.experimental import shard_map + + +for prim in it.chain( + ad_util.__dict__.values(), + ann.__dict__.values(), + convolution.__dict__.values(), + fft.__dict__.values(), + lax.__dict__.values(), + linalg.__dict__.values(), + ops.__dict__.values(), + prng.__dict__.values(), + random.__dict__.values(), + shard_map.__dict__.values(), + slicing.__dict__.values(), + special.__dict__.values(), + windowed_reductions.__dict__.values(), +): + if isinstance(prim, core.Primitive): + roofline.register_standard_roofline(prim) + + +@roofline.register_roofline(lax.dot_general_p) +def _dot_general_roofline( + ctx: roofline.RooflineRuleContext, + *args, + dimension_numbers: lax.DotDimensionNumbers, + **kw, +) -> roofline.RooflineResult: + lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + (lhs_contract, _), (lhs_batch, _) = dimension_numbers + + flops = ( + 2 + * lhs.size + * rhs.size + / np.prod([lhs.shape[i] for i in lhs_contract]) + / np.prod([lhs.shape[i] for i in lhs_batch]) + ) + + hbm_bytes = 0 + if not ctx.pin_lhs_in_vmem: + hbm_bytes += lhs.bytes + hbm_bytes += out.bytes + if not ctx.pin_rhs_in_vmem: + hbm_bytes += rhs.bytes + + return roofline.RooflineResult(flops=int(flops), hbm_bytes=hbm_bytes) + + +def _return_zeros_if_one_sized_axis( + ctx: roofline.RooflineRuleContext, axes: tuple[str, ...] +) -> roofline.RooflineResult | None: + axes_size = np.prod([ctx.mesh.shape[axis] for axis in axes]) + if axes_size > 1: + return None + return roofline.RooflineResult( + ici_bytes={axis: 0 for axis in axes}, + ici_latency={axis: 0 for axis in axes}, + ) + + +def _ring_collective_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + is_reduce: bool = True, + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axes): + return zeros_result + + mesh = ctx.mesh.shape + current_shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in) + if is_reduce: + current_shard_size /= np.prod([mesh[axis] for axis in axes]) + + # We model the slowest color as the bottleneck. + sorted_axes = sorted(axes, key=lambda x: mesh[x], reverse=True) + num_axes = len(sorted_axes) + + ici_bytes = 0 + # Phase split. + current_shard_size //= num_axes + for axis in sorted_axes: + axis_size = mesh[axis] + # Do phase. + ici_bytes += current_shard_size * (axis_size - 1) + # Increase shard size. + current_shard_size *= axis_size + + # Bottleneck is the longest axis. + ici_latency = mesh[sorted_axes[0]] * num_axes + + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in sorted_axes}, + ici_latency={axis: int(ici_latency) for axis in sorted_axes}, + ) + + +roofline.register_roofline(lax_parallel.reduce_scatter_p)( + lambda *args, axis_name, **kw: _ring_collective_roofline(*args, axes=axis_name, **kw) +) +roofline.register_roofline(lax_parallel.all_gather_p)( + lambda *args, axis_name, **kw: _ring_collective_roofline( + *args, axes=axis_name, is_reduce=False, **kw + ) +) + + +def _scalar_collective_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in] + ctx = replace(ctx, avals_in=[core.ShapedArray((1,), shape.dtype) for shape in shapes]) + return _ring_collective_roofline(ctx, *args, axes=axes, is_reduce=False, **kw) + + +roofline.register_roofline(lax_parallel.pmin_p)(_scalar_collective_roofline) +roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline) + + +@roofline.register_roofline(shard_map.psum2_p) +def _psum2_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + ring_roofline = _ring_collective_roofline(ctx, *args, axes=axes, **kw) + + def double_dict(d: dict[str, int]) -> dict[str, int]: + return {k: v * 2 for k, v in d.items()} + + return roofline.RooflineResult( + ici_bytes=double_dict(ring_roofline.ici_bytes), + ici_latency=double_dict(ring_roofline.ici_latency), + ) + + +@roofline.register_roofline(lax_parallel.all_to_all_p) +def _all_to_all_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis_name: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name): + return zeros_result + + mesh = ctx.mesh.shape + size = roofline.RooflineShape.total_bytes(ctx.avals_in) * np.prod([ + mesh[axis] for axis in axis_name + ]) + + smallest_axis = sorted(axis_name, key=lambda x: mesh[x])[0] + num_axes = len(axis_name) + bisection_bw = mesh[smallest_axis] ** (num_axes - 1) + if mesh[smallest_axis] > 2: + # Times 2 because of wraparound. + bisection_bw *= 2 + + # Half the data needs to cross the bisection on average. + ici_bytes = size / 2 / bisection_bw + + # The latency is the max number of hops across the mesh. + ici_latency = sum(mesh[axis] / 2 for axis in axis_name) + + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in axis_name}, + ici_latency={axis: int(ici_latency) for axis in axis_name}, + ) + + +@roofline.register_roofline(lax_parallel.ppermute_p) +def _ppermute_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis_name: tuple[str, ...], + perm: tuple[tuple[int, int], ...], + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name): + return zeros_result + + mesh = ctx.mesh.shape + mesh_dims: list[int] = [mesh.get(axis, 1) for axis in axis_name] + shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in) + + ici_contention: dict[tuple[tuple[int, ...], ...], float] = defaultdict(float) + ici_latency = 0 + + for src, dst in perm: + if src == dst: + continue + # Perms are linearized. + src_coords = tuple(int(i) for i in np.unravel_index(src, mesh_dims)) + dst_coords = tuple(int(i) for i in np.unravel_index(dst, mesh_dims)) + + ici_latency_for_perm = 0 + + # For each dimension. + for i in range(len(axis_name)): + dim_size = mesh_dims[i] + src_pos = src_coords[i] + dst_pos = dst_coords[i] + + if src_pos != dst_pos: + # Calculate distance with wraparound. + clockwise_dist = (dst_pos - src_pos) % dim_size + counter_dist = (src_pos - dst_pos) % dim_size + direction = 1 if clockwise_dist <= counter_dist else -1 + + curr_pos = src_pos + while curr_pos != dst_pos: + curr_coords = util.tuple_update(src_coords, i, curr_pos) + next_pos = (curr_pos + direction) % dim_size + next_coords = util.tuple_update(curr_coords, i, next_pos) + ici_contention[tuple(sorted([curr_coords, next_coords]))] += 1 + curr_pos = next_pos + + distance = min(clockwise_dist, counter_dist) + ici_latency_for_perm += distance + + ici_latency = max(ici_latency, ici_latency_for_perm) + + ici_bytes = shard_size * max(ici_contention.values(), default=0) + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in axis_name}, + ici_latency={axis: int(ici_latency) for axis in axis_name}, + ) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 24bed503491b..b4609282e2f8 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -46,16 +46,17 @@ from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer -from jax._src.mesh import AbstractMesh, Mesh +from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) +from jax._src.extend import ffi from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy -from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, +from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2) + split_list, subs_list2) from jax.api_util import flatten_fun_nokwargs, shaped_abstractify, argnums_partial from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -454,30 +455,9 @@ def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] class ShardMapPrimitive(core.Primitive): multiple_results = True - def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, rewrite: bool, auto: frozenset[AxisName] - ) -> Sequence[MaybeTracer]: - top_trace = core.find_top_trace(args) - fun, env_todo = process_env_traces(fun, top_trace.level, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto) - - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - out_names = out_names_thunk() - _, xforms = env_todo() - for t in xforms: - out_names = t(out_names) - return out_names - - tracers = map(top_trace.full_raise, args) - outs = top_trace.process_shard_map( # pytype: disable=attribute-error - shard_map_p, fun, tracers, mesh=mesh, in_names=in_names, - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - todos, _ = env_todo() - return map(core.full_lower, core.apply_todos(todos, outs)) + def bind_with_trace(self, trace, fun_and_args, params): + fun, *args = fun_and_args + return trace.process_shard_map(shard_map_p, fun, args, **params) def get_bind_params(self, params): new_params = dict(params) @@ -489,56 +469,38 @@ def get_bind_params(self, params): shard_map_p = ShardMapPrimitive('shard_map') -@lu.transformation_with_aux -def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep, - rewrite, auto, *args: Any): - outs = yield args, {} - todos, out_names_transforms = [], [] - while True: - tracers = [x for x in outs if isinstance(x, core.Tracer) - and (level is None or x._trace.level > level)] - if tracers: - ans = max(tracers, key=op.attrgetter('_trace.level')) - else: - break - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, (todo, xform) = trace.post_process_shard_map( - outs, mesh, in_names, out_names_thunk, check_rep, rewrite, auto) - todos.append(todo) - out_names_transforms.append(xform) - yield outs, (tuple(todos), tuple(out_names_transforms)) - # Staging def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, - in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh, + in_tracers: Sequence[Any], *, mesh: Mesh, in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, rewrite: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: + in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) - main = trace.main - with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) - out_avals_ = map(_check_shapedarray, genavals) + with (core.extend_axis_env_nd(list(mesh.shape.items())), + set_abstract_mesh(pjit.get_abstract_mesh_from_avals(in_avals_))): + jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) - in_rep = map(partial(_in_names_to_rep, mesh), in_names) if check_rep: + in_rep = map(partial(_in_names_to_rep, mesh), in_names) out_rep = _check_rep(mesh, jaxpr, in_rep) _check_reps(mesh, out_names_thunk(), out_rep) - out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_) + out_avals = map(_check_shapedarray, out_avals_) + out_avals = [_check_shapedarray(_unshard_aval(mesh, names, aval)) + for names, aval in zip(out_names_thunk(), out_avals)] source_info = source_info_util.current() out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(trace.instantiate_const, consts)) + constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with core.extend_axis_env_nd(mesh.shape.items()): + with core.extend_axis_env_nd(list(mesh.shape.items())): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, @@ -568,17 +530,32 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue raise NotImplementedError(f"Unsupported aval type: {type(aval)}") def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: + ) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + if config.sharding_in_types.value: + new_mesh = AbstractMesh( + mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names}) + new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim)) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array def _unshard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue,) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + # TODO(yashkatariya): Reset the mesh properly based on the input avals if the + # mesh of shard_map specifies collective axes. + if config.sharding_in_types.value: + spec = _names_to_pspec(names)._normalized_spec(aval.ndim) + new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking @@ -646,14 +623,18 @@ def _rule_missing(prim: core.Primitive, *_, **__): # Lowering def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, names, aval_in + ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in ) -> ir.Attribute: axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) - return ns._to_sdy_sharding(aval_in.ndim).build() + sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) + if auto: + for dim_sharding in sdy_sharding.dimension_shardings: + dim_sharding.is_closed = False + return sdy_sharding.build() def _shard_map_lowering_shardy( @@ -683,10 +664,10 @@ def _shard_map_lowering_shardy( return out_nodes in_shardings = sdy.TensorShardingPerValueAttr.get(map( - partial(_shardy_shard_map_sharding, ctx, mesh), + partial(_shardy_shard_map_sharding, ctx, mesh, auto), in_names, ctx.avals_in)) out_shardings = sdy.TensorShardingPerValueAttr.get(map( - partial(_shardy_shard_map_sharding, ctx, mesh), + partial(_shardy_shard_map_sharding, ctx, mesh, auto), out_names, ctx.avals_out)) output_types = map(mlir.aval_to_ir_type, ctx.avals_out) manual_computation_op = sdy.ManualComputationOp( @@ -804,28 +785,23 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, mesh = get_mesh_from_args(args, mesh) args = map(partial(_unmatch_spec, mesh), in_names, args) in_rep = map(partial(_in_names_to_rep, mesh), in_names) - with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main: - fun, out_rep = _shmap_subtrace(fun, main, in_rep) - with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main): - outs = fun.call_wrapped(*args) - del main + outs, out_rep = _run_shmap(fun, mesh, args, in_rep, check_rep) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_rep: - _check_reps(mesh, out_names_thunk(), out_rep()) + _check_reps(mesh, out_names_thunk(), out_rep) pspecs = map(_names_to_pspec, out_names_thunk()) return map(partial(_match_spec, mesh, check_rep), pspecs, outs) core.EvalTrace.process_shard_map = _shard_map_impl -@lu.transformation_with_aux -def _shmap_subtrace(main, in_rep, *in_vals): - t = main.with_cur_sublevel() - in_tracers = map(partial(ShardMapTracer, t), in_rep, in_vals) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - outs, out_rep = unzip2((t.val, t.rep) for t in out_tracers) - del t, in_tracers, ans, out_tracers - yield outs, out_rep +def _run_shmap(f, mesh, args, reps, check_rep): + trace = ShardMapTrace(mesh, check_rep) + in_tracers = map(partial(ShardMapTracer, trace), reps, args) + with core.set_current_trace(trace): + with core.extend_axis_env_nd(mesh.shape.items()): + ans = f.call_wrapped(*in_tracers) + outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) + return outs, out_rep def _names_to_pspec(names: AxisNames) -> PartitionSpec: ndmin = max(names) + 1 if names else 0 @@ -877,20 +853,21 @@ class ShardMapTrace(core.Trace): mesh: Mesh check: bool - def __init__(self, *args, mesh, check): - super().__init__(*args) + def __init__(self, mesh, check): self.mesh = mesh self.check = check - def pure(self, val): - val_ = _unmatch_spec(self.mesh, {}, val) - return ShardMapTracer(self, None, val_) - - def sublift(self, tracer): - return ShardMapTracer(self, tracer.rep, tracer.val) + def to_val_rep_pair(self, val): + if isinstance(val, ShardMapTracer): + return val.val, val.rep + elif isinstance(val, Tracer): + raise Exception("Shouldn't have any non-shard_map tracers") + else: + val_ = _unmatch_spec(self.mesh, {}, val) + return val_, None def process_primitive(self, prim, tracers, params): - in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) + in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) eager_rule = eager_rules.get(prim) if eager_rule: out_vals = eager_rule(self.mesh, *in_vals, **params) @@ -926,36 +903,21 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros - in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) - fun, out_rep = _shmap_subtrace(fun, self.main, in_rep) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(ShardMapTracer, self), out_rep(), out_vals) - - def post_process_custom_jvp_call(self, out_tracers, _): - assert False # unreachable + in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) + out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check) + return map(partial(ShardMapTracer, self), out_rep, out_vals) def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): - # Since ShardMapTrace is only used as a base main, we can drop the jvp. if symbolic_zeros: msg = ("custom_vjp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros - in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) - fun, out_rep = _shmap_subtrace(fun, self.main, in_rep) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(ShardMapTracer, self), out_rep(), out_vals) - - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable - - def process_axis_index(self, frame): - with core.eval_context(), jax.disable_jit(False): - return jax.jit(lambda: jax.lax.axis_index(frame.name))() + in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) + out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check) + return map(partial(ShardMapTracer, self), out_rep, out_vals) class ShardMapTracer(core.Tracer): @@ -970,16 +932,14 @@ def __init__(self, trace, rep, val): @property def aval(self): aval = core.get_aval(self.val) - if (isinstance(aval, core.ConcreteArray) and - self.rep == set(self._trace.mesh.axis_names)): + return core.mapped_aval(self._trace.mesh.size, 0, aval) + + def to_concrete_value(self): + if self.rep == set(self._trace.mesh.axis_names): with core.eval_context(): - return core.get_aval(self.val[0]) + return core.to_concrete_value(self.val[0]) else: - aval = core.raise_to_shaped(aval) - return core.mapped_aval(self._trace.mesh.size, 0, aval) - - def full_lower(self) -> ShardMapTracer: - return self + return None def __str__(self) -> str: with core.eval_context(): @@ -1023,17 +983,16 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): # New primitives for efficient transposition # psum2_p is like psum_p except has a different transpose, so mostly copied: -psum2_p = core.AxisPrimitive('psum2') +psum2_p = core.Primitive('psum2') psum2_p.multiple_results = True psum2_p.def_impl(lax_parallel.psum_p.impl) psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) -batching.primitive_batchers[psum2_p] = partial(lax_parallel._reduction_batcher, psum2_p) -batching.axis_primitive_batchers[psum2_p] = \ +batching.fancy_primitive_batchers[psum2_p] = \ partial(lax_parallel._batched_reduction_collective, psum2_p, lambda v, axis_size: axis_size * v) -core.axis_substitution_rules[psum2_p] = \ - partial(lax_parallel._subst_all_names_in_param, 'axes') +batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes') + def _psum2_transpose_rule(cts, *args, axes, axis_index_groups): del args return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) @@ -1046,7 +1005,7 @@ def pbroadcast(x, axis_name): xs, treedef = tree_flatten(x) ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) return tree_unflatten(treedef, ys) -pbroadcast_p = core.AxisPrimitive('pbroadcast') +pbroadcast_p = core.Primitive('pbroadcast') pbroadcast_p.multiple_results = True pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) @@ -1057,12 +1016,6 @@ def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): axis_index_groups=axis_index_groups) return vals_out, dims_in batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher -def _pbroadcast_axis_batcher(size, name, trace_type, vals_in, dims_in, *, axes, - groups): - raise NotImplementedError # vmap with axis name involved in this primitive -batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_axis_batcher -core.axis_substitution_rules[pbroadcast_p] = \ - partial(lax_parallel._subst_all_names_in_param, 'axes') ad.deflinear2(pbroadcast_p, lambda cts, *_, axes, axis_index_groups: psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)) @@ -1338,6 +1291,40 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) return out_vals, out_rep +@register_check(control_flow.conditionals.cond_p) +def _cond_rule(mesh, *in_rep, branches): + _, *args_rep = in_rep + out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) + for branch in branches[1:]: + out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) + if not out_rep_ == out_rep: + raise Exception("The branches of cond produced mismatched replication " + "types. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_rep=False argument " + "to shard_map") + return out_rep + +@register_rewrite(control_flow.conditionals.cond_p) +def _cond_rewrite(mesh, in_rep, *args, branches): + pred_rep, *args_rep = in_rep + _, out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) + for branch in branches[1:]: + _, out_rep_ = _replication_rewrite_nomatch(mesh, branch, args_rep) + if out_rep: + out_rep = map(op.and_, out_rep, out_rep_) + else: + out_rep = out_rep_ + out_rep = map(partial(op.and_, pred_rep), out_rep) + branches_ = tuple(_replication_rewrite_match(mesh, branch, args_rep, out_rep) + for branch in branches) + out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) + return out_vals, out_rep + +@register_check(control_flow.conditionals.platform_index_p) +def _platform_index_rule(mesh, *_, **__): + return set(mesh.axis_names) +register_norewrite(control_flow.conditionals.platform_index_p) @register_rewrite(core.closed_call_p) def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): @@ -1388,20 +1375,17 @@ def fwd_jaxpr_thunk_(*zeros): def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) - -# TODO(mattjj): make standard_check handle multiple outputs, share code @register_check(control_flow.solves.linear_solve_p) -def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs): - in_rep_ = [r for r in in_rep if r is not None] - assert in_rep - if not in_rep_[:-1] == in_rep_[1:]: - msg = ("shard_map check_rep rewrite failed. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a workaround pass the " - "check_rep=False argument to shard_map") - raise Exception(msg) - return [in_rep_[0]] * len(jaxprs.solve.out_avals) +def _linear_solve_check(mesh, *in_rep, jaxprs, **_): + out_rep = _standard_check(control_flow.solves.linear_solve_p, mesh, *in_rep) + return [out_rep] * len(jaxprs.solve.out_avals) register_standard_rewrite(control_flow.solves.linear_solve_p) +@register_check(ffi.ffi_call_p) +def _ffi_call_check(mesh, *in_rep, result_avals, **_): + out_rep = _standard_check(ffi.ffi_call_p, mesh, *in_rep) + return [out_rep] * len(result_avals) +register_standard_rewrite(ffi.ffi_call_p) del _check_rules[lax.tie_p] @@ -1421,23 +1405,23 @@ def _shard_map_batch( check_rep: bool, rewrite: bool, auto: frozenset) -> Sequence[batching.BatchTracer]: - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers) - if all(bdim is batching.not_mapped for bdim in in_dims): - return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names, - out_names_thunk=out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): raise NotImplementedError - fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims)) - new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] # type: ignore + new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] for ax in names} for names, d in zip(in_names, in_dims)] - spmd_axis_name = trace.spmd_axis_name + spmd_axis_name = trace.axis_data.spmd_name if spmd_axis_name is not None: used = {n for names in in_names for ns in names.values() for n in ns} if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") - new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore + new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped else ns for ns, d in zip(new_in_names, in_dims)] + new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) + new_axis_data = batching.AxisData(trace.axis_data.name, new_size, trace.axis_data.spmd_name) + else: + new_axis_data = trace.axis_data + fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) @as_hashable_function(closure=out_names_thunk) def new_out_names_thunk(): return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) @@ -1445,25 +1429,13 @@ def new_out_names_thunk(): new_params = dict(mesh=mesh, in_names=new_in_names, out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, auto=auto) - out_vals = prim.bind(fun, *in_vals, **new_params) + with core.set_current_trace(trace.parent_trace): + out_vals = prim.bind(fun, *in_vals, **new_params) make_tracer = partial(batching.BatchTracer, trace, source_info=source_info_util.current()) return map(make_tracer, out_vals, out_dims()) batching.BatchTrace.process_shard_map = _shard_map_batch -def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - del mesh, in_names, out_names_thunk, check_rep, rewrite, auto - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - m = trace.main - def todo(vals): - trace = m.with_cur_sublevel() - return map(partial(batching.BatchTracer, trace), vals, dims, srcs) - out_names_transform = partial(_batch_out_names, trace.spmd_axis_name, dims) - return vals, (todo, out_names_transform) -batching.BatchTrace.post_process_shard_map = _shard_map_batch_post_process - def _batch_out_names(spmd_axis_name, dims, out_names): out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] for ax in names} for names, d in zip(out_names, dims)] @@ -1480,11 +1452,11 @@ def _batch_out_names(spmd_axis_name, dims, out_names): def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): - primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] args, in_tree = tree_flatten((primals, tangents)) - f_jvp = ad.jvp_subtrace(f, trace.main) + f_jvp = ad.jvp_subtrace(f, trace.tag) f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] @@ -1496,36 +1468,22 @@ def new_out_names_thunk(): out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, auto=auto) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) - result = shard_map_p.bind(f_jvp, *args, **params) + result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) primal_out, tangent_out = tree_unflatten(out_tree(), result) tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t for p, t in zip(primal_out, tangent_out)] return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] ad.JVPTrace.process_shard_map = _shard_map_jvp -def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - del mesh, in_names, out_names_thunk, check_rep, rewrite, auto - primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) - out, treedef = tree_flatten((primals, tangents)) - tangents_nz = [type(t) is not ad.Zero for t in tangents] - m = trace.main - def todo(x): - primals, tangents = tree_unflatten(treedef, x) - return map(partial(ad.JVPTracer, m.with_cur_sublevel()), primals, tangents) - def out_names_transform(out_names): - return (*out_names, *(n for n, nz in zip(out_names, tangents_nz) if nz)) - return out, (todo, out_names_transform) -ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process - def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): + tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - all_names = _all_mesh_names(mesh) + all_names = _all_mesh_names_except_spmd(mesh, trace) in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals) - f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.main, False) + f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, False) f = _promote_scalar_residuals(f) f_known, aux = pe.partial_eval_wrapper_nounits( f, (*in_knowns,), (*in_avals_sharded,)) @@ -1540,7 +1498,7 @@ def known_out_names(): known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_rep=check_rep, rewrite=rewrite, auto=auto) - out = shard_map_p.bind(f_known, *in_consts, **known_params) + out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux() num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) @@ -1553,7 +1511,7 @@ def known_out_names(): {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.full_raise, env) + env_tracers = map(trace.to_jaxpr_tracer, env) unk_arg_tracers = [t for t in tracers if not t.is_known()] unk_params = dict(mesh=mesh, in_names=unk_in_names, out_names=unk_out_names, jaxpr=jaxpr, check_rep=False, @@ -1569,64 +1527,15 @@ def known_out_names(): return pe.merge_lists(out_knowns, out_tracers, out_consts) pe.JaxprTrace.process_shard_map = _shard_map_partial_eval -def _shard_map_partial_eval_post_process( - trace, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): - del check_rep - all_names = _all_mesh_names(mesh) - unk_tracers = [t for t in tracers if not t.is_known()] - jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers) - # TODO(mattjj): output forwarding optimization - which = [not getattr(v.aval, 'shape', True) for v in jaxpr.constvars] - res = [jax.lax.broadcast(x, (1,)) if not getattr(v.aval, 'shape', True) else x - for x, v in zip(res, jaxpr.constvars)] - jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) - - out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers]) - out = [*consts, *res] - main = trace.main - with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_ = pe.convert_constvars_jaxpr(jaxpr) - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res_ = split_list(out, [len(out) - len(res)]) - const_tracers = map(trace.new_instantiated_const, res_) - env_tracers = map(trace.full_raise, env) - - staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env) - staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names, - out_names=(*out_names_unknown,), check_rep=False, - rewrite=rewrite, auto=auto) - - out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_) - out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in out_avals] - name_stack = trace._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, - shard_map_p, staged_params, effs, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - def out_names_transform(out_names): - nonlocal out_names_unknown - out_names_unknown, out_names_known = partition_list(out_knowns, out_names) - return (*out_names_known,) + ({0: all_names},) * len(res) - out_names_unknown: list | None = None - - return out, (todo, out_names_transform) -pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process - -@lu.transformation -def _promote_scalar_residuals(*args, **kwargs): - jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs +@lu.transformation2 +def _promote_scalar_residuals(f, *args, **kwargs): + jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) which = [f1 is None and f2 is None and not v.aval.shape for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x for x in out_consts] - yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) + return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) def _promote_scalar_residuals_jaxpr(jaxpr, which): @lu.wrap_init @@ -1641,19 +1550,20 @@ def fun(*res_and_args): return jaxpr -def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: +def _unmentioned2(mesh: Mesh, names: AxisNames, + auto: frozenset[AxisName]) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} - return [n for n in _all_mesh_names(mesh) if n not in name_set] + name_set = {n for ns in names.values() for n in ns} | auto + return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set] def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns)))) + else x if rewrite or dtypes.dtype(x) == dtypes.float0 + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) @@ -1671,7 +1581,7 @@ def fun_trans(out_cts, args): ) out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero else x if rewrite - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns))) + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) for ns, x in zip(in_names, out)] return out @@ -1692,18 +1602,6 @@ def new_out_names_thunk(): return tree_unflatten(out_tree(), out_flat) ad.primitive_transposes[shard_map_p] = _shard_map_transpose -def _shard_map_axis_subst(params, subst, traverse): - if 'jaxpr' not in params: - return params - if not traverse: - return params - def shadowed_subst(name): - return (name,) if name in params['mesh'].shape else subst(name) - with core.extend_axis_env_nd(params['mesh'].shape.items()): - new_jaxpr = core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst) - return dict(params, jaxpr=new_jaxpr) -core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst - # Remat def _partial_eval_jaxpr_custom_rule( @@ -1753,7 +1651,7 @@ def _partial_eval_jaxpr_custom_rule( def _add_reshapes(which, jaxpr_known, jaxpr_staged): # add singleton axes to residuals which are from jaxpr_known and are scalars - which_ = [w and not v.aval.shape + which_ = [w and not v.aval.shape # pytype: disable=attribute-error for w, v in zip(which, jaxpr_staged.invars[:len(which)])] if not any(which_): return jaxpr_known, jaxpr_staged assert not jaxpr_known.constvars and not jaxpr_staged.constvars @@ -1783,7 +1681,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, in_fwd, out_fwd, which, params_known, params_staged): # prune inputs to jaxpr_known according to unks_in mesh = params_known['mesh'] - all_names = _all_mesh_names(mesh) + all_names = _all_mesh_names_except_spmd(mesh) in_names_known, _ = partition_list(unks_in, params_known['in_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) out_names_known = out_names_known + [{0: all_names}] * sum(which) @@ -1801,21 +1699,18 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, out_names=tuple(out_names_staged), check_rep=False) return new_params_known, new_params_staged, all_names - # TODO(mattjj): remove this mechanism when we revise mesh scopes -def _all_mesh_names(mesh: Mesh) -> tuple[AxisName, ...]: - stack = core.thread_local_state.trace_state.trace_stack.stack - names = {n for frame in stack - if (ns := frame.payload.get('spmd_axis_name', ())) is not None - for n in ns} - return tuple(name for name in mesh.axis_names if name not in names) - +def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]: + spmd_names = core.get_axis_env().spmd_axis_names + return tuple(name for name in mesh.axis_names if name not in spmd_names) # DCE # TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] with core.extend_axis_env_nd(mesh.shape.items()): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) @@ -1882,13 +1777,13 @@ def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): check_rep=False, auto=frozenset()), in_specs, out_specs) -@lu.transformation -def _handle_reshapes(in_axes, out_axes_thunk, *args, **kwargs): +@lu.transformation2 +def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), list(args), list(in_axes)) - out = yield args, {} - yield tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), - list(out), list(out_axes_thunk())) + out = f(*args) + return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), + list(out), list(out_axes_thunk())) def _axis_to_spec(axis_name, ax): if isinstance(ax, int): @@ -1926,59 +1821,55 @@ def __init__(self, trace, rep, val): def aval(self) -> core.AbstractValue: return core.get_aval(self.val) - def full_lower(self) -> RewriteTracer: - return self + def to_concrete_value(self): + return core.to_concrete_value(self.val) def __str__(self) -> str: return str(self.val) # TODO(mattjj): could show replication info here __repr__ = __str__ # for debuggers, like `p x` class RewriteTrace(core.Trace): + parent_trace : core.Trace + tag : core.TraceTag mesh: Mesh - dyna: int - def __init__(self, *args, mesh, dyna): - super().__init__(*args) + def __init__(self, parent_trace, tag, mesh): + self.parent_trace = parent_trace + self.tag = tag self.mesh = mesh - self.dyna = dyna - - def pure(self, val) -> RewriteTracer: - return RewriteTracer(self, set(self.mesh.axis_names), val) - def lift(self, tracer: core.Tracer) -> RewriteTracer: - return RewriteTracer(self, set(self.mesh.axis_names), tracer) - - def sublift(self, tracer: core.Tracer) -> RewriteTracer: - return RewriteTracer(self, tracer.rep, tracer.val) + def to_val_rep_pair(self, val): + # TODO: add a tag to tell if self + if isinstance(val, RewriteTracer) and val._trace.tag is self.tag: + return val.val, val.rep + else: + return val, set(self.mesh.axis_names) def process_primitive(self, prim, in_tracers, params): rule = _rewrite_rules.get(prim, partial(_rule_missing, prim)) - in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) - with core.new_dynamic(self.dyna): + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) + with core.set_current_trace(self.parent_trace): out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params) out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals) return out_tracers if prim.multiple_results else out_tracers[0] def process_call(self, call_primitive, f, in_tracers, params): - in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) - f, out_reps = _rewrite_subtrace(f, self.main, tuple(in_reps)) - with core.new_dynamic(self.dyna): + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) + f, out_reps = _rewrite_subtrace(f, self.tag, self.mesh, tuple(in_reps)) + with core.set_current_trace(self.parent_trace): out_vals = call_primitive.bind(f, *in_vals, **params) return map(partial(RewriteTracer, self), out_reps(), out_vals) - def post_process_call(self, call_primitive, out_tracers, params): - assert False # unreachable - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) - in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) - fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) - jvp, out_reps2 = _rewrite_subtrace(jvp, self.main, in_reps * 2) - with core.new_dynamic(self.dyna): + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) + fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) + jvp, out_reps2 = _rewrite_subtrace(jvp, self.tag, self.mesh, in_reps * 2) + with core.set_current_trace(self.parent_trace): out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) if not fst: @@ -1986,9 +1877,6 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): out_reps = out_reps[:len(out_reps) // 2] return map(partial(RewriteTracer, self), out_reps, out_vals) - def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): - assert False # unreachable - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): if symbolic_zeros: @@ -1996,12 +1884,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) - in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) - fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) + fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]] - fwd, out_reps2 = _rewrite_subtrace(fwd, self.main, fwd_in_reps) + fwd, out_reps2 = _rewrite_subtrace(fwd, self.tag, self.mesh, fwd_in_reps) bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps) - with core.new_dynamic(self.dyna): + with core.set_current_trace(self.parent_trace): out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) @@ -2010,45 +1898,34 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, _, out_reps = split_list(out_reps, [res_tree.num_leaves]) return map(partial(RewriteTracer, self), out_reps, out_vals) - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable - - # TODO process_axis_index - def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): in_reps = map(partial(_in_names_to_rep, mesh), in_names) out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()] fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) return _match_rep(fun, mesh, out_reps_src, out_reps_dst) -def _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps): - return _efficient_transpose_outer(_efficient_transpose_inner(fun), mesh, in_reps) - -@lu.transformation_with_aux -def _efficient_transpose_outer(mesh, in_reps, *args): - lvl = core.dynamic_level() - with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: - out_vals, out_reps = yield (main, mesh, in_reps, args), {} - del main - yield out_vals, out_reps - -@lu.transformation -def _efficient_transpose_inner(main, mesh, in_reps, args): - t = main.with_cur_sublevel() - in_tracers = map(partial(RewriteTracer, t), in_reps, args) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - yield unzip2((t.val, t.rep) for t in out_tracers) - -@lu.transformation -def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): - outs = yield args, {} +@lu.transformation_with_aux2 +def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): + with core.take_current_trace() as parent: + tag = core.TraceTag() + t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) + in_tracers = map(partial(RewriteTracer, t), in_reps, args) + with core.set_current_trace(t): + ans = f(*in_tracers) + out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) + del t, in_tracers, ans + store.store(out_reps) + return out_vals + +@lu.transformation2 +def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): + outs = f(*args) out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ _check_reps2(mesh, out_reps_dst, out_reps_src) outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)] - yield outs + return outs # TODO(mattjj): caching def _replication_rewrite_match( @@ -2060,8 +1937,7 @@ def _replication_rewrite_match( f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) f = _match_rep(f, mesh, out_rep, out_rep_dst) - with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts) # TODO(mattjj): caching @@ -2072,28 +1948,26 @@ def _replication_rewrite_nomatch( ) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts), out_rep() -@lu.transformation_with_aux -def _rewrite_subtrace(main, in_reps, *in_vals): - assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) - t = main.with_cur_sublevel() - in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) - with core.new_dynamic(main.level): - outs = yield in_tracers, {} - out_tracers = map(t.full_raise, outs) - out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers) - yield out_vals, out_reps +@lu.transformation_with_aux2 +def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals): + with core.take_current_trace() as parent_trace: + assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) + t = RewriteTrace(parent_trace, tag, mesh) + in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) + with core.set_current_trace(t): + outs = f(*in_tracers) + out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs)) + store.store(out_reps) + return out_vals def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): def new_bwd(*args): - lvl = core.dynamic_level() - with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: - bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), main, in_reps()) - out = bwd_.call_wrapped(*args) - del main + tag = core.TraceTag() + bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), tag, mesh, in_reps()) + out = bwd_.call_wrapped(*args) return map(_match_replication, reps_thunk(), reps_dst, out) return new_bwd diff --git a/jax/experimental/slab/slab.py b/jax/experimental/slab/slab.py index af7b079eeb7f..8324e4c55457 100644 --- a/jax/experimental/slab/slab.py +++ b/jax/experimental/slab/slab.py @@ -89,7 +89,7 @@ def xprod(xs: Iterable[XInt]) -> XInt: return xmul(*list(xs)) def static_int(x: XInt) -> bool: - return isinstance(core.get_aval(x), core.ConcreteArray) + return core.is_concrete(x) def static_shape(s: DShape) -> bool: return all(map(static_int, s)) diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index f4fe0b9040e6..6962ef78bcff 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -19,7 +19,7 @@ from functools import partial -from jax import core +from jax._src import core from jax._src import dispatch from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index f65f7b0a194b..d8bf1ee4a7bd 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -606,8 +606,11 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequenc bcoo_dot_general_p = core.Primitive('bcoo_dot_general') -def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers, - precision: None = None, preferred_element_type: None = None) -> BCOO | Array: +def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, + dimension_numbers: DotDimensionNumbers, + precision: None = None, + preferred_element_type: None = None, + out_type=None) -> BCOO | Array: """A general contraction operation. Args: @@ -625,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: the result will be dense, of type ndarray. """ # TODO(jakevdp) make use of these? - del precision # unused + del precision, out_type # unused if isinstance(lhs, BCOO) and isinstance(rhs, BCOO): shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers) @@ -1051,7 +1054,8 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers) indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True) kwds = {'dimension_numbers': dimension_numbers, 'precision': None, - 'preferred_element_type': None} + 'preferred_element_type': None, + 'out_type': None} A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds) return A, B, indices @@ -1691,7 +1695,8 @@ def _update(d, i): return BCOO((new_data, new_indices), shape=shape) -def bcoo_broadcast_in_dim(mat: BCOO, *, shape: Shape, broadcast_dimensions: Sequence[int]) -> BCOO: +def bcoo_broadcast_in_dim(mat: BCOO, *, shape: Shape, broadcast_dimensions: Sequence[int], + sharding=None) -> BCOO: """Expand the size and rank of a BCOO array by duplicating the data. A BCOO equivalence to jax.lax.broadcast_in_dim. @@ -1821,7 +1826,9 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO: return BCOO((new_data, new_indices), shape=out_aval.shape) -def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[int] | None = None) -> BCOO: +def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], + dimensions: Sequence[int] | None = None, + sharding=None) -> BCOO: """Sparse implementation of {func}`jax.lax.reshape`. Args: diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 8aa7d80c7a29..ed7e53d4c64e 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -462,7 +462,8 @@ def _bcsr_extract_batching_rule(batched_args, batch_dims): def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, dimension_numbers: DotDimensionNumbers, precision: None = None, - preferred_element_type: None = None) -> Array: + preferred_element_type: None = None, + out_type=None) -> Array: """A general contraction operation. Args: @@ -479,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, are sparse, the result will be sparse, of type BCSR. If either input is dense, the result will be dense, of type ndarray. """ - del precision # unused + del precision, out_type # unused if isinstance(rhs, (np.ndarray, jax.Array)): if isinstance(lhs, (np.ndarray, jax.Array)): return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers, @@ -713,7 +714,8 @@ def _bcsr_dot_general_gpu_lowering( #---------------------------------------------------------------------- # BCOO functions that maybe should be primitives? -def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequence[int]) -> BCSR: +def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequence[int], + sharding=None) -> BCSR: result_bcoo = bcoo.bcoo_broadcast_in_dim( mat.to_bcoo(), shape=shape, broadcast_dimensions=broadcast_dimensions) return BCSR.from_bcoo(result_bcoo) diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py index 6c827325befc..f9d28f5ff83c 100644 --- a/jax/experimental/sparse/nm.py +++ b/jax/experimental/sparse/nm.py @@ -14,7 +14,7 @@ """N:M-sparsity associated primitives.""" -from jax import core +from jax._src import core from jax._src import dispatch from jax._src.lax.lax import DotDimensionNumbers from jax._src.lib import gpu_sparse diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index efdf1888f436..c83d9a667888 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -97,6 +97,7 @@ lax.sin_p, lax.sinh_p, lax.sqrt_p, + lax.square_p, lax.tan_p, lax.tanh_p, lax.convert_element_type_p, @@ -276,16 +277,6 @@ def spvalue_to_aval(spvalue): # ------------------------------------------------------------------------------ # Implementation of sparsify() using tracers. -def popattr(obj: Any, name: str) -> Any: - assert hasattr(obj, name) - val = getattr(obj, name) - delattr(obj, name) - return val - -def setnewattr(obj: Any, name: str, val: Any): - assert not hasattr(obj, name) - setattr(obj, name, val) - class SparseTracer(core.Tracer): def __init__(self, trace: core.Trace, *, spvalue): self._spvalue = spvalue @@ -293,9 +284,9 @@ def __init__(self, trace: core.Trace, *, spvalue): @property def spenv(self): - if not hasattr(self._trace.main, 'spenv'): - raise RuntimeError("Internal: main does not have spenv defined.") - return self._trace.main.spenv + if not hasattr(self._trace, 'spenv'): + raise RuntimeError("Internal: trace does not have spenv defined.") + return self._trace.spenv @property def aval(self): @@ -305,71 +296,71 @@ def full_lower(self): return self class SparseTrace(core.Trace): - def pure(self, val: Any): - if not hasattr(self.main, 'spenv'): - raise RuntimeError("Internal: main does not have spenv defined.") - spvalue, = arrays_to_spvalues(self.main.spenv, [val]) - return SparseTracer(self, spvalue=spvalue) - def lift(self, val: core.Tracer): - if not hasattr(self.main, 'spenv'): - raise RuntimeError("Internal: main does not have spenv defined.") - spvalue, = arrays_to_spvalues(self.main.spenv, [val]) - return SparseTracer(self, spvalue=spvalue) + def __init__(self, parent_trace, tag, spenv): + self.parent_trace = parent_trace + self.tag = tag + self.spenv = spenv - def sublift(self, val: SparseTracer): - return SparseTracer(val._trace, spvalue=val._spvalue) + def to_sparse_tracer(self, val): + if isinstance(val, SparseTracer) and self.tag is val._trace.tag: + return val + else: + with core.set_current_trace(self.parent_trace): + spvalue, = arrays_to_spvalues(self.spenv, [val]) + return SparseTracer(self, spvalue=spvalue) def process_primitive(self, primitive, tracers, params): - spenv = popattr(self.main, 'spenv') + tracers = [self.to_sparse_tracer(t) for t in tracers] spvalues = [t._spvalue for t in tracers] if any(spvalue.is_sparse() for spvalue in spvalues): if primitive not in sparse_rules_bcoo: _raise_unimplemented_primitive(primitive) - out_spvalues = sparse_rules_bcoo[primitive](spenv, *(t._spvalue for t in tracers), **params) + with core.set_current_trace(self.parent_trace): + out_spvalues = sparse_rules_bcoo[primitive](self.spenv, *(t._spvalue for t in tracers), **params) else: - out_bufs = primitive.bind(*(spenv.data(spvalue) for spvalue in spvalues), **params) - out_spvalues = arrays_to_spvalues(spenv, out_bufs if primitive.multiple_results else [out_bufs]) - setnewattr(self.main, 'spenv', spenv) + out_bufs = primitive.bind_with_trace(self.parent_trace, tuple(self.spenv.data(spvalue) for spvalue in spvalues), params) + out_spvalues = arrays_to_spvalues(self.spenv, out_bufs if primitive.multiple_results else [out_bufs]) out_tracers = tuple(SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues) return out_tracers if primitive.multiple_results else out_tracers[0] def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): - spenv = popattr(self.main, 'spenv') + assert False spvalues = tuple(t._spvalue for t in tracers) - in_bufs = spenv._buffers + in_bufs = self.spenv._buffers fun, out_spvalues = sparsify_subtrace(f, self.main, spvalues) if any(params['donated_invars']): raise NotImplementedError("sparsify does not support donated_invars") params = dict(params, donated_invars=tuple(False for buf in in_bufs)) bufs_out = call_primitive.bind(fun, *in_bufs, **params) - setnewattr(self.main, 'spenv', SparsifyEnv(bufs_out)) return [SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues()] def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): # TODO(jakevdp): handle the jvp here del primitive, jvp, symbolic_zeros - return fun.call_wrapped(*tracers) - -@lu.transformation_with_aux -def sparsify_subtrace(main, spvalues, *bufs): - setnewattr(main, 'spenv', SparsifyEnv(bufs)) - trace = main.with_cur_sublevel() - in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] - outs = yield in_tracers, {} - out_traces = [trace.full_raise(out) for out in outs] - buffers = popattr(main, 'spenv')._buffers - yield buffers, [out._spvalue for out in out_traces] + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) + +@lu.transformation_with_aux2 +def sparsify_subtrace(f, store, tag, spenv, spvalues, *bufs): + with core.take_current_trace() as parent: + trace = SparseTrace(parent, tag, spenv) + with core.set_current_trace(trace): + in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] + outs = f(*in_tracers) + out_traces = [trace.to_sparse_tracer(out) for out in outs] + buffers = spenv._buffers + store.store([out._spvalue for out in out_traces]) + return buffers def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]): - with core.new_main(SparseTrace) as main: - spenv = SparsifyEnv() - spvalues = arrays_to_spvalues(spenv, args) - in_bufs = spenv._buffers - fun, out_spvalues = sparsify_subtrace(wrapped_fun, main, spvalues) - out_bufs = fun.call_wrapped(*in_bufs) - spenv = SparsifyEnv(out_bufs) - del main + tag = core.TraceTag() + spenv = SparsifyEnv() + spvalues = arrays_to_spvalues(spenv, args) + in_bufs = spenv._buffers + fun, out_spvalues = sparsify_subtrace(wrapped_fun, tag, spenv, spvalues) + out_bufs = fun.call_wrapped(*in_bufs) + spenv = SparsifyEnv(out_bufs) return spvalues_to_arrays(spenv, out_spvalues()) def _sparsify_with_tracer(fun): @@ -773,7 +764,7 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): + keep_unused, inline, compiler_options_kvs): if any(donated_invars): raise NotImplementedError("sparse xla_call with donated_invars") @@ -809,7 +800,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat)) sparse_rules_bcoo[pjit.pjit_p] = _pjit_sparse diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 7ef1ed781c15..2cb765676411 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -111,4 +111,4 @@ def _dot_general_validated_shape( rhs = core.ShapedArray(rhs_shape, np.float32) return _dot_general_shape_rule( lhs, rhs, dimension_numbers=dimension_numbers, - precision=None, preferred_element_type=None) + precision=None, preferred_element_type=None, out_type=None) diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index e8ef32935cbf..bbb5925ab41a 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -14,16 +14,24 @@ """Modules for JAX extensions. -The :mod:`jax.extend` package provides modules for access to JAX +The :mod:`jax.extend` module provides modules for access to JAX internal machinery. See `JEP #15856 `_. +This module is not the only means by which JAX aims to be +extensible. For example, the main JAX API offers mechanisms for +`customizing derivatives +`_, +`registering custom pytree definitions +`_, +and more. + API policy ---------- Unlike the `public API `_, -this package offers **no compatibility guarantee** across releases. +this module offers **no compatibility guarantee** across releases. Breaking changes will be announced via the `JAX project changelog `_. """ diff --git a/jax/extend/backend.py b/jax/extend/backend.py index b1e471133482..8d5488baba16 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -24,3 +24,6 @@ get_backend as get_backend, register_backend_factory as register_backend_factory, ) +from jax._src.interpreters.pxla import ( + get_default_device as get_default_device +) diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index feb70b5171be..d8a10154cf4a 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -127,6 +127,7 @@ sinh_p as sinh_p, sort_p as sort_p, sqrt_p as sqrt_p, + square_p as square_p, squeeze_p as squeeze_p, sub_p as sub_p, tan_p as tan_p, @@ -203,6 +204,7 @@ pmin_p as pmin_p, ppermute_p as ppermute_p, psum_p as psum_p, + ragged_all_to_all_p as ragged_all_to_all_p, ) from jax._src.lax.ann import ( diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 74c52dddbae8..8b80d033fa5c 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -22,5 +22,7 @@ merge_linear_aux as merge_linear_aux, transformation as transformation, transformation_with_aux as transformation_with_aux, + transformation2 as transformation2, + transformation_with_aux2 as transformation_with_aux2, wrap_init as wrap_init, ) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 28816afb01e3..4ded4a803ae0 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -18,8 +18,6 @@ from __future__ import annotations from jax._src.interpreters.ad import ( - CustomJVPException as CustomJVPException, - CustomVJPException as CustomVJPException, JVPTrace as JVPTrace, JVPTracer as JVPTracer, UndefinedPrimal as UndefinedPrimal, @@ -67,21 +65,9 @@ vjp as vjp, zero_jvp as zero_jvp, zeros_like_aval as zeros_like_aval, - zeros_like_jaxval as zeros_like_jaxval, zeros_like_p as zeros_like_p, ) -_deprecations = { - # Finalized Mar 18, 2024; remove after June 18, 2024 - "config": ( - "jax.interpreters.ad.config is deprecated. Use jax.config directly.", - None, - ), - "source_info_util": ( - "jax.interpreters.ad.source_info_util is deprecated. Use jax.extend.source_info_util.", - None, - ), -} def backward_pass(jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in): diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 607fc6fa596d..7a93a6942c21 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -50,6 +50,7 @@ defbroadcasting as defbroadcasting, defreducer as defreducer, defvectorized as defvectorized, + fancy_primitive_batchers as fancy_primitive_batchers, flatten_fun_for_vmap as flatten_fun_for_vmap, from_elt as from_elt, from_elt_handlers as from_elt_handlers, @@ -64,7 +65,6 @@ reducer_batcher as reducer_batcher, register_vmappable as register_vmappable, spec_types as spec_types, - spmd_axis_primitive_batchers as spmd_axis_primitive_batchers, to_elt as to_elt, to_elt_handlers as to_elt_handlers, unregister_vmappable as unregister_vmappable, diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 3c63948bee63..dca438996229 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -62,8 +62,8 @@ debug_info as debug_info, debug_info_final as debug_info_final, def_trivial_padding as def_trivial_padding, - extend_jaxpr_stack as extend_jaxpr_stack, forwarding_rules as forwarding_rules, + has_effects as has_effects, infer_lambda_input_type as infer_lambda_input_type, instantiate_const_at as instantiate_const_at, make_jaxpr_effects as make_jaxpr_effects, @@ -81,15 +81,9 @@ recipe_to_eqn as recipe_to_eqn, result_info as result_info, sig_info as sig_info, - trace_to_jaxpr as trace_to_jaxpr, trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic, trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2, - trace_to_jaxpr_final as trace_to_jaxpr_final, - trace_to_jaxpr_final2 as trace_to_jaxpr_final2, trace_to_jaxpr_nounits as trace_to_jaxpr_nounits, - trace_to_subjaxpr as trace_to_subjaxpr, - trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic, - trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2, trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 15c9a2cfe49d..f3fd8bac558c 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -40,7 +40,6 @@ ArrayMapping as ArrayMapping, UNSPECIFIED as _UNSPECIFIED, # noqa: F401 array_mapping_to_axis_resources as array_mapping_to_axis_resources, - is_unspecified as _is_unspecified, # noqa: F401 ) from jax._src.sharding_specs import ( diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bbd5b65d5d3e..b3a470f5e049 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -23,72 +23,27 @@ apply_primitive as apply_primitive, ) -from jax._src import xla_bridge as _xb from jax._src.lib import xla_client as _xc - -_xe = _xc._xla -Backend = _xe.Client +Backend = _xc._xla.Client +del _xc # Deprecations _deprecations = { - # Added 2024-06-28 + # Finalized 2024-10-24; remove after 2025-01-24 "xb": ( - "jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.", - _xb + ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " + "Use jax.lib.xla_bridge instead."), None ), "xc": ( - "jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.", - _xc, + ("jax.interpreters.xla.xc was removed in JAX v0.4.36. " + "Use jax.lib.xla_client instead."), None ), "xe": ( - "jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.", - _xe, - ), - # Finalized 2024-05-13; remove after 2024-08-13 - "backend_specific_translations": ( - "jax.interpreters.xla.backend_specific_translations is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "translations": ( - "jax.interpreters.xla.translations is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "register_translation": ( - "jax.interpreters.xla.register_translation is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "xla_destructure": ( - "jax.interpreters.xla.xla_destructure is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "TranslationRule": ( - "jax.interpreters.xla.TranslationRule is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "TranslationContext": ( - "jax.interpreters.xla.TranslationContext is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "XlaOp": ( - "jax.interpreters.xla.XlaOp is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, + ("jax.interpreters.xla.xe was removed in JAX v0.4.36. " + "Use jax.lib.xla_extension instead."), None ), } -import typing from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -if typing.TYPE_CHECKING: - xb = _xb - xc = _xc - xe = _xe -else: - __getattr__ = _deprecation_getattr(__name__, _deprecations) +__getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr -del typing diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 7f42cfca5fe8..321b1dda19cf 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -206,6 +206,7 @@ sqrt as sqrt, sqrt_p as sqrt_p, square as square, + square_p as square_p, squeeze as squeeze, squeeze_p as squeeze_p, stop_gradient as stop_gradient, @@ -330,7 +331,6 @@ linear_solve_p as linear_solve_p, map as map, scan as scan, - scan_bind as scan_bind, scan_p as scan_p, switch as switch, while_loop as while_loop, @@ -362,6 +362,8 @@ psum_p as psum_p, psum_scatter as psum_scatter, pswapaxes as pswapaxes, + ragged_all_to_all as ragged_all_to_all, + ragged_all_to_all_p as ragged_all_to_all_p, ) from jax._src.lax.other import ( conv_general_dilated_local as conv_general_dilated_local, @@ -378,16 +380,3 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p - - -_deprecations = { - # Finalized 2024-05-13; remove after 2024-08-13 - "tie_in": ( - "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. " - "Replace z = tie_in(x, y) with z = y.", None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 654abc35fc78..2bcb1cb037f4 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -17,7 +17,6 @@ default_backend as _deprecated_default_backend, get_backend as _deprecated_get_backend, xla_client as _deprecated_xla_client, - _backends as _backends, ) from jax._src.compiler import ( diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index aaf3791037d0..cd3696d8838c 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -18,7 +18,6 @@ get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile mlir_api_version = _xc.mlir_api_version -ArrayImpl = _xc.ArrayImpl Client = _xc.Client CompileOptions = _xc.CompileOptions DeviceAssignment = _xc.DeviceAssignment @@ -95,6 +94,11 @@ "XlaComputation is deprecated; use StableHLO instead.", _xc.XlaComputation, ), + # Added Nov 20 2024 + "ArrayImpl": ( + "jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.", + _xc.ArrayImpl, + ), } import typing as _typing @@ -106,6 +110,7 @@ ops = _xc.ops register_custom_call_target = _xc.register_custom_call_target shape_from_pyval = _xc.shape_from_pyval + ArrayImpl = _xc.ArrayImpl Device = _xc.Device FftType = _FftType PaddingType = _xc.PaddingType diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 20ce459685aa..52fe94e231d1 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -24,7 +24,6 @@ pmap_lib = _xe.pmap_lib profiler = _xe.profiler pytree = _xe.pytree -ArrayImpl = _xe.ArrayImpl Device = _xe.Device DistributedRuntimeClient = _xe.DistributedRuntimeClient HloModule = _xe.HloModule @@ -33,6 +32,28 @@ PjitFunctionCache = _xe.PjitFunctionCache PjitFunction = _xe.PjitFunction PmapFunction = _xe.PmapFunction -XlaRuntimeError = _xe.XlaRuntimeError +_deprecations = { + # Added Nov 20 2024 + "ArrayImpl": ( + "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", + _xe.ArrayImpl, + ), + "XlaRuntimeError": ( + "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", + _xe.XlaRuntimeError, + ), +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + ArrayImpl = _xe.ArrayImpl + XlaRuntimeError = _xe.XlaRuntimeError +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing del _xe diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 496d03261384..ebe725c448ee 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -49,17 +49,3 @@ squareplus as squareplus, mish as mish, ) - -# Deprecations - -_deprecations = { - # Finalized 2024-05-13; remove after 2024-08-13 - "normalize": ( - "jax.nn.normalize is deprecated. Use jax.nn.standardize instead.", - None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index d50d55033c33..12736c1cd9b1 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -202,6 +202,7 @@ printoptions as printoptions, promote_types as promote_types, put as put, + put_along_axis as put_along_axis, ravel as ravel, ravel_multi_index as ravel_multi_index, repeat as repeat, @@ -273,6 +274,15 @@ except ImportError: pass +# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 +try: + from jax._src.numpy.lax_numpy import ( + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + ) +except ImportError: + pass + from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, @@ -307,8 +317,9 @@ all as all, average as average, count_nonzero as count_nonzero, - cumsum as cumsum, cumprod as cumprod, + cumsum as cumsum, + cumulative_prod as cumulative_prod, cumulative_sum as cumulative_sum, max as max, mean as mean, @@ -471,11 +482,6 @@ "jnp.round_ is deprecated; use jnp.round instead.", round ), - # Deprecated 18 Sept 2023 and removed 06 Feb 2024 - "trapz": ( - "jnp.trapz is deprecated; use jnp.trapezoid instead.", - None - ), } import typing diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 30363c8f4e47..5d357ab1bb03 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -29,6 +29,46 @@ _Device = Device ComplexWarning: type +class ufunc: + def __init__(self, func: Callable[..., Any], /, + nin: int, nout: int, *, + name: str | None = None, + nargs: int | None = None, + identity: Any = None, + call: Callable[..., Any] | None = None, + reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None, + ): ... + @property + def nin(self) -> int: ... + @property + def nout(self) -> int: ... + @property + def nargs(self) -> int: ... + @property + def identity(self) -> builtins.bool | int | float: ... + def __call__(self, *args: ArrayLike) -> Any: ... + def reduce(self, a: ArrayLike, /, *, + axis: int | None = 0, + dtype: DTypeLike | None = None, + out: None = None, + keepdims: builtins.bool = False, + initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def accumulate(self, a: ArrayLike, /, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, + inplace: builtins.bool = True) -> Array: ... + def reduceat(self, a: ArrayLike, indices: Any, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ... + class BinaryUfunc(Protocol): @property def nin(self) -> int: ... @@ -39,9 +79,10 @@ class BinaryUfunc(Protocol): @property def identity(self) -> builtins.bool | int | float: ... def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ... - def reduce(self, arr: ArrayLike, /, *, + def reduce(self, a: ArrayLike, /, *, axis: int | None = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: builtins.bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: ... @@ -276,6 +317,9 @@ def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... +def cumulative_prod(x: ArrayLike, /, *, axis: int | None = ..., + dtype: DTypeLike | None = ..., + include_initial: builtins.bool = ...) -> Array: ... def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., include_initial: builtins.bool = ...) -> Array: ... @@ -431,6 +475,8 @@ def fromfile(*args, **kwargs): ... def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = ..., **kwargs) -> Array: ... def fromiter(*args, **kwargs): ... +def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, + *, identity: Any = None) -> ufunc: ... def fromstring( string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str ) -> Array: ... @@ -583,8 +629,8 @@ def log(x: ArrayLike, /) -> Array: ... def log10(x: ArrayLike, /) -> Array: ... def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... -def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logaddexp: BinaryUfunc +logaddexp2: BinaryUfunc logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... logical_or: BinaryUfunc @@ -739,6 +785,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ...) -> Array: ... def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ... +def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, + axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ... def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... diff --git a/jax/random.py b/jax/random.py index 29a625389811..b99cd531f18c 100644 --- a/jax/random.py +++ b/jax/random.py @@ -251,20 +251,3 @@ weibull_min as weibull_min, wrap_key_data as wrap_key_data, ) - -_deprecations = { - # Finalized Jul 26 2024; remove after Nov 2024. - "shuffle": ( - "jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.", - None, - ) -} - -import typing -if typing.TYPE_CHECKING: - pass -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 84cc697d1894..83d0b4b25923 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -62,6 +62,11 @@ def platform_tag(cpu: str) -> str: }[(platform.system(), cpu)] return f"{platform_name}_{cpu_name}" +def get_githash(jaxlib_git_hash): + if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): + with open(jaxlib_git_hash, "r") as f: + return f.readline().strip() + return jaxlib_git_hash def build_wheel( sources_path: str, output_path: str, package_name: str, git_hash: str = "" diff --git a/jax/version.py b/jax/version.py index 7fde742cd259..b5e5f255f152 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.35" +_version = "0.4.38" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -60,7 +60,11 @@ def _version_from_git_tree(base_version: str) -> str | None: except: return None else: - return f"{base_version}.dev{datestring}+{commit_hash}" + version = f"{base_version}.dev{datestring}+{commit_hash}" + suffix = os.environ.get("JAX_CUSTOM_VERSION_SUFFIX", None) + if suffix: + return version + "." + suffix + return version def _get_version_for_build() -> str: @@ -133,7 +137,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.35" +_minimum_jaxlib_version = "0.4.36" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index 8e99907d7078..ce31684de46f 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -55,7 +55,7 @@ def has_ext_modules(self): 'with_cuda': [ "nvidia-cublas-cu12>=12.1.3.1", "nvidia-cuda-cupti-cu12>=12.1.105", - "nvidia-cuda-nvcc-cu12>=12.1.105", + "nvidia-cuda-nvcc-cu12>=12.6.85", "nvidia-cuda-runtime-cu12>=12.1.105", "nvidia-cudnn-cu12>=9.1,<10.0", "nvidia-cufft-cu12>=11.0.2.54", diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index a84a6b34ea48..d504d0a11666 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -22,6 +22,11 @@ project_name = f"jax-rocm{rocm_version}-plugin" package_name = f"jax_rocm{rocm_version}_plugin" +# Extract ROCm version from the `ROCM_PATH` environment variable. +default_rocm_path = "/opt/rocm" +rocm_path = os.getenv("ROCM_PATH", default_rocm_path) +rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown" + def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( 'version', os.path.join(pkg_path, 'version.py')) @@ -43,7 +48,7 @@ def has_ext_modules(self): name=project_name, version=__version__, cmdclass=_cmdclass, - description="JAX Plugin for AMD GPUs", + description=f"JAX Plugin for AMD GPUs (ROCm:{rocm_detected_version})", long_description="", long_description_content_type="text/markdown", author="Ruturaj4", diff --git a/jax_plugins/rocm/setup.py b/jax_plugins/rocm/setup.py index d131e732c91a..ec3eae2d8821 100644 --- a/jax_plugins/rocm/setup.py +++ b/jax_plugins/rocm/setup.py @@ -21,6 +21,11 @@ project_name = f"jax-rocm{rocm_version}-pjrt" package_name = f"jax_plugins.xla_rocm{rocm_version}" +# Extract ROCm version from the `ROCM_PATH` environment variable. +default_rocm_path = "/opt/rocm" +rocm_path = os.getenv("ROCM_PATH", default_rocm_path) +rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown" + def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( 'version', os.path.join(pkg_path, 'version.py')) @@ -41,7 +46,7 @@ def load_version_module(pkg_path): setup( name=project_name, version=__version__, - description="JAX XLA PJRT Plugin for AMD GPUs", + description=f"JAX XLA PJRT Plugin for AMD GPUs (ROCm:{rocm_detected_version})", long_description="", long_description_content_type="text/markdown", author="Ruturaj4", diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 8c402cfcefe8..e7ba1dd3de16 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -208,27 +208,47 @@ pybind_extension( ], ) -pybind_extension( - name = "cuda_plugin_extension", - srcs = ["cuda_plugin_extension.cc"], - module_name = "cuda_plugin_extension", +cc_library( + name = "gpu_plugin_extension", + srcs = ["gpu_plugin_extension.cc"], + hdrs = ["gpu_plugin_extension.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], deps = [ + ":kernel_nanobind_helpers", "@com_google_absl//absl/status", - "@nanobind", - "//jaxlib:kernel_nanobind_helpers", - "@xla//third_party/python_runtime:headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@tsl//tsl/platform:statusor", "@xla//xla:util", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", - # TODO(jieying): move to jaxlib after py_client_gpu is separated from py_client "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +pybind_extension( + name = "cuda_plugin_extension", + srcs = ["cuda_plugin_extension.cc"], + module_name = "cuda_plugin_extension", + deps = [ + ":gpu_plugin_extension", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/pjrt:status_casters", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -237,21 +257,12 @@ pybind_extension( srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/status", + ":gpu_plugin_extension", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@xla//third_party/python_runtime:headers", - "@xla//xla:status", - "@xla//xla:util", - "@xla//xla/ffi/api:c_api", - "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_helpers", - "@xla//xla/python:py_client_gpu", - "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 19b82a5ce149..ed815e1b1bd2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// LAPACK uses a packed representation to represent a mixture of real -// eigenvectors and complex conjugate pairs. This helper unpacks the -// representation into regular complex matrices. -template -static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag, - const T* packed, std::complex* unpacked) { - for (int j = 0; j < n;) { - if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { - // Real values in each row without imaginary part - // Second row of the imaginary part is not provided - for (int i = 0; i < n; ++i) { - unpacked[j * n + i] = {packed[j * n + i], 0.}; - } - ++j; - } else { - // Complex values where the real part is in the jth row - // and the imaginary part is in the next row (j + 1) - for (int i = 0; i < n; ++i) { - const T real_part = packed[j * n + i]; - const T imag_part = packed[(j + 1) * n + i]; - unpacked[j * n + i] = {real_part, imag_part}; - unpacked[(j + 1) * n + i] = {real_part, -imag_part}; - } - j += 2; - } - } -} - // lapack geev template diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 7d15e494fffc..cddcb1162120 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ #define JAXLIB_CPU_LAPACK_KERNELS_H_ +#include #include #include #include @@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian { // lapack geev +// LAPACK uses a packed representation to represent a mixture of real +// eigenvectors and complex conjugate pairs. This helper unpacks the +// representation into regular complex matrices. +template +static void UnpackEigenvectors(Int n, const T* eigenvals_imag, + const T* packed, std::complex* unpacked) { + for (int j = 0; j < n;) { + if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { + // Real values in each row without imaginary part + // Second row of the imaginary part is not provided + for (int i = 0; i < n; ++i) { + unpacked[j * n + i] = {packed[j * n + i], 0.}; + } + ++j; + } else { + // Complex values where the real part is in the jth row + // and the imaginary part is in the next row (j + 1) + for (int i = 0; i < n; ++i) { + const T real_part = packed[j * n + i]; + const T imag_part = packed[(j + 1) * n + i]; + unpacked[j * n + i] = {real_part, imag_part}; + unpacked[(j + 1) * n + i] = {real_part, -imag_part}; + } + j += 2; + } + } +} + template struct RealGeev { using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 34e40d12d5be..afce2c000ecc 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -476,6 +476,55 @@ pybind_extension( ], ) +cc_library( + name = "cuda_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + module_name = "_hybrid", + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_hybrid_kernels", + ":cuda_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + "@xla//xla/tsl/cuda:cudart", + ], +) + cc_library( name = "cuda_gpu_kernels", srcs = ["//jaxlib/gpu:gpu_kernels.cc"], @@ -633,6 +682,7 @@ py_library( name = "cuda_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_rnn", diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda_plugin_extension.cc index ea81109b36c0..34cf462d623e 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda_plugin_extension.cc @@ -12,135 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include -#include -#include #include "nanobind/nanobind.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/ffi/api/c_api.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" -#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "jaxlib/gpu_plugin_extension.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" -#include "xla/tsl/python/lib/core/numpy.h" -#include "xla/util.h" namespace nb = nanobind; namespace xla { namespace { -absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, - const char* fn_name_c_str, - size_t fn_name_size, nb::object fn, - int api_version, - XLA_FFI_Handler_Traits traits) { - if (c_api->extension_start == nullptr) { - return Unimplemented("The plugin does not have extension."); - } - const PJRT_Extension_Base* next = - reinterpret_cast(c_api->extension_start); - while (next != nullptr && - next->type != - PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { - next = next->next; - } - if (next == nullptr) { - return Unimplemented("The plugin does not have a custom call extension."); - } - PJRT_Gpu_Register_Custom_Call* register_custom_call = - reinterpret_cast(next)->custom_call; - - if (traits != 0) { - return Unimplemented("The plugin does not support custom call traits."); - } - - PJRT_Gpu_Register_Custom_Call_Args args; - args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name_c_str; - args.function_name_size = fn_name_size; - -#if PJRT_API_GPU_EXTENSION_VERSION >= 1 - args.api_version = api_version; -#endif - - auto as_capsule = [](nb::object obj) -> absl::StatusOr { - nb::capsule capsule; - if (!nb::try_cast(obj, capsule)) { - return absl::InvalidArgumentError( - "Custom call target registration requires handlers as PyCapsules"); - } - return capsule; - }; - -#if PJRT_API_GPU_EXTENSION_VERSION <= 1 - TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); - args.custom_call_function = fn_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); -#else - args.handler_instantiate = nullptr; - args.handler_prepare = nullptr; - args.handler_initialize = nullptr; - args.handler_execute = nullptr; - - // Register legacy custom call target (untyped void* API). - if (api_version == 0) { - TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); - args.handler_execute = capsule_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - // Register XLA FFI handler (typed API with explicit function signatures). - if (api_version == 1) { - auto capsule_execute = as_capsule(fn); - if (capsule_execute.ok()) { - args.handler_execute = capsule_execute->data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - nb::dict bundle; - if (nb::try_cast(fn, bundle)) { - auto handler = [&](const char* name) -> absl::StatusOr { - if (!bundle.contains(name)) return nullptr; - TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); - return capsule.data(); - }; - - TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); - TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); - TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); - TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - return absl::InvalidArgumentError( - "Unsupported custom call target type for api_version=1"); - } - - return absl::UnimplementedError(absl::StrFormat( - "API version %d is not supported by RegisterCustomCallTarget. " - "Supported versions are 0 and 1.", - api_version)); -#endif -} - -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -155,31 +41,7 @@ static std::string ToString(CUresult result) { } // namespace NB_MODULE(cuda_plugin_extension, m) { - tsl::ImportNumpy(); - m.def( - "register_custom_call_target", - [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, - nb::str xla_platform_name, int api_version, - XLA_FFI_Handler_Traits traits) { - const char* fn_name_c_str; - size_t fn_name_size; - nb::str fn_name_bn_str; - if (nb::try_cast(fn_name_py, fn_name_bn_str)) { - fn_name_c_str = fn_name_bn_str.c_str(); - fn_name_size = nb::len(fn_name_bn_str); - } else{ - nb::bytes bytes = nb::cast(fn_name_py); - fn_name_c_str = bytes.c_str(); - fn_name_size = bytes.size(); - } - xla::ThrowIfError(RegisterCustomCallTarget( - static_cast(c_api.data()), fn_name_c_str, - fn_name_size, std::move(fn), api_version, traits)); - }, - nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), - nb::arg("xla_platform_name"), nb::arg("api_version") = 0, - nb::arg("traits") = 0); - m.def("registrations", &Registrations); + BuildGpuPluginExtension(m); m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 7d50a91cfcda..a5069cfb4a8e 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -20,6 +20,7 @@ load( "jax_visibility", "xla_py_proto_library", ) +# Placeholder: load proto_library licenses(["notice"]) @@ -37,6 +38,9 @@ exports_files(srcs = [ "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", + "hybrid.cc", + "hybrid_kernels.cc", + "hybrid_kernels.h", "linalg.cc", "linalg_kernels.cc", "linalg_kernels.cu.cc", diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 62977c5f57a1..a1e59385e6fa 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -60,8 +60,14 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", SyevdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA", + SytrdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA", + GesvdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvdj_ffi", "CUDA", + GesvdjFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_cholesky_update_ffi", "CUDA", CholeskyUpdateFfi); diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc new file mode 100644 index 000000000000..afe95a650d29 --- /dev/null +++ b/jaxlib/gpu/hybrid.cc @@ -0,0 +1,60 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/gpu/hybrid_kernels.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace { +namespace ffi = xla::ffi; +namespace nb = nanobind; + +void GetLapackKernelsFromScipy() { + static bool initialized = false; // Protected by GIL + if (initialized) return; + nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas"); + nb::module_ cython_lapack = + nb::module_::import_("scipy.linalg.cython_lapack"); + nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); + auto lapack_ptr = [&](const char* name) { + return nb::cast(lapack_capi[name]).data(); + }; + + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>(lapack_ptr("cgeev")); + AssignKernelFn>( + lapack_ptr("zgeev")); +} + +NB_MODULE(_hybrid, m) { + m.def("initialize", GetLapackKernelsFromScipy); + m.def("has_magma", []() { return MagmaLookup().FindMagmaInit().ok(); }); + m.def("registrations", []() { + nb::dict dict; + dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal); + dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp); + return dict; + }); +} + +} // namespace +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc new file mode 100644 index 000000000000..1ce2e547b11f --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.cc @@ -0,0 +1,631 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/hybrid_kernels.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +namespace ffi = ::xla::ffi; + +// This helper class is used to define a host buffer that can be copied to and +// from a device buffer. +template +class HostBuffer { + public: + HostBuffer(std::size_t size) : size_(size) { + data_ = std::unique_ptr(new T[size]); + } + + absl::Status CopyFromDevice(gpuStream_t stream, const T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(data_.get(), buffer, size_ * sizeof(T), + gpuMemcpyDeviceToHost, stream)); + } + + absl::Status CopyToDevice(gpuStream_t stream, T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(buffer, data_.get(), size_ * sizeof(T), + gpuMemcpyHostToDevice, stream)); + } + + T* get() const { return data_.get(); } + + private: + std::unique_ptr data_; + size_t size_; +}; + +// Forwarded from MAGMA for use as an input parameter. +typedef enum { + MagmaNoVec = 301, + MagmaVec = 302, +} magma_vec_t; + +// Compile time lookup of MAGMA function names depending on the data type. +template +struct always_false : std::false_type {}; + +template +struct MagmaGeev { + static_assert(always_false::value, "unsupported data type"); +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_sgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_dgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_cgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_zgeev"; +}; + +MagmaLookup::~MagmaLookup() { + if (initialized_) { + void* magma_finalize = dlsym(handle_, "magma_finalize"); + if (magma_finalize != nullptr) { + reinterpret_cast(magma_finalize)(); + } + } + if (handle_ != nullptr) { + dlclose(handle_); + } +} + +absl::StatusOr MagmaLookup::FindMagmaInit() { + void* magma_init = nullptr; + std::vector paths; + const char* magma_lib_path = std::getenv("JAX_GPU_MAGMA_PATH"); + if (magma_lib_path != nullptr) { + paths.push_back(magma_lib_path); + } else { + paths.push_back("libmagma.so.2"); + paths.push_back("libmagma.so"); + paths.push_back(nullptr); + } + for (const auto& path : paths) { + handle_ = dlopen(path, RTLD_LAZY); + if (handle_ != nullptr) { + magma_init = dlsym(handle_, "magma_init"); + if (magma_init != nullptr) { + if (path != nullptr) { + lib_path_ = std::string(path); + } + break; + } + } + } + if (handle_ == nullptr || magma_init == nullptr) { + return absl::InternalError( + "Unable to dlopen a MAGMA shared library that defines a magma_init " + "symbol. Use the JAX_GPU_MAGMA_PATH environment variable to " + "specify an explicit path to the library."); + } + return magma_init; +} + +absl::Status MagmaLookup::Initialize() { + if (failed_) { + return absl::InternalError("MAGMA initialization was unsuccessful."); + } + if (!initialized_) { + auto maybe_magma_init = FindMagmaInit(); + if (!maybe_magma_init.ok()) { + failed_ = true; + return maybe_magma_init.status(); + } + reinterpret_cast(maybe_magma_init.value())(); + initialized_ = true; + } + return absl::OkStatus(); +} + +absl::StatusOr MagmaLookup::Find(const char name[]) { + if (!initialized_) { + return absl::InternalError("MAGMA support has not been initialized."); + } + + auto it = symbols_.find(name); + if (it != symbols_.end()) return it->second; + + void* symbol = dlsym(handle_, name); + if (symbol == nullptr) { + if (lib_path_.has_value()) { + return absl::InternalError(absl::StrFormat( + "Unable to load the symbol '%s' from the MAGMA library at '%s'.", + name, lib_path_.value())); + + } else { + return absl::InternalError(absl::StrFormat( + "Unable to load a globally defined symbol called '%s'. Use the " + "JAX_GPU_MAGMA_PATH environment variable to specify an explicit " + "path to the library.", + name)); + } + } + + symbols_.insert({name, symbol}); + return symbol; +} + +// Lookup the MAGMA symbol for the given function name. This function only +// dlopen the MAGMA library once per process. +absl::StatusOr FindMagmaSymbol(const char name[]) { + static absl::Mutex mu; + static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu); + absl::MutexLock lock(&mu); + auto status = lookup.Initialize(); + if (!status.ok()) { + return status; + } + return lookup.Find(name); +} + +// Real-valued eigendecomposition + +template +class EigRealHost { + using Real = ffi::NativeType; + + public: + explicit EigRealHost() = default; + EigRealHost(EigRealHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecomposition::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + EigenvalueDecomposition::fn(&jobvl_, &jobvr_, &n_, x, &n_, wr, wi, + vl, &n_, vr, &n_, work, &lwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigRealMagma { + using Real = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Real*, int, Real*, Real*, Real*, + int, Real*, int, Real*, int, int*); + + public: + explicit EigRealMagma() = default; + EigRealMagma(EigRealMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Real query_host; + fn_(jobvl_, jobvr_, n, nullptr, n, nullptr, nullptr, nullptr, n, nullptr, n, + &query_host, -1, &query_info); + return static_cast(query_host); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, n_, wr, wi, vl, n_, vr, n_, work, lwork, info); + } + + private: + int n_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigReal(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto wr_host = HostBuffer(batch * cols); + auto wi_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto work_left = AllocateScratchMemory(cols * cols); + auto work_right = AllocateScratchMemory(cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), + [](auto value) { return std::isfinite(value); }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, wr_host.get() + i * cols, + wi_host.get() + i * cols, work_left.get(), work_right.get(), + work_host.get(), lwork, info_host.get() + i); + if (info_host.get()[i] == 0) { + if (left) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_left.get(), + vl_host.get() + i * cols * cols); + } + if (right) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_right.get(), + vr_host.get() + i * cols * cols); + } + } + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + wr_host.CopyToDevice(stream, wr->typed_data())); + FFI_RETURN_IF_ERROR_STATUS( + wi_host.CopyToDevice(stream, wi->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigRealDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != wr->element_type() || dataType != wi->element_type() || + ffi::ToComplex(dataType) != vl->element_type() || + ffi::ToComplex(dataType) != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(wr->dimensions(), {batch, cols}, "wr", "eig")); + FFI_RETURN_IF_ERROR(CheckShape(wi->dimensions(), {batch, cols}, "wi", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::F32: + if (use_magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + case ffi::F64: + if (use_magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_real", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigReal, EigRealDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // wr + .Ret() // wi + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +// Complex-valued eigendecomposition + +template +class EigCompHost { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + public: + explicit EigCompHost() = default; + EigCompHost(EigCompHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecompositionComplex::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + EigenvalueDecompositionComplex::fn(&jobvl_, &jobvr_, &n_, x, &n_, + w, vl, &n_, vr, &n_, work, + &lwork, rwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigCompMagma { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Complex*, int, Complex*, + Complex*, int, Complex*, int, Complex*, int, Real*, int*); + + public: + explicit EigCompMagma() = default; + EigCompMagma(EigCompMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + lda_ = std::max(n_, 1); + ldvl_ = left ? n_ : 1; + ldvr_ = right ? n_ : 1; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Complex query_host; + fn_(jobvl_, jobvr_, n_, nullptr, lda_, nullptr, nullptr, ldvl_, nullptr, + ldvr_, &query_host, -1, nullptr, &query_info); + return static_cast(query_host.real()); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, lda_, w, vl, ldvl_, vr, ldvr_, work, lwork, + rwork, info); + } + + private: + int n_, lda_, ldvl_, ldvr_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigComp(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto w_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto rwork_host = + AllocateScratchMemory(2 * cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { + return std::isfinite(z.real()) && std::isfinite(z.imag()); + }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, w_host.get() + i * cols, + vl_host.get() + i * cols * cols, + vr_host.get() + i * cols * cols, work_host.get(), lwork, + rwork_host.get(), info_host.get() + i); + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + w_host.CopyToDevice(stream, w->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigCompDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != w->element_type() || dataType != vl->element_type() || + dataType != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::C64: + if (use_magma) { + return EigComp(EigCompMagma(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + case ffi::C128: + if (use_magma) { + return EigComp(EigCompMagma(), batch, cols, + stream, left, right, x, w, vl, vr, info); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_comp", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigComp, EigCompDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // w + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h new file mode 100644 index 000000000000..2890837a2bd5 --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_HYBRID_KERNELS_H_ +#define JAXLIB_GPU_HYBRID_KERNELS_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +// The MagmaLookup class is used for dlopening the MAGMA shared library, +// initializing it, and looking up MAGMA symbols. +class MagmaLookup { + public: + explicit MagmaLookup() = default; + ~MagmaLookup(); + absl::StatusOr FindMagmaInit(); + absl::Status Initialize(); + absl::StatusOr Find(const char name[]); + + private: + bool initialized_ = false; + bool failed_ = false; + void* handle_ = nullptr; + std::optional lib_path_ = std::nullopt; + absl::flat_hash_map symbols_; +}; + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_HYBRID_KERNELS_H_ diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc index b10655645924..3a24e355ead0 100644 --- a/jaxlib/gpu/make_batch_pointers.cu.cc +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/gpu/make_batch_pointers.h" #include +#include #include "jaxlib/gpu/vendor.h" @@ -24,8 +25,9 @@ namespace JAX_GPU_NAMESPACE { namespace { __global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out, - int batch, int batch_elem_size) { - for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch; + int64_t batch, + int64_t batch_elem_size) { + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch; idx += blockDim.x * gridDim.x) { buffer_out[idx] = buffer_in + idx * batch_elem_size; } @@ -33,8 +35,9 @@ __global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out, } // namespace void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in, - void* buffer_out, int batch, int batch_elem_size) { - const int block_dim = 128; + void* buffer_out, int64_t batch, + int64_t batch_elem_size) { + const std::size_t block_dim = 128; const std::size_t grid_dim = std::min(1024, (batch + block_dim - 1) / block_dim); MakeBatchPointersAsyncKernel<<>>( diff --git a/jaxlib/gpu/make_batch_pointers.h b/jaxlib/gpu/make_batch_pointers.h index f2fd064961e8..f43ac25c7e50 100644 --- a/jaxlib/gpu/make_batch_pointers.h +++ b/jaxlib/gpu/make_batch_pointers.h @@ -16,13 +16,16 @@ limitations under the License. #ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_ #define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_ +#include + #include "jaxlib/gpu/vendor.h" namespace jax { namespace JAX_GPU_NAMESPACE { void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in, - void* buffer_out, int batch, int batch_elem_size); + void* buffer_out, int64_t batch, + int64_t batch_elem_size); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 38936ee497cf..c74d9a1476c2 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -482,6 +482,7 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi); dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi); + dict[JAX_GPU_PREFIX "solver_sytrd_ffi"] = EncapsulateFfiHandler(SytrdFfi); #ifdef JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index 4d1af3c50d76..d93d049d41db 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -317,6 +317,34 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched); #endif // JAX_GPU_CUDA +// Symmetric tridiagonal reduction: sytrd + +#define JAX_GPU_DEFINE_SYTRD(Type, Name) \ + template <> \ + absl::StatusOr SytrdBufferSize(gpusolverDnHandle_t handle, \ + gpusolverFillMode_t uplo, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \ + handle, uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, \ + /*E=*/nullptr, /*tau=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Sytrd(gpusolverDnHandle_t handle, \ + gpusolverFillMode_t uplo, int n, Type *a, \ + RealType::value *d, RealType::value *e, \ + Type *tau, Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, uplo, n, a, n, d, e, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_SYTRD(float, gpusolverDnSsytrd); +JAX_GPU_DEFINE_SYTRD(double, gpusolverDnDsytrd); +JAX_GPU_DEFINE_SYTRD(gpuComplex, gpusolverDnChetrd); +JAX_GPU_DEFINE_SYTRD(gpuDoubleComplex, gpusolverDnZhetrd); +#undef JAX_GPU_DEFINE_SYTRD + } // namespace solver } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h index 336480e2e13b..e84a688a6081 100644 --- a/jaxlib/gpu/solver_interface.h +++ b/jaxlib/gpu/solver_interface.h @@ -188,8 +188,8 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBufferSize); #define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \ gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ - Type *a, Real *s, Type *u, Type *v, Type *workspace, \ - int lwork, int *info, gesvdjInfo_t params + Type *a, Real *s, Type *u, Type *v, Type *workspace, int lwork, \ + int *info, gesvdjInfo_t params JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj); #undef JAX_GPU_SOLVER_Gesvdj_ARGS @@ -199,15 +199,28 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBatchedBufferSize); #undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS -#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \ - gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \ - Real *s, Type *u, Type *v, Type *workspace, int lwork, \ - int *info, gpuGesvdjInfo_t params, int batch +#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \ + Real *s, Type *u, Type *v, Type *workspace, int lwork, int *info, \ + gpuGesvdjInfo_t params, int batch JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched); #undef JAX_GPU_SOLVER_GesvdjBatched_ARGS #endif // JAX_GPU_CUDA +// Symmetric tridiagonal reduction: sytrd + +#define JAX_GPU_SOLVER_SytrdBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpublasFillMode_t uplo, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SytrdBufferSize); +#undef JAX_GPU_SOLVER_SytrdBufferSize_ARGS + +#define JAX_GPU_SOLVER_Sytrd_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpublasFillMode_t uplo, int n, Type *a, Real *d, \ + Real *e, Type *tau, Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Sytrd); +#undef JAX_GPU_SOLVER_Sytrd_ARGS + #undef JAX_GPU_SOLVER_EXPAND_DEFINITION } // namespace solver diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 7852da4bc04f..7e6f14ed4717 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -915,7 +915,8 @@ ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto s_data = static_cast::value*>(s->untyped_data()); + auto s_data = + static_cast::value*>(s->untyped_data()); auto u_data = static_cast(u->untyped_data()); auto v_data = static_cast(v->untyped_data()); auto info_data = info->typed_data(); @@ -1014,6 +1015,101 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch, #endif // JAX_GPU_CUDA +// Symmetric tridiagonal reduction: sytrd + +template +ffi::Error SytrdImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result d, + ffi::Result e, + ffi::Result tau, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + FFI_ASSIGN_OR_RETURN(int lwork, + solver::SytrdBufferSize(handle.get(), uplo, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "sytrd")); + + auto* a_data = static_cast(a.untyped_data()); + auto* out_data = static_cast(out->untyped_data()); + auto* d_data = + static_cast::value*>(d->untyped_data()); + auto* e_data = + static_cast::value*>(e->untyped_data()); + auto* tau_data = static_cast(tau->untyped_data()); + auto* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + int out_step = n * n; + for (int64_t i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Sytrd(handle.get(), uplo, n, out_data, + d_data, e_data, tau_data, + workspace, lwork, info_data)); + out_data += out_step; + d_data += n; + e_data += n - 1; + tau_data += n - 1; + ++info_data; + } + return ffi::Error::Success(); +} + +ffi::Error SytrdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool lower, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result d, + ffi::Result e, + ffi::Result tau, + ffi::Result> info) { + auto dataType = a.element_type(); + if (out->element_type() != dataType || + d->element_type() != ffi::ToReal(dataType) || + e->element_type() != ffi::ToReal(dataType) || + tau->element_type() != dataType) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to sytrd must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to sytrd must be square"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "sytrd")); + FFI_RETURN_IF_ERROR(CheckShape(d->dimensions(), {batch, cols}, "d", "sytrd")); + FFI_RETURN_IF_ERROR( + CheckShape(e->dimensions(), {batch, cols - 1}, "e", "sytrd")); + FFI_RETURN_IF_ERROR( + CheckShape(tau->dimensions(), {batch, cols - 1}, "tau", "sytrd")); + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "sytrd")); + SOLVER_DISPATCH_IMPL(SytrdImpl, batch, rows, stream, scratch, lower, a, out, + d, e, tau, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in sytrd", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SytrdFfi, SytrdDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("lower") + .Arg() // a + .Ret() // out + .Ret() // d + .Ret() // e + .Ret() // tau + .Ret>() // info +); + #undef SOLVER_DISPATCH_IMPL #undef SOLVER_BLAS_DISPATCH_IMPL diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 022564eb108c..2f9494d7fb38 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -36,6 +36,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi); #ifdef JAX_GPU_CUDA XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index c4a9af5ffe2e..22397ff908bc 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -466,7 +466,8 @@ KernelCall::KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1, absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) { std::vector params; - params.reserve(parameters_.size()); + // We need an additional parameter for the scratchpad buffer. + params.reserve(parameters_.size() + 1); for (size_t i = 0; i < parameters_.size(); ++i) { const Parameter& param = parameters_[i]; if (std::holds_alternative(param.value)) { @@ -492,6 +493,14 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) { param.value))); } } + // Triton's kernel ABI expects an additional scratchpad global memory. + // For now it is only used for on-device creation of TMA descriptors, which + // we do not use yet, so we are just replacing this argument with a null + // pointer. + // TODO: b/381242007 - Allocate a proper buffer if we want to use + // device-side TMA APIs. + void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns. + params.push_back(&scratch_ptr); return kernel_.Launch(stream, grid_, params.data()); } diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu_plugin_extension.cc new file mode 100644 index 000000000000..ba7896aa5dfe --- /dev/null +++ b/jaxlib/gpu_plugin_extension.cc @@ -0,0 +1,178 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu_plugin_extension.h" + +#include +#include + +#include "nanobind/nanobind.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/c_api.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/py_client_gpu.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, + const char* fn_name_c_str, + size_t fn_name_size, nb::object fn, + int api_version, + XLA_FFI_Handler_Traits traits) { + if (c_api->extension_start == nullptr) { + return Unimplemented("The plugin does not have extension."); + } + const PJRT_Extension_Base* next = + reinterpret_cast(c_api->extension_start); + while (next != nullptr && + next->type != + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { + next = next->next; + } + if (next == nullptr) { + return Unimplemented("The plugin does not have a custom call extension."); + } + PJRT_Gpu_Register_Custom_Call* register_custom_call = + reinterpret_cast(next)->custom_call; + + if (traits != 0) { + return Unimplemented("The plugin does not support custom call traits."); + } + + PJRT_Gpu_Register_Custom_Call_Args args; + args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; + args.function_name = fn_name_c_str; + args.function_name_size = fn_name_size; + +#if PJRT_API_GPU_EXTENSION_VERSION >= 1 + args.api_version = api_version; +#endif + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration requires handlers as PyCapsules"); + } + return capsule; + }; + +#if PJRT_API_GPU_EXTENSION_VERSION <= 1 + TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); + args.custom_call_function = fn_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); +#else + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = nullptr; + + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); + args.handler_execute = capsule_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + auto capsule_execute = as_capsule(fn); + if (capsule_execute.ok()) { + args.handler_execute = capsule_execute->data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); + return capsule.data(); + }; + + TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); +#endif +} + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(xla::XlaPythonGpuCallback); + return dict; +} + +} // namespace + +void BuildGpuPluginExtension(nanobind::module_& m) { + tsl::ImportNumpy(); + m.def( + "register_custom_call_target", + [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, + nb::str xla_platform_name, int api_version, + XLA_FFI_Handler_Traits traits) { + const char* fn_name_c_str; + size_t fn_name_size; + nb::str fn_name_bn_str; + if (nb::try_cast(fn_name_py, fn_name_bn_str)) { + fn_name_c_str = fn_name_bn_str.c_str(); + fn_name_size = nb::len(fn_name_bn_str); + } else { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name_c_str = bytes.c_str(); + fn_name_size = bytes.size(); + } + xla::ThrowIfError(RegisterCustomCallTarget( + static_cast(c_api.data()), fn_name_c_str, + fn_name_size, std::move(fn), api_version, traits)); + }, + nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), + nb::arg("xla_platform_name"), nb::arg("api_version") = 0, + nb::arg("traits") = 0); + m.def("registrations", &Registrations); +} + +} // namespace xla diff --git a/jaxlib/gpu_plugin_extension.h b/jaxlib/gpu_plugin_extension.h new file mode 100644 index 000000000000..ae8cd73dbcfb --- /dev/null +++ b/jaxlib/gpu_plugin_extension.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_ +#define JAXLIB_GPU_PLUGIN_EXTENSION_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildGpuPluginExtension(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_ diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index e364b91e278c..b96040acd614 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -13,12 +13,9 @@ # limitations under the License. from __future__ import annotations - -import functools from functools import partial import importlib import itertools -import operator import jaxlib.mlir.ir as ir @@ -61,8 +58,6 @@ xla_client.register_custom_call_target(_name, _value, platform="ROCM", api_version=api_version) -_prod = lambda xs: functools.reduce(operator.mul, xs, 1) - def _threefry2x32_lowering(prng, platform: str, keys, data, length: int | ir.Value | None = None, diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 03fd43e9ef89..59819f1fc914 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -56,6 +56,21 @@ xla_client.register_custom_call_target(_name, _value, platform="CUDA", api_version=api_version) +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: + try: + _cuhybrid = importlib.import_module( + f"{cuda_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _cuhybrid = None + else: + break + +if _cuhybrid: + for _name, _value in _cuhybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="CUDA", + api_version=1) + try: from .rocm import _blas as _hipblas # pytype: disable=import-error except ImportError: @@ -88,6 +103,34 @@ xla_client.register_custom_call_target(_name, _value, platform="ROCM", api_version=api_version) +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hiphybrid = importlib.import_module( + f"{rocm_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _hiphybrid = None + else: + break + +if _hiphybrid: + for _name, _value in _hiphybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM", + api_version=1) + +def initialize_hybrid_kernels(): + if _cuhybrid: + _cuhybrid.initialize() + if _hiphybrid: + _hiphybrid.initialize() + +def has_magma(): + if _cuhybrid: + return _cuhybrid.has_magma() + if _hiphybrid: + return _hiphybrid.has_magma() + return False + def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" return np.finfo(dtype).dtype diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 3c812d62cfae..976e5f26cb4b 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,8 +20,8 @@ load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_roc load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") load("@rules_python//python:defs.bzl", "py_test") -load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") +load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") # Explicitly re-exports names to avoid "unused variable" warnings from .bzl # lint tools. @@ -43,6 +43,7 @@ mosaic_gpu_internal_users = [] mosaic_internal_users = [] pallas_gpu_internal_users = [] pallas_tpu_internal_users = [] +pallas_extension_deps = [] jax_internal_export_back_compat_test_util_visibility = [] jax_internal_test_harnesses_visibility = [] @@ -65,7 +66,9 @@ _py_deps = { "filelock": ["@pypi_filelock//:pkg"], "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], + "magma": [], "matplotlib": ["@pypi_matplotlib//:pkg"], + "mpmath": [], "opt_einsum": ["@pypi_opt_einsum//:pkg"], "pil": ["@pypi_pillow//:pkg"], "portpicker": ["@pypi_portpicker//:pkg"], @@ -305,6 +308,95 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _jax_wheel_impl(ctx): + executable = ctx.executable.wheel_binary + + output = ctx.actions.declare_directory(ctx.label.name) + args = ctx.actions.args() + args.add("--output_path", output.path) # required argument + args.add("--cpu", ctx.attr.platform_tag) # required argument + jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path + args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + + if ctx.attr.enable_cuda: + args.add("--enable-cuda", "True") + if ctx.attr.platform_version == "": + fail("platform_version must be set to a valid cuda version for cuda wheels") + args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + if ctx.attr.enable_rocm: + args.add("--enable-rocm", "True") + if ctx.attr.platform_version == "": + fail("platform_version must be set to a valid rocm version for rocm wheels") + args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + if ctx.attr.skip_gpu_kernels: + args.add("--skip_gpu_kernels") + + args.set_param_file_format("flag_per_line") + args.use_param_file("@%s", use_always = False) + ctx.actions.run( + arguments = [args], + inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], + outputs = [output], + executable = executable, + ) + return [DefaultInfo(files = depset(direct = [output]))] + +_jax_wheel = rule( + attrs = { + "wheel_binary": attr.label( + default = Label("//jaxlib/tools:build_wheel"), + executable = True, + # b/365588895 Investigate cfg = "exec" for multi platform builds + cfg = "target", + ), + "platform_tag": attr.string(mandatory = True), + "git_hash": attr.label(allow_single_file = True), + "enable_cuda": attr.bool(default = False), + # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. + "platform_version": attr.string(mandatory = True, default = ""), + "skip_gpu_kernels": attr.bool(default = False), + "enable_rocm": attr.bool(default = False), + }, + implementation = _jax_wheel_impl, + executable = False, +) + +def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): + """Create jax artifact wheels. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the name of the wheel + wheel_binary: the binary to use to build the wheel + enable_cuda: whether to build a cuda wheel + platform_version: the cuda version to use for the wheel + + Returns: + A directory containing the wheel + """ + _jax_wheel( + name = name, + wheel_binary = wheel_binary, + enable_cuda = enable_cuda, + platform_version = platform_version, + # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to + # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to + # the git hash file needs to be created first. + git_hash = select({ + "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", + "//conditions:default": None, + }), + # Following the convention in jax/tools/build_utils.py. + # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. + platform_tag = select({ + "//jaxlib/tools:macos_arm64": "arm64", + "//jaxlib/tools:win_amd64": "AMD64", + "//jaxlib/tools:arm64": "aarch64", + "@platforms//cpu:x86_64": "x86_64", + }), + ) + jax_test_file_visibility = [] def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 9eef615ccc07..5c1d316cf255 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -27,6 +27,7 @@ from jaxlib import xla_client from .cpu import _lapack +from .cpu._lapack import schur from .cpu._lapack import eig from .hlo_helpers import ( custom_call, hlo_u8, hlo_s32, @@ -353,9 +354,9 @@ def geev_hlo(ctx, dtype, input, *, # # gees : Schur factorization -def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None, +def gees_hlo(ctx, dtype, a, *, jobvs=True, sort=False, select=None, a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() + fn_base = prepare_lapack_call(fn_base="gees", dtype=dtype) a_type = ir.RankedTensorType(a.type) etype = a_type.element_type assert len(a_shape_vals) >= 2 @@ -368,70 +369,108 @@ def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None, raise NotImplementedError( "The sort feature of LAPACK's gees routine is not implemented.") - jobvs = ord('V' if jobvs else 'N') - sort = ord('S' if sort else 'N') + mode = ( + schur.ComputationMode.kComputeSchurVectors + if jobvs + else schur.ComputationMode.kNoComputeSchurVectors + ) + sort = schur.Sort.kSortEigenvalues if sort else schur.Sort.kNoSortEigenvalues + if ctx.is_forward_compat(): + fn = fn_base + workspaces: list[ShapeTypePair] + eigvals: list[ShapeTypePair] + if not np.issubdtype(dtype, np.complexfloating): + workspaces = [(a_shape_vals, etype)] + workspace_layouts = [layout] + eigvals = [(batch_dims_vals + (n,), etype)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + else: + workspaces = [(a_shape_vals, etype), + ([n], ir.ComplexType(etype).element_type), + ] + workspace_layouts = [layout, [0]] + eigvals = [(batch_dims_vals + (n,), etype)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] - if dtype == np.float32: - fn = "lapack_sgees" - elif dtype == np.float64: - fn = "lapack_dgees" - elif dtype == np.complex64: - fn = "lapack_cgees" - elif dtype == np.complex128: - fn = "lapack_zgees" - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") + i32_type = ir.IntegerType.get_signless(32) - workspaces: list[ShapeTypePair] + scalar_layout = [] + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + shape_type_pairs = workspaces + eigvals + [ + (a_shape_vals, etype), + (batch_dims_vals, i32_type), + (batch_dims_vals, i32_type)] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + out = custom_call( + fn, + result_types=result_types, + operands=[ + batch_size_val, + ensure_hlo_s32(n), + hlo_u8(mode.value), + hlo_u8(sort.value), + # TODO: figure out how to put the callable select function here + a + ], + operand_layouts=[scalar_layout] * 4 + [layout], + result_layouts=workspace_layouts + eigvals_layouts + [ + layout, + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + ], + operand_output_aliases={4: 0}, + result_shapes=result_shapes, + ).results + if sort == schur.Sort.kSortEigenvalues: + return (out[0], out[3], out[4], out[5]) + else: + return (out[0], out[3], out[5]) + fn = fn_base + "_ffi" eigvals: list[ShapeTypePair] - if not np.issubdtype(dtype, np.complexfloating): - workspaces = [(a_shape_vals, etype)] - workspace_layouts = [layout] - eigvals = [(batch_dims_vals + (n,), etype)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - else: - workspaces = [(a_shape_vals, etype), - ([n], ir.ComplexType(etype).element_type), - ] - workspace_layouts = [layout, [0]] - eigvals = [(batch_dims_vals + (n,), etype)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] + is_complex = np.issubdtype(dtype, np.complexfloating) + eigvals = [(batch_dims_vals + (n,), etype)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + if not is_complex: + eigvals = eigvals * 2 + eigvals_layouts = eigvals_layouts * 2 i32_type = ir.IntegerType.get_signless(32) - - scalar_layout = [] - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - shape_type_pairs = workspaces + eigvals + [ + shape_type_pairs = [ (a_shape_vals, etype), + (a_shape_vals, etype), + *eigvals, (batch_dims_vals, i32_type), - (batch_dims_vals, i32_type)] + (batch_dims_vals, i32_type), + ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) out = custom_call( fn, result_types=result_types, - operands=[ - batch_size_val, - ensure_hlo_s32(n), - hlo_u8(jobvs), - hlo_u8(sort), - # TODO: figure out how to put the callable select function here - a - ], - operand_layouts=[scalar_layout] * 4 + [layout], - result_layouts=workspace_layouts + eigvals_layouts + [ - layout, - tuple(range(num_bd - 1, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), + operands=[a], + # TODO(paruzelp): Use FFI execution context to put `select` + operand_layouts=[layout], + result_layouts=[ + layout, + layout, + *eigvals_layouts, + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), ], - operand_output_aliases={4: 0}, + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={ + "mode": _enum_to_char_attr(mode), + "sort": _enum_to_char_attr(sort), + }, + api_version=4, ).results - if sort == ord('S'): - return (out[0], out[3], out[4], out[5]) + # out: Schur Form, Schur Vectors, Eigenvalues, Selected Eigenvalues, Info + if is_complex: + return out[0], out[1], out[2], out[3], out[4] else: - return (out[0], out[3], out[5]) + return out[0], out[1], (out[2], out[3]), out[4], out[5] # gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form. @@ -509,8 +548,9 @@ def gehrd_hlo(ctx, dtype, a): # sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. -def sytrd_hlo(dtype, a, *, lower): - _lapack.initialize() +def sytrd_hlo(ctx, dtype, a, *, lower): + fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy" + fn_base = prepare_lapack_call(fn_base=fn_base + "trd", dtype=dtype) a_type = ir.RankedTensorType(a.type) dims = a_type.shape assert len(dims) >= 2 @@ -518,52 +558,83 @@ def sytrd_hlo(dtype, a, *, lower): assert m == n, (m, n) batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - b = 1 - for d in batch_dims: - b *= d + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + i32_type = ir.IntegerType.get_signless(32) - if dtype == np.float32: - fn = "lapack_ssytrd" - lwork = _lapack.lapack_ssytrd_workspace(n, n) - diag_type = a_type.element_type - elif dtype == np.float64: - fn = "lapack_dsytrd" - lwork = _lapack.lapack_dsytrd_workspace(n, n) - diag_type = a_type.element_type - elif dtype == np.complex64: - fn = "lapack_chetrd" - lwork = _lapack.lapack_chetrd_workspace(n, n) + if ctx.is_forward_compat(): + fn = fn_base + b = 1 + for d in batch_dims: + b *= d + + if dtype == np.float32: + lwork = _lapack.lapack_ssytrd_workspace(n, n) + diag_type = a_type.element_type + elif dtype == np.float64: + lwork = _lapack.lapack_dsytrd_workspace(n, n) + diag_type = a_type.element_type + elif dtype == np.complex64: + lwork = _lapack.lapack_chetrd_workspace(n, n) + diag_type = ir.F32Type.get() + elif dtype == np.complex128: + lwork = _lapack.lapack_zhetrd_workspace(n, n) + diag_type = ir.F64Type.get() + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + return custom_call( + fn, + result_types=[ + a.type, + ir.RankedTensorType.get(batch_dims + (n,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), + ir.RankedTensorType.get(batch_dims, i32_type), + ir.RankedTensorType.get([lwork], a_type.element_type), + ], + operands=[hlo_s32(n), hlo_s32(1 if lower else 0), hlo_s32(max(1, n)), + hlo_s32(b), hlo_s32(lwork), a], + operand_layouts=[[]] * 5 + [layout], + result_layouts=[ + layout, + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + [0], + ], + operand_output_aliases={5: 0}, + ).results[:5] + fn = fn_base + "_ffi" + if dtype == np.float32 or dtype == np.complex64: diag_type = ir.F32Type.get() - elif dtype == np.complex128: - fn = "lapack_zhetrd" - lwork = _lapack.lapack_zhetrd_workspace(n, n) + elif dtype == np.float64 or dtype == np.complex128: diag_type = ir.F64Type.get() else: raise NotImplementedError(f"Unsupported dtype {dtype}") - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - out = custom_call( + # Returns x_out, on_diag, off_diag, tau, info + return custom_call( fn, result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (n,), diag_type), - ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type), - ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get([lwork], a_type.element_type), + a.type, + ir.RankedTensorType.get(batch_dims + (n,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), + ir.RankedTensorType.get(batch_dims, i32_type), ], - operands=[hlo_s32(n), hlo_s32(1 if lower else 0), hlo_s32(max(1, n)), - hlo_s32(b), hlo_s32(lwork), a], - operand_layouts=[[]] * 5 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ - layout, - (num_bd,) + tuple(range(num_bd - 1, -1, -1)), - (num_bd,) + tuple(range(num_bd - 1, -1, -1)), - (num_bd,) + tuple(range(num_bd - 1, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), ], - operand_output_aliases={5: 0}, + operand_output_aliases={0: 0}, + backend_config={ + "uplo": _matrix_uplo_attr(lower=lower), + }, + api_version=4, ).results - return out[:5] diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 1c45a4ce9463..0b94f9d1d948 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -149,6 +149,18 @@ py_extension( ], ) +py_extension( + name = "_mosaic_gpu_ext", + srcs = ["mosaic_gpu_ext.cc"], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + ], +) + # This is here, instead of in jaxlib/mosaic/python, so it's in the same # directory as libjaxlib_mlir_capi.so (produced by # :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc new file mode 100644 index 000000000000..7204bbaa1658 --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// clang-format: off +// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h, +// otherwise this code will not build on Windows. +#include "pybind11/pybind11.h" +// clang-format: on + +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" // IWYU pragma: keep +#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" + +PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) { + m.def( + "register_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle dialect = mlirGetDialectHandle__mosaic_gpu__(); + mlirDialectHandleRegisterDialect(dialect, context); + if (load) { + mlirDialectHandleLoadDialect(dialect, context); + } + }, + py::arg("context"), py::arg("load") = true); +} diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 5452520204b8..da7498ed437d 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -1,3 +1,6 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_python//python:defs.bzl", "py_library") + # Copyright 2023 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,9 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "pallas_extension_deps") licenses(["notice"]) @@ -41,6 +42,7 @@ cc_library( "dialect/tpu/tpu_dialect.cc", "dialect/tpu/tpu_ops.cc", "dialect/tpu/util.cc", + ":extension_srcs", ] + glob([ "dialect/tpu/transforms/*.cc", ]), @@ -60,6 +62,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -83,7 +86,7 @@ cc_library( "@xla//xla:array", "@xla//xla:shape_util", "@xla//xla:util", - ], + ] + pallas_extension_deps, ) gentbl_cc_library( @@ -226,3 +229,11 @@ cc_library( ], alwayslink = True, ) + +filegroup( + name = "extension_srcs", + srcs = [ + "dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc", + "dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc", + ], +) diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 0fff8eee6529..681ee708edd8 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load( + "@llvm-project//mlir:tblgen.bzl", + "gentbl_cc_library", + "gentbl_filegroup", + "td_library", +) package( default_applicable_licenses = [], @@ -24,7 +29,10 @@ td_library( srcs = ["mosaic_gpu.td"], includes = ["."], deps = [ + "@llvm-project//mlir:BasicPtxBuilderIntTdFiles", "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", ], ) @@ -47,17 +55,11 @@ gentbl_cc_library( "mosaic_gpu_dialect.cc.inc", ), ( - [ - "-gen-op-decls", - "--typedefs-dialect=mosaic_gpu", - ], + ["-gen-op-decls"], "mosaic_gpu_ops.h.inc", ), ( - [ - "-gen-op-defs", - "--typedefs-dialect=mosaic_gpu", - ], + ["-gen-op-defs"], "mosaic_gpu_ops.cc.inc", ), ( @@ -74,6 +76,28 @@ gentbl_cc_library( ], "mosaic_gpu_types.cc.inc", ), + ( + ["-gen-enum-decls"], + "mosaic_gpu_enums.h.inc", + ), + ( + ["-gen-enum-defs"], + "mosaic_gpu_enums.cc.inc", + ), + ( + [ + "-gen-attrdef-decls", + "--attrdefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_attrdefs.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + "--attrdefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_attrdefs.cc.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mosaic_gpu.td", @@ -88,6 +112,7 @@ cc_library( hdrs = ["mosaic_gpu.h"], deps = [ ":mosaic_gpu_inc_gen", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -95,9 +120,11 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", "@tsl//tsl/platform:statusor", @@ -127,3 +154,64 @@ cc_test( "@tsl//tsl/platform:errors", ], ) + +gentbl_filegroup( + name = "mosaic_gpu_python_gen_raw", + tbl_outs = [ + ( + [ + "-gen-python-enum-bindings", + "-bind-dialect=mosaic_gpu", + ], + "_mosaic_gpu_gen_enums_raw.py", + ), + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=mosaic_gpu", + ], + "_mosaic_gpu_gen_ops_raw.py", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = ":mosaic_gpu.td", + deps = [ + ":mosaic_gpu_td_files", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +genrule( + name = "mosaic_gpu_python_gen_enums", + srcs = ["_mosaic_gpu_gen_enums_raw.py"], + outs = ["_mosaic_gpu_gen_enums.py"], + cmd = """ + cat $(location _mosaic_gpu_gen_enums_raw.py) | \ + sed -e 's/^from \\.\\.ir/from jaxlib\\.mlir\\.ir/g; s/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@""", +) + +genrule( + name = "mosaic_gpu_python_gen_ops", + srcs = ["_mosaic_gpu_gen_ops_raw.py"], + outs = ["_mosaic_gpu_gen_ops.py"], + cmd = "cat $(location _mosaic_gpu_gen_ops_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@", +) + +DIALECT_CAPI_SOURCES = [ + ":integrations/c/gpu_dialect.cc", +] + +DIALECT_CAPI_HEADERS = [ + ":integrations/c/gpu_dialect.h", +] + +cc_library( + name = "gpu_dialect_capi", + srcs = DIALECT_CAPI_SOURCES, + hdrs = DIALECT_CAPI_HEADERS, + deps = [ + ":mosaic_gpu", + ":mosaic_gpu_inc_gen", + "@llvm-project//mlir:CAPIIR", + ], +) diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc new file mode 100644 index 000000000000..1a854f395044 --- /dev/null +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc @@ -0,0 +1,25 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" + +#include "mlir/CAPI/Registration.h" +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" + +extern "C" { + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu, + mosaic_gpu::MosaicGPUDialect); +} diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h new file mode 100644 index 000000000000..bb6cf6e3af4a --- /dev/null +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_ +#define JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_ + +#include + +#include "mlir/CAPI/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu); + +#ifdef __cplusplus +} +#endif + +#endif // JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_ diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 933d798238e3..c86450fbdf0c 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -18,37 +18,44 @@ limitations under the License. #include #include +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/TypeSwitch.h" // IWYU pragma: keep -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Dialect.h" -#include "mlir/include/mlir/IR/DialectImplementation.h" // IWYU pragma: keep -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Location.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/TypeRange.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/IR/Diagnostics.h" #include "tsl/platform/statusor.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" - +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_enums.cc.inc" +#define GET_ATTRDEF_CLASSES +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc" #define GET_TYPEDEF_CLASSES #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc" #define GET_OP_CLASSES @@ -230,11 +237,89 @@ void DeclareRuntimeFunctions(mlir::OpBuilder& builder) { .setVisibility(mlir::func::FuncOp::Visibility::Private); } +bool IsContiguous(mlir::MemRefType type) { + return type.getLayout().isIdentity() || + (type.hasStaticShape() && type.getNumElements() > 0 && + mlir::memref::isStaticShapeAndContiguousRowMajor(type)); +} + +namespace { +llvm::LogicalResult VerifyCommonLoadStoreOp( + mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name, + mlir::MemRefType smem_type, absl::string_view smem_name, + mlir::ArrayRef slice_lengths, int num_indices) { + auto error = [loc](auto... params) { + return emitError(loc, llvm::formatv(params...)); + }; + + if (!IsContiguous(smem_type)) { + return error("The `{0}` memref must be contiguous.", smem_name); + } + if (gmem_type.getElementType() != smem_type.getElementType()) { + return error( + "The `source` and `destination` memrefs must have the same element " + "type."); + } + if (absl::c_any_of(slice_lengths, [](int64_t s) { return s < -1; })) { + return error( + "The `slice_lengths` attribute must not contain values less than -1."); + } + if (gmem_type.getRank() != + smem_type.getRank() + absl::c_count(slice_lengths, -1)) { + return error( + "The rank of the `{0}` must be equal to the rank of the " + "`{1}` plus the number of collapsed dimensions as indicated " + "by -1 values in `slice_lengths`.", + gmem_name, smem_name); + } + if (num_indices != gmem_type.getRank()) { + return error("The size of `indices` must be equal to the rank of `{0}`.", + gmem_name); + } + if (slice_lengths.size() != gmem_type.getRank()) { + return error( + "The size of `slice_lengths` must be equal to the rank of `{0}`.", + gmem_name); + } + return llvm::success(); +} +} // namespace + +llvm::LogicalResult AsyncLoadOp::verify() { + auto r = VerifyCommonLoadStoreOp(getLoc(), getSource().getType(), "source", + getDestination().getType(), "destination", + getSliceLengths(), getIndices().size()); + if (failed(r)) { + return r; + } + + for (int i = 0; i < getCollective().size(); ++i) { + for (int k = i + 1; k < getCollective().size(); ++k) + if (getCollective()[i] == getCollective()[k]) { + return emitError( + "The `collective` attribute must not contain duplicate " + "dimensions."); + } + } + + return llvm::success(); +} + +llvm::LogicalResult AsyncStoreOp::verify() { + return VerifyCommonLoadStoreOp(getLoc(), getDestination().getType(), + "destination", getSource().getType(), "source", + getSliceLengths(), getIndices().size()); +} + void MosaicGPUDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc" + >(); addOperations< #define GET_OP_LIST #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_ops.cc.inc" diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index 1badcab28012..14c0d0295a8f 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -19,17 +19,23 @@ limitations under the License. #include #include +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/raw_ostream.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_enums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.h.inc" #define GET_TYPEDEF_CLASSES #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.h.inc" #define GET_OP_CLASSES @@ -40,6 +46,10 @@ namespace mosaic_gpu { using Memref = ::mlir::TypedValue<::mlir::MemRefType>; using Pointer = ::mlir::TypedValue<::mlir::LLVM::LLVMPointerType>; +struct GlobalMemory : public mlir::SideEffects::Resource::Base { + llvm::StringRef getName() final { return ""; } +}; + constexpr absl::string_view kRuntimeTmaDescriptorInitializerName = "mosaic_gpu_init_tma_desc"; constexpr absl::string_view kRuntimeMemcpyAsyncH2DName = diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index e7154f577a7a..4129dcd1b345 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -16,15 +16,22 @@ limitations under the License. #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypeInterfaces.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/CommonTypeConstraints.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/CommonAttrConstraints.td" +include "mlir/IR/CommonTypeConstraints.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" def MosaicGPU_Dialect : Dialect { let name = "mosaic_gpu"; let cppNamespace = "::mosaic_gpu"; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } class MosaicGPU_Type traits = []> @@ -32,25 +39,241 @@ class MosaicGPU_Type traits = []> let mnemonic = mnemonic_; } +class MosaicGPU_Attr + : AttrDef { + let mnemonic = mnemonic_; +} + def MosaicGPU_Barrier : MosaicGPU_Type<"Barrier", "barrier", [MemRefElementTypeInterface]> { let summary = "barrier"; let description = "A barrier to use for synchronizing threads"; } +def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; + def MosaicGPU_InitializeBarrierOp : Op { let summary = "Initializes a memref of barriers"; let description = [{ Initializes a memref of barriers each meant to synchronize exactly `arrival_count` threads. + + The base pointer of the result memref corresponds to `base_pointer`, which + must be a pointer to a shared memory location. }]; - let arguments = (ins ConfinedAttr:$arrival_count); + let arguments = (ins + LLVM_PointerShared:$base_pointer, + ConfinedAttr:$arrival_count); let results = (outs MemRefOf<[MosaicGPU_Barrier]>:$barriers_ref); let assemblyFormat = [{ - $arrival_count attr-dict `:` type($barriers_ref) + $base_pointer $arrival_count attr-dict `:` type($barriers_ref) + }]; +} + +def MosaicGPU_FragmentedLayout : + I32EnumAttr<"FragmentedLayout", "The layout of a FragmentedArray", [ + + // FragmentedArrays in this layout are always the result of a splat. Each + // thread in the warpgroup has a single copy of the value, regardless of + // the shape of the FragmentedArray. This makes it trivial to broadcast, + // reshape and do elementwise operations with all other layouts. + I32EnumAttrCase<"WGSplatFragLayout", 0>, + + // Convert the array to 1D and then shard across threads. + I32EnumAttrCase<"WGStridedFragLayout", 1>, + + // [m, n] matrix, where m % 64 == 0 == n % 8. + I32EnumAttrCase<"WGMMAFragLayout", 2>, + + // [m] vector, where m % 64 == 0. + I32EnumAttrCase<"WGMMARowFragLayout", 3> + ]> { + let cppNamespace = "::mosaic_gpu"; + let genSpecializedAttr = 0; +} + +def MosaicGPU_FragmentedLayoutAttr : EnumAttr< + MosaicGPU_Dialect, MosaicGPU_FragmentedLayout, "fragmented_layout"> { + let assemblyFormat = "`<` $value `>`"; +} + +// Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td +// but it was not possible to reuse that definition. Including that file +// pulls in ops definitions that we don't want and they fail to compile. +def MosaicGPU_Dimension : I32EnumAttr<"Dimension", + "a dimension, either 'x', 'y', or 'z'", + [ + I32EnumAttrCase<"x", 0>, + I32EnumAttrCase<"y", 1>, + I32EnumAttrCase<"z", 2> + ]>{ + let cppNamespace = "::mosaic_gpu"; + let genSpecializedAttr = 0; +} + +def MosaicGPU_DimensionAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode", + "What swizzling to use for a memory access.", + [ + I32EnumAttrCase<"kNoSwizzle", 0, "none">, + I32EnumAttrCase<"k32ByteSwizzle", 1, "32">, + I32EnumAttrCase<"k64ByteSwizzle", 2, "64">, + I32EnumAttrCase<"k128ByteSwizzle", 3, "128"> + ]>{ + let cppNamespace = "::mosaic_gpu"; + let genSpecializedAttr = 0; +} + +def MosaicGPU_SwizzlingModeAttr : EnumAttr; + +def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> { + let parameters = (ins Variadic:$tiling); + let summary = "Tiles a suffix of memref dimensions."; + let description = [{ + For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), + the shape of the result will be (5, 2, 4, 64, 32). The shape always ends + with the tile shape, and the size of tiled dimensions is divided by the tile + size. This is especially useful for swizzled WGMMA, which expect tiled + layouts in shared memory. + + Each tiled dimension must have a size that is either smaller than the + corresponding tile size or a multiple of the tile size. }]; + let assemblyFormat = "`<` $tiling `>`"; +} + +def TransposeTransformAttr : MosaicGPU_Attr<"TransposeTransform", "transpose"> { + let parameters = (ins Variadic:$permutation); + let summary = "Specifies how to transpose a memref."; + let assemblyFormat = "`<` $permutation `>`"; +} + +def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">; + +def MosaicGPU_AsyncLoadOp : Op]>]> { + let summary = "Schedules an async load of a MemRef from GMEM to SMEM"; + let description = [{ + Schedules an async copy of the contents of the `source` MemRef in GMEM to + the `destination` MemRef in SMEM. The `destination` MemRef in SMEM must be + contiguous. + + If `arrive` is true, the `arrive.expect-tx(expect_count)` operation will be + executed on the provided `barrier` before the copy is scheduled. Upon + completion of the copy, the `complete-tx(complete-count)` operation will + always be executed on the provided `barrier`. + + The `indices` and `slice_lengths` inputs define what slice of the GMEM + `source` corresponds to the SMEM `destination`. Both `indices` and + `slice_lengths` must have a length equal to the rank of the `source`. The + values in `indices` are the starting indices of each dimension and the + values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths` + indicates that the slice length is 1 and that the corresponding dimension + should be collapsed and does not appear in the `destination` MemRef. + + Additional `transforms` may be provided to control how the `source` data is + mapped to the `destination`. The transformations will be composed in the + order they are provided. The `swizzle` attribute controls what swizzling + is applied to the data after it is transformed, before it is finally written + to SMEM. The transformed data is written in row-major order to the + contiguous SMEM `destination`. The untransformed `source` data does not need + to be contiguous, except for the last dimension, which needs to be + contiguous and the minor-most dimension. + + The `collective` attribute can be provided to use TMA multicast to more + efficiently load the GMEM data in cases where multiple thread blocks are + grouped together in a cluster and need to load the same data. Each block in + a cluster will first load a slice from GMEM to SMEM and then the slices will + be multicast to all other blocks in the cluster. In this way TMA multicast + guarnatees L2 cache hits. The `collective` attribute is the list of + cluster dimensions along which to partition the input data loads. + + The `predicate` input should be set to `true` by a single thread in the + warpgroup so that it schedules the load operation. All other threads in the + warpgroup should set the `predicate` to `false`. + }]; + + let arguments = (ins + MemRefOf<[AnyType]>:$source, + MemRefOf<[AnyType]>:$destination, + MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier, + Variadic:$indices, + PtxPredicate:$predicate, + + // Attributes + DenseI64ArrayAttr:$slice_lengths, + TypedArrayAttrBase, "transforms">:$transforms, + DefaultValuedAttr:$swizzle, + DefaultValuedAttr:$arrive, + TypedArrayAttrBase:$collective + ); + + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `destination` `(` $destination `:` type($destination) `)` + `barrier` `(` $barrier `:` type($barrier) `)` + `indices` `(` $indices `)` + `predicate` `(` $predicate `)` + attr-dict + }]; + + let hasVerifier = 1; +} + +def MosaicGPU_AsyncStoreOp : Op]>]> { + let summary = "Schedules an async store of a MemRef from SMEM to GMEM"; + let description = [{ + Schedules an async store of the contents of the `source` MemRef in SMEM to + the `destination` MemRef in GMEM. The `source` MemRef in SMEM must be + contiguous. + + The `indices` and `slice_lengths` inputs define what slice of the GMEM + `destination` corresponds to the SMEM `source`. Both `indices` and + `slice_lengths` must have a length equal to the rank of the `destination`. + The values in `indices` are the starting indices of each dimension and the + values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths` + indicates that this dimension is collapsed in the `source` and needs to be + expanded to a slice of size 1 in the `destination`. + + Additional `transforms` may be provided to control how the `destination` + data in GMEM is mapped to the `source` data in SMEM. The transformations + will be composed in the order they are provided. The `swizzle` attribute + is the swizzling mode of the `source` data in SMEM. The `source` SMEM data + is contiguous and the transformed data is written to the `destination` GMEM + which does not need to be contiguous. + + The `predicate` input should be set to `true` by a single thread in the + warpgroup so that it schedules the store operation. All other threads in the + warpgroup should set the `predicate` to `false`. + }]; + + let arguments = (ins + MemRefOf<[AnyType]>:$source, + MemRefOf<[AnyType]>:$destination, + Variadic:$indices, + PtxPredicate:$predicate, + + // Attributes + DenseI64ArrayAttr:$slice_lengths, + TypedArrayAttrBase, "transforms">:$transforms, + DefaultValuedAttr:$swizzle + ); + + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `destination` `(` $destination `:` type($destination) `)` + `indices` `(` $indices `)` + `predicate` `(` $predicate `)` + attr-dict + }]; + + let hasVerifier = 1; } -#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ \ No newline at end of file +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 53c8d44f6c32..34f6241661d5 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" +#include "llvm/include/llvm/ADT/ArrayRef.h" +#include "llvm/include/llvm/ADT/SmallVector.h" #include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h" #include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" @@ -118,8 +118,8 @@ class MosaicGpuTest : public ::testing::Test { }; TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) { - llvm::ArrayRef shape{1, 2, 3}; - llvm::ArrayRef slice_shape{1, 2}; + std::vector shape{1, 2, 3}; + std::vector slice_shape{1, 2}; mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(&context_); @@ -128,7 +128,7 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) { EXPECT_THAT( FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type, - memref_type, slice_shape), + memref_type, mlir::ArrayRef(slice_shape)), StatusIs( absl::StatusCode::kFailedPrecondition, HasSubstr( @@ -136,8 +136,8 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) { } TEST_F(MosaicGpuTest, InitTmaDescriptorGracefullyRejectsSubByteTypes) { - llvm::ArrayRef shape{1, 2, 3}; - llvm::ArrayRef slice_shape{1, 2, 3}; + std::vector shape{1, 2, 3}; + std::vector slice_shape{1, 2, 3}; mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(&context_); @@ -145,14 +145,14 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorGracefullyRejectsSubByteTypes) { mlir::MemRefType::get(shape, builder_.getI4Type()); EXPECT_THAT(FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type, - memref_type, slice_shape), + memref_type, mlir::ArrayRef(slice_shape)), StatusIs(absl::StatusCode::kUnimplemented, HasSubstr("Sub-byte types are not yet supported"))); } TEST_F(MosaicGpuTest, InitTmaDescriptorProducesACallToRuntime) { - llvm::ArrayRef shape{1, 2, 3}; - llvm::ArrayRef slice_shape{1, 2, 3}; + std::vector shape{1, 2, 3}; + std::vector slice_shape{1, 2, 3}; mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(&context_); @@ -161,7 +161,7 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorProducesACallToRuntime) { absl::StatusOr fn_or = FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type, - memref_type, slice_shape); + memref_type, mlir::ArrayRef(slice_shape)); ASSERT_OK(fn_or); llvm::SmallVector call_ops = @@ -193,34 +193,6 @@ TEST_F(MosaicGpuTest, RuntimeFunctionsAreRegistered) { mosaic_gpu::kRuntimeMemcpyAsyncH2DName)); } -TEST_F(MosaicGpuTest, InitializeBarrierOpEnforcesRelevantInvariants) { - auto loc = builder_.getUnknownLoc(); - auto f32 = builder_.getF32Type(); - auto barrier = BarrierType::get(&context_); - - // InitializeBarrierOp requires a memref with type `BarrierType`. - auto initialize_op = builder_.create( - loc, mlir::MemRefType::get({1, 2}, f32), /*arrival_count=*/1); - EXPECT_FALSE(mlir::succeeded(mlir::verify(*module_))); - ExpectLastErrorContains("must be memref of barrier values"); - initialize_op->erase(); - - // InitializeBarrierOp requires a non-negative arrival count. - initialize_op = builder_.create( - loc, mlir::MemRefType::get({1, 2}, barrier), /*arrival_count=*/0); - EXPECT_FALSE(mlir::succeeded(mlir::verify(*module_))); - ExpectLastErrorContains("value is positive"); - initialize_op->erase(); - - // Checks that InitializeBarrierOp prints nicely. - initialize_op = builder_.create( - loc, mlir::MemRefType::get({1, 2}, barrier), /*arrival_count=*/1); - EXPECT_TRUE(mlir::succeeded(mlir::verify(*module_))); - EXPECT_THAT( - MlirToString(initialize_op), - HasSubstr( - "mosaic_gpu.initialize_barrier 1 : memref<1x2x!mosaic_gpu.barrier>")); -} } // anonymous namespace } // namespace mosaic_gpu diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 1b6b8b935c99..6edad713b17a 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" @@ -39,6 +38,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "absl/log/check.h" namespace mlir::tpu { @@ -169,7 +169,7 @@ class RectangularVregBounds : public VRegDataBounds { // // The tiling attribute makes it possible to subdivide a single vector register // into multiple subtiles that traverse the last dimension of a value. For -// example, consider vregs of shape (4, 5) an array: +// example, consider vregs of shape (4, 5) on (2, 10) array: // // a b c d e f g h i j // k l m n o p q r s t @@ -259,18 +259,23 @@ class VectorLayout { int layout_rank() const { return layout_rank(implicit_dim_); } bool operator==(const VectorLayout &other) const; - bool operator!=(const VectorLayout &other) const { - return !(*this == other); - } - - // How many tiles fit in each vector register. - int64_t tilesPerVreg(const std::array target_shape) const { - const int64_t tile_elems = tiling_[0] * tiling_[1]; - const int64_t vreg_capacity = packing() * target_shape[0] * target_shape[1]; + bool operator!=(const VectorLayout &other) const { return !(*this == other); } + + static int64_t tilesPerVreg(const std::array target_shape, + const int8_t bitwidth, + const std::array tiling) { + CHECK_NE(0, bitwidth) << "bitwidth cannot be 0"; + const int64_t tile_elems = tiling[0] * tiling[1]; + const int64_t vreg_capacity = + (32 / bitwidth) * target_shape[0] * target_shape[1]; const auto [tiles_per_vreg, rem] = std::div(vreg_capacity, tile_elems); CHECK_EQ(rem, 0); return tiles_per_vreg; } + // How many tiles fit in each vector register. + int64_t tilesPerVreg(const std::array target_shape) const { + return VectorLayout::tilesPerVreg(target_shape, bitwidth_, tiling_); + } int64_t sublanesPerTile(const std::array target_shape) const { auto [sublanes_per_tile, rem] = @@ -283,8 +288,16 @@ class VectorLayout { // // We never reuse the same vector register to store data of multiple rows, // so only the minormost dimension can increase. + static std::array vregSlice(std::array target_shape, + const int8_t bitwidth, + const std::array tiling) { + return { + tiling[0], + VectorLayout::tilesPerVreg(target_shape, bitwidth, tiling) * tiling[1]}; + } + std::array vregSlice(std::array target_shape) const { - return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]}; + return VectorLayout::vregSlice(target_shape, bitwidth_, tiling_); } template diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 44199612ea73..0019581921c4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -31,6 +31,11 @@ def TPU_Dialect : Dialect { let cppNamespace = "::mlir::tpu"; let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let extraClassDeclaration = [{ + static StringRef GetCoreTypeKey() { return "tpu.core_type"; } + + static std::optional GetCoreTypeAttr(Operation *op); + }]; } class TPU_Attr traits = []> @@ -39,13 +44,26 @@ class TPU_Attr traits = []> } // TODO(b/369418606): Find out the way to verify vreg size. -def TPU_Vreg : Type; +def TPU_Vreg : Type; class TPU_Type traits = []> : TypeDef { let mnemonic = mnemonic_; } +def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [ + I32EnumAttrCase<"kTc", 0, "tc">, + I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">, + I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_CoreTypeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>; def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>; def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>; @@ -161,8 +179,8 @@ def TPU_ReductionKindAttr } def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> { - let arguments = (ins AnyVector:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); - let results = (outs AnyVector:$output); + let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) }]; @@ -196,13 +214,29 @@ def TPU_LoadOp : TPU_Op<"load"> { }]; } +// TODO(jevinjiang): migrate tpu.strided_store to general vector store op. +def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask // Elementwise mask. + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) + }]; + let hasVerifier = 1; +} + def TPU_StridedLoadOp : TPU_Op<"strided_load"> { let arguments = (ins AnyMemRef:$base, Variadic:$indices, DenseI32ArrayAttr:$strides ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) }]; @@ -211,7 +245,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> { def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let arguments = (ins - AnyVector:$valueToStore, + AnyVectorOfNonZeroRank:$valueToStore, AnyMemRef:$base, Variadic:$indices, DenseI32ArrayAttr:$strides @@ -257,7 +291,7 @@ def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { // TODO(jevinjiang): deprecate to use dynamic_rotate. def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let arguments = (ins - AnyVector:$value, + AnyVectorOfNonZeroRank:$value, SI32Attr:$amount, SI32Attr:$dimension, // When the stride is specified, the rotation amount for each index on the @@ -265,7 +299,7 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { OptionalAttr:$stride, OptionalAttr:$stride_dimension ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value) }]; @@ -274,7 +308,7 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { let arguments = (ins - AnyVector:$value, + AnyVectorOfNonZeroRank:$value, I32:$amount, SI32Attr:$dimension, // When the stride is specified, the rotation amount for each index on the @@ -282,7 +316,7 @@ def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { OptionalAttr:$stride, OptionalAttr:$stride_dimension ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) }]; @@ -291,28 +325,35 @@ def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { def TPU_IotaOp : TPU_Op<"iota", [Pure]> { let arguments = (ins OptionalAttr:$dimension); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ attr-dict `:` type($output) }]; } +// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically. +// b/376295711 def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { let arguments = (ins - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, I32Attr:$dimension, I32Attr:$times ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }]; } def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { + let description = [{ + For each sublane `i`, broadcasts the value in lane `lane + i` along the entire + sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i` + is not defined (can be anything). + }]; let arguments = (ins - AnyVector:$source, // All sublanes should be equal. + TPU_Vreg:$source, // All sublanes should be equal. I32Attr:$lane // Coordinates of the first element to take. ); // Output shape should be the same, except for position dim which contains // the newly inserted dimension. - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $lane attr-dict `:` type($source) `->` type($output) }]; @@ -321,30 +362,30 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { // Integer unpacks are always signed at the moment. def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { let arguments = (ins - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, I32Attr:$index ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; } // Integer packs are always signed at the moment. def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> { let arguments = (ins - Variadic:$sources, + Variadic:$sources, TPU_PackFormatEnum:$pack_format ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; } def TPU_GatherOp : TPU_Op<"gather", [Pure]> { let arguments = (ins - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, DenseI32ArrayAttr:$indices, I32Attr:$dimension ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `[` $indices `]` `in` $dimension attr-dict `:` type($source) `->` type($output) @@ -353,41 +394,66 @@ def TPU_GatherOp : TPU_Op<"gather", [Pure]> { def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> { let arguments = (ins - AnyVector:$source, - AnyVector:$indices, // If this is 2D, only the first row matters. + AnyVectorOfNonZeroRank:$source, + AnyVectorOfNonZeroRank:$indices, // If this is 2D, only the first row matters. I32Attr:$dimension ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `[` $indices `]` `in` $dimension attr-dict `:` type($source) `,` type($indices) `->` type($output) }]; } -// TODO(apaszke): Add a verifier for this op + +def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> { + let parameters = (ins + ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims, + ArrayRefParameter<"int64_t", "">:$rhs_contracting_dims, + ArrayRefParameter<"int64_t", "">:$lhs_non_contracting_dims, + ArrayRefParameter<"int64_t", "">:$rhs_non_contracting_dims, + // The contract is a flattened structure, wherein, each element is half of a + // pair of indices. The first element is always 0 (lhs) or 1 (rhs) and the + // second index is the index from the lhs or rhs. + ArrayRefParameter<"int64_t", "">:$output_dim_order, + OptionalArrayRefParameter<"int64_t", "">:$lhs_batch_dims, + OptionalArrayRefParameter<"int64_t", "">:$rhs_batch_dims + ); + let assemblyFormat = "`<` `[` $lhs_contracting_dims `]` `,` `[` $rhs_contracting_dims `]` `,` " + "`[` $lhs_non_contracting_dims `]` `,` `[` $rhs_non_contracting_dims `]` `,` " + "`[` $output_dim_order `]` `,` " + "`[` (`]`):($lhs_batch_dims^ `]`)? `,` " + "`[` (`]`):($rhs_batch_dims^ `]`)? `>`"; +} + // TODO(apaszke): Think hard about precision def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { let arguments = (ins - AnyVector:$lhs, - AnyVector:$rhs, - AnyVector:$acc, + AnyVectorOfNonZeroRank:$lhs, + AnyVectorOfNonZeroRank:$rhs, + AnyVectorOfNonZeroRank:$acc, + // These flags are deprecated - if dimension_numbers are defined, + // these flags are ignored. They will always be false after canonicalize. DefaultValuedAttr:$transpose_lhs, DefaultValuedAttr:$transpose_rhs, - OptionalAttr:$precision + OptionalAttr:$precision, + // NOTE: User-level optional, once canonicalized, always present. + OptionalAttr:$dimension_numbers ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result) }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> { let arguments = (ins - Variadic:$sources, + Variadic:$sources, I32Attr:$dimension ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $sources `in` $dimension attr-dict `:` type($sources) `->` type($output) }]; @@ -395,8 +461,8 @@ def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> { } def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> { - let arguments = (ins AnyVector:$input); - let results = (outs AnyVector:$output); + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; let hasVerifier = 1; } @@ -408,16 +474,16 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> { } def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { - let arguments = (ins Variadic:$input); - let results = (outs AnyVector:$output); + let arguments = (ins Variadic:$input); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; } def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> { - let arguments = (ins AnyVector:$input); - let results = (outs Variadic:$output); + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs Variadic:$output); let hasCanonicalizeMethod = 1; let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) @@ -571,6 +637,7 @@ def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> { ); let results = (outs); let assemblyFormat = [{ $semaphore `,` $amount attr-dict `:` type($semaphore)}]; + let hasVerifier = 1; } def TPU_AllocaSemaphoreOp : TPU_Op<"sem_alloc"> { @@ -591,12 +658,18 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { MemRefOf<[TPU_SemaphoreType]>:$semaphore, I32:$amount, Optional:$device_id, // For remote DMAs - Optional:$core_id // For megacore + Optional:$core_id, // For megacore + OptionalAttr:$core_type ); - let assemblyFormat = [{ - $semaphore `,` $amount (`,` $device_id^)? (`,` $core_id^)? attr-dict `:` type($semaphore) +let assemblyFormat = [{ + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; + let builders = [ + // A backward-compatible builder that sets `core_type` to nullptr. + OpBuilder<(ins "Value":$semaphore, "Value":$amount, + "Value":$device_id, "Value":$core_id)>, + ]; } def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { @@ -654,8 +727,8 @@ def TPU_DelayOp : TPU_Op<"delay"> { // Expands the granularity of mask to subelements. def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> { - let arguments = (ins AnyVector:$input); - let results = (outs AnyVector:$result); + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($result) }]; @@ -681,7 +754,7 @@ def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { let arguments = (ins); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); } def TPU_LogOp : TPU_Op<"log"> { @@ -692,6 +765,7 @@ def TPU_LogOp : TPU_Op<"log"> { ); let results = (outs); let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; + let hasVerifier = 1; } def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> { @@ -706,7 +780,10 @@ def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::Fun } def MosaicSerdePass : Pass<"mosaic-serde", "::mlir::ModuleOp"> { - let options = [Option<"serialize", "serialize", "bool", "", "">]; + let options = [ + Option<"serialize", "serialize", "bool", "", "">, + Option<"target_version", "target-version", "int", "", ""> // Only used when serialize=true. + ]; } def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> { @@ -784,6 +861,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, Option<"vmem_banks", "vmem-banks", "int", /*default=*/"-1", "">, + Option<"max_shuffle_sublane_offset", "max-shuffle-sublane-offset", "int", /*default=*/"-1", "Max sublane offset per shuffled load/store">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index df00093fabe6..92e8953837e3 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -33,6 +34,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "absl/hash/hash.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" @@ -68,6 +70,27 @@ void TPUDialect::initialize() { >(); } +/* static */ std::optional TPUDialect::GetCoreTypeAttr( + Operation *op) { + Attribute attr = op->getAttr(GetCoreTypeKey()); + if (attr == nullptr) { + return std::nullopt; + } + if (!mlir::isa(attr)) { + return std::nullopt; + } + return mlir::cast(attr).getValue(); +} + +FailureOr> GetCoreTypeOfParentFunc(Operation &op) { + mlir::Operation *func_op = op.getParentOfType(); + if (func_op == nullptr) { + return op.emitError() << "Operation " << op.getName() + << " is not inside a func.func"; + } + return TPUDialect::GetCoreTypeAttr(func_op); +} + void VectorLayoutAttr::print(AsmPrinter &printer) const { printer << '<'; printer << getLayout(); @@ -210,4 +233,18 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { return false; } +DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, + bool transpose_lhs, + bool transpose_rhs) { + return tpu::DotDimensionNumbersAttr::get( + builder.getContext(), + /*lhs_contracting_dims=*/{transpose_lhs ? 0 : 1}, + /*rhs_contracting_dims=*/{transpose_rhs ? 1 : 0}, + /*lhs_non_contracting_dims=*/{transpose_lhs ? 1 : 0}, + /*rhs_non_contracting_dims=*/{transpose_rhs ? 0 : 1}, + /*output_dim_order=*/{0, transpose_lhs ? 1 : 0, 1, transpose_rhs ? 0 : 1}, + /*lhs_batch_dims=*/{}, + /*rhs_batch_dims=*/{}); +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index e827faed3d0e..a8569acc6239 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -63,6 +64,7 @@ struct ApplyVectorLayoutContext { std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; int64_t vmem_banks = -1; // -1 means "unspecified". + int32_t max_shuffle_sublane_offset = -1; // -1 means "unspecified". }; std::pair mightCommunicateBetweenChips(Operation* op); @@ -93,6 +95,10 @@ std::unique_ptr> createDebugAssertInsertionPass(); #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" +// Determine the core type of the given op based on the `tpu.core_type` +// annotation of its parent function. +FailureOr> GetCoreTypeOfParentFunc(Operation &op); + // Changes the memory space of the value and propagates it through the program. LogicalResult specializeMemorySpace(TypedValue value, MemorySpace memory_space); @@ -103,6 +109,10 @@ MemRefType getMemRefType(Value value); bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8); +DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, + bool transpose_lhs, + bool transpose_rhs); + #define GEN_PASS_REGISTRATION #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 1d3ea99f4d4c..07e2e3e19197 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -25,9 +28,12 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "mlir/include/mlir/IR/Builders.h" #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/IRMapping.h" +#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -87,6 +93,25 @@ LogicalResult MemRefSliceOp::verify() { auto target_type = getType(); auto target_layout = target_type.getLayout(); auto target_memory_space = target_type.getMemorySpace(); + auto indices = getBaseIdx(); + auto slice_shape = getResult().getType().getShape(); + if (!source_type.hasStaticShape()) { + return emitOpError( + "Only slicing of memrefs with static shapes is supported."); + } + auto source_shape = source_type.getShape(); + bool is_semaphore = + HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem); + if (is_semaphore && + !isa(source_type.getElementType())) { + return emitOpError( + "References to semaphore memory space must have a semaphore element " + "type."); + } + if (indices.size() != slice_shape.size() || + indices.size() != source_shape.size()) { + return emitOpError("Indices and slice shapes must match."); + } // TODO(apaszke): Check that the result has a smaller shape. // TODO(apaszke): Check that strides are equivalent. // Source and target attributes may be different before propagation is done by @@ -437,6 +462,31 @@ LogicalResult StridedStoreOp::verify() { getValueToStore().getType()); } +LogicalResult VectorStoreOp::verify() { + if (!getStrides().empty()) { + return emitError("Not implemented: general vector store with strides."); + } + VectorType value_ty = getValueToStore().getType(); + MemRefType ref_ty = getBase().getType(); + + if (value_ty.getElementType() != ref_ty.getElementType()) { + return emitOpError( + "Expected base and valueToStore element type should match"); + } + if (llvm::size(getIndices()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " indices"; + } + if (getMask()) { + if (value_ty.getElementTypeBitWidth() != 32) { + return emitError( + "Not implemented: masked store with non-32-bit element type"); + } + if (value_ty.getShape() != getMask().getType().getShape()) + return emitOpError("Expected valueToStore shape to match mask shape"); + } + return success(); +} + LogicalResult ReinterpretCastOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); @@ -465,7 +515,7 @@ LogicalResult verifyRotateOp(Op op) { } if (op.getStride().has_value() != op.getStrideDimension().has_value()) { op.emitOpError( - "Expected either none or both stride and stride dimension are " + "Expected either none or both stride and stride dimension are " "present"); return failure(); } @@ -507,6 +557,289 @@ class CanonicalizeAddOfMatmul : public OpRewritePattern { } }; +LogicalResult MatmulOp::verify() { + // Note - this is not yet an exhaustive verification of matmul. Many of the + // invariants are spread across infer, apply, llo and below. This is, + // however, a good start and the recommended place to add more invariants. + const VectorType lhs_ty = getLhs().getType(); + const VectorType rhs_ty = getRhs().getType(); + const VectorType acc_ty = getAcc().getType(); + const VectorType res_ty = getResult().getType(); + if (acc_ty != res_ty) { + return emitOpError( + "Not implemented: matmul acc and result have different types"); + } + if (acc_ty.getElementTypeBitWidth() != 32) { + return emitOpError("Expected matmul acc to be 32-bit"); + } + + if (getTransposeLhs()) { + emitOpError( + "Lhs transpose not supported via this API - please use the " + "dimension numbers API."); + return failure(); + } + + if (getDimensionNumbers().has_value()) { + auto dimension_numbers = getDimensionNumbers().value(); + auto lhs_contracting_dims = dimension_numbers.getLhsContractingDims(); + auto rhs_contracting_dims = dimension_numbers.getRhsContractingDims(); + if (lhs_contracting_dims.size() != 1) { + emitOpError("Not implemented: lhs contracting dims must be of size 1"); + return failure(); + } + if (rhs_contracting_dims.size() != 1) { + emitOpError("Not implemented: rhs contracting dims must be of size 1"); + return failure(); + } + + auto lhs_contracting_dim = lhs_contracting_dims[0]; + auto rhs_contracting_dim = rhs_contracting_dims[0]; + + auto lhs_batch_dims = dimension_numbers.getLhsBatchDims(); + auto rhs_batch_dims = dimension_numbers.getRhsBatchDims(); + + auto lhs_non_contracting_dims = + dimension_numbers.getLhsNonContractingDims(); + auto rhs_non_contracting_dims = + dimension_numbers.getRhsNonContractingDims(); + + if (lhs_contracting_dims.size() + lhs_non_contracting_dims.size() + + lhs_batch_dims.size() != + lhs_ty.getShape().size()) { + emitOpError( + "Not implemented: lhs contracting + non contracting + batch dims " + "must be of the same size as the lhs shape"); + return failure(); + } + if (rhs_contracting_dims.size() + rhs_non_contracting_dims.size() + + rhs_batch_dims.size() != + rhs_ty.getShape().size()) { + emitOpError( + "Not implemented: rhs contracting + non contracting + batch dims " + "must be of the same size as the rhs shape"); + return failure(); + } + + if (lhs_ty.getShape()[lhs_contracting_dim] != + rhs_ty.getShape()[rhs_contracting_dim]) { + emitOpError( + "Not implemented: lhs and rhs contracting dims must be of the same " + "size"); + return failure(); + } + + if (lhs_batch_dims.size() != rhs_batch_dims.size()) { + emitOpError( + "Not implemented: lhs and rhs should have the same number of batch " + "dims"); + return failure(); + } + if (lhs_batch_dims.size() > 1) { + emitOpError("Not implemented: Up to 1 batch dim supported"); + return failure(); + } + + int64_t lhs_rank = lhs_ty.getShape().size(); + int64_t rhs_rank = rhs_ty.getShape().size(); + + std::vector seen_dims_lhs(lhs_rank, false); + std::vector seen_dims_rhs(rhs_rank, false); + + auto check_and_mark_dims = [&](const std::vector &dims, + std::vector &seen_dims, + const std::string_view operand) { + for (int64_t dim : dims) { + if (seen_dims[dim]) { + emitOpError("Illegal: Dim ") + << dim << " repeats in dimension numbers of " << operand; + return failure(); + } + seen_dims[dim] = true; + } + return success(); + }; + + if (failed( + check_and_mark_dims(lhs_contracting_dims, seen_dims_lhs, "lhs")) || + failed(check_and_mark_dims(lhs_non_contracting_dims, seen_dims_lhs, + "lhs")) || + failed(check_and_mark_dims(lhs_batch_dims, seen_dims_lhs, "lhs"))) { + return failure(); + } + + if (failed( + check_and_mark_dims(rhs_contracting_dims, seen_dims_rhs, "rhs")) || + failed(check_and_mark_dims(rhs_non_contracting_dims, seen_dims_rhs, + "rhs")) || + failed(check_and_mark_dims(rhs_batch_dims, seen_dims_rhs, "rhs"))) { + return failure(); + } + + for (int64_t dim = 0; dim < lhs_rank; ++dim) { + if (!seen_dims_lhs[dim]) { + emitOpError("Illegal: Dim ") + << dim << " is not seen in lhs dimension numbers"; + return failure(); + } + } + for (int64_t dim = 0; dim < rhs_rank; ++dim) { + if (!seen_dims_rhs[dim]) { + emitOpError("Illegal: Dim ") + << dim << " is not seen in rhs dimension numbers"; + } + } + + const std::optional batch_dim_lhs = + lhs_batch_dims.empty() ? std::nullopt + : std::optional(lhs_batch_dims[0]); + const std::optional batch_dim_rhs = + rhs_batch_dims.empty() ? std::nullopt + : std::optional(rhs_batch_dims[0]); + if (batch_dim_lhs != batch_dim_rhs) { + emitOpError("Not Implemented: batch dims must be equal"); + return failure(); + } + if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) { + emitOpError("Not Implemented: batch dims pos must be 0"); + return failure(); + } + // Invariant above enforces only 1 batch dim atm, and that both are eq + std::optional batch_size = std::nullopt; + if (batch_dim_lhs.has_value()) { + batch_size = lhs_ty.getShape()[batch_dim_lhs.value()]; + auto rhs_batch_size = rhs_ty.getShape()[batch_dim_rhs.value()]; + if (batch_size != rhs_batch_size) { + emitOpError("Not Implemented: batch dims must be equal"); + return failure(); + } + if (batch_size == 0) { + emitOpError("Illegal: batch size must be > 0"); + return failure(); + } + } + auto output_dim_order = dimension_numbers.getOutputDimOrder(); + if (output_dim_order.size() % 2 != 0) { + emitOpError( + "Illegal: output dim order must have an even number of elements."); + return failure(); + } + if (batch_size.has_value()) { + if (output_dim_order[0] != 0 || output_dim_order[1] != 0) { + emitOpError( + "Not implemented: Output with batch size must be the lhs 0 idx for " + "now."); + return failure(); + } + } + + // Invariants above enforce a single batch idx for now, and that it is in + // position 0. Future extensions to this will be to: + // 1. Support multiple batch dims + // 2. Support batch dims in any position in the output dim order + if (lhs_non_contracting_dims.size() != 1) { + emitOpError( + "Not implemented: lhs non contracting dims must be of size 1"); + return failure(); + } + if (rhs_non_contracting_dims.size() != 1) { + emitOpError( + "Not implemented: rhs non contracting dims must be of size 1"); + return failure(); + } + + // A bit long winded, but the invariants we enforce below are: + // 1. The output order idx is 0 (lhs) or 1 (rhs) + // 2. The output dim order is in valid bounds + // 3. We saw the rhs and lhs non contracting dims in the output dim order + // 4. We never see the contracting dims in the output dim order + // 5. We only see each of the non contracting dim once + std::vector lhs_dims_seen_in_output(lhs_rank, false); + std::vector rhs_dims_seen_in_output(rhs_rank, false); + + // Iterate over the output dimension order + for (int dim_pos = 0; dim_pos < output_dim_order.size(); dim_pos += 2) { + auto idx = output_dim_order[dim_pos]; + auto dim = output_dim_order[dim_pos + 1]; + + if (idx != 0 && idx != 1) { + emitOpError("Illegal: output dim order index must be 0 or 1"); + return failure(); + } + auto is_lhs = (idx == 0); + + if (is_lhs) { + if (dim < 0 || dim >= lhs_rank) { + emitOpError("Illegal: lhs dimension index out of bounds"); + return failure(); + } + if (lhs_dims_seen_in_output[dim]) { + emitOpError("Illegal: lhs dimension ") + << dim << " appears more than once in output dim order"; + return failure(); + } + if (dim == lhs_contracting_dim) { + emitOpError("Illegal: contracting dimension ") + << dim << " appears in lhs output dim order"; + return failure(); + } + // batch_dim_lhs is either 0 or nullopt + if (dim == batch_dim_lhs) { + // Upstream invariants enforce that batch dim is in position 0 + // of the output dim order. + rhs_dims_seen_in_output[dim] = true; + } + lhs_dims_seen_in_output[dim] = true; + } else { + if (dim < 0 || dim >= rhs_rank) { + emitOpError("Illegal: rhs dimension index out of bounds"); + return failure(); + } + if (rhs_dims_seen_in_output[dim]) { + emitOpError("Illegal: rhs dimension ") + << dim << " appears more than once in output dim order"; + return failure(); + } + if (dim == rhs_contracting_dim) { + emitOpError("Illegal: contracting dimension ") + << dim << " appears in rhs output dim order"; + return failure(); + } + if (dim == batch_dim_rhs) { + // Upstream invariants enforce that batch dim is in position 0 + // of the output dim order. + lhs_dims_seen_in_output[dim] = true; + } + rhs_dims_seen_in_output[dim] = true; + } + } + + // Check that all dims have been seen (except contracting dims) + for (int i = 0; i < lhs_rank; ++i) { + if (i == lhs_contracting_dim) { + continue; + } + if (!lhs_dims_seen_in_output[i]) { + emitOpError("Illegal: lhs non-contracting dimension ") + << i << " is not seen in output dim order"; + return failure(); + } + } + + for (int i = 0; i < rhs_rank; ++i) { + if (i == rhs_contracting_dim) { + continue; + } + if (!rhs_dims_seen_in_output[i]) { + emitOpError("Illegal: rhs non-contracting dimension ") + << i << " is not seen in output dim order"; + return failure(); + } + } + } + return success(); +} + void MatmulOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add, @@ -535,11 +868,50 @@ LogicalResult GetBarrierSemaphoreOp::verify() { return success(); } +void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state, + Value semaphore, Value amount, Value device_id, + Value core_id) { + build(builder, state, semaphore, amount, device_id, core_id, + /*core_type=*/nullptr); +} + LogicalResult SemaphoreSignalOp::verify() { auto sem_type = getMemRefType(getSemaphore()); if (sem_type.getRank() != 0) { return emitOpError("Semaphore reference must be rank 0"); } + + FailureOr> issuing_core_type_maybe = + GetCoreTypeOfParentFunc(**this); + if (failed(issuing_core_type_maybe)) { + return issuing_core_type_maybe; + } + CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc); + CoreType target_core_type = getCoreType().value_or(issuing_core_type); + + if (getCoreId() == nullptr && getDeviceId() == nullptr) { + if (target_core_type != issuing_core_type) { + return emitOpError( + absl::StrFormat("Target core type (%s) must match source core type " + "(%s) when device_id and core_id are not specified", + stringifyCoreType(target_core_type), + stringifyCoreType(issuing_core_type))); + } + } + if ((issuing_core_type == CoreType::kTc && + target_core_type == CoreType::kScScalarSubcore) || + (issuing_core_type == CoreType::kScScalarSubcore && + target_core_type == CoreType::kTc)) { + return emitOpError("Signalling between TC and SC is not implemented"); + } + return success(); +} + +LogicalResult SemaphoreWaitOp::verify() { + auto sem_type = getMemRefType(getSemaphore()); + if (sem_type.getRank() != 0) { + return emitOpError("Semaphore reference must be rank 0"); + } return success(); } @@ -691,6 +1063,30 @@ LogicalResult ConcatenateOp::verify() { return success(); } +LogicalResult LogOp::verify() { + FailureOr> logging_core_type_maybe = + GetCoreTypeOfParentFunc(**this); + if (failed(logging_core_type_maybe)) { + return failure(); + } + CoreType logging_core_type = logging_core_type_maybe->value_or(CoreType::kTc); + if ((logging_core_type == CoreType::kScScalarSubcore || + logging_core_type == CoreType::kScVectorSubcore) && + getFormattedAttr() != nullptr && getFormattedAttr().getValue()) { + return emitOpError("Formatted logging is not supported on SC"); + } + switch (logging_core_type) { + case CoreType::kTc: + case CoreType::kScScalarSubcore: + return success(); + case CoreType::kScVectorSubcore: + return emitOpError("Log op is not supported on the SC vector subcore"); + } + return emitOpError( + absl::StrFormat("Unexpected core type: %s", + stringifyCoreType(logging_core_type_maybe->value()))); +} + } // namespace tpu } // namespace mlir diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 3a9a36f6c0b7..5c9b3d178c15 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -13,16 +13,15 @@ #include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Compiler.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -52,7 +51,6 @@ #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" -#include "llvm/ADT/ArrayRef.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" @@ -61,6 +59,7 @@ #include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" @@ -170,25 +169,6 @@ FailureOr> getInternalScratch( .getResult(); } -// Models Numpy's np.repeat, repeating each element `repeats` times along the -// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is -// 3, this will return [1, 1, 1, 2, 2, 2]. -xla::Array repeat(const xla::Array &src, const int repeats, - const int64_t axis) { - SmallVector dims(toArrayRef(src.dimensions())); - dims[axis] *= repeats; - xla::Array res(dims); - src.Each([&](absl::Span idx, const Value v) { - SmallVector res_idx(toArrayRef(idx)); - res_idx[axis] *= repeats; - for (int i = 0; i < repeats; ++i) { - res(res_idx) = v; - ++res_idx[axis]; - } - }); - return res; -} - // Models Numpy's np.concatenate xla::Array concatenate(const ArrayRef> arrays, const int64_t axis) { @@ -675,6 +655,105 @@ FailureOr> getInLayouts( return in_layouts; } +// Insert a minor dimension to the implicit shape. The original minor dimension +// becomes the new second minor dimension, laid out across sublanes. +// +// The returned vreg array uses the original tiling and the offsets specified in +// new_offsets to hold the value with the new implicit shape. +// +// Args: +// vregs: The vreg array with *implicit* array shape. +// ishape: The implicit shape of the represented value. +// layout: The layout used for the represented value. The implicit +// dimension is ignored, since this function operates directly at +// the level of the implicit shape. +// new_offsets: The offsets to use for the layout of the returned vreg array. +FailureOr> insertImplicitMinorDimension( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + const xla::Array &vregs, const ArrayRef ishape, + const VectorLayout &layout, const LayoutOffsets new_offsets) { + if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) { + return emitError(loc, "Not implemented: Unsupported bitwidth or tiling"); + } + if (layout.offsets()[1].has_value()) { + if (!new_offsets[0]) { + // TODO(tlongeri): This can only be valid if the dim size is 1. + return emitError(loc, "Not implemented: Replication mismatch"); + } + if (*new_offsets[0] != *layout.offsets()[1] % ctx.target_shape[0] && + *layout.offsets()[1] + *(ishape.end() - 1) > ctx.target_shape[1]) { + // This requires blending data from different vregs. + return emitError(loc, + "Not implemented: Misaligned offsets and shape does not " + "fit in one vreg"); + } + } + // new_layout is only to get the new vreg array shape, the implicit dim is + // irrelevant (since we already have the implicit shape): + const VectorLayout new_layout(layout.bitwidth(), new_offsets, layout.tiling(), + VectorLayout::ImplicitDim::kNone); + SmallVector new_ishape(ishape); + new_ishape.push_back(1); + xla::Array new_vregs(new_layout.tileArrayShape( + /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(new_ishape), + ctx.target_shape)); + // Preallocate an indices vector to avoid repeated allocations: + SmallVector idxs; + new_vregs.Each([&](const absl::Span dst_idx, + Value *const dst_vreg) { + // Indices of the new vreg in the new vreg array: + const int64_t new_2nd_minor_idx = *(dst_idx.end() - 2); + const int64_t new_3rd_minor_idx = *(dst_idx.end() - 3); + idxs.assign(dst_idx.begin(), dst_idx.end()); + if (!layout.offsets()[0].has_value() && new_3rd_minor_idx != 0) { + // All vregs along that dimension are the same + *(idxs.end() - 3) = 0; + *dst_vreg = new_vregs(idxs); + } else if (!layout.offsets()[1].has_value() && new_2nd_minor_idx != 0) { + // All vregs along that dimension are the same + *(idxs.end() - 2) = 0; + *dst_vreg = new_vregs(idxs); + } else { + // dst_vreg will hold slice [row_idx, col_idx:(col_idx + target_shape[0])] + // of the after-offsets source shape + const int64_t row_idx = + layout.offsets()[0] ? new_3rd_minor_idx + *layout.offsets()[0] : 0; + const int64_t col_idx = layout.offsets()[1] + ? new_2nd_minor_idx * ctx.target_shape[0] + + *layout.offsets()[1] - *new_offsets[0] + : 0; + + idxs.pop_back(); + *(idxs.end() - 2) = row_idx / ctx.target_shape[0]; + *(idxs.end() - 1) = col_idx / ctx.target_shape[1]; + Value src_vreg = vregs(idxs); + // TODO(tlongeri): We can sometimes skip operations when dst_vreg will + // hold a single non-padding element (first or last) and we don't need + // replication in the output. + if (layout.offsets()[0].has_value()) { + // [ . . . . . . . . ] [ . . . . a b c d ] + // [ . . . . a b c d ] => [ . . . . a b c d ] + // [ . . . . . . . . ] [ . . . . a b c d ] + // [ . . . . . . . . ] [ . . . . a b c d ] + src_vreg = broadcastSublane( + builder, src_vreg, + /*sublane_idx=*/row_idx % ctx.target_shape[0], ctx.target_shape); + } + if (layout.offsets()[1].has_value()) { + // [ . . . . a b c d ] [ a a a a a a a a ] + // [ . . . . a b c d ] => [ b b b b b b b b ] + // [ . . . . a b c d ] [ c c c c c c c c ] + // [ . . . . a b c d ] [ d d d d d d d d ] + src_vreg = builder.create( + loc, src_vreg.getType(), src_vreg, + /*lane=*/col_idx % ctx.target_shape[1]); + } + *dst_vreg = src_vreg; + } + }); + return new_vregs; +} + LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -771,20 +850,13 @@ FailureOr> ext_op_rule_impl(RewriteContext &ctx, const VectorLayout &layout_out) { const auto result_ty = cast(op.getResult().getType()); auto source = cast>(op.getIn()); - const auto source_ty = source.getType(); auto output_vregs_shape = - layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape); + layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( xla::Array input_vregs, - disassemble(builder, layout_in, source, ctx.target_shape)); + disassemble(builder, layout_in, source, ctx.target_shape, + /*use_implicit_shape=*/true)); xla::Array output_vregs(output_vregs_shape); - // TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble? - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(), - ctx.target_shape)); - output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), - ctx.target_shape)); - } const VectorType res_vreg_ty = getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_in.implicit_dim() != layout_out.implicit_dim()) { @@ -821,9 +893,6 @@ FailureOr> ext_op_rule_impl(RewriteContext &ctx, op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); }); } - if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - output_vregs.Reshape(output_vregs_shape); - } return output_vregs; } @@ -846,8 +915,9 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, *layouts_out.front())); const auto result_ty = cast(extf_op.getResult().getType()); extf_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), ctx.target_shape) - .getResult()); + std::move(output_vregs), ctx.target_shape, + /*use_implicit_shape=*/true) + .getResult()); extf_op.erase(); return success(); } @@ -867,8 +937,10 @@ LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op, *layouts_out.front())); const auto result_ty = cast(extsi_op.getResult().getType()); extsi_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), ctx.target_shape) - .getResult()); + std::move(output_vregs), + ctx.target_shape, + /*use_implicit_shape=*/true) + .getResult()); extsi_op.erase(); return success(); } @@ -919,8 +991,10 @@ LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op, *v = builder.create(op.getLoc(), res_vreg_ty, unpacked); }); extui_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), ctx.target_shape) - .getResult()); + std::move(output_vregs), + ctx.target_shape, + /*use_implicit_shape=*/true) + .getResult()); extui_op.erase(); return success(); } @@ -931,13 +1005,13 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorLayout &layout_out) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); auto source = cast>(op.getIn()); - const auto source_ty = source.getType(); auto result_ty = cast(op.getResult().getType()); auto output_vregs_shape = - layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape); + layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( xla::Array input_vregs, - disassemble(builder, layout_in, source, ctx.target_shape)); + disassemble(builder, layout_in, source, ctx.target_shape, + /*use_implicit_shape=*/true)); xla::Array output_vregs(output_vregs_shape); if (layout_in.bitwidth() != 32) { return op.emitOpError("Not implemented: Only 32-bit truncation supported"); @@ -952,12 +1026,6 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, if (layout_in.tiling() != ctx.target_shape) { return op.emitOpError("Not implemented: Only (8,128) tiling supported"); } - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(), - ctx.target_shape)); - output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), - ctx.target_shape)); - } VectorType res_vreg_ty = getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_out.tiling() == ctx.target_shape) { @@ -1002,11 +1070,9 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, } else { return op.emitOpError("Not implemented: unsupported output tiling"); } - if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - output_vregs.Reshape(output_vregs_shape); - } op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, - std::move(output_vregs), ctx.target_shape) + std::move(output_vregs), ctx.target_shape, + /*use_implicit_shape=*/true) .getResult()); op.erase(); return success(); @@ -1715,15 +1781,36 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); })); TPU_ASSERT_OP(layouts_out.front().has_value()); auto matmul_op = cast(op); - const auto transpose_lhs = matmul_op.getTransposeLhs(); - const auto transpose_rhs = matmul_op.getTransposeRhs(); - const auto &layout_lhs = *layouts_in[0]; - const auto &layout_rhs = *layouts_in[1]; - const auto &layout_acc = *layouts_in[2]; - const auto &layout_out = *layouts_out[0]; + if (matmul_op.getTransposeRhs()) { + return op.emitOpError( + "Transposition must have been erased into dimension numbers during " + "canonicalization"); + } + + auto dimension_numbers = matmul_op.getDimensionNumbers(); + if (!dimension_numbers.has_value()) { + return op.emitOpError( + "Dimension numbers must be provided, ensure canonicalization has been " + "run."); + } + auto transposed_mkn = isTransposedMatmul(dimension_numbers.value()); + if (!transposed_mkn.has_value()) { + return op.emitOpError( + "Dimension numbers must be MKN, ensure canonicalization has been " + "run."); + } + auto [transpose_lhs, transpose_rhs] = transposed_mkn.value(); if (transpose_lhs) { - return op.emitOpError("Not implemented: Transposed LHS"); + return op.emitOpError( + "Transposition of LHS is not supported in apply_vector_layout, ensure " + "canonicalization has been run."); } + + auto &layout_lhs = *layouts_in[0]; + auto &layout_rhs = *layouts_in[1]; + auto &layout_acc = *layouts_in[2]; + auto &layout_out = *layouts_out[0]; + const std::array, 4> all_layouts = {layout_lhs, layout_rhs, layout_acc, layout_out}; for (const VectorLayout &layout : all_layouts) { @@ -1763,19 +1850,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, // TODO(tlongeri): This should be part of the tpu::MatmulOp verifier TPU_ASSERT_EQ_OP(lhs_shape.size(), 2); TPU_ASSERT_EQ_OP(rhs_shape.size(), 2); - // The code below puts no constraints on the second dimension of both lhs and - // rhs. However, leading axis of lhs and rhs needs to be a multiple of native - // tiling for packed types. - if (layout_lhs.packing() != 1 && lhs_shape[0] % layout_lhs.tiling()[0] != 0) { - return op.emitOpError( - "Not implemented: Unsupported LHS shape with padded tiling and " - "narrower data type"); - } - if (layout_rhs.packing() != 1 && rhs_shape[0] % layout_rhs.tiling()[0] != 0) { - return op.emitOpError( - "Not implemented: Unsupported RHS shape with padded tiling and " - "narrower data type"); - } const int64_t padded_lhs_rows = llvm::alignTo(lhs_shape[0], layout_lhs.tiling()[0]); @@ -1786,10 +1860,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, const int64_t padded_rhs_cols = llvm::alignTo(rhs_shape[1], layout_rhs.tiling()[1]); - if (llvm::alignTo(lhs_shape[0], layout_acc.tiling()[0]) != padded_lhs_rows) { - return op.emitOpError( - "Not implemented: Matmul acc requires less padding than lhs"); - } FAILUREOR_ASSIGN_OR_RETURN( xla::Array lhs_vregs, disassemble(builder, layout_lhs, lhs, ctx.target_shape)); @@ -1800,7 +1870,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, xla::Array rhs_vregs, disassemble(builder, layout_rhs, rhs, ctx.target_shape)); TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]); - TPU_ASSERT_EQ_OP(padded_lhs_rows, acc_vregs.dim(0) * layout_acc.tiling()[0]); TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]); const VectorType i32_vreg_ty = @@ -1822,12 +1891,14 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, // We can also extend this helper function with padding_top and padding_left // based on the offsets in vregs. - // TODO(b/341729764): Support mask subelements. + const Value i32_zeros_vreg = builder.create( + op.getLoc(), + DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0))); + const Value i32_max_vreg = builder.create( + op.getLoc(), DenseElementsAttr::get( + i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff))); auto maskVregs = [&](xla::Array &vregs, int64_t padding_bottom, int64_t padding_right) { - const Value i32_zeros_vreg = builder.create( - op.getLoc(), - DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0))); auto vreg_ty = cast(vregs.begin()->getType()); int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1; // Mask out the bottom. @@ -1835,14 +1906,49 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, // We have limited the row size of LHS and RHS need to be a multiple of // native tiling at the beginning of this rule. Therefore, it is safe to // bitcast to x32 vreg for masking. - CHECK_EQ(padding_bottom % packing, 0); - padding_bottom /= packing; - auto mask_bottom = getX32VmaskByPaddingEnd(0, padding_bottom); + int sub_padding = padding_bottom % packing; + int x32_padding_bottom = padding_bottom / packing; + auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom); + // Create an int32 vreg which contains subelement masking and then + // logical_and with target vreg to mask out the unaligned paddings. + // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is + // [8, 128], then the mask will be: + // + // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff] + // sublane 6: [0 , 0 , ..., 0 ] + // sublane 7: [0 , 0 , ..., 0 ] + // + // Through this way, in order to mask sub-elements, each target vreg only + // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select + // + packing). + Value partial_sublane_mask = builder.create( + op.getLoc(), + DenseElementsAttr::get( + i32_vreg_ty, + builder.getI32IntegerAttr( + 0xffffffff >> + (sub_padding * vreg_ty.getElementTypeBitWidth())))); + // Insert 0xffffffff above the blended sublane. + Value sublane_mask = builder.create( + getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg, + partial_sublane_mask); + // Insert 0 below the blended sublane. + sublane_mask = builder.create(mask_bottom, sublane_mask, + i32_zeros_vreg); for (int64_t i = 0; i < vregs.dim(1); ++i) { Value &vreg = vregs({vregs.dim(0) - 1, i}); Value i32_vreg = builder.create(i32_vreg_ty, vreg); - i32_vreg = builder.create(mask_bottom, i32_vreg, - i32_zeros_vreg); + if (sub_padding > 0) { + i32_vreg = builder.create(i32_vreg, sublane_mask); + } else { + i32_vreg = builder.create(mask_bottom, i32_vreg, + i32_zeros_vreg); + } vreg = builder.create(vreg_ty, i32_vreg); } } @@ -1928,8 +2034,9 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, lhs_zeros_vreg); xla::Array target_rhs_vregs( {target_rhs_row_vregs, target_rhs_col_vregs}, rhs_zeros_vreg); - xla::Array target_acc_vregs({acc_vregs.dim(0), target_acc_col_vregs}, - acc_zeros_vreg); + xla::Array target_acc_vregs( + {lhs_vregs.dim(0) * layout_lhs.packing(), target_acc_col_vregs}, + acc_zeros_vreg); target_lhs_vregs.UpdateSlice(lhs_vregs, {0, 0}); target_rhs_vregs.UpdateSlice(rhs_vregs, {0, 0}); target_acc_vregs.UpdateSlice(acc_vregs, {0, 0}); @@ -1984,6 +2091,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, const tpu::ContractPrecisionAttr precision_attr = // May be null op.getAttrOfType("precision"); + const tpu::DotDimensionNumbersAttr dot_dimension_numbers_attr = + defaultDimensionNumbers(builder, false, transpose_rhs); for (int64_t j = 0; j < nj; ++j) { for (int64_t k = 0; k < nk; ++k) { // TODO(tlongeri): there should be a way to slice without copying @@ -2000,7 +2109,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, acc_col->setAttr("out_layout", acc_layout_attr); auto new_acc_col = builder.create( op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col, - transpose_lhs, transpose_rhs, precision_attr); + /*transpose_lhs=*/false, /*transpose_rhs=*/false, precision_attr, + dot_dimension_numbers_attr); auto new_acc_vregs = builder.create( op.getLoc(), TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))), @@ -2550,7 +2660,10 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(res_layout.has_value()); auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank(); - if (dimension >= num_untiled_dims) { + if (res_ty.getRank() == 1 && + res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) { + tiling_dim = 1; + } else if (dimension >= num_untiled_dims) { tiling_dim = dimension - num_untiled_dims; } @@ -2572,6 +2685,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: result/input offsets mismatch."); } + if (layout.implicit_dim() != res_layout->implicit_dim()) { + return op.emitOpError( + "Not implemented: result/input implicit dim mismatch."); + } + if (i > 1) { auto curr_offsets = layout.offsets(); auto last_operand_offsets = layouts_in[i - 1]->offsets(); @@ -2607,29 +2725,47 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, if (!tiling_dim.has_value()) { out_vregs = concatenate(operand_vregs, dimension); } else { - if (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) { + bool is_rank1_with_no_implicit_dim = res_ty.getRank() == 1 && + res_layout->implicit_dim() == + VectorLayout::ImplicitDim::kNone; + if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kMinor || + is_rank1_with_no_implicit_dim) { return op.emitOpError("Not implemented: implicit dim"); } + if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && + res_layout->bitwidth() != 32) { + return op.emitOpError( + "Not implemented: only 32-bit bitwidth supported for SecondMinor " + "implicit dim"); + } if (res_layout->offsets()[tiling_dim.value()] != 0) { return op.emitOpError("Not implemented: result non-zero offset."); } - if (!res_layout->hasNativeTiling(ctx.target_shape)) { + if (!res_layout->hasNativeTiling(ctx.target_shape) && + res_ty.getRank() != 1) { return op.emitOpError("Not implemented: Non native tiling in concat."); } int64_t offset_at_dim = 0; { for (int i = 0; i < op.getNumOperands(); ++i) { - auto operand = op.getOperand(i); - auto const &layout = *layouts_in[i]; - - auto vty = cast(operand.getType()); - auto shape = vty.getShape(); - - auto starting_point = offset_at_dim; - auto offset_amount = - starting_point % layout.tiling()[tiling_dim.value()]; - if (offset_amount != layout.offsets()[tiling_dim.value()]) { + Value operand = op.getOperand(i); + const Layout &layout = *layouts_in[i]; + xla::Array vreg_array = operand_vregs[i]; + std::array vreg_slice = layout->vregSlice(ctx.target_shape); + std::array tiling = layout->tiling(); + + VectorType vty = cast(operand.getType()); + ArrayRef shape = vty.getShape(); + + int64_t starting_point = offset_at_dim; + int64_t offset_amount = + starting_point % vreg_slice[tiling_dim.value()]; + if (offset_amount >= tiling[tiling_dim.value()]) { + return op.emitError( + "Not implemented: Input offsets outside of the first tile"); + } + if (offset_amount != layout->offsets()[tiling_dim.value()]) { return op.emitOpError( "Not implemented: Relayout not called, unaligned dims " "concatenated without proper offsets. Ensure that " @@ -2644,9 +2780,12 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, for (size_t i = 0; i < operand_vregs.size(); ++i) { auto &vreg = operand_vregs[i]; const auto &layout = layouts_in[i]; + const int packing = res_layout->packing(); - if (layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError("Not implemented: implicit dim"); + if (layout->tiling()[0] % packing != 0) { + return op.emitOpError( + "Illegal tiling: Non-native tiling in concat - this should " + "have been caught earlier!"); } const int64_t operand_offset = *layout->offsets()[tiling_dim.value()]; @@ -2659,8 +2798,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, } const auto bitwidth = res_ty.getElementTypeBitWidth(); - const int packing = res_layout->packing(); - SmallVector out_idx; vreg.Each([&](absl::Span idx, Value *v) { out_idx.assign(idx.begin(), idx.end()); @@ -2670,17 +2807,29 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const VectorType vmask_ty = getNativeVregOrVmaskType( builder.getI1Type(), bitwidth, ctx.target_shape); if (tiling_dim.value() == 0) { // sublane - mask = builder.create( - op.getLoc(), vmask_ty, - ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(operand_offset * packing), - boundIdxConst(layout->tiling()[1])}); + if (operand_offset % packing != 0) { + // Packed case, degenerate where we have a half or quarter + // sublane. + // TODO(mvoz): We can probably always use the + // CreateSubelementMaskOp if (1) optimize it on TPUv4 and (2) Add + // support for unpacked types in some of the invariants in + // lower_to_llo. + mask = builder.create( + op.getLoc(), vmask_ty, 0, operand_offset, packing); + } else { + auto sublane_offset = operand_offset / packing; + mask = builder.create( + op.getLoc(), vmask_ty, + ArrayRef{boundIdxConst(0), boundIdxConst(0)}, + ArrayRef{boundIdxConst(sublane_offset), + boundIdxConst(layout->tiling()[1])}); + } } else { // lane mask = builder.create( op.getLoc(), vmask_ty, ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(layout->tiling()[0]), - boundIdxConst(operand_offset * packing)}); + ArrayRef{boundIdxConst(layout->tiling()[0] / packing), + boundIdxConst(operand_offset)}); } // Blend the current value with the existing value in the output. *v = builder.create(op.getLoc(), mask, @@ -2949,48 +3098,6 @@ LogicalResult tpu_region_rule(RewriteContext &ctx, Operation &op, return success(); } -LogicalResult tpu_repeat_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError("Not implemented: Only 2D layouts supported"); - } - if (layout_in != layout_out) { - return op.emitOpError("Not implemented: Changing layout mid-repeat"); - } - if (!layout_in.hasNaturalTopology(ctx.target_shape) || - layout_in.offsets() != LayoutOffsets{0, 0}) { - return op.emitOpError("Not implemented: Non-trivial layouts unsupported"); - } - OpBuilder builder(&op); - tpu::RepeatOp repeat_op = cast(op); - VectorType src_ty = repeat_op.getSource().getType(); - const uint32_t dim = repeat_op.getDimension(); - if (dim != src_ty.getRank() - 1) { - return op.emitOpError( - "Not implemented: Only repeats along the last dim supported"); - } - if (src_ty.getShape().back() % ctx.target_shape.back() != 0) { - return op.emitOpError("Not implemented: Only free repeats are suppported"); - } - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array &in_vregs, - disassemble(builder, layout_in, repeat_op.getSource(), ctx.target_shape)); - xla::Array out_vregs = repeat(in_vregs, repeat_op.getTimes(), dim); - repeat_op->replaceAllUsesWith( - assemble(builder, repeat_op.getResult().getType(), layout_out, out_vregs, - ctx.target_shape) - .getOperation()); - repeat_op->erase(); - return success(); -} - LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -3020,14 +3127,29 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( Tiling memref_tiling, getMemRefTiling(load_op.getBase(), ctx.target_shape)); - if (memref_tiling != layout_out.tiling() && - !(memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 && - memref_tiling[1] % layout_out.tiling()[1] == 0)) { - // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). - // TODO(b/295393167): need to support strided load for bitwidth < 32. - if (layout_out.bitwidth() != 32 || - layout_out.tiling() != std::array{1, ctx.target_shape[1]}) { - return op.emitOpError("Not implemented"); + if (memref_tiling != layout_out.tiling()) { + if (memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 && + memref_tiling[1] % layout_out.tiling()[1] == 0) { + // In this case, it is valid to use output tiling (1, 128 * packing) when + // loading from a 1D memref. + } else if (layout_out.bitwidth() == 32 && + layout_out.tiling() == + std::array{1, ctx.target_shape[1]}) { + // In this case, it is valid to use output tiling (1, TARGET_SHAPE.lanes) + // because we strided-load one row from each tile of the memref. This can + // save us a bunch of loads! + // TODO(b/295393167): need to support strided load for bitwidth < 32. + } else if (layout_out.bitwidth() == 32 && + canReinterpretToUntiledMemref( + load_op.getBase(), ctx.target_shape, + /*allow_minormost_padding=*/true)) { + // In this case, if the memref can be reinterpreted to untiled, it is + // valid to use any tiling for output. But using native tiling can save us + // a bunch of loads! + } else { + return op.emitOpError( + "Not implemented: dismatch in memref tiling and vector tiling in " + "load"); } } // TODO(apaszke): Check that loads are from vmem! @@ -3213,8 +3335,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, } const VectorLayout &layout_out = *layouts_out.front(); DenseElementsAttr value = cast(constant_op.getValue()); - const VectorType target_vty = - getNativeVregType(vty.getElementType(), ctx.target_shape); + const VectorType target_vty = getNativeVregOrVmaskType( + vty.getElementType(), layout_out.bitwidth(), ctx.target_shape); if (value.isSplat()) { if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) { return op.emitOpError( @@ -4097,6 +4219,13 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, dst_tiled_dims[1] % dst_vreg_slice[1] == 0) { // Shapecast (..., 128) -> (..., m * 128 * packing). no_op = true; + } else if (layout_in.offsets() == LayoutOffsets{0, 0} && + layout_out.offsets() == LayoutOffsets{0, 0} && + layout_in.tiling()[0] == 1 && layout_out.tiling()[0] == 1 && + src_vreg_slice[1] == dst_vreg_slice[1] && + src_tiled_dims[1] % src_vreg_slice[1] == 0 && + dst_tiled_dims[1] % dst_vreg_slice[1] == 0) { + no_op = true; } FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_vregs, @@ -4112,54 +4241,16 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, layout_in.bitwidth() == 32 && layout_in.hasNativeTiling(ctx.target_shape) && layout_in.tiling() == layout_out.tiling() && - layout_in.offsets()[0].value_or(0) == 0 && - layout_in.offsets()[1] == 0 && layout_out.offsets()[0] == 0 - // layout_out.offsets[1] can be anything, as we produce a - // replicated result - ) { - // First, insert the new singleton lane dimension. - SmallVector s = layout_in.implicitShape(src_shape); - s.push_back(1); - xla::Array dst_vregs_local(layout_out.tileArrayShape( - /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(s), - ctx.target_shape)); - TPU_ASSERT_EQ_OP(dst_vregs_local.dimensions().back(), - 1); // We're inserting a singleton dimension - dst_vregs_local.Each( - [&](const absl::Span dst_idx, Value *const dst_vreg) { - const int64_t col_idx = *(dst_idx.end() - 2); - const int64_t row_idx = *(dst_idx.end() - 3); - auto [sublanes_in_lane, rem] = - std::div(ctx.target_shape[1], ctx.target_shape[0]); - CHECK_EQ(rem, 0); - if (!layout_in.offsets()[0].has_value() && row_idx != 0) { - return; // All vregs along that dimension are the same. - } - SmallVector src_idx(toArrayRef(dst_idx)); - src_idx.pop_back(); - *(src_idx.end() - 2) /= ctx.target_shape[0]; - *(src_idx.end() - 1) /= sublanes_in_lane; - Value col_vreg = src_vregs(src_idx); - // BroadcastInSublanesOp requires the sublanes to be replicated. - if (layout_in.offsets()[0].has_value()) { - const int32_t sublane = row_idx % ctx.target_shape[0]; - col_vreg = broadcastSublane(builder, col_vreg, sublane, - ctx.target_shape); - } - *dst_vreg = builder.create( - col_vreg.getType(), col_vreg, - /*lane=*/(col_idx % sublanes_in_lane) * ctx.target_shape[0]); - }); - if (!layout_in.offsets()[0].has_value()) { - // Broadcast the sublane vregs. - // TODO(tlongeri): This could be done more efficiently - dst_vregs_local.Each([&](const absl::Span dst_idx, - Value *const dst_vreg) { - SmallVector first_row_idx(toArrayRef(dst_idx)); - *(first_row_idx.end() - 3) = 0; - *dst_vreg = dst_vregs_local(first_row_idx); - }); - } + (!layout_in.offsets()[1].has_value() || + *layout_in.offsets()[1] % ctx.target_shape[0] == + layout_out.offsets()[0] || + *layout_in.offsets()[1] + src_tiled_dims[1] <= + ctx.target_shape[1])) { + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array dst_vregs_local, + insertImplicitMinorDimension(ctx, builder, op.getLoc(), src_vregs, + layout_in.implicitShape(src_shape), + layout_in, layout_out.offsets())); // Now, reshape the major axes of the vreg array. dst_vregs_local.Reshape( layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape)); @@ -4177,18 +4268,15 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, shape_cast_op->erase(); return success(); } -LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - MLIRContext *const mlir_ctx = op.getContext(); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), - [&](const Layout &l) { return l.has_value(); })); + +template +LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, + const VectorLayout &to_store_layout, + TypedValue store_mask = nullptr) { + Operation &op = *(store_op.getOperation()); + MLIRContext *const mlir_ctx = store_op.getContext(); ImplicitLocOpBuilder builder(op.getLoc(), &op); - vector::StoreOp store_op = cast(op); const VectorType ty = store_op.getValueToStore().getType(); - const VectorLayout &to_store_layout = *layouts_in.front(); const auto memref_ty = getMemRefType(store_op.getBase()); if (!ty.getRank()) { return op.emitOpError("Not implemented: scalar stores to vmem"); @@ -4204,14 +4292,31 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( const Tiling memref_tiling, getMemRefTiling(store_op.getBase(), ctx.target_shape)); - if (memref_tiling != to_store_layout.tiling() && - !(memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 && - memref_tiling[1] % to_store_layout.tiling()[1] == 0)) { - // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). - // TODO(b/295393167): need to support strided store for bitwidth < 32. - if (to_store_layout.bitwidth() != 32 || - to_store_layout.tiling() != Tiling{1, ctx.target_shape[1]}) { - return op.emitOpError("Not implemented"); + if (memref_tiling != to_store_layout.tiling()) { + if (memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 && + memref_tiling[1] % to_store_layout.tiling()[1] == 0) { + // In this case, it is valid to have to_store tiling (1, 128 * packing) + // when storing to a 1D memref. + } else if (to_store_layout.bitwidth() == 32 && + to_store_layout.tiling() == + std::array{1, ctx.target_shape[1]}) { + // In this case, it is valid to have to_store tiling (1, + // TARGET_SHAPE.lanes) because we strided-store one row to each tile of + // the memref. This can save us a bunch of stores! + // TODO(b/295393167): need to support strided store for bitwidth < 32. + } else if (to_store_layout.bitwidth() == 32 && + // We accept padding in the minormost dim, because + // apply_vector_layout will properly mask stores。 + canReinterpretToUntiledMemref( + store_op.getBase(), ctx.target_shape, + /*allow_minormost_padding=*/true)) { + // In this case, if the memref can be reinterpreted to untiled, it is + // valid to use any tiling for to_store. But using native tiling can save + // us a bunch of stores! + } else { + return op.emitOpError( + "Not implemented: dismatch in memref tiling and vector tiling in " + "store"); } } @@ -4268,10 +4373,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, } else { // Convert dynamic store to dynamic slice + static store. This saves us a // bunch of scalar core work. - auto slice_result = - sliceRef(builder, store_op.getBase(), - store_op.getVectorType().getShape(), store_op.getIndices(), - ArrayRef(memref_tiling).take_back(tiled_dims)); + auto slice_result = sliceRef( + builder, store_op.getBase(), ty.getShape(), store_op.getIndices(), + ArrayRef(memref_tiling).take_back(tiled_dims)); if (failed(slice_result)) { return failure(); } @@ -4292,6 +4396,13 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, xla::Array tiles, disassemble(builder, to_store_layout, store_op.getValueToStore(), ctx.target_shape)); + std::optional> tile_masks; + if (store_mask) { + FAILUREOR_ASSIGN_OR_RETURN( + tile_masks, + disassemble(builder, to_store_layout, store_mask, ctx.target_shape)); + TPU_ASSERT_EQ_OP(tile_masks->dimensions(), tiles.dimensions()); + } const int64_t ndims = ty.getRank(); const auto base_s = is_1d ? IdxConst(0, builder, op.getLoc()) : tile_base_idxs.front(); @@ -4313,6 +4424,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, const absl::Status status = tiles.EachStatus([&](const absl::Span idx, const Value tile) -> absl::Status { + const auto tile_mask = store_mask ? (*tile_masks)(idx) : nullptr; const std::unique_ptr bounds = to_store_layout.tileDataBounds(mlir_ctx, stored_shape, toArrayRef(idx), ctx.target_shape); @@ -4372,19 +4484,19 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, updated = builder.create(mask, tile, data); } builder.create( - updated, base_addr, indices, sublane_mask, - /*mask=*/nullptr, + updated, base_addr, indices, sublane_mask, tile_mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } else { builder.create( tile, base_addr, indices, sublane_mask, - /*mask=*/mask, + tile_mask + ? builder.create(mask, tile_mask).getResult() + : mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } } else { builder.create( - tile, base_addr, indices, sublane_mask, - /*mask=*/nullptr, + tile, base_addr, indices, sublane_mask, tile_mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } return absl::OkStatus(); @@ -4394,7 +4506,35 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, } store_op->erase(); return success(); +} + +LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto store_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), + [&](const Layout &l) { return l.has_value(); })); + return vector_store_impl(ctx, store_op, *layouts_in.front()); +} + +LogicalResult tpu_vector_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto store_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + TPU_ASSERT_OP(layouts_in.front().has_value()); + auto other_layouts_in = layouts_in.drop_front(); + if (store_op.getMask()) { + TPU_ASSERT_EQ_OP(layouts_in.front(), layouts_in.back()); + other_layouts_in = other_layouts_in.drop_back(); } + TPU_ASSERT_OP(llvm::none_of(other_layouts_in, + [&](const Layout &l) { return l.has_value(); })); + return vector_store_impl(ctx, store_op, *layouts_in.front(), + store_op.getMask()); +} LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, @@ -4584,48 +4724,84 @@ LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op, } const llvm::StringMap &rules() { - static auto rules = new llvm::StringMap{ - {arith::ConstantOp::getOperationName(), arith_constant_rule}, - {arith::ExtFOp::getOperationName(), arith_extf_rule}, - {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, - {arith::ExtUIOp::getOperationName(), arith_extui_rule}, - {arith::TruncFOp::getOperationName(), arith_truncf_rule}, - {arith::TruncIOp::getOperationName(), arith_trunci_rule}, - {func::ReturnOp::getOperationName(), func_return_rule}, - {scf::ForOp::getOperationName(), scf_for_rule}, - {scf::WhileOp::getOperationName(), scf_while_rule}, - {scf::ConditionOp::getOperationName(), scf_condition_rule}, - {scf::IfOp::getOperationName(), scf_if_rule}, - {scf::YieldOp::getOperationName(), yield_rule}, - {tpu::YieldOp::getOperationName(), yield_rule}, - {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, - {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, - {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, - {tpu::IotaOp::getOperationName(), tpu_iota_rule}, - {tpu::GatherOp::getOperationName(), tpu_gather_rule}, - {tpu::LoadOp::getOperationName(), tpu_load_rule}, - {tpu::StoreOp::getOperationName(), tpu_store_rule}, - {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, - {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, - {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, - {tpu::RegionOp::getOperationName(), tpu_region_rule}, - {tpu::RepeatOp::getOperationName(), tpu_repeat_rule}, - {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, - {tpu::TraceOp::getOperationName(), tpu_trace_rule}, - {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, - {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule}, - {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, - {vector::ExtractOp::getOperationName(), vector_extract_rule}, - {vector::LoadOp::getOperationName(), vector_load_rule}, - {vector::MultiDimReductionOp::getOperationName(), - vector_multi_reduction_rule}, - {vector::ExtractStridedSliceOp::getOperationName(), - vector_extract_strided_slice_rule}, - {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, - {vector::StoreOp::getOperationName(), vector_store_rule}, - {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; + static const llvm::StringMap *rules = [] { + static auto rules = new llvm::StringMap{ + {arith::ConstantOp::getOperationName(), arith_constant_rule}, + {arith::ExtFOp::getOperationName(), arith_extf_rule}, + {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, + {arith::ExtUIOp::getOperationName(), arith_extui_rule}, + {arith::TruncFOp::getOperationName(), arith_truncf_rule}, + {arith::TruncIOp::getOperationName(), arith_trunci_rule}, + {func::ReturnOp::getOperationName(), func_return_rule}, + {scf::ForOp::getOperationName(), scf_for_rule}, + {scf::WhileOp::getOperationName(), scf_while_rule}, + {scf::ConditionOp::getOperationName(), scf_condition_rule}, + {scf::IfOp::getOperationName(), scf_if_rule}, + {scf::YieldOp::getOperationName(), yield_rule}, + {tpu::YieldOp::getOperationName(), yield_rule}, + {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, + {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, + {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, + {tpu::IotaOp::getOperationName(), tpu_iota_rule}, + {tpu::GatherOp::getOperationName(), tpu_gather_rule}, + {tpu::LoadOp::getOperationName(), tpu_load_rule}, + {tpu::StoreOp::getOperationName(), tpu_store_rule}, + {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, + {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, + {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, + {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, + {tpu::RegionOp::getOperationName(), tpu_region_rule}, + {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, + {tpu::TraceOp::getOperationName(), tpu_trace_rule}, + {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, + {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule}, + {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, + {vector::ExtractOp::getOperationName(), vector_extract_rule}, + {vector::LoadOp::getOperationName(), vector_load_rule}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_reduction_rule}, + {vector::ExtractStridedSliceOp::getOperationName(), + vector_extract_strided_slice_rule}, + {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, + {vector::StoreOp::getOperationName(), vector_store_rule}, + {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; + + llvm::StringMap extended_rules = mlir::tpu::extensions::rules(); + for (auto &entry : extended_rules) { + rules->insert(&entry); + } + return rules; + }(); return *rules; } + +// Determines whether we should handle bank conflict for the given stride and +// max_sublane_offset. +// +// See `handleBankConflict` for how this is done. +bool shouldHandleBankConflict(const ApplyVectorLayoutContext &ctx, + int32_t stride, int max_sublane_offset) { + return ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && + ctx.vmem_banks < stride * ctx.target_shape[0] && + ctx.max_shuffle_sublane_offset > 0 && + ctx.max_shuffle_sublane_offset >= max_sublane_offset; +} + +// Handles load/store bank conflict by adding one extra sublane to stride and +// adjusting sublane offsets accordingly. +// +// For example, when store stride is 4 and load sublane offsets are +// [0, 1, 2, 3, 4, 5, 6, 7], the store bank conflict can be avoided by changing +// stride to 5 and sublane offsets to [0, 1, 2, 3, 5, 6, 7, 8]. +void handleBankConflict(int32_t &stride, absl::Span sublane_offsets) { + // Add one extra sublane to stride to avoid bank conflict. + for (int i = 0; i < sublane_offsets.size(); ++i) { + // Adjust sublane offsets to match the stride. + sublane_offsets[i] += i / stride; + } + ++stride; +} + } // namespace RollVectorsOp assemble(OpBuilder &builder, VectorType vty, @@ -4676,6 +4852,11 @@ FailureOr> disassemble( TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value()); TPU_ASSERT_LOC(val.getLoc(), def_layout->generalizes(layout, vty.getShape(), target_shape)); + auto layout_product = + xla::Product(layout.tileArrayShape(vty.getShape(), target_shape)); + auto def_layout_product = + xla::Product(def_layout->tileArrayShape(vty.getShape(), target_shape)); + TPU_ASSERT_LOC(val.getLoc(), layout_product == def_layout_product); // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of // having `tileArrayShape` and `tileArrayImplicitShape`. SmallVector layout_shape = @@ -5143,8 +5324,10 @@ FailureOr> tpu_rotate_with_overflow( // Compute the mask for the blend. // Positive blends blend "forward" and negative blends blend "backward". auto mask_val = amount; + auto vreg_rot_amount = amount; if (amount < 0) { mask_val = layout_in.tiling()[tiling_dim] - std::abs(amount); + vreg_rot_amount += target_shape[tiling_dim]; } auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, loc); auto mask = builder.create( @@ -5156,7 +5339,8 @@ FailureOr> tpu_rotate_with_overflow( in_tiles.Each([&](absl::Span idxs, Value *v) { if (dim >= in_tiles.num_dimensions() - 2) { *v = builder.create(loc, res_vreg_ty, in_tiles(idxs), - amount, tiling_dim, nullptr, nullptr); + vreg_rot_amount, tiling_dim, nullptr, + nullptr); } }); @@ -5606,16 +5790,37 @@ LogicalResult retileToLargeTileWithScratch( // The older hardware has limited support for shuffles so even if we have bank // conflicts, we just accept them and will have the lowering unroll the // loads/stores. + int64_t num_offsets = sublane_offsets.num_elements(); + // The max sublane offset before handling bank conflicts is always + // (num_offsets - 1). To avoid bank conflicts, we need to add one extra + // sublane to stride so (num_offsets - 1) / stride is the extra offset needed + // to pad sublanes. + // + // For example, if store stride = 4, sublane_count = 8, and + // load offsets = [0, 1, 2, 3, 4, 5, 6, 7], then the sublane offsets after + // handling bank conflicts will be [0, 1, 2, 3, 5, 6, 7, 8] and the max + // sublane offset will be 7 + (8 - 1) / 4 = 8. + // + // Before + // <-------- sublanes ---------> + // 0 1 ... 32 + // store: x---x---x---x---x---x---x---x + // load: xxxxxxxxx-------------------- + // + // After + // <-------- sublanes ---------> + // 0 5 ... 40 + // store: x----x----x----x----x----x----x----x + // load: xxxx-xxxx--------------------------- + // + // where "x" indicates a sublane that needs to be accessed and "-"" indicates + // a sublane that does not need to be accessed. + int max_sublane_offset = (num_offsets - 1) + (num_offsets - 1) / stride; bool should_handle_bank_confict = - ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && - ctx.vmem_banks < stride * ctx.target_shape[0]; - // Add one extra sublane to stride to avoid bank conflict. + shouldHandleBankConflict(ctx, stride, max_sublane_offset); if (should_handle_bank_confict) { - // Adjust sublane offsets to match the stride. - for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { - *(sublane_offsets.begin() + i) += i / stride; - } - stride += 1; + handleBankConflict(stride, absl::MakeSpan(sublane_offsets.data(), + sublane_offsets.num_elements())); } sublane_offsets.TransposeDimensions({0, 2, 1}); @@ -5738,9 +5943,34 @@ LogicalResult retileToSmallTileWithScratch( // The older hardware has limited support for shuffles so even if we have // bank conflicts, we just accept them and will have the lowering unroll the // loads/stores. + int64_t num_offsets = sublane_offsets.num_elements(); + // The max sublane offset before handling bank conflicts is always + // (num_offsets - 1). To avoid bank conflicts, we need to add one extra + // sublane to stride so (num_offsets - 1) / stride is the extra offset needed + // to pad sublanes. + // + // For example, if store stride = 4, sublane_count = 8, and + // load offsets = [0, 1, 2, 3, 4, 5, 6, 7], then the sublane offsets after + // handling bank conflicts will be [0, 1, 2, 3, 5, 6, 7, 8] and the max + // sublane offset will be 7 + (8 - 1) / 4 = 8. + // + // Before + // <-------- sublanes ---------> + // 0 4 ... + // store: x---x---x---x---x---x---x---x + // load: xxxxxxxxx------------------- + // + // After + // <-------- sublanes ---------> + // 0 5 ... + // store: x----x----x----x----x----x----x----x + // load: xxxx-xxxx--------------------------- + // + // where "x" indicates a sublane that needs to be accessed and "-"" indicates + // a sublane that does not need to be accessed. + int max_sublane_offset = (num_offsets - 1) + (num_offsets - 1) / stride; bool should_handle_bank_confict = - ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && - ctx.vmem_banks < stride * ctx.target_shape[0]; + shouldHandleBankConflict(ctx, stride, max_sublane_offset); bool use_shuffled_load = false; if (ctx.hardware_generation <= 4) { if (src_tile[0] == 8) { @@ -5759,11 +5989,8 @@ LogicalResult retileToSmallTileWithScratch( // Add one extra sublane to stride to avoid bank conflict. if (should_handle_bank_confict) { - // Adjust sublane offsets to match the stride. - for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { - *(sublane_offsets.begin() + i) += i / stride; - } - stride += 1; + handleBankConflict(stride, absl::MakeSpan(sublane_offsets.data(), + sublane_offsets.num_elements())); } sublane_offsets.TransposeDimensions({0, 2, 1}); auto mlirIndexConst = [&](int d) { @@ -5937,35 +6164,49 @@ FailureOr>> changeTiling( } const int packing = src.packing(); const int8_t bitwidth = src.bitwidth(); - VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling, - src.implicit_dim()); - if (!dst.isValid(target_shape)) { - return emitError(loc, "Not implemented: invalid offsets in tiling target"); - } - auto dst_tiles_shape = - dst.tileArrayImplicitShape(vty.getShape(), target_shape); // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating // sublanes. if (try_replicate_rows && packing == 1 && *(vregs.dimensions().end() - 2) == 1 && - src.offsets() == LayoutOffsets{0, 0} && src.tiling() == std::array{1, ctx.target_shape[1]} && dst_tiling == ctx.target_shape) { - xla::Array retiled(dst_tiles_shape); + DCHECK_EQ(src.offsets()[0].value_or(0), 0); + const LayoutOffset dst_minor_offset = + src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1]) + : std::nullopt; + const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset}, + dst_tiling, src.implicit_dim()); + xla::Array retiled( + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); retiled.Each([&](absl::Span idx, Value *tile) { SmallVector src_idx(idx.begin(), idx.end()); *(src_idx.end() - 2) *= target_shape[0]; - *(src_idx.end() - 1) /= target_shape[0]; - const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0]; - CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1); - *tile = - broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); + if (!src.offsets()[1].has_value()) { + // With (1, 128) tiling each vreg holds values from a single row. This + // means that if the columns are replicated, then the whole vreg is + // already replicated. + *(src_idx.end() - 1) = 0; + *tile = vregs(src_idx); + } else { + // The column (in units of sublanes) of the sublane we want: + const int64_t sublane_column = + *(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1]; + *(src_idx.end() - 1) = sublane_column / target_shape[0]; + const int64_t src_sl_idx = sublane_column % target_shape[0]; + *tile = + broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); + } }); - // We have successfully replicated sublanes. - dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling, - dst.implicit_dim()); + // We have successfully replicated sublanes return std::pair(dst, std::move(retiled)); } + VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling, + src.implicit_dim()); + if (!dst.isValid(target_shape)) { + return emitError(loc, "Not implemented: invalid offsets in tiling target"); + } + auto dst_tiles_shape = + dst.tileArrayImplicitShape(vty.getShape(), target_shape); // (8,128) -> (8 * packing,128) tiling change for packed type. if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape && dst_tiling == std::array{ctx.target_shape[0] * dst.packing(), @@ -6177,6 +6418,26 @@ FailureOr>> changeImplicitDim( }); return std::make_pair(dst, new_vregs); } + if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone && + dst_implicit_dim == VectorLayout::ImplicitDim::kMinor && + src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) { + // TODO(tlongeri): Make insertImplicitMinorDimension more flexible about + // offsets, then we can pass dst_offset_hints directly. + const LayoutOffset dst_2nd_minor_offset = + !src.offsets()[1] || *src.offsets()[1] + *(vty.getShape().end() - 1) <= + ctx.target_shape[1] + ? dst_offset_hints[0] + : LayoutOffset(*src.offsets()[1] % ctx.target_shape[0]); + VectorLayout dst(src.bitwidth(), + {dst_2nd_minor_offset, dst_offset_hints[1]}, src.tiling(), + VectorLayout::ImplicitDim::kMinor); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array dst_vregs, + insertImplicitMinorDimension(ctx, builder, loc, vregs, + src.implicitShape(vty.getShape()), src, + dst.offsets())); + return std::make_pair(dst, std::move(dst_vregs)); + } return emitError(loc, "Not implemented: Unsupported implicit dim change: from ") << src << " to " << dst_implicit_dim; @@ -6194,6 +6455,14 @@ FailureOr> relayout(RewriteContext &ctx, return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); } VectorType vty = v.getType(); + const bool is_mask = vty.getElementTypeBitWidth() == 1; + if (is_mask) { + if (src.bitwidth() != 32 || dst.bitwidth() != 32) { + return emitError(v.getLoc(), + "Not implemented: mask relayout with non-32 bitwidth in " + "vector layout"); + } + } { // Replication imposes a replication constraint on the *logical* value of // the vector: When moving along a replicated axis, all elements must be @@ -6227,32 +6496,99 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); + if (is_mask) { + auto new_tile_ty = + getNativeVregOrVmaskType(builder.getI32Type(), 32, target_shape); + src_tiles.Each([&](const absl::Span idx, Value *tile) { + *tile = + builder.create(tile->getLoc(), new_tile_ty, *tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI32Type()); + } + auto assemble_with_mask_check = [&](xla::Array &tiles, + bool use_implicit_shape = false) { + if (is_mask) { + auto zeros_tile = builder.create( + tiles.begin()->getLoc(), + DenseElementsAttr::get(cast(tiles.begin()->getType()), + builder.getI32IntegerAttr(0))); + tiles.Each([&](const absl::Span idx, Value *tile) { + *tile = builder.create( + tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI1Type()); + } + return assemble(builder, vty, dst, tiles, target_shape, use_implicit_shape) + .getResult(); + }; // Two easy cases: source is more general, or is replicated. if (src.generalizes(dst, vty.getShape(), target_shape)) { // A value with a replicated offset might use fewer vregs than a value with // a non-zero offset. - if (xla::Product(src.tileArrayShape(vty.getShape(), target_shape)) != - xla::Product(dst.tileArrayShape(vty.getShape(), target_shape))) { - return emitError(v.getLoc(), - "Not implemented: source layout is more general, but " - "vreg count changes"); + auto src_product = + xla::Product(src.tileArrayShape(vty.getShape(), target_shape)); + auto dst_product = + xla::Product(dst.tileArrayShape(vty.getShape(), target_shape)); + if (src_product != dst_product) { + TPU_ASSERT_LOC(v.getLoc(), dst_product > src_product); + auto src_offsets = src.offsets(); + + TPU_ASSERT_LOC(v.getLoc(), src_offsets != dst.offsets()); + TPU_ASSERT_LOC(v.getLoc(), src.bitwidth() == dst.bitwidth()); + + if (src.implicit_dim() != dst.implicit_dim()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and implicit dims are mismatched"); + } + + if (src.tiling() != dst.tiling()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and tiling are mismatched"); + } + + // This case is moving from a replicated to a non replicated layout. + // As such, we need to make a new destination shape that is the + // materialization of the src shape with replication. + FAILUREOR_ASSIGN_OR_RETURN(auto src_vregs, + disassemble(builder, src, v, target_shape, + /*use_implicit_shape=*/true)); + auto dst_vregs_shape = dst.tileArrayShape(vty.getShape(), target_shape); + xla::Array dst_vregs(dst_vregs_shape); + dst_vregs.Each([&](const absl::Span idx, Value *vreg) { + SmallVector local_idx(idx.begin(), idx.end()); + if (!src_offsets[0].has_value()) { + local_idx[local_idx.size() - 2] = 0; + } + if (!src_offsets[1].has_value()) { + local_idx[local_idx.size() - 1] = 0; + } + *vreg = src_vregs(local_idx); + }); + return assemble(builder, vty, dst, std::move(dst_vregs), target_shape, + /*use_implicit_shape=*/true) + .getResult(); } src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() && - !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) { + !src.offsets()[1].has_value()) { // A fully replicated value is always easy to relayout - // It would be nice to be able to assert this here, but given replicated - // values our rules can introduce equivalent expressions. - // assert all(t is src_tiles_list[0] for t in src_tiles_list) xla::Array dst_tiles( - /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape), - /*value=*/src_tiles.data()[0]); - return assemble(builder, vty, dst, std::move(dst_tiles), target_shape) - .getResult(); + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + SmallVector idxs; + dst_tiles.Each([&](const absl::Span src_idx, Value *vreg) { + idxs.assign(src_idx.begin(), src_idx.end()); + dst.eraseImplicit(idxs); + src.insertImplicit(idxs, 0); + *(idxs.end() - 2) = 0; + *(idxs.end() - 1) = 0; + *vreg = src_tiles(idxs); + }); + return assemble_with_mask_check(dst_tiles, /*use_implicit_shape=*/true); } // Consider (1,128),-2 -> (8,128). In this case we can change the implicit @@ -6290,9 +6626,8 @@ FailureOr> relayout(RewriteContext &ctx, dst.offsets())); CHECK_EQ(src, dst); // At this point we've should be done. - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } // TODO(apaszke): Implement a debug mode that inserts additional assertions. @@ -6318,8 +6653,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { if (vector_operand == nullptr) { continue; } - auto vty = vector_operand.getType(); - // The operand should always be an Operation (and not a BlockArgument) // since we expect the FuncOp to have only memrefs and semaphores as // arguments. @@ -6334,7 +6667,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); - if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) { + if (*lo == *li) { continue; } OpBuilder builder(&op); @@ -6420,6 +6753,7 @@ struct ApplyVectorLayoutPass mxu_noncontracting_size = ctx.mxu_shape[1]; max_sublanes_in_scratch = ctx.max_sublanes_in_scratch; vmem_banks = ctx.vmem_banks; + max_shuffle_sublane_offset = ctx.max_shuffle_sublane_offset; } void runOnOperation() override { // Fail if hardware_generation has not been set from the default value. @@ -6432,7 +6766,9 @@ struct ApplyVectorLayoutPass .target_shape = {sublane_count, lane_count}, .mxu_shape = {mxu_contracting_size, mxu_noncontracting_size}, .max_sublanes_in_scratch = max_sublanes_in_scratch, - .vmem_banks = vmem_banks}; + .vmem_banks = vmem_banks, + .max_shuffle_sublane_offset = max_shuffle_sublane_offset, + }; if (failed(applyLayoutFunc(ctx, getOperation()))) { signalPassFailure(); return; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h new file mode 100644 index 000000000000..33c9e7421004 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -0,0 +1,21 @@ +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ + +#include + +#include "llvm/include/llvm/ADT/StringMap.h" +#include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" + +namespace mlir::tpu::extensions { + +const llvm::StringMap< + std::function, ArrayRef)>> & +rules(); + +} // namespace mlir::tpu::extensions + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 9f2a8ed73a44..b471f92609c3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1,5 +1,10 @@ +#include +#include #include #include +#include +#include +#include #include #include "llvm/ADT/STLExtras.h" @@ -15,6 +20,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h" @@ -22,6 +29,7 @@ #include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/Block.h" #include "mlir/include/mlir/IR/Builders.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/Operation.h" @@ -39,6 +47,9 @@ namespace mlir::tpu { LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); + auto transpose_lhs = op.getTransposeLhs(); + auto transpose_rhs = op.getTransposeRhs(); + auto lhs = op.getLhs(); auto rhs = op.getRhs(); auto acc = op.getAcc(); @@ -51,6 +62,51 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { auto rhs_element_type = rhs_ty.getElementType(); auto acc_element_type = acc_ty.getElementType(); + // there are a few primary paths for dimension_numbers in matmul + // 1) No dimension numbers provided -> set to default + // 2) defined and not default -> verify and apply + // 3) defined and matching defaultDimensionNumbers -> no-op for + // canonicalization of dims + std::optional batch_size = std::nullopt; + + // MKN matmul - no dims or transpositions set + if (!op.getDimensionNumbers().has_value()) { + // Legacy API - convert it to dimension numbers + op.setDimensionNumbersAttr( + defaultDimensionNumbers(builder, transpose_lhs, transpose_rhs)); + } else if ( + // Dot dim API - dimensions are provided and are not default + (op.getDimensionNumbers().value() != + defaultDimensionNumbers(builder, false, false))) { + auto dimension_numbers = op.getDimensionNumbers(); + auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims(); + auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims(); + + auto lhs_batch_dims = dimension_numbers->getLhsBatchDims(); + auto rhs_batch_dims = dimension_numbers->getRhsBatchDims(); + + // Invariant in matmul verifier: <= 1 batch dim atm, and that lhs and rhs + // are the same + // Invariant in matmul verifier: Exactly one contracting and non contracting + // dim in each of lhs and rhs for now. + batch_size = + lhs_batch_dims.empty() + ? std::nullopt + : std::optional(lhs_ty.getShape()[lhs_batch_dims[0]]); + // Lower each dim in contracting dims by size(batch_dims) + auto batch_adjusted_lhs_contracting_dim = + lhs_contracting_dims[0] - lhs_batch_dims.size(); + auto batch_adjusted_rhs_contracting_dim = + rhs_contracting_dims[0] - rhs_batch_dims.size(); + + if (batch_adjusted_lhs_contracting_dim != 1) { + transpose_lhs = true; + } + if (batch_adjusted_rhs_contracting_dim != 0) { + transpose_rhs = true; + } + } + auto extsi_sitofp = [&builder, &op](TypedValue element) { const VectorType ty = element.getType(); auto shape = ty.getShape(); @@ -87,10 +143,12 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { if (lhs_element_type.isInteger()) { auto float_lhs = extsi_sitofp(lhs); op->setOperand(0, float_lhs); + lhs = cast>(float_lhs.getResult()); } if (rhs_element_type.isInteger()) { auto float_rhs = extsi_sitofp(rhs); op->setOperand(1, float_rhs); + rhs = cast>(float_rhs.getResult()); } } // TODO(mvoz): Add more invariants. @@ -113,6 +171,91 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { return failure(); } } + + auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) { + auto precision_attr = op.getPrecisionAttr(); + + // If we are transposing the lhs, we need to transpose the lhs before + // matmul here, as we don't have lhs fusion implemented in apply. + if (transpose_lhs) { + auto lhs_ty = cast(lhs.getType()); + auto rank = lhs_ty.getShape().size(); + + // This transposition must run on vectors with rank >= 2 + CHECK_GE(rank, 2); + + std::vector perm(rank); + std::iota(perm.begin(), perm.end(), 0); + std::swap(perm[rank - 2], perm[rank - 1]); + + std::vector shape(lhs_ty.getShape()); + std::swap(shape[rank - 2], shape[rank - 1]); + + auto lhs_ty_transposed = VectorType::get(shape, lhs_ty.getElementType()); + + const SmallVector perm_vec = + SmallVector(perm.begin(), perm.end()); + lhs = builder.create( + lhs_ty_transposed, lhs, + DenseI64ArrayAttr::get(builder.getContext(), perm_vec)); + } + auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false, + transpose_rhs); + // transpose flags are always false here, because ddn takes precedence + // after this pass. + auto matmul_res = builder.create( + op.getLoc(), acc.getType(), lhs, rhs, acc, + /*transpose_lhs=*/false, + /*transpose_rhs=*/false, precision_attr, ddn); + return matmul_res; + }; + + // If we have a batch_size, we want to slice rhs and lhs [:batch_size], + // and then do O[i] = A[i] @ B[i] + // Produce an output shape of [batch_size, m, n] + if (batch_size.has_value()) { + std::vector outputs; + + for (int64_t i = 0; i < batch_size; ++i) { + auto sliced_lhs = builder.create(op.getLoc(), lhs, + ArrayRef{i}); + auto sliced_rhs = builder.create(op.getLoc(), rhs, + ArrayRef{i}); + + auto sliced_acc = builder.create(op.getLoc(), acc, + ArrayRef{i}); + + auto matmul_res = + dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(), + sliced_acc.getResult()); + auto res_ty = matmul_res.getType().cast(); + auto res_shape = res_ty.getShape(); + // reshape to 1x[prior_shape] + auto reshape_shape = llvm::to_vector(res_shape); + reshape_shape.insert(reshape_shape.begin(), 1); + auto shape_cast = builder.create( + op.getLoc(), VectorType::get(reshape_shape, res_ty.getElementType()), + matmul_res); + outputs.push_back(shape_cast); + } + // Technically almost identical to the case where batch_size is 1, but + // we want to avoid the spurious concat here. + if (batch_size == 1) { + op.replaceAllUsesWith(outputs[0]); + op.erase(); + return success(); + } + auto output = builder + .create(op.getLoc(), acc_ty, outputs, + /*dimension=*/0) + .getResult(); + op.replaceAllUsesWith(output); + op.erase(); + } else { + auto matmul_res = dot_dim_matmul(lhs, rhs, acc).getResult(); + op.replaceAllUsesWith(matmul_res); + op.erase(); + } return success(); }; @@ -308,9 +451,14 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) { } const tpu::ContractPrecisionAttr precision_attr = // May be null contraction_op->getAttrOfType("precision"); + + const auto dot_dimension_numbers_attr = + defaultDimensionNumbers(builder, false, transpose_rhs); + auto matmul_op = builder.create( contraction_op->getLoc(), acc_ty, lhs, rhs, acc, - /*transpose_lhs=*/false, transpose_rhs, precision_attr); + /*transpose_lhs=*/false, + /*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr); contraction_op.replaceAllUsesWith(matmul_op.getResult()); contraction_op.erase(); auto result = tpu_matmul_rule(matmul_op); @@ -350,6 +498,29 @@ LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) { return success(); } +LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) { + auto op = dyn_cast(raw_op); + if (!isa(op.getType())) { + return op.emitOpError("Only vector types supported"); + } + auto operand = op.getSource(); + auto times = op.getTimes(); + if (times == 1) { + // A true no op - kind of an odd edge case, but this does come up in + // flash_attention_backward tests. + op.replaceAllUsesWith(operand); + op.erase(); + return success(); + } + auto operands = std::vector(times, operand); + ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + auto concat = builder.create(op.getLoc(), op.getType(), + operands, op.getDimension()); + op.replaceAllUsesWith(concat.getResult()); + op.erase(); + return success(); +} + using canonicalize_rule_type = std::function; @@ -357,10 +528,11 @@ const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ {tpu::MatmulOp::getOperationName(), canonicalize_matmul}, {vector::ContractionOp::getOperationName(), canonicalize_contraction}, - {vector::ContractionOp::getOperationName(), canonicalize_extract}, + {vector::ExtractOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}, - {arith::SelectOp::getOperationName(), canonicalize_select}}; + {arith::SelectOp::getOperationName(), canonicalize_select}, + {tpu::RepeatOp::getOperationName(), canonicalize_repeat}}; return *rules; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc index 5478c64f9944..846e3bbb341f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc @@ -122,6 +122,14 @@ void tpu_strided_store_rule(tpu::StridedStoreOp op) { /*strides=*/op.getStrides()); } +void tpu_vector_store_rule(tpu::VectorStoreOp op) { + // TODO(b/379925823): Take strides into account. + assertIsValidSubwindow( + op, op.getIndices(), + /*window_shape=*/op.getValueToStore().getType().getShape(), + /*full_shape=*/op.getBase().getType().getShape()); +} + const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ // TODO: tpu::LoadOp, tpu::StoreOp @@ -133,6 +141,8 @@ const llvm::StringMap &rules() { as_generic_rule(tpu_strided_load_rule)}, {tpu::StridedStoreOp::getOperationName(), as_generic_rule(tpu_strided_store_rule)}, + {tpu::VectorStoreOp::getOperationName(), + as_generic_rule(tpu_vector_store_rule)}, }; return *rules; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc new file mode 100644 index 000000000000..e7528533938f --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -0,0 +1,19 @@ +#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" + +#include "llvm/include/llvm/ADT/StringMap.h" +#include "mlir/include/mlir/IR/Operation.h" + +namespace mlir::tpu::extensions { + +using RewriteContext = ApplyVectorLayoutContext; + +using rule_type = std::function, ArrayRef)>; + +const llvm::StringMap &rules() { + static const llvm::StringMap *rules = + new llvm::StringMap{}; + return *rules; +} + +} // namespace mlir::tpu::extensions \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc new file mode 100644 index 000000000000..a67728076de1 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -0,0 +1,13 @@ +#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" + +#include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/LogicalResult.h" + +namespace mlir::tpu::extensions { + +bool canInferVectorLayout(const Operation &op) { return false; } + +LogicalResult inferVectorLayout(const Operation &op) { return failure(); } + +} // namespace mlir::tpu::extensions \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 541393fc2758..046b642f98a3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -87,6 +87,16 @@ FailureOr inferLayout(MemRefType memref_ty, int64_t leading_tile_rows = 0) { if (auto tiled_layout_attr = dyn_cast(memref_ty.getLayout())) { + if (leading_tile_rows > 0 && !tiled_layout_attr.getTiles().empty() && + tiled_layout_attr.getTiles().front().dimensions().size() == 2 && + tiled_layout_attr.getTiles().front().dimensions()[0] != + leading_tile_rows) { + return emitError(UnknownLoc::get(memref_ty.getContext()), + "Trying to infer memref layout with sublane tiling ") + << leading_tile_rows + << ", but the memref already has sublane tiling " + << tiled_layout_attr.getTiles().front().dimensions()[0]; + } return tiled_layout_attr; } if (auto affine_map_attr = dyn_cast(memref_ty.getLayout())) { @@ -226,13 +236,25 @@ LogicalResult inferOp(Operation &op, const int hardware_generation, if (auto alloca_op = dyn_cast(op)) { TypedValue arg = alloca_op.getResult(); const MemRefType memref_ty = alloca_op.getResult().getType(); - FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, - target_shape, tpu_tiling_flags)); + // If the memref can be reinterpreted to untiled, force to use tiling + // {1, target.lane_count} for 32 bit. + int64_t leading_tile_rows = 0; + // TODO(b/375038685): generalize untiled memref with packed type which + // needs to update load/store rules. + if (memref_ty.getElementTypeBitWidth() == 32 && memref_ty.getRank() > 1 && + *(memref_ty.getShape().end() - 1) <= target_shape[1]) { + leading_tile_rows = 1; + } + FAILUREOR_ASSIGN_OR_RETURN( + const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation, target_shape, + tpu_tiling_flags, leading_tile_rows)); alloca_op.getResult().setType(new_memref_ty); if (memref_ty != new_memref_ty) { OpBuilder builder(alloca_op->getContext()); builder.setInsertionPointAfter(alloca_op); + // TODO(b/376130272): add a canonicalizer for EraseLayoutOp so that if we + // have erase(erase(x)) then we rewrite it to erase(x). auto erase_op = builder.create( arg.getLoc(), MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), @@ -296,22 +318,56 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation, } FAILUREOR_ASSIGN_OR_RETURN( - const MemRefType new_memref_ty, + MemRefType new_memref_ty, inferMemref(memref_ty, hardware_generation, target_shape, tpu_tiling_flags, leading_tile_rows)); arg.setType(new_memref_ty); new_arg_types.push_back(arg.getType()); if (memref_ty != new_memref_ty) { + Value val = arg; + Operation * arg_use_op = nullptr; + // If the arg memref can be reinterpreted to untiled, we can insert + // ReinterpretCastOp to use tiling {packing, target.lane_count} before + // EraseLayoutOp for only the arg memrefs and expect the rest memref + // layout inference is based on the casted layout automatically. This + // would help lift many restrictions in alignment check when consuming + // this memref. + if (canReinterpretToUntiledMemref(cast>(val), + target_shape, + /*allow_minormost_padding=*/true) && + // TODO(b/375038685): generalize untiled memref with packed type which + // needs to update load/store rules. + new_memref_ty.getElementTypeBitWidth() == 32) { + auto tiled_layout = + cast(new_memref_ty.getLayout()); + SmallVector tiles(tiled_layout.getTiles()); + SmallVector new_tile_strides(tiled_layout.getTileStrides()); + for (int i = 0; i < new_tile_strides.size() - 2; ++i) { + new_tile_strides[i] *= tiles[0].dimension(0); + } + tiles[0] = ::xla::Tile({1, target_shape[1]}); + new_memref_ty = MemRefType::get( + new_memref_ty.getShape(), new_memref_ty.getElementType(), + TiledLayoutAttr::get(new_memref_ty.getContext(), tiles, + new_tile_strides), + new_memref_ty.getMemorySpace()); + arg_use_op = builder.create(val.getLoc(), + new_memref_ty, val); + val = arg_use_op->getResult(0); + } // Some standard MLIR ops have static checks that seems unreasonable, // and we know they hold in the way they are used in Mosaic. Still, // verification with layouts likes to fail, because it can't statically // prove the properties. auto erase_op = builder.create( - arg.getLoc(), + val.getLoc(), MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), /*layout=*/nullptr, new_memref_ty.getMemorySpace()), - arg); - arg.replaceAllUsesExcept(erase_op.getResult(), erase_op); + val); + if (!arg_use_op) { + arg_use_op = erase_op; + } + arg.replaceAllUsesExcept(erase_op.getResult(), arg_use_op); } } f.setFunctionType( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 2adc3bf0768e..dd63ba66cb9c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -50,6 +50,8 @@ limitations under the License. #include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" +#include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" namespace mlir::tpu { @@ -185,7 +187,7 @@ class VectorLayoutInferer { false_ty.getElementTypeBitWidth() == kNativeBitwidth, "Only 32-bit select supported"); } - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + if (inferElementwise(&any_op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -196,7 +198,7 @@ class VectorLayoutInferer { auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth() : op.getIn().getType().getIntOrFloatBitWidth(); if (in_bitwidth == 1) { - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + if (inferElementwise(&any_op).failed()) { return failure(); } } else { @@ -212,7 +214,7 @@ class VectorLayoutInferer { TPU_CHECK_OP(static_cast(lhs_ty) == static_cast(rhs_ty), "Only one side of cmp is a vector?"); // TODO(tlongeri): Check that TPU generation supports comparison. - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + if (inferElementwise(&any_op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -287,10 +289,6 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); @@ -323,8 +321,14 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op)) { + if (inferStore(op, + /*has_mask=*/op.getMask() != nullptr) + .failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { + if (inferStore(op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -340,6 +344,10 @@ class VectorLayoutInferer { if (inferElementwise(&any_op).failed()) { return failure(); } + } else if (mlir::tpu::extensions::canInferVectorLayout(any_op)) { + if (mlir::tpu::extensions::inferVectorLayout(any_op).failed()) { + return failure(); + } } else { any_op.emitOpError("unsupported in vector layout inference"); return failure(); @@ -363,6 +371,11 @@ class VectorLayoutInferer { TPU_CHECK_OP(ty.getRank() > 0, "rank 0 vectors unsupported"); TPU_CHECK_OP(elems, "expected vector constants to use DenseElementsAttr"); auto bitwidth = ty.getElementTypeBitWidth(); + if (bitwidth == 1) { + // i1 is a special case where the layout bitwidth can be different from + // the element bitwidth, see comment in VectorLayout class + bitwidth = kNativeBitwidth; + } if (elems.isSplat()) { if (ty.getRank() == 1) { // Here, we choose to lay out along lanes arbitrarily. It would be @@ -763,14 +776,11 @@ class VectorLayoutInferer { LogicalResult infer(tpu::ConcatenateOp op) { TPU_CHECK_OP(!op.getSources().empty(), "Need at least one vector to concatenate"); - auto res_rank = op.getType().getRank(); - auto dimension = op.getDimension(); + int64_t res_rank = op.getType().getRank(); + uint32_t dimension = op.getDimension(); TPU_CHECK_OP(0 <= dimension && dimension < res_rank, "Expect a valid concatenate dimension"); - if (res_rank == 1) { - NYI("Support concatenation with 1D vectors"); - } - auto res_ty = op.getResult().getType(); + VectorType res_ty = op.getResult().getType(); int8_t bitwidth = res_ty.getElementTypeBitWidth(); std::optional tiling_dim; @@ -783,29 +793,39 @@ class VectorLayoutInferer { if (tiling_dim.has_value()) { int64_t starting_point = 0; - auto first_layout = getLayout(op.getSources().front()); - auto op_layouts = getLayoutFromOperands(op); + Layout first_layout = getLayout(op.getSources().front()); + SmallVector op_layouts = getLayoutFromOperands(op); SmallVector in_layouts; in_layouts.reserve(op.getSources().size()); - auto native_tiling = nativeTiling(bitwidth); - + // Set implicit dim to treat 1D as (1, N) and tile it as (1, 128) + std::array tiling = + res_rank == 1 ? std::array{1L, target_shape_[1]} + : nativeTiling(bitwidth); + ImplicitDim implicit_dim = + res_rank == 1 ? ImplicitDim::kSecondMinor : ImplicitDim::kNone; + std::array vreg_slice = + VectorLayout::vregSlice(target_shape_, bitwidth, tiling); for (int i = 0; i < op.getSources().size(); ++i) { // Compute the offset per source. // Ex: for a cat of (10, 128), (10, 128) on dim 0, where the - // vreg_sice for that dim is 8, the first source starts at + // vreg_slice for that dim is 8, the first source starts at // offset 0, and overflows the vreg // by 2, so the offset for the second input is 2. - auto op_shape = + ArrayRef op_shape = cast(op.getSources()[i].getType()).getShape(); - auto offset_amount = starting_point % native_tiling[tiling_dim.value()]; - auto op_layout = op_layouts[i]; + Layout op_layout = op_layouts[i]; + int64_t offset_amount = starting_point % vreg_slice[tiling_dim.value()]; + if (offset_amount >= tiling[tiling_dim.value()]) { + return op.emitError( + "Not implemented: Input offsets outside of the first tile"); + } SmallVector in_idx{op_layout->offsets()[0].value_or(0), op_layout->offsets()[1].value_or(0)}; in_idx[tiling_dim.value()] = offset_amount; starting_point += op_shape[dimension]; in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]}, - native_tiling, ImplicitDim::kNone)); + tiling, implicit_dim)); } SmallVector res_layout_offsets( {first_layout->offsets()[0].value_or(0), @@ -814,13 +834,13 @@ class VectorLayoutInferer { // TODO(mvoz): A tiny optimization we could do here later is to // no-op setting tiling when sublane dim size is aligned to sublane // tiling. - auto res_layout = + VectorLayout res_layout = VectorLayout(bitwidth, {res_layout_offsets[0], res_layout_offsets[1]}, - native_tiling, ImplicitDim::kNone); + tiling, implicit_dim); setLayout(op, in_layouts, res_layout); return success(); } else { - auto layout = getLayout(op.getSources().front()); + Layout layout = getLayout(op.getSources().front()); // When concatenating vectors with replicated offsets, we want to reset // the replicated offset to zero. Because we are not sure if the // replicated value from each vector are same. @@ -883,66 +903,21 @@ class VectorLayoutInferer { } LogicalResult infer(tpu::MatmulOp op) { - auto get_operand_layout = - [&](Value v, llvm::StringRef operand_name, - std::optional major_multiple = std::nullopt, - std::optional minor_multiple = - std::nullopt) -> std::optional { - auto layout = getLayout(v); - if (!layout.has_value()) { - op->emitOpError("Internal error: assert failed: Operand ") - << operand_name << " has no vector layout"; - return std::nullopt; - } - auto vty = cast(v.getType()); - auto tiling = nativeTiling(vty.getElementTypeBitWidth()); - auto shape = vty.getShape().take_back(2); - if (shape[0] % major_multiple.value_or(tiling[0]) != 0 || - shape[1] % minor_multiple.value_or(tiling[1]) != 0) { - op->emitOpError("Matmul operand ") - << operand_name << " must have a shape divisible by (" - << major_multiple.value_or(tiling[0]) << ", " - << minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0] - << ", " << shape[1] << ")"; - return std::nullopt; - } - // Override tiling to match the native one. - return VectorLayout(layout->bitwidth(), {0, 0}, tiling, - ImplicitDim::kNone); - }; - auto res_ty = dyn_cast(op->getResult(0).getType()); - TPU_CHECK_OP(res_ty, "only vector results supported"); - TPU_CHECK_OP(res_ty.getElementTypeBitWidth() == kNativeBitwidth, - "only 32-bit matmul results supported"); - std::array in_layout; - CHECK_EQ(op->getNumOperands(), 3); - std::optional lhs_major_multiple; - std::optional rhs_major_multiple; - // We don't restrict the first lhs axis when the data is not packed. - if (cast(op->getOperand(0).getType()) - .getElementTypeBitWidth() == kNativeBitwidth) { - lhs_major_multiple = 1; - } - // We don't restrict the first rhs axis when the data is not packed. - if (cast(op->getOperand(1).getType()) - .getElementTypeBitWidth() == kNativeBitwidth) { - rhs_major_multiple = 1; - } - in_layout[0] = - get_operand_layout(op->getOperand(0), "lhs", lhs_major_multiple, 1); - if (!in_layout[0].has_value()) { - return failure(); - } - in_layout[1] = - get_operand_layout(op->getOperand(1), "rhs", rhs_major_multiple, 1); - if (!in_layout[1].has_value()) { - return failure(); - } - in_layout[2] = get_operand_layout(op->getOperand(2), "result", 1, 1); - if (!in_layout[2].has_value()) { - return failure(); - } - setLayout(op, in_layout, + auto lhs_bitwidth = op.getLhs().getType().getElementTypeBitWidth(); + auto rhs_bitwidth = op.getRhs().getType().getElementTypeBitWidth(); + auto acc_bitwidth = op.getAcc().getType().getElementTypeBitWidth(); + auto res_bitwidth = op.getResult().getType().getElementTypeBitWidth(); + TPU_CHECK_OP(acc_bitwidth == kNativeBitwidth, + "Expected 32-bit acc in tpu::MatmulOp"); + TPU_CHECK_OP(res_bitwidth == kNativeBitwidth, + "Expected 32-bit result in tpu::MatmulOp"); + auto lhs_layout = VectorLayout( + lhs_bitwidth, {0, 0}, nativeTiling(lhs_bitwidth), ImplicitDim::kNone); + auto rhs_layout = VectorLayout( + rhs_bitwidth, {0, 0}, nativeTiling(rhs_bitwidth), ImplicitDim::kNone); + auto acc_layout = VectorLayout( + acc_bitwidth, {0, 0}, nativeTiling(acc_bitwidth), ImplicitDim::kNone); + setLayout(op, {lhs_layout, rhs_layout, acc_layout}, VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, ImplicitDim::kNone)); return success(); @@ -1019,12 +994,6 @@ class VectorLayoutInferer { return success(); } - LogicalResult infer(tpu::RepeatOp op) { - auto src_layout = getLayout(op.getSource()); - setLayout(op, src_layout, src_layout); - return success(); - } - LogicalResult infer(tpu::TraceOp op) { static LogicalResult (*match_yield)(Operation *) = [](Operation *op) { TPU_CHECK_OP(isa(op), "expected yield terminator"); @@ -1096,12 +1065,10 @@ class VectorLayoutInferer { return success(); } if (auto src_ty = dyn_cast(some_src_ty)) { - TPU_CHECK_OP(src_ty.getRank() >= 2, "source rank below 2D unsupported"); - TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported"); auto some_layout = getLayout(op.getSource()); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); auto &layout = *some_layout; - if (layout.implicit_dim() != ImplicitDim::kNone) { + if (layout.implicit_dim() != ImplicitDim::kNone && src_ty.getRank() > 1) { VectorLayout layout_2d(layout.bitwidth(), layout.offsets(), layout.tiling(), ImplicitDim::kNone); if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) { @@ -1187,7 +1154,7 @@ class VectorLayoutInferer { } LogicalResult infer(vector::LoadOp op) { - auto src_ty = op.getMemRefType(); + auto src_ty = getMemRefType(op.getBase()); auto res_ty = op.getVectorType(); TPU_CHECK_OP(src_ty.getRank() == res_ty.getRank(), "memref and vector rank mismatch"); @@ -1280,6 +1247,17 @@ class VectorLayoutInferer { setLayout(op, in_layout, VectorLayout(bitwidth, {std::nullopt, offsets[1]}, layout_tiling, ImplicitDim::kNone)); + } else if (bitwidth == 32 && + canReinterpretToUntiledMemref( + op.getBase(), target_shape_, + /*allow_minormost_padding=*/true) && + *(src_ty.getShape().end() - 2) > 1) { + // Since it is untiled, we can load from any arbitrary address which + // means we can always set the sublane offset to 0. + // Note: if the src_shape[-2] == 1, we can just use the tiling from ref. + setLayout(op, in_layout, + VectorLayout(bitwidth, {0, offsets[1].value_or(0)}, + nativeTiling(bitwidth), ImplicitDim::kNone)); } else { setLayout( op, in_layout, @@ -1437,7 +1415,14 @@ class VectorLayoutInferer { // 1D tilings that use 1 in the sublane dimension. int64_t sublane_tiling = vreg_slice[0]; do { - if (src_tiled_ishape[1] == res_tiled_ishape[1] && + auto src_res_tiled_equal = src_tiled_ishape[1] == res_tiled_ishape[1]; + auto vreg_num_elements = + target_shape_[0] * target_shape_[1] * layout.packing(); + auto single_subline_mod_1024 = + (sublane_tiling == 1 && + src_tiled_ishape[1] % vreg_num_elements == 0 && + res_tiled_ishape[1] % vreg_num_elements == 0); + if ((src_res_tiled_equal || single_subline_mod_1024) && src_tiled_ishape[0] % sublane_tiling == 0 && res_tiled_ishape[0] % sublane_tiling == 0) { std::array tiling = {sublane_tiling, target_shape_[1]}; @@ -1445,11 +1430,11 @@ class VectorLayoutInferer { // unfolding, it's still a no-op, but we need to // add support in apply-vector-layout. LayoutOffsets offsets = {0, layout.offsets()[1]}; - setLayout(op, - VectorLayout(layout.bitwidth(), offsets, tiling, - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), offsets, tiling, - implicit_dim)); + setLayout( + op, + VectorLayout(layout.bitwidth(), offsets, tiling, + layout.implicit_dim()), + VectorLayout(layout.bitwidth(), offsets, tiling, implicit_dim)); return success(); } sublane_tiling /= 2; @@ -1514,8 +1499,9 @@ class VectorLayoutInferer { return failure(); } - LogicalResult infer(vector::StoreOp op) { - auto ref_ty = op.getMemRefType(); + template + LogicalResult inferStore(Op op, bool has_mask = false) { + auto ref_ty = getMemRefType(op.getBase()); auto store_ty = op.getValueToStore().getType(); TPU_CHECK_OP(ref_ty.getRank() == store_ty.getRank(), "memref and vector rank mismatch"); @@ -1596,15 +1582,36 @@ class VectorLayoutInferer { // We can strided store sublanes if we're storing a single sublane for // multiple times. Enabling this helps store one entire row to memref // more efficiently. - store_layout = VectorLayout(store_ty.getElementTypeBitWidth(), offsets, - {1, tiling[1]}, ImplicitDim::kNone); + store_layout = + VectorLayout(bitwidth, offsets, {1, tiling[1]}, ImplicitDim::kNone); + } else if (bitwidth == 32 && + // We accept padding in the minormost dim, because + // apply_vector_layout will properly mask stores. + canReinterpretToUntiledMemref( + op.getBase(), target_shape_, + /*allow_minormost_padding=*/true)) { + // Since it is untiled, we can store to any arbitrary address which + // means the sublane offset can be any value and we can fold it to + // 2nd minor index. + auto prev_store_layout = getLayout(op.getValueToStore()); + TPU_CHECK_OP(prev_store_layout.has_value(), "missing vector layout"); + offsets[0] = prev_store_layout->offsets()[0].value_or(0); + if (offsets[1].value_or(0) >= tiling[1]) { + offsets[1] = 0; + } + store_layout = VectorLayout(bitwidth, offsets, nativeTiling(bitwidth), + ImplicitDim::kNone); } else { - store_layout = VectorLayout(store_ty.getElementTypeBitWidth(), offsets, - {tiling[0], tiling[1]}, ImplicitDim::kNone); + store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]}, + ImplicitDim::kNone); } } SmallVector in_layout{store_layout}; in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout); + if (has_mask) { + // Mask layout should be the same as the layout of value to store. + in_layout.push_back(store_layout); + } setInLayout(op, in_layout); return success(); } @@ -1719,7 +1726,7 @@ class VectorLayoutInferer { return success(); } - LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) { + LogicalResult inferElementwise(Operation *op) { TPU_CHECK_OP(op->getNumResults() == 1, "only one result supported"); TPU_CHECK_OP(op->getNumOperands() > 0, "elementwise ops with no operands unsupported"); @@ -1728,26 +1735,45 @@ class VectorLayoutInferer { std::optional out_layout_candidate; std::optional out_layout; SmallVector, 4> in_layouts; - int64_t bit_width = -1; + int64_t bitwidth = -1; + // Find the bitwidth of the operands/results. They must all be the same + // except for the case of i1s, which use a "fake" bitwidth for layouts. + // They can be relayouted (in principle) to any other fake bitwidth, so we + // don't commit to their bitwidth. See comments in VectorLayout class. + for (Value val : llvm::concat(op->getOperands(), op->getResults())) { + if (const VectorType vty = dyn_cast(val.getType())) { + const int64_t val_bitwidth = vty.getElementTypeBitWidth(); + if (val_bitwidth != 1) { + if (bitwidth == -1) { + bitwidth = val_bitwidth; + } else if (bitwidth != val_bitwidth) { + return op->emitOpError( + "Mismatched bitwidth in elementwise for non-i1 " + "operands/results"); + } + } + } + } for (int64_t i = 0; i < op->getNumOperands(); ++i) { if (auto vty = dyn_cast(op->getOperand(i).getType())) { - if (bit_width == -1) { - bit_width = vty.getElementTypeBitWidth(); - } - TPU_CHECK_OP( - !check_bitwidth || bit_width == vty.getElementTypeBitWidth(), - "Generic elementwise rule only supports operands of same width"); auto some_layout = getLayout(op->getOperand(i)); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); auto &layout = *some_layout; - // If the input is fully replicated, don't use it to commit to any - // layout. Replicated values are easy to relayout. - if (is_fully_replicated(some_layout)) { + if (bitwidth == -1) { + // All operands/results are i1s, just commit to the first bitwidth + DCHECK(!out_layout.has_value()); + bitwidth = layout.bitwidth(); + out_layout = layout; + in_layouts.push_back(layout); + } else if (bitwidth != layout.bitwidth()) { + DCHECK_EQ(vty.getElementTypeBitWidth(), 1); + in_layouts.push_back(std::nullopt); + } else if (is_fully_replicated(some_layout)) { + // If the input is fully replicated, don't use it to commit to any + // layout. Replicated values are easy to relayout. in_layouts.push_back(std::nullopt); out_layout_candidate = layout; - continue; - } - if (!out_layout) { + } else if (!out_layout) { // TODO(apaszke): There are probably smarter ways to choose layout. out_layout = layout; in_layouts.push_back(some_layout); @@ -1761,8 +1787,9 @@ class VectorLayoutInferer { // any replication bits that might have been present in out_layout, // since there is no guarantee that the conflicting inputs could // even become replicated. + DCHECK_EQ(out_layout->bitwidth(), bitwidth); out_layout = - VectorLayout(out_layout->bitwidth(), + VectorLayout(bitwidth, {out_layout->offsets()[0].value_or(0), out_layout->offsets()[1].value_or(0)}, out_layout->tiling(), out_layout->implicit_dim()); @@ -1777,9 +1804,6 @@ class VectorLayoutInferer { } Layout final_out_layout = std::nullopt; if (auto out_vty = dyn_cast(op->getResult(0).getType())) { - TPU_CHECK_OP( - !check_bitwidth || bit_width == out_vty.getElementTypeBitWidth(), - "Generic elementwise rule can't change element type width"); if (out_layout) { final_out_layout = *out_layout; } else if (out_layout_candidate) { @@ -1809,9 +1833,9 @@ class VectorLayoutInferer { "only 32-bit random bit generation supported"); // TODO: b/342054464 - Support implicit dims for PRNGRandomBitsOp. LayoutOffsets offsets = {0, 0}; - setOutLayout(op, VectorLayout( - kNativeBitwidth, offsets, nativeTiling(kNativeBitwidth), - ImplicitDim::kNone)); + setOutLayout( + op, VectorLayout(kNativeBitwidth, offsets, + nativeTiling(kNativeBitwidth), ImplicitDim::kNone)); return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h new file mode 100644 index 000000000000..dc16ddbdf26c --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -0,0 +1,15 @@ +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ + +#include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/Support/LLVM.h" + +namespace mlir::tpu::extensions { + +bool canInferVectorLayout(const Operation &op); + +LogicalResult inferVectorLayout(const Operation &op); + +} // namespace mlir::tpu::extensions + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index 569038500067..0fd88ac1f294 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -70,6 +70,10 @@ LogicalResult specializeMemorySpace(TypedValue value, to_update.pop_back(); // Here we only have to handle the operations allowed on refs with // unspecified memory space. + if (auto op = dyn_cast(some_op)) { + updateResultFrom(op, op.getInput().getType()); + continue; + } if (auto op = dyn_cast(some_op)) { updateResultFrom(op, op.getMemRef().getType()); continue; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 3f6050f31dab..6717e3a3e8ec 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -43,6 +43,8 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; +// When this is bumped, we should file a TODO to update the forward-compatible +// version in tpu_custom_call.py in a month! constexpr int kVersion = 3; StringRef mangle(StringRef name, std::string* storage) { @@ -63,7 +65,7 @@ std::optional demangle(StringRef name) { using rule_type = std::function; -LogicalResult enqueue_dma_rule(Operation* op, int version) { +LogicalResult enqueue_dma_upgrade(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. if (version < 2) { if (op->getNumOperands() == 3) { // Local DMA. @@ -84,7 +86,14 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) { return success(); } -LogicalResult semaphore_signal_rule(Operation* op, int version) { +LogicalResult enqueue_dma_downgrade(Operation* op, int version) { + if (version < 2) { + return op->emitError("Downgrade to version ") << version << " unsupported"; + } + return success(); +} + +LogicalResult semaphore_signal_upgrade(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. if (version < 2) { if (op->getNumOperands() == 2) { // Local signal. @@ -102,7 +111,25 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) { return success(); } -LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) { +LogicalResult semaphore_signal_downgrade(Operation* op, int version) { + if (version < 2) { + auto operands = op->getAttrOfType( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr()); + if (!operands || operands.size() != 4) { + return op->emitError("Missing or invalid AttrSizedOperandSegments"); + } + if (operands[3]) { + return op->emitError("Downgrade to version ") + << version << " impossible: core_id is set"; + } + op->removeAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr()); + } + return success(); +} + +LogicalResult vector_multi_dim_reduce_upgrade(Operation* op, int version) { // Changed reductions_dims from ArrayAttr of IntegerAttrs to DenseI64ArrayAttr // in version 3. if (version < 3) { @@ -130,21 +157,49 @@ LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) { return success(); } +LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) { + if (version < 3) { + return op->emitError("Downgrade to version ") << version << " unsupported"; + } + return success(); +} + const llvm::StringMap& upgrade_rules() { static auto rules = new llvm::StringMap{ - {EnqueueDMAOp::getOperationName(), enqueue_dma_rule}, - {SemaphoreSignalOp::getOperationName(), semaphore_signal_rule}, + {EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade}, + {SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade}, {vector::MultiDimReductionOp::getOperationName(), - vector_multi_dim_reduce_rule} + vector_multi_dim_reduce_upgrade} }; return *rules; } +const llvm::StringMap& downgrade_rules() { + static auto rules = new llvm::StringMap{ + {EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade}, + {SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_dim_reduce_downgrade}}; + return *rules; +} + struct MosaicSerdePass : public impl::MosaicSerdePassBase { using Base::Base; void runOnOperation() override { ModuleOp module = getOperation(); + if (!serialize.hasValue()) { + module.emitError("serialize option must be specified"); + return signalPassFailure(); + } + int serialize_version = + target_version.hasValue() ? target_version : kVersion; + if (serialize && serialize_version > kVersion) { + module.emitError("The highest supported version is ") + << kVersion << " but requested serialization at version " + << serialize_version; + return signalPassFailure(); + } if (serialize && !module->getContext()->allowsUnregisteredDialects()) { module.emitError() << "Cannot serialize within a context that does not " "allow unregistered dialects."; @@ -156,7 +211,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { module->setAttr( kVersionAttrName, IntegerAttr::get(IntegerType::get(module->getContext(), 64), - kVersion)); + serialize_version)); } else { IntegerAttr version_attr = module->getAttrOfType(kVersionAttrName); @@ -175,7 +230,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { module->removeAttr(kVersionAttrName); } std::string name_storage; - auto result = module.walk([this, &name_storage, version](Operation* op) { + auto result = module.walk([&](Operation* op) { if (isa(op)) { // Don't mangle the ModuleOp itself. return WalkResult::advance(); } @@ -207,6 +262,16 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { auto new_op = Operation::create( op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(), op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions()); + // Downgrade the op to the target version, if needed. + if (serialize && kVersion != serialize_version) { + if (const auto rule = + downgrade_rules().find(op->getName().getStringRef()); + rule != downgrade_rules().end()) { + if (rule->second(new_op, serialize_version).failed()) { + return WalkResult::interrupt(); + } + } + } op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op); op->replaceAllUsesWith(new_op->getResults()); op->erase(); diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 0e3e6d0d9cd8..c5f9833761b9 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -15,12 +15,18 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/util.h" +#include #include +#include +#include #include "llvm/Support/MathExtras.h" #include "absl/types/span.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" +#include "mlir/include/mlir/IR/Value.h" +#include "mlir/include/mlir/IR/ValueRange.h" #include "mlir/include/mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { SmallVector ComputeTileStrides(MemRefType memref_ty, @@ -39,4 +45,106 @@ SmallVector ComputeTileStrides(MemRefType memref_ty, } return tile_strides; } + +std::optional> isTransposedMatmul( + DotDimensionNumbersAttr dim_numbers) { + auto lhs_contracting_dims = dim_numbers.getLhsContractingDims(); + auto rhs_contracting_dims = dim_numbers.getRhsContractingDims(); + auto lhs_non_contracting_dims = dim_numbers.getLhsNonContractingDims(); + auto rhs_non_contracting_dims = dim_numbers.getRhsNonContractingDims(); + + if (lhs_contracting_dims.size() != 1 || rhs_contracting_dims.size() != 1 || + lhs_non_contracting_dims.size() != 1 || + rhs_non_contracting_dims.size() != 1) { + return std::nullopt; + } + + int64_t lhs_non_contracting_dim = lhs_non_contracting_dims[0]; + int64_t lhs_contracting_dim = lhs_contracting_dims[0]; + int64_t rhs_non_contracting_dim = rhs_non_contracting_dims[0]; + int64_t rhs_contracting_dim = rhs_contracting_dims[0]; + + bool lhs_transposed = lhs_non_contracting_dim > lhs_contracting_dim; + + bool rhs_transposed = rhs_contracting_dim > rhs_non_contracting_dim; + + return std::pair{lhs_transposed, rhs_transposed}; +} + +bool canReinterpretToUntiledMemref(TypedValue tiled_memref, + const std::array& target_shape, + bool allow_minormost_padding) { + MemRefType tiled_memref_ty = tiled_memref.getType(); + auto tiled_layout = + dyn_cast(tiled_memref_ty.getLayout()); + ValueRange dynamic_sizes = {}; + if (!tiled_layout) { + if (auto erase_op = tiled_memref.getDefiningOp()) { + tiled_memref = erase_op.getOperand(); + tiled_memref_ty = tiled_memref.getType(); + tiled_layout = + dyn_cast(tiled_memref_ty.getLayout()); + // TODO(b/375641258): Currently we rely on the pattern `slice -> + // (squeeze)* -> eraseLayout` to get the dynamic sizes, but other patterns + // may not work here: eg., slice -> eraseLayout -> reshape -> + // eraseLayout`. We should fix this! For now, if we can not get the + // expected dynamic sizes, we consider the memref cannot be reinterpreted + // to untiled. + Value ref = tiled_memref; + while (auto squeeze_op = ref.getDefiningOp()) { + ref = squeeze_op.getInput(); + } + if (auto slice_op = ref.getDefiningOp()) { + dynamic_sizes = slice_op.getDynamicSizes(); + } + } + } + if (!tiled_layout) { + // We expect the tiled memref to have a tiled layout. + return false; + } + if (tiled_memref_ty.getNumDynamicDims() != dynamic_sizes.size()) { + return false; + } + if (tiled_layout.getTiles().empty() || + tiled_layout.getTiles().front().dimensions().size() != 2 || + tiled_memref_ty.getRank() < 2) { + // TODO(b/375642202): Currently we only support >= 2D memref, we might + // need to handle 1D memref if we find a use case. + return false; + } + auto rank = tiled_memref_ty.getRank(); + auto packing = 32 / tiled_memref_ty.getElementTypeBitWidth(); + if (tiled_memref_ty.isDynamicDim(rank - 1)) { + // TODO(jevinjiang): we can still allow the minormost padding if we know the + // max bound of the dynamic size is not larger than the target_shape[1]. + if (!isGuaranteedDivisible(dynamic_sizes.back(), target_shape[1])) { + return false; + } + dynamic_sizes = dynamic_sizes.drop_back(); + } else { + if (!allow_minormost_padding && + tiled_memref_ty.getShape()[rank - 1] != target_shape[1]) { + return false; + } + } + if (tiled_memref_ty.isDynamicDim(rank - 2)) { + if (!isGuaranteedDivisible(dynamic_sizes.back(), packing)) { + return false; + } + } else { + if (tiled_memref_ty.getShape()[rank - 2] % packing != 0) { + return false; + } + } + // Check if the minormost dim has a single tile. + return *(tiled_layout.getTileStrides().end() - 1) == 1 && + *(tiled_layout.getTileStrides().end() - 2) == 1; +} + +bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) { + auto memory_space = + dyn_cast_or_null(ty.getMemorySpace()); + return memory_space && memory_space.getValue() == space; +} } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index c18bd1b3fbc2..9052afad499a 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -2,7 +2,6 @@ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ #include -#include #include #include #include @@ -16,7 +15,8 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/types/span.h" -#include "tsl/platform/statusor.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "mlir/include/mlir/IR/Value.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? @@ -98,6 +98,25 @@ std::string shapeToString(const T &shape) { SmallVector ComputeTileStrides(MemRefType memref_ty, absl::Span tiling); +// Assuming MKN matmul - This function must only be called after +// canonicalization passes. +// +// Given a set of dimension numbers, Returns a pair of booleans, where the +// first is true if the lhs is transposed +// and the second is true if the rhs is transposed. +std::optional> isTransposedMatmul( + DotDimensionNumbersAttr dim_numbers); + +// Returns true if a >=2D memref has a tiled layout and can be equivalently +// considered as an untiled memref, except for potential padding in the +// minormost dimension up to target_shape[1] (if allow_minormost_padding is +// true). +bool canReinterpretToUntiledMemref(TypedValue tiled_memref, + const std::array &target_shape, + bool allow_minormost_padding = false); + +// Determines whether the given MemRefType has the given memory space. +bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space); } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index e5eaeb347137..cb52488e79cc 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -26,6 +26,19 @@ py_library( deps = [":_mosaic_gpu_ext"], ) +cc_library( + name = "target", + srcs = ["target.cc"], + hdrs = ["target.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:MC", + ], +) + cc_library( name = "passes", srcs = [ @@ -104,16 +117,21 @@ cc_library( srcs = ["custom_call.cc"], deps = [ ":passes", + ":target", "//jaxlib/cuda:cuda_vendor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:ControlFlowToLLVM", @@ -168,10 +186,11 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/strings", "@nanobind", - "@xla//xla/service:custom_call_status", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/tsl/cuda:cudart", ], ) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 103f9f78c32f..54792b3097f7 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -31,12 +31,14 @@ limitations under the License. #include #include -#include "jaxlib/gpu/vendor.h" #include "absl/base/optimization.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "llvm/include/llvm/ADT/SmallVector.h" #include "llvm/include/llvm/Support/CodeGen.h" @@ -52,6 +54,7 @@ limitations under the License. #include "mlir/include/mlir/Conversion/Passes.h" #include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" @@ -79,8 +82,10 @@ limitations under the License. #include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/include/mlir/Transforms/Passes.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" #include "jaxlib/mosaic/gpu/passes.h" +#include "jaxlib/mosaic/gpu/target.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -89,8 +94,30 @@ namespace { using MosaicInitFunc = void(void****); using MosaicHostFunc = void(void**); +absl::StatusOr> GetSmAndPtxIsaVersion() { + // Assumes driver has been initialized and a context exists. XLA already has + // some utilities to query this, but we try to stay runtime-agnostic, so we + // build our own here. + CUdevice device; + if (cuCtxGetDevice(&device) != CUDA_SUCCESS) { + return absl::InternalError("Failed to get device for current context"); + } + int major = 0; + if (cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, + device) != CUDA_SUCCESS) { + return absl::InternalError("Failed to get major compute capability"); + } + int minor = 0; + if (cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, + device) != CUDA_SUCCESS) { + return absl::InternalError("Failed to get minor compute capability"); + } + return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor); +} + mlir::FailureOr GetPassPipeline( - mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target) { + mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, + const std::string& sm, const std::string& ptx_isa) { static bool register_once = []() { llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget(); @@ -117,12 +144,14 @@ mlir::FailureOr GetPassPipeline( mosaic::gpu::registerGpuLaunchLoweringPass(); mosaic::gpu::registerConvertGpuToLLVMPass(); mosaic::gpu::registerByvalInsertionPass(); + mlir::arith::registerArithExpandOpsPass(); return true; }(); (void)register_once; - return mlir::parsePassPipeline( + return mlir::parsePassPipeline(absl::StrCat( R"( builtin.module( + arith-expand, canonicalize, gpu-launch-sink-index-computations, convert-nvgpu-to-nvvm, @@ -131,7 +160,9 @@ mlir::FailureOr GetPassPipeline( convert-scf-to-cf, convert-nvvm-to-llvm, expand-strided-metadata, - nvvm-attach-target{O=3 chip=sm_90a fast=false features=+ptx80 ftz=false module= triple=nvptx64-nvidia-cuda}, + nvvm-attach-target{O=3 chip=)", + sm, R"( fast=false features=+)", ptx_isa, + R"( ftz=false module= triple=nvptx64-nvidia-cuda}, lower-affine, convert-arith-to-llvm{index-bitwidth=0}, convert-index-to-llvm{index-bitwidth=64}, @@ -144,19 +175,19 @@ mlir::FailureOr GetPassPipeline( gpu.module(mosaic-byval-insertion), gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, - gpu-module-to-binary{format=)" + - mlir::gpu::stringifyCompilationTarget(target).str() + R"(}, + gpu-module-to-binary{format=)", + mlir::gpu::stringifyCompilationTarget(target).str(), R"(}, convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, - )" + + )", (target != mlir::gpu::CompilationTarget::Assembly ? "gpu-launch-lowering," - : "") + + : ""), R"( convert-to-llvm, reconcile-unrealized-casts ) - )"); + )")); } mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes, @@ -251,7 +282,8 @@ class TemporaryDirectory { std::string path; }; -void DumpCompilationOutput(mlir::ModuleOp module) { +void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, + const std::string& ptx_isa) { bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; @@ -260,8 +292,9 @@ void DumpCompilationOutput(mlir::ModuleOp module) { } module = module.clone(); // Prevent accidental modification. - auto passes = GetPassPipeline(module.getContext(), - mlir::gpu::CompilationTarget::Assembly); + absl::Cleanup module_destroyer = [module] { module->erase(); }; + auto passes = GetPassPipeline( + module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa); if (mlir::failed(passes) || mlir::failed(RunPasses(std::move(*passes), module))) { return; @@ -297,7 +330,7 @@ void DumpCompilationOutput(mlir::ModuleOp module) { // Run ptxas to generate SASS. std::vector ptxas_args = { "ptxas", "--opt-level", "3", - "--gpu-name", "sm_90a", "--output-file", + "--gpu-name", sm.c_str(), "--output-file", elf_path.c_str(), ptx_path.c_str()}; if (dump_ptxas) { ptxas_args.push_back("-v"); @@ -321,9 +354,15 @@ void DumpCompilationOutput(mlir::ModuleOp module) { absl::StatusOr> Compile( mlir::ModuleOp module) { - DumpCompilationOutput(module); - auto passes = GetPassPipeline(module.getContext(), - mlir::gpu::CompilationTarget::Binary); + auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); + if (!sm_and_ptx_isa.ok()) { + return sm_and_ptx_isa.status(); + } + const std::string sm = sm_and_ptx_isa.value().first; + const std::string ptx_isa = sm_and_ptx_isa.value().second; + DumpCompilationOutput(module, sm, ptx_isa); + auto passes = GetPassPipeline( + module.getContext(), mlir::gpu::CompilationTarget::Binary, sm, ptx_isa); if (mlir::failed(passes)) { return absl::InternalError("Failed to construct pass pipeline"); } @@ -377,6 +416,40 @@ GetKernelCache() { return std::make_pair(&context_cache, &mutex); } +absl::StatusOr> GetHostAndInitFuncNames( + mlir::ModuleOp module_op) { + // We look for two top level C-interface functions: + // - "host" function with symbol name "_mlir_ciface_" + // - "init" function with symbol name "_mlir_ciface__init" + constexpr std::string_view prefix = "_mlir_ciface_"; + std::vector names; + for (mlir::LLVM::LLVMFuncOp llvm_func : + module_op.getOps()) { + if (llvm_func.getName().starts_with(prefix)) { + names.push_back(llvm_func.getName().str()); + } + } + if (auto size = names.size(); size != 2) { + return absl::InternalError(absl::StrFormat( + "Expected to locate 2 symbols with %s prefix in the MLIR module, found " + "%d instead.", + prefix, size)); + } + // _mlir_ciface__init now follows _mlir_ciface_ + std::sort(names.begin(), names.end()); + + std::string host_func_name = names[0]; + std::string init_func_name = names[1]; + + if (init_func_name != absl::StrCat(host_func_name, "_init")) { + return absl::InternalError(absl::StrFormat( + "Expected init function name to equal the concatenation of the host " + "function name and the \"_init\" suffix, instead got " + "init_func_name=%s, host_func_name=%s.", + init_func_name, host_func_name)); + } + return std::make_pair(host_func_name, init_func_name); +} absl::StatusOr CompileAndInit(const char* module) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); @@ -392,9 +465,16 @@ absl::StatusOr CompileAndInit(const char* module) { return maybe_engine.status(); } mlir::ExecutionEngine* execution_engine = maybe_engine->get(); - auto main = execution_engine->lookupPacked("_mlir_ciface_main"); - auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); - if (!init || !main) { + + auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op); + if (!host_and_init_func_names.ok()) { + return host_and_init_func_names.status(); + } + auto [host_name, init_name] = host_and_init_func_names.value(); + + auto host = execution_engine->lookupPacked(host_name); + auto init = execution_engine->lookupPacked(init_name); + if (!init || !host) { return absl::InternalError("Failed to retrieve kernel function"); } void* module_ptr = nullptr; @@ -404,7 +484,7 @@ absl::StatusOr CompileAndInit(const char* module) { void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); return CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*main)); + reinterpret_cast(*host)); } // Each compiled kernel has a unique init func, and each kernel is used from diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index ec574de4368f..608270239882 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -13,50 +13,132 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include #include "nanobind/nanobind.h" -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" +#include "absl/cleanup/cleanup.h" +#include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/service/custom_call_status.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" namespace jax::cuda { namespace { +namespace ffi = xla::ffi; namespace nb = nanobind; -void EventRecordCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto* event = reinterpret_cast(opaque); - if (gpuEventRecord(**event, reinterpret_cast(stream)) != - gpuSuccess) { - const char message[] = "Failed to record event"; - XlaCustomCallStatusSetFailure(status, message, sizeof(message)); +static std::string ToString(CUresult result) { + const char* error_name; + if (cuGetErrorName(result, &error_name)) { + return absl::StrCat("UNKNOWN ERROR (", static_cast(result), ")"); } + const char* error_string; + if (cuGetErrorString(result, &error_string)) { + return error_name; + } + return absl::StrCat(error_name, ": ", error_string); +} + +// Ensure it is safe to store gpuEvent_t in a uint64_t buffer. +static_assert(sizeof(gpuEvent_t) <= sizeof(uint64_t)); + +static const auto* kEventRecord = + ffi::Ffi::Bind() + .Ctx>() + .Attr("copy_before") + .RemainingArgs() + .Ret>() // event + .RemainingRets() + .To([](gpuStream_t stream, bool copy_before, + auto remaining_args, auto ret, auto remaining_rets) { + static auto* event = new gpuEvent_t; + if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); + res) { + return ffi::Error::Internal( + absl::StrCat("Failed to create event: ", ToString(res))); + } + auto do_copy = [&]() { + gpuMemcpyAsync(ret->untyped_data(), event, + sizeof(gpuEvent_t), gpuMemcpyHostToDevice, stream); + }; + if (copy_before) { + do_copy(); + } + if (auto res = gpuEventRecord(*event, stream); res) { + return ffi::Error::Internal( + absl::StrCat("Failed to record event: ", ToString(res))); + } + if (!copy_before) { + do_copy(); + } + return ffi::Error::Success(); + }) + .release(); + +XLA_FFI_Error* EventRecord(XLA_FFI_CallFrame* call_frame) { + return kEventRecord->Call(call_frame); +} + +static const auto* kEventElapsed = + ffi::Ffi::Bind() + .Ctx>() + .Arg>() // start_event + .Arg>() // end_event + .Ret>() // elapsed_ms + .To([](gpuStream_t stream, auto start, auto end, auto out) { + gpuStreamSynchronize(stream); + auto start_event = std::make_unique(); + auto end_event = std::make_unique(); + absl::MakeCleanup([&]() { + gpuEventDestroy(*start_event); + gpuEventDestroy(*end_event); + }); + gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpyDeviceToHost); + gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpyDeviceToHost); + float elapsed; + if (auto res = + gpuEventElapsedTime(&elapsed, *start_event, *end_event); + res) { + return ffi::Error::Internal(absl::StrCat( + "Failed to get elapsed time between events: ", ToString(res))); + } + gpuMemcpy(out->untyped_data(), &elapsed, sizeof(float), + gpuMemcpyHostToDevice); + return ffi::Error::Success(); + }) + .release(); + +XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) { + return kEventElapsed->Call(call_frame); } NB_MODULE(_mosaic_gpu_ext, m) { - m.def("_gpu_event_create", []() { - gpuEvent_t* event = new gpuEvent_t(); - gpuEventCreate(event, GPU_EVENT_DEFAULT); - return reinterpret_cast(event); + m.def("registrations", []() { + return nb::make_tuple( + nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)), + nb::make_tuple("mgpu_event_elapsed", EncapsulateFunction(EventElapsed)) + ); }); - m.def("_gpu_event_destroy", [](uintptr_t event) { - gpuEventDestroy(*reinterpret_cast(event)); - }); - m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) { - float elapsed_ms = -1; - if (gpuEventElapsedTime( - &elapsed_ms, *reinterpret_cast(start_event), - *reinterpret_cast(end_event)) != gpuSuccess) { - throw std::runtime_error("Failed to get elapsed time between events"); + m.def("_sync_all_devices", []() { + int devices = 0; + if (cudaGetDeviceCount(&devices) != gpuSuccess) { + throw std::runtime_error("Failed to get device count"); + } + for (int i = 0; i < devices; ++i) { + if (cudaSetDevice(i) != gpuSuccess) { + throw std::runtime_error("Failed to set device"); + } + if (cudaDeviceSynchronize() != gpuSuccess) { + throw std::runtime_error("Failed to synchronize device"); + } } - return elapsed_ms; }); - m.def("_record_event_capsule", - []() { return EncapsulateFunction(EventRecordCall); }); } } // namespace diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc new file mode 100644 index 000000000000..a1a66a709cbe --- /dev/null +++ b/jaxlib/mosaic/gpu/target.cc @@ -0,0 +1,88 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/mosaic/gpu/target.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "llvm/include/llvm/MC/MCSubtargetInfo.h" +#include "llvm/include/llvm/MC/TargetRegistry.h" + +namespace mosaic::gpu { + +absl::StatusOr> GetSmAndPtxIsaVersion( + int major, int minor) { + // "base" compute capability as reported by the driver. + // For example for a Hopper H200 GPU this would return sm_90, and never + // sm_90a. + std::string sm_base = absl::StrCat("sm_", major, minor); + + const std::string triple = "nvptx64-nvidia-cuda"; + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(triple, error); + if (target == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to lookup LLVM target based on triple %s: %s", triple, error)); + } + + // Check if there's a variant of the current SM that ends in "a" + // (has architecture-specific capabilities) + const char* sm_arch_specific = nullptr; + { + // generic subtarget + std::unique_ptr subtarget_info{ + target->createMCSubtargetInfo(triple, "", "")}; + if (subtarget_info == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to get generic LLVM subtarget info for triple %s", triple)); + } + for (const llvm::SubtargetSubTypeKV& subtype : + subtarget_info->getAllProcessorDescriptions()) { + if (absl::StartsWith(subtype.Key, sm_base) && + absl::EndsWith(subtype.Key, "a")) { + sm_arch_specific = subtype.Key; + break; + } + } + } + + const std::string sm = sm_arch_specific ? sm_arch_specific : sm_base; + + std::unique_ptr subtarget_info{ + target->createMCSubtargetInfo(triple, sm, "")}; + if (subtarget_info == nullptr) { + return absl::InternalError( + absl::StrFormat("Failed to get LLVM subtarget info for sm %s", sm)); + } + + for (const llvm::SubtargetFeatureKV& feature : + subtarget_info->getEnabledProcessorFeatures()) { + if (absl::StartsWith(feature.Key, "ptx")) { + std::string ptx_isa = feature.Key; + return std::make_pair(sm, ptx_isa); + } + } + return absl::InternalError(absl::StrFormat( + "Failed to find a PTX ISA LLVM subtarget feature for %s", sm)); +} + +} // namespace mosaic::gpu diff --git a/jaxlib/mosaic/gpu/target.h b/jaxlib/mosaic/gpu/target.h new file mode 100644 index 000000000000..070ecedebd01 --- /dev/null +++ b/jaxlib/mosaic/gpu/target.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_TARGET_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_TARGET_H_ + +#include +#include + +#include "absl/status/statusor.h" + +namespace mosaic::gpu { + +absl::StatusOr> GetSmAndPtxIsaVersion( + int major, int minor); + +} // namespace mosaic::gpu + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_TARGET_H_ diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 48268bfcf30a..6e575fb3092a 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -17,6 +17,20 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") load("@rules_python//python:defs.bzl", "py_library") +py_library( + name = "gpu_dialect", + srcs = [ + "mosaic_gpu.py", + "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_enums.py", + "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_ops.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//jaxlib/mlir", + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", + ], +) + gentbl_filegroup( name = "tpu_python_gen_raw", tbl_outs = [ diff --git a/jaxlib/mosaic/python/mosaic_gpu.py b/jaxlib/mosaic/python/mosaic_gpu.py new file mode 100644 index 000000000000..f99f53cfdb69 --- /dev/null +++ b/jaxlib/mosaic/python/mosaic_gpu.py @@ -0,0 +1,36 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python bindings for the MLIR Mosaic GPU dialect. + +Note: this file *must* be called `mosaic_gpu.py`, in order to match the dialect +name. Otherwise, MLIR is unable to find the module during dialect search. +""" + +# ruff: noqa: F401 +# ruff: noqa: F403 + + +# pylint: disable=g-bad-import-order +from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_ops import * # pylint: disable=wildcard-import # type: ignore[import-not-found] +from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_enums import * # pylint: disable=wildcard-import # type: ignore[import-not-found] +from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import * # pylint: disable=wildcard-import # type: ignore[import-not-found] + +try: + from jaxlib.mlir.dialects._ods_common import _cext +except ImportError: + from mlir.dialects._ods_common import _cext # type: ignore[import-not-found] + + +_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python") diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 8c1e96ed5966..6b481682a885 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -427,6 +427,48 @@ pybind_extension( ], ) +cc_library( + name = "hip_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_hybrid", + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_hybrid_kernels", + ":hip_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) + cc_library( name = "triton_kernels", srcs = ["//jaxlib/gpu:triton_kernels.cc"], @@ -494,6 +536,7 @@ py_library( name = "rocm_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_rnn", diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc index c6855879e8be..f28b5c9b4e53 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm_plugin_extension.cc @@ -12,134 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include -#include -#include #include "nanobind/nanobind.h" -#include "absl/status/status.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/ffi/api/c_api.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" -#include "xla/pjrt/c/pjrt_c_api_helpers.h" -#include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" -#include "xla/tsl/python/lib/core/numpy.h" -#include "xla/util.h" +#include "jaxlib/gpu_plugin_extension.h" namespace nb = nanobind; namespace xla { namespace { -absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, - const char* fn_name_c_str, size_t fn_name_size, - nb::object fn, int api_version, - XLA_FFI_Handler_Traits traits) { - if (c_api->extension_start == nullptr) { - return Unimplemented("The plugin does not have extension."); - } - const PJRT_Extension_Base* next = - reinterpret_cast(c_api->extension_start); - while (next != nullptr && - next->type != - PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { - next = next->next; - } - if (next == nullptr) { - return Unimplemented("The plugin does not have a custom call extension."); - } - PJRT_Gpu_Register_Custom_Call* register_custom_call = - reinterpret_cast(next)->custom_call; - - if (traits != 0) { - return Unimplemented("The plugin does not support custom call traits."); - } - - PJRT_Gpu_Register_Custom_Call_Args args; - args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name_c_str; - args.function_name_size = fn_name_size; - -#if PJRT_API_GPU_EXTENSION_VERSION >= 1 - args.api_version = api_version; -#endif - - auto as_capsule = [](nb::object obj) -> absl::StatusOr { - nb::capsule capsule; - if (!nb::try_cast(obj, capsule)) { - return absl::InvalidArgumentError( - "Custom call target registration requires handlers as PyCapsules"); - } - return capsule; - }; - -#if PJRT_API_GPU_EXTENSION_VERSION <= 1 - TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); - args.custom_call_function = fn_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); -#else - args.handler_instantiate = nullptr; - args.handler_prepare = nullptr; - args.handler_initialize = nullptr; - args.handler_execute = nullptr; - - // Register legacy custom call target (untyped void* API). - if (api_version == 0) { - TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); - args.handler_execute = capsule_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - // Register XLA FFI handler (typed API with explicit function signatures). - if (api_version == 1) { - auto capsule_execute = as_capsule(fn); - if (capsule_execute.ok()) { - args.handler_execute = capsule_execute->data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - nb::dict bundle; - if (nb::try_cast(fn, bundle)) { - auto handler = [&](const char* name) -> absl::StatusOr { - if (!bundle.contains(name)) return nullptr; - TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); - return capsule.data(); - }; - - TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); - TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); - TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); - TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - return absl::InvalidArgumentError( - "Unsupported custom call target type for api_version=1"); - } - - return absl::UnimplementedError(absl::StrFormat( - "API version %d is not supported by RegisterCustomCallTarget. " - "Supported versions are 0 and 1.", - api_version)); -#endif -} - -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - std::string ToString(hipError_t result) { #define OSTREAM_ROCM_ERROR(__name) \ case hipError##__name: \ @@ -179,31 +65,7 @@ std::string ToString(hipError_t result) { } // namespace NB_MODULE(rocm_plugin_extension, m) { - tsl::ImportNumpy(); - m.def( - "register_custom_call_target", - [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, - nb::str xla_platform_name, int api_version, - XLA_FFI_Handler_Traits traits) { - const char* fn_name_c_str; - size_t fn_name_size; - nb::str fn_name_bn_str; - if (nb::try_cast(fn_name_py, fn_name_bn_str)) { - fn_name_c_str = fn_name_bn_str.c_str(); - fn_name_size = nb::len(fn_name_bn_str); - } else{ - nb::bytes bytes = nb::cast(fn_name_py); - fn_name_c_str = bytes.c_str(); - fn_name_size = bytes.size(); - } - xla::ThrowIfError(RegisterCustomCallTarget( - static_cast(c_api.data()), fn_name_c_str, - fn_name_size, std::move(fn), api_version, traits)); - }, - nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), - nb::arg("xla_platform_name"), nb::arg("api_version") = 0, - nb::arg("traits") = 0); - m.def("registrations", &Registrations); + BuildGpuPluginExtension(m); m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/setup.py b/jaxlib/setup.py index dea9503c7c00..989a8314eb92 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -72,6 +72,7 @@ def has_ext_modules(self): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={ 'jaxlib': [ diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 4553dc1e3ea8..48dc03cfb7d6 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -14,9 +14,11 @@ # JAX is Autograd and XLA +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows", "jax_py_test") +load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") licenses(["notice"]) # Apache 2 @@ -30,11 +32,11 @@ py_binary( "//jaxlib", "//jaxlib:README.md", "//jaxlib:setup.py", - "@xla//xla/python:xla_client.py", - "@xla//xla/python:xla_extension", - "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:api.h", + "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", + "@xla//xla/python:xla_client.py", + "@xla//xla/python:xla_extension", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]) + if_cuda([ @@ -44,11 +46,11 @@ py_binary( "//jaxlib/rocm:rocm_gpu_support", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", ], ) @@ -57,7 +59,7 @@ jax_py_test( srcs = ["build_wheel_test.py"], data = [":build_wheel"], deps = [ - "@bazel_tools//tools/python/runfiles", + "@bazel_tools//tools/python/runfiles", ], ) @@ -102,11 +104,11 @@ py_binary( "//jax_plugins/rocm:__init__.py", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", ], ) @@ -131,10 +133,75 @@ py_binary( "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + +selects.config_setting_group( + name = "macos", + match_any = [ + "@platforms//os:osx", + "@platforms//os:macos", + ], +) + +selects.config_setting_group( + name = "arm64", + match_any = [ + "@platforms//cpu:aarch64", + "@platforms//cpu:arm64", + ], +) + +selects.config_setting_group( + name = "macos_arm64", + match_all = [ + ":arm64", + ":macos", + ], +) + +selects.config_setting_group( + name = "win_amd64", + match_all = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", ], ) + +string_flag( + name = "jaxlib_git_hash", + build_setting_default = "", +) + +config_setting( + name = "jaxlib_git_hash_nightly_or_release", + flag_values = { + ":jaxlib_git_hash": "nightly", + }, +) + +jax_wheel( + name = "jaxlib_wheel", + wheel_binary = ":build_wheel", +) + +jax_wheel( + name = "jax_cuda_plugin_wheel", + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_kernels_wheel", +) + +jax_wheel( + name = "jax_cuda_pjrt_wheel", + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_plugin_wheel", +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 99334dca0162..7abbf5958225 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -108,6 +108,7 @@ def prepare_wheel_cuda( f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", f"__main__/jaxlib/cuda_plugin_extension.{pyext}", f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", @@ -145,6 +146,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_solver.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", f"__main__/jaxlib/rocm/_rnn.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", @@ -172,11 +174,12 @@ def prepare_wheel_rocm( if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: + git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=args.jaxlib_git_hash, + git_hash=git_hash, ) finally: tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 0e2bba0c74d0..08c2389c292a 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -167,11 +167,12 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: + git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=args.jaxlib_git_hash, + git_hash=git_hash, ) finally: if tmpdir: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 5ebdf6e4c6b6..4db36fa0ea97 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -231,6 +231,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) @@ -244,6 +245,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", ], ) @@ -410,7 +412,8 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=args.jaxlib_git_hash) + git_hash = build_utils.get_githash(args.jaxlib_git_hash) + build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash) finally: if tmpdir: tmpdir.cleanup() diff --git a/pyproject.toml b/pyproject.toml index 9f5f06e7a1b0..d688f7fbbf01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ show_error_codes = true disable_error_code = "attr-defined, name-defined, annotation-unchecked" no_implicit_optional = true warn_redundant_casts = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = [ @@ -59,8 +58,6 @@ filterwarnings = [ # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", - "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", - "default:The .* method is good for exploring strategies.*", # NOTE: this is probably not where you want to add code to suppress a # warning. Only pytest tests look at this list, whereas Bazel tests also diff --git a/setup.py b/setup.py index 5df8ea75ffa0..eb28752c18a0 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.35' +_current_jaxlib_version = '0.4.36' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.34' +_latest_jaxlib_version_on_pypi = '0.4.36' -_libtpu_version = '0.0.2' +_libtpu_version = '0.0.6' _libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup' def load_version_module(pkg_path): @@ -119,6 +119,7 @@ def load_version_module(pkg_path): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], zip_safe=False, ) diff --git a/tests/BUILD b/tests/BUILD index dc5d5b37316b..c25d10f460aa 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -225,9 +225,7 @@ jax_multiplatform_test( "tpu_v4_2x2", "tpu_v5p_2x2", "tpu_v5e_4x2", - "cpu_shardy", "gpu_2gpu_shardy", - "tpu_v3_2x2_shardy", "tpu_v5e_4x2_shardy", ], shard_count = { @@ -246,10 +244,8 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "cpu_shardy", "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", - "tpu_v4_2x2_shardy", "tpu_v3_2x2", "gpu_2gpu", ], @@ -270,7 +266,13 @@ jax_multiplatform_test( backend_tags = { "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, + enable_configs = [ + "tpu_v3_2x2_shardy", + ], tags = ["multiaccelerator"], + deps = [ + "//jax:experimental", + ], ) jax_multiplatform_test( @@ -280,6 +282,7 @@ jax_multiplatform_test( "tpu_v3_2x2", "tpu_v5e_4x2", "tpu_v4_2x2", + "tpu_v3_2x2_shardy", ], deps = [ "//jax:experimental", @@ -293,7 +296,6 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_backends = ["gpu"], - env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"}, tags = [ "config-cuda-only", "multiaccelerator", @@ -307,6 +309,24 @@ jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], enable_backends = ["gpu"], + enable_configs = [ + "gpu_2gpu_shardy", + ], + tags = [ + "config-cuda-only", + ], + deps = [ + "//jax:experimental", + ], +) + +jax_multiplatform_test( + name = "mock_gpu_topology_test", + srcs = ["mock_gpu_topology_test.py"], + enable_backends = ["gpu"], + enable_configs = [ + "gpu_h100", + ], tags = [ "config-cuda-only", ], @@ -411,11 +431,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, ) jax_multiplatform_test( @@ -550,6 +565,7 @@ jax_multiplatform_test( name = "lax_test", srcs = ["lax_test.py"], backend_tags = { + "cpu": ["not_run:arm"], # Numerical issues, including https://github.com/jax-ml/jax/issues/24787 "tpu": ["noasan"], # Times out. }, shard_count = { @@ -560,7 +576,7 @@ jax_multiplatform_test( deps = [ "//jax:internal_test_util", "//jax:lax_reference", - ] + py_deps("numpy"), + ] + py_deps("numpy") + py_deps("mpmath"), ) jax_multiplatform_test( @@ -647,6 +663,13 @@ jax_multiplatform_test( }, ) +jax_multiplatform_test( + name = "magma_linalg_test", + srcs = ["magma_linalg_test.py"], + enable_backends = ["gpu"], + deps = py_deps("magma"), +) + jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], @@ -995,6 +1018,9 @@ jax_multiplatform_test( "gpu": ["--jax_num_generated_cases=40"], "tpu": ["--jax_num_generated_cases=40"], }, + disable_configs = [ + "cpu_shardy", # TODO(b/376475853): array values mismatch, need to fix and re-enable. + ], shard_count = { "cpu": 50, "gpu": 50, @@ -1174,6 +1200,12 @@ jax_multiplatform_test( srcs = ["key_reuse_test.py"], ) +jax_multiplatform_test( + name = "roofline_test", + srcs = ["roofline_test.py"], + enable_backends = ["cpu"], +) + jax_multiplatform_test( name = "x64_context_test", srcs = ["x64_context_test.py"], @@ -1232,6 +1264,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], + disable_configs = [ + "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. + ], enable_configs = [ "cpu", "gpu_h100", @@ -1247,6 +1282,9 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + disable_configs = [ + "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. + ], enable_configs = [ "tpu_v2_1x1", "tpu_v3_2x2", @@ -1261,6 +1299,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], + disable_configs = [ + "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. + ], enable_configs = [ "cpu", "gpu_h100", @@ -1311,10 +1352,8 @@ jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], enable_configs = [ - "cpu_shardy", "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", - "tpu_v4_2x2_shardy", ], shard_count = { "cpu": 50, @@ -1346,6 +1385,15 @@ jax_multiplatform_test( ], ) +jax_multiplatform_test( + name = "colocated_python_test", + srcs = ["colocated_python_test.py"], + deps = [ + "//jax:experimental_colocated_python", + "//jax/extend:ifrt_programs", + ], +) + jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], @@ -1395,13 +1443,13 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], + disable_configs = [ + "cpu_shardy", # TODO(b/355263220): enable once export is supported. + ], enable_configs = [ "tpu_v3_2x2", ], tags = [], - deps = [ - "//jax/experimental/export", - ], ) jax_multiplatform_test( @@ -1435,6 +1483,7 @@ jax_multiplatform_test( disable_configs = [ "gpu_a100", # TODO(b/269593297): matmul precision issues "gpu_h100", # Scarce resources. + "cpu_shardy", # TODO(b/355263220): enable once export is supported. ], shard_count = { "cpu": 40, @@ -1499,16 +1548,28 @@ jax_multiplatform_test( srcs = ["cudnn_fusion_test.py"], enable_backends = [], enable_configs = [ + "gpu_a100", "gpu_h100", ], tags = ["multiaccelerator"], ) +jax_py_test( + name = "custom_partitioning_sharding_rule_test", + srcs = ["custom_partitioning_sharding_rule_test.py"], + deps = [ + "//jax", + "//jax:experimental", + "//jax:test_util", + ], +) + exports_files( [ "api_test.py", "array_test.py", "cache_key_test.py", + "colocated_python_test.py", "compilation_cache_test.py", "memories_test.py", "pmap_test.py", diff --git a/tests/aot_test.py b/tests/aot_test.py index bca0d66ed384..62fecfaf48a4 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -19,6 +19,7 @@ import jax from jax._src import core from jax._src import test_util as jtu +import jax._src.lib from jax._src.lib import xla_client as xc from jax.experimental import topologies from jax.experimental.pjit import pjit @@ -62,6 +63,11 @@ def verify_serialization(lowered): jax.pmap(lambda x: x * x).lower( np.zeros((len(jax.devices()), 4), dtype=np.float32))) + @unittest.skipIf( + jax._src.lib.xla_extension_version < 300, + 'AOT compiler registration was broken in XLA extension version below' + ' 300.', + ) def test_topology_pjit_serialize(self): try: aot_topo = topologies.get_topology_desc( diff --git a/tests/api_test.py b/tests/api_test.py index d0a711f4a617..38467809e9d3 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -33,6 +33,7 @@ import re import subprocess import sys +import traceback import types from typing import NamedTuple import unittest @@ -286,13 +287,15 @@ def test_jit_default_device(self, module): self.assertEqual(f(sticky).devices(), system_default_devices) self.assertEqual(f(1).devices(), system_default_devices) - # TODO(skye): make this work! def test_jit_default_platform(self): - with self.assertRaisesWithLiteralMatch( - ValueError, "jax.default_device must be passed a Device object " - "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"): with jax.default_device("cpu"): - jax.jit(lambda x: x + 1)(1) + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, "cpu") + self.assertEqual(result.device, jax.local_devices(backend="cpu")[0]) + + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, jax.default_backend()) + self.assertEqual(result.device, jax.local_devices()[0]) def test_complex_support(self): self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j) @@ -1372,6 +1375,35 @@ def f(x): } ) + def test_compile_options_jit(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + f_jit = jit( + f, + compiler_options={ + "xla_embed_ir_in_executable": True, + "xla_dump_max_hlo_modules": 200, + "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5, + })(1.0) # doesn't crash. + + def test_exec_time_optimization_effort_compiler_option(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + f_jit = jit( + f, + compiler_options={ + "exec_time_optimization_effort": 0.0, + })(1.0) # doesn't crash. + + with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"): + f_jit = jit( + f, + compiler_options={ + "exec_time_compilation_effort": 0.0, + })(1.0) + def test_jit_lower_compile_with_compiler_options_invalid(self): def f(x): return jnp.sqrt(x ** 2) + 1. @@ -1389,7 +1421,21 @@ def f(x): lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) - def test_jit_lower_compile_with_compiler_options_multiple(self): + def test_jit_compile_with_compiler_options_multiple(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + with jtu.count_jit_compilation_cache_miss() as count: + jit(f, compiler_options={"xla_embed_ir_in_executable": True})(1.) + jit(f, compiler_options={"xla_embed_ir_in_executable": False})(1.) + self.assertEqual(count[0], 2) + + # We should still error on invalid options after some valid compiles + with self.assertRaisesRegex( + xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"): + jit(f, compiler_options={"invalid_key": "invalid_value"})(1.) + + def test_lower_compile_with_compiler_options_multiple(self): def f(x): return jnp.sqrt(x ** 2) + 1. @@ -1457,6 +1503,8 @@ def test_caches_depend_on_axis_env(self): ans = jax.vmap(g, axis_name="i", axis_size=3, out_axes=None)() self.assertEqual(ans, expected) + # Since stackless, the vmap(f) version gets compiled a second time + @unittest.skip def test_caches_dont_depend_on_unnamed_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 f = jax.jit(lambda: jnp.sin(1)) @@ -1919,6 +1967,23 @@ def f(x1, x2, g): ): jax.vmap(f, (0, 0, None))(jnp.ones(2), jnp.ones(3), jnp.add) + def test_vmap_inconsistent_sizes_constructs_proper_error_message_kwargs(self): + # regression test for https://github.com/jax-ml/jax/issues/24406 + def f(x1, x2, a3): + return x1 + x2 + a3 + + with self.assertRaisesRegex( + ValueError, + "vmap got inconsistent sizes for array axes to be mapped:\n" + r" \* most axes \(2 of them\) had size 2, e.g. axis 0 of argument x1 of type float32\[2\];\n" + r" \* one axis had size 1: axis 0 of kwargs\['a3'\] of type float32\[1\]", + ): + jax.vmap(f)( + jnp.ones(2, dtype=jnp.float32), + a3=jnp.ones(1, dtype=jnp.float32), + x2=jnp.ones(2, dtype=jnp.float32) + ) + def test_device_get_scalar(self): x = np.arange(12.).reshape((3, 4)).astype("float32") x = api.device_put(x) @@ -2986,9 +3051,11 @@ def test_error_for_invalid_dtype(self): with jax.enable_checks(False): with self.assertRaisesRegex(TypeError, err_str): lax.add(jnp.array(7), np.array("hello")) - with jax.enable_checks(True): - with self.assertRaises(AssertionError): - lax.add(jnp.array(7), np.array("hello")) + # TODO(dougalm): re-enable checks at the beginning of `bind`. We just + # need to know which arguments to a generic primitive are ordinary operands vs functions. + # with jax.enable_checks(True): + # with self.assertRaises(AssertionError): + # lax.add(jnp.array(7), np.array("hello")) def test_vmap_preserves_docstr(self): def superfun(a): @@ -3071,7 +3138,7 @@ def f(x, y): "vmap got inconsistent sizes for array axes to be mapped:\n" r" \* one axis had size 1: axis 0 of argument x of type int32\[1\];" "\n" - r" \* one axis had size 2: axis 0 of argument y of type int32\[2\]"): + r" \* one axis had size 2: axis 0 of kwargs\['y'\] of type int32\[2\]"): f(jnp.array([1], 'int32'), y=jnp.array([1, 2], 'int32')) def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): @@ -3420,13 +3487,10 @@ def test_escaped_tracers_cant_lift_sublevels(self): re.DOTALL)): api.jit(lambda x: x)(self._saved_tracer) + @unittest.skip # TODO(dougalm): rethink what this should do under stackless def test_escaped_tracers_tracer_from_higher_level(self): api.grad(self.helper_save_tracer)(0.) - with self.assertRaisesRegex( - UnexpectedTracerError, - re.compile( - "Encountered an unexpected tracer.*Tracer from a higher level", - re.DOTALL)): + with self.assertRaises(UnexpectedTracerError): api.grad(lambda x: x)(self._saved_tracer) def test_escaped_tracers_incompatible_sublevel(self): @@ -3446,8 +3510,7 @@ def func1(x): return x + self._saved_tracer with self.assertRaisesRegex( UnexpectedTracerError, - re.compile("Encountered an unexpected tracer.*Can't lift", - re.DOTALL)): + re.compile("unexpected tracer")): api.grad(func1)(2.) def test_escaped_tracers_not_among_input_tracers(self): @@ -3673,18 +3736,6 @@ def g(x): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): g(1) - def test_join_concrete_arrays_with_omnistaging(self): - # https://github.com/jax-ml/jax/issues/4622 - x = jnp.array([1., 2., 3.]) - y = jnp.array([1., 2., 4.]) - - @jit - def f(): - core.lattice_join(core.ConcreteArray(x.dtype, x), - core.ConcreteArray(y.dtype, y)) - - f() # doesn't crash - def test_linearize_aux(self): def fn(x): return x * 2 - 3, x > 0 @@ -3842,7 +3893,7 @@ def g(x): x = g(x) return x - msg = r'Leaked trace MainTrace\(2,DynamicJaxprTrace\)' + msg = r'Leaked trace DynamicJaxprTrace' with self.assertRaisesRegex(Exception, f"{msg}"): f(3) @@ -4571,7 +4622,7 @@ def test_cache_miss_explanations_no_source_info(self): jax.jit(operator.add)(42, 24) @parameterized.named_parameters([ - {"testcase_name": f"{dtype}", "dtype": dtype} + {"testcase_name": f"{np.dtype(dtype)}", "dtype": dtype} for dtype in jtu.dtypes.custom_floats]) def test_jit_custom_floats(self, dtype): f = lambda x: x + 1 @@ -4707,6 +4758,7 @@ def f(inputs): for a, b in zip(ans, expected): self.assertAllClose(a, b) + @unittest.skip # TODO(dougalm): figure out with Matt what to do with this feature def test_inner_jit_forwarded_consts_stay_const(self): out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash self.assertEqual(out, 3) @@ -4752,6 +4804,21 @@ def add_one_and_dupe(x: int) -> tuple[int, int]: jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True) jax.eval_shape(jit_add_one_dupe, 0) # don't crash + def test_use_direct_linearize(self): + + def check_invariant_to_use_direct_linearize(f): + with config.use_direct_linearize(False): + ans1 = f() + with config.use_direct_linearize(True): + ans2 = f() + + self.assertEqual(ans1, ans2) + + def sin_of_sin(x): + return lax.sin(jax.jit(lax.sin)(x)) + + check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) + class RematTest(jtu.JaxTestCase): @@ -4856,6 +4923,7 @@ def g(x): msg = str(e) self.assertNotIn('static_argnums', msg) + @unittest.skip def test_remat_grad_python_control_flow_static_argnums(self): @partial(jax.remat, static_argnums=(0,)) def g(x): @@ -4878,6 +4946,7 @@ def f(x): expected = np.cos(2.) self.assertAllClose(ans, expected, check_dtypes=False) + @unittest.skip def test_remat_grad_python_control_flow_unhashable_static_argnums(self): @partial(jax.remat, static_argnums=(0,)) def g(x): @@ -6423,6 +6492,21 @@ def f(x): y_, = vjp(jnp.ones_like(y)) self.assertAllClose(y, y_, atol=0, rtol=0) + def test_concreteness_error_includes_user_code(self): + @jax.remat + def f(x): + if x > 0: + return x + else: + return jnp.sin(x) + + try: + f(3.) + except TracerBoolConversionError: + self.assertIn('x > 0', traceback.format_exc()) + else: + assert False + @jtu.with_config(jax_pprint_use_color=False) class JaxprTest(jtu.JaxTestCase): @@ -7105,8 +7189,8 @@ def g_jvp(primals, tangents): g.defjvp(g_jvp) return g(1.) - self.assertRaises(ad.CustomJVPException, lambda: api.jvp(f, (3.,), (1.,))) - self.assertRaises(ad.CustomJVPException, lambda: api.grad(f)(3.)) + self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) def test_nondiff_arg(self): @partial(jax.custom_jvp, nondiff_argnums=(0,)) @@ -7181,7 +7265,7 @@ def g_jvp(h, primals, tangents): h = lambda y: x + y # capture x return g(h, x) - with self.assertRaisesRegex(ad.CustomJVPException, "Detected differentiation"): + with self.assertRaises(UnexpectedTracerError): api.jvp(f, (2.,), (1.,)) def test_vmap_axes(self): @@ -7592,8 +7676,8 @@ def f_jvp(primals, _): f.defjvp(f_jvp) primals = (2., 3) - tangents = (np.ones(()), np.zeros((), float0),) - expected_tangents = (2., np.zeros((), float0)) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) self.assertAllClose(api.jvp(f, primals, tangents), (primals, expected_tangents)) @@ -10760,6 +10844,67 @@ def rule(axis_size, in_batched, xs): ys = api.vmap(f)(x=xs) self.assertAllClose(ys, jnp.cos(xs)) + def test_partial_eval_raises(self): + @jax.custom_batching.custom_vmap + def f(x): + return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + del axis_size # unused + return jnp.cos(xs), in_batched[0] + + with self.assertRaisesRegex( + ValueError, + "Linearization failed to produce known values for all output primals", + ): + jax.grad(f)(0.5) + + def test_compose_custom_vjp(self): + @jax.custom_vjp + @jax.custom_batching.custom_vmap + def f(x, y): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + return jnp.cos(xs) * ys, True + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd) + + xs = jnp.linspace(0, 1, 5) + ys = jnp.linspace(-0.1, 0.1, 5) + self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) + jax.grad(f)(xs[0], ys[0]) # Doesn't crash. + + def test_compose_custom_vjp_bwd_rule(self): + # This tests the case where both the forward and backward rules are wrapped + # in custom_vmap. + @jax.custom_batching.sequential_vmap + def fun_fwd(x, y): + return jnp.sin(x) * y, (x, y) + + @jax.custom_batching.sequential_vmap + def fun_bwd(res, ct): + x, y = res + return x * ct, y * ct + + fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) + fun.defvjp(fun_fwd, fun_bwd) + + xs = jnp.linspace(0, 1, 5) + y = jnp.array(0.5, dtype=xs.dtype) + f = jax.vmap(jax.jit(fun), in_axes=(0, None)) + out, f_vjp = jax.vjp(f, xs, y) + f_vjp(out) # Doesn't crash. + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2ac2edcdfd99..2f8d4d1c666f 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -6,9 +6,16 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] # Test suite attempts in-place mutation: array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking +array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted + +# clip out dtype has ambiguous semantics (https://github.com/numpy/numpy/issues/24976) +array_api_tests/test_operators_and_elementwise_functions.py::test_clip + +# JAX raises a ValueError rather than the expected IndexError for out-of-bound axis +array_api_tests/test_manipulation_functions.py::test_expand_dims diff --git a/tests/array_test.py b/tests/array_test.py index b3492e4d152f..9618a8cf4665 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -24,7 +24,6 @@ import jax import jax.numpy as jnp from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import op_shardings from jax._src import test_util as jtu @@ -608,16 +607,11 @@ def test_array_not_hashable(self): with self.assertRaisesRegex(TypeError, "unhashable type"): hash(x) - @jax.jit - def check_tracer_hash(x): - self.assertIsInstance(hash(x), int) + with self.assertRaisesRegex(TypeError, "unhashable type"): + jax.jit(hash)(x) - if deprecations.is_accelerated('tracer-hash'): - with self.assertRaisesRegex(TypeError, "unhashable type"): - check_tracer_hash(x) - else: - with self.assertWarnsRegex(FutureWarning, "unhashable type"): - check_tracer_hash(x) + with self.assertRaisesRegex(TypeError, "unhashable type"): + jax.vmap(hash)(x) def test_shape_dtype_struct_sharding_jit(self): mesh = jtu.create_mesh((8,), ('x')) @@ -1133,6 +1127,14 @@ def test_default_pmap_sharding_with_devices(self): ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order) self.assertEqual(ps._device_assignment, new_order) + def test_default_pmap_sharding_replicated(self): + x = np.zeros((len(jax.local_devices()), 8), dtype=np.float32) + x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(x) + ps = jax.sharding.PmapSharding.default( + shape=(8,), sharded_dim=None, + devices=jax.local_devices()) + self.assertEqual(x.sharding, ps) + def test_mesh_repr(self): mesh = jtu.create_mesh((1, 1), ('x', 'y')) mesh_repr = repr(mesh) diff --git a/tests/batching_test.py b/tests/batching_test.py index 2b0b0d63a6f5..608053c23254 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -932,23 +932,6 @@ def f(scale): self.assertAllClose(ans, expected, check_dtypes=False, rtol=jtu.default_gradient_tolerance) - def testIssue387(self): - # https://github.com/jax-ml/jax/issues/387 - R = self.rng().rand(100, 2) - - def dist_sq(R): - dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :] - zero = jnp.zeros_like(dR) - dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR)) - return jnp.sum(dR ** 2, axis=2) - - @jit - def f(R): - _ = dist_sq(R) - return jnp.sum(R ** 2) - - _ = hessian(f)(R) # don't crash on UnshapedArray - @jax.legacy_prng_key('allow') def testIssue489(self): # https://github.com/jax-ml/jax/issues/489 diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 00925c5f7dfc..f84a9d5fb39f 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -31,6 +31,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.mesh import Mesh from jax._src.partition_spec import PartitionSpec as P @@ -68,6 +69,8 @@ def test_serialized_compile_options(self): debug_options.xla_dump_hlo_as_long_text = True debug_options.xla_dump_disable_metadata = True debug_options.xla_dump_hlo_pipeline_re = "xyzzy" + if jaxlib_version > (0, 4, 35): + debug_options.xla_gpu_experimental_autotune_cache_mode = 2 hash2 = self.get_hashed_value( cache_key._hash_serialized_compile_options, compile_options ) @@ -173,7 +176,7 @@ def _infer_sharding_from_operands(mesh, arg_shapes, result_shape): @custom_partitioning def _cp_add(x, y): - return jax.numpy.add(x, y) + return jax.numpy.add(x, y) _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, @@ -196,14 +199,59 @@ def _cp_add(x, y): r'(.*?backend_config\s*=\s*"([^"]*)".*?)' r'\}' ) - with config.remove_custom_partitioning_ptr_from_cache_key(True): - with computation.context: - updated_module = cache_key._remove_custom_partitioning_ptr( - type_cast(ir.Module, computation.operation.clone())) - bcs = [match[2] for - match in re.findall(pattern, str(updated_module), re.DOTALL)] - for bc in bcs: - self.assertEqual(bc, "REMOVED") + with computation.context: + updated_module = cache_key._remove_callbacks( + type_cast(ir.Module, computation.operation.clone()), + ignore_callbacks=cache_key.IgnoreCallbacks.ALL, + ) + bcs = [ + match[2] + for match in re.findall(pattern, str(updated_module), re.DOTALL) + ] + for bc in bcs: + self.assertEqual(bc, "REMOVED") + + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + hash_without_callback_ptrs = cache_key.get( + computation, + devices, + compile_options, + backend, + ignore_callbacks=cache_key.IgnoreCallbacks.CUSTOM_PARTITIONING, + ) + expected_hash = cache_key.get( + updated_module, devices, compile_options, backend + ) + self.assertEqual(expected_hash, hash_without_callback_ptrs) + + @jtu.skip_on_devices("cpu") + def test_host_callbacks_ptrs_removed(self): + def _host_callback(x, y): + jax.debug.print("x={x[0]} y={y[0]}", x=x, y=y) + + computation = ( + jax.jit(_host_callback) + .lower( + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + ) + .compiler_ir() + ) + pattern = r'(.*?backend_config\s*=\s*"([^"]*)".*?)' + with computation.context: + updated_module = cache_key._remove_callbacks( + type_cast(ir.Module, computation.operation.clone()), + ignore_callbacks=cache_key.IgnoreCallbacks.ALL, + ) + bcs = [ + match[1] + for match in re.findall(pattern, str(updated_module), re.DOTALL) + ] + for bc in bcs: + self.assertEqual(bc, "REMOVED") def test_different_device_assignment(self): computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py new file mode 100644 index 000000000000..bbd5c38068f3 --- /dev/null +++ b/tests/colocated_python_test.py @@ -0,0 +1,327 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import threading +import time +from typing import Sequence + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member +from jax.experimental import colocated_python +from jax.experimental.colocated_python import func as colocated_python_func +from jax.experimental.colocated_python import serialization +from jax.extend.ifrt_programs import ifrt_programs +import jax.numpy as jnp +import numpy as np + +config.parse_flags_with_absl() + + +def _colocated_cpu_devices( + devices: Sequence[jax.Device], +) -> Sequence[jax.Device]: + """Returns CPU devices colocated with the given devices.""" + try: + return colocated_python.colocated_cpu_devices(devices) + except (ValueError, AttributeError): + # PjRt-IFRT prepares CPU devices by its own. + # TODO(hyeontaek): Remove this fallback path once PjRt-IFRT prepares CPU + # devices by its own. + cpu_backend_devices = jax.local_devices(backend="cpu") + device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + + available_devices = devices[: min(len(cpu_backend_devices), len(devices))] + return [ + cpu_backend_devices[device_index_map[d.id]] for d in available_devices + ] + + +@contextlib.contextmanager +def _count_colocated_python_specialization_cache_miss() -> list[int]: + """Counts the number of cache misses for colocated_python specialization.""" + original_get_specialized_func = colocated_python_func._get_specialized_func + count = [0] + + @jax.util.cache(max_size=None) + def get_specialized_func(*args, **kwargs): + count[0] += 1 + return original_get_specialized_func(*args, **kwargs) + + colocated_python_func._get_specialized_func = get_specialized_func + try: + yield count + finally: + colocated_python_func._get_specialized_func = original_get_specialized_func + + +_exit_stack = contextlib.ExitStack() + + +def setUpModule(): + # TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT + # prepares CPU devices by its own. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + + +def tearDownModule(): + _exit_stack.close() + + +class ColocatedPythonTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if xla_extension_version < 300: + self.skipTest("Requires xla_extension_version >= 300") + + def testMakeColocatedPythonProgram(self): + def add_one(x): + return x + 1 + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) + sds = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) + + pickled_function = serialization._serialize(add_one) + program = ifrt_programs.make_colocated_python_program( + "add_one", pickled_function, [cpu_devices[0]], [sds], [sds] + ) + del program + + def testSimpleFunction(self): + @colocated_python.colocated_python + def add_one(x): + return x + 1 + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + + with _count_colocated_python_specialization_cache_miss() as count: + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + def testSimpleFunctioWithTree(self): + @colocated_python.colocated_python + def add_one(x): + return jax.tree.map(lambda x: x + 1, x) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = [np.array(1), (np.array(2), {"v": np.array(3)})] + x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) + + with _count_colocated_python_specialization_cache_miss() as count: + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 1) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 1) + + def testEmptyInputFailsWithoutSpecialization(self): + @colocated_python.colocated_python + def make_zero(): + return jnp.array(0) + + with self.assertRaisesRegex( + ValueError, + "No devices found. colocated_python function without input arguments" + " must be first specialized with devices.", + ): + _ = make_zero() + + def testEmptyInputWithDevicesSpecialization(self): + @colocated_python.colocated_python + def make_zero(): + return jnp.array(0) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + + with _count_colocated_python_specialization_cache_miss() as count: + make_zero = make_zero.specialize(devices=cpu_devices[:1]) + out = make_zero() + out = jax.device_get(out) + self.assertEqual(out, np.array(0)) + self.assertEqual(count[0], 1) + + out = make_zero() + out = jax.device_get(out) + self.assertEqual(out, np.array(0)) + self.assertEqual(count[0], 1) + + def testInputPolymorphismWithoutOutSpecsFn(self): + @colocated_python.colocated_python + def add_one(x): + return jax.tree.map(lambda x: x + 1, x) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + + with _count_colocated_python_specialization_cache_miss() as count: + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + # Different input tree structure and dtype/shape. + x = [np.array(1), (np.array(2), {"v": np.array(3)})] + x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + def testInputPolymorphismAllowedWithOutSpecsFn(self): + @colocated_python.colocated_python + def add_one(x): + return jax.tree.map(lambda x: x + 1, x) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + + with _count_colocated_python_specialization_cache_miss() as count: + add_one = add_one.specialize(out_specs_fn=lambda x: x) + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + # Different input tree structure and dtype/shape. + x = [np.array(1), (np.array(2), {"v": np.array(3)})] + x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + out = add_one(x) + out = jax.device_get(out) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + @parameterized.named_parameters( + ("on_main_thread", True), + ("on_non_main_thread", False), + ) + def testSequentialExecution(self, on_main_thread: bool): + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + # Make sure that this input array is ready for use by the colocated Python + # function and does not disrupt elapsed time measurement. + jax.block_until_ready(x) + + @colocated_python.colocated_python + def sleep(x: jax.Array) -> jax.Array: + time.sleep(5) + return x + + # Specify out_specs_fn so that all executions are asynchronously dispatched. + sleep = sleep.specialize(out_specs_fn=lambda x: x) + + def sleep_twice_and_wait(x: jax.Array) -> None: + _ = sleep(x) + jax.block_until_ready(sleep(x)) + + start_time = time.time() + + # Two executions of `sleep` within `sleep_twice_and_wait` should run + # sequentially. + if on_main_thread: + sleep_twice_and_wait(x) + else: + t = threading.Thread(target=sleep_twice_and_wait, args=(x,)) + t.start() + t.join() + + elapsed_time = time.time() - start_time + + # If sequential execution did not happen, elapsed time typically will be + # around 5 seconds. + self.assertGreaterEqual(elapsed_time, 10) + + def testConcurrentExecution(self): + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + # Make sure that this input array is ready for use by the colocated Python + # function and does not disrupt elapsed time measurement. + jax.block_until_ready(x) + + @colocated_python.colocated_python + def sleep(x: jax.Array) -> jax.Array: + time.sleep(5) + return x + + # Specify out_specs_fn so that all executions are asynchronously dispatched. + sleep = sleep.specialize(out_specs_fn=lambda x: x) + + def sleep_and_wait(x: jax.Array) -> None: + jax.block_until_ready(sleep(x)) + + start_time = time.time() + + # All three executions of `sleep_and_wait` should run concurrently. + t1 = threading.Thread(target=sleep_and_wait, args=(x,)) + t2 = threading.Thread(target=sleep_and_wait, args=(x,)) + t1.start() + t2.start() + sleep_and_wait(x) + t1.join() + t2.join() + + elapsed_time = time.time() - start_time + + self.assertGreaterEqual(elapsed_time, 5) + # If concurrent execution did not happen, elapsed time typically will be + # around 15 seconds. + self.assertLess(elapsed_time, 10) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index e5222814fb02..428e518eab51 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -18,6 +18,7 @@ from functools import partial import logging import math +import os import platform import unittest from unittest import mock @@ -40,6 +41,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface +from jax._src.lib import xla_client as xc from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -420,6 +422,8 @@ def test_persistent_cache_hit_no_logging(self): self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING)) def test_persistent_cache_miss_logging_with_explain(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") with (config.explain_cache_misses(True), config.compilation_cache_dir("jax-cache")): @@ -464,6 +468,8 @@ def test_persistent_cache_miss_logging_with_explain(self): def test_persistent_cache_miss_logging_with_no_explain(self): # test that cache failure messages do not get logged in WARNING + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") with (config.explain_cache_misses(False), config.compilation_cache_dir("jax-cache")): # omitting writing to cache because compilation is too fast @@ -531,6 +537,43 @@ def test_backend_serialization_deserialization(self): self.assertEqual( executable.fingerprint, deserialized_executable.fingerprint) + def test_persistent_cache_enable_xla_caches(self): + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("Test requires AutotuneCacheMode bindings") + s = os.sep + with config.compilation_cache_dir("jax-cache"): + with config.persistent_cache_enable_xla_caches("none"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("all"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, f"jax-cache{s}xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("xla_gpu_kernel_cache_file"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, f"jax-cache{s}xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("xla_gpu_per_fusion_autotune_cache_dir"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) @jtu.with_config( jax_enable_compilation_cache=False, @@ -566,5 +609,17 @@ def test_tasks_disable_cache_metric(self): "/jax/compilation_cache/task_disabled_cache"] self.assertEqual(count_after_second_use, count_after_first_use) + def test_persistent_cache_enable_xla_caches_disabled(self): + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("Test requires AutotuneCacheMode bindings") + with config.enable_compilation_cache(False): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/core_test.py b/tests/core_test.py index 94b7010907a9..7ca941c69c7b 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -33,14 +33,13 @@ from jax._src import linear_util as lu from jax._src import util from jax._src import test_util as jtu -from jax._src.core import UnshapedArray, ShapedArray, DBIdx +from jax._src.core import ShapedArray, DBIdx from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow config.parse_flags_with_absl() -_ = pe.PartialVal.unknown(UnshapedArray(np.float32)) __ = pe.PartialVal.unknown(ShapedArray((), np.float32)) def call(f, *args): @@ -348,13 +347,6 @@ def g_vmap(x): 'This BatchTracer with object id'): g_vmap(jnp.ones((1, ))) - def test_concrete_array_string_representation(self): - # https://github.com/jax-ml/jax/issues/5364 - self.assertEqual( - str(core.ConcreteArray(np.dtype(np.int32), - np.array([1], dtype=np.int32))), - 'ConcreteArray([1], dtype=int32)') - def test_dropvar_avals(self): def f(x): def body(c, _): @@ -541,15 +533,6 @@ def test_jaxpr_undefined_eqn_invar(self): r"Variable '.+_test' not defined\n\nin equation:", lambda: core.check_jaxpr(jaxpr)) - @parameterized.parameters( - {'value': 0, 'weak_type': True}, - {'value': np.int32(0), 'weak_type': False}, - {'value': np.array([0]), 'weak_type': False} - ) - def test_raise_to_shaped_weak_type(self, value, weak_type): - aval = core.raise_to_shaped(core.get_aval(value)) - self.assertEqual(aval.weak_type, weak_type) - @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index 151cb72be8dc..7dc0571bc172 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest, parameterized from unittest import SkipTest from jax._src import test_util as jtu +from jax._src.lib import cuda_versions import jax import jax.numpy as jnp from jax._src.cudnn import cudnn_fusion @@ -26,8 +27,9 @@ class CudnnFusionTest(jtu.JaxTestCase): def setUp(self): if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on >= sm90 GPUs") + not jtu.is_cuda_compute_capability_at_least("8.0") or + cuda_versions.cudnn_get_version() < 90110): + self.skipTest("Only works on >= sm80 GPUs with cuDNN 9.1.1+") super().setUp() @parameterized.parameters(["", "pmap"]) diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py new file mode 100644 index 000000000000..3aed16510a4f --- /dev/null +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -0,0 +1,468 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from jax._src import test_util as jtu +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import sdy +from jax._src.custom_partitioning_sharding_rule import ArrayMapping, BATCHING, CompoundFactor, sdy_sharding_rule_to_mlir, str_to_sdy_sharding_rule, SdyShardingRule +from jax._src.lib.mlir.dialects import hlo as stablehlo + + +class SdyShardingRuleTest(jtu.JaxTestCase): + def test_compound_factor_not_enough_factors(self): + with self.assertRaisesRegex(ValueError, "A compound factor should contain at least two factors"): + CompoundFactor("i") + + def test_compound_factor_batching_now_allowed(self): + with self.assertRaisesRegex(ValueError, "Ellipsis can't be used in a compound factor"): + CompoundFactor(BATCHING, "i") + + def test_compound_factor_element_not_a_str(self): + with self.assertRaisesRegex(ValueError, "Each element of CompoundFactor must be a str"): + CompoundFactor("i", 2) + + def test_compound_factor_str(self): + c = CompoundFactor("i", "j", "k") + self.assertEqual(str(c), "('i', 'j', 'k')") + + def test_value_mapping_element_not_a_str_or_compound_factor(self): + with self.assertRaisesRegex(ValueError, "Each element of ArrayMapping must be a str or CompoundFactor"): + ArrayMapping(CompoundFactor("i", "j"), 3) + + def test_value_mapping_factor_name_not_start_with_letter(self): + with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"): + ArrayMapping("3i", "j") + + def test_value_mapping_ellipsis_not_first(self): + with self.assertRaisesRegex(ValueError, "Ellipsis can only be used at the beginning of a dimension"): + ArrayMapping("i_j", BATCHING) + + def test_value_mapping_str(self): + v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k") + self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')") + + def test_sdy_sharding_rule_factor_size_not_used(self): + with self.assertRaisesRegex(ValueError, "Factor k is not used"): + SdyShardingRule(("i",), ("j",), k=10) + + def test_sdy_sharding_rule_factor_sizes_missing(self): + with self.assertRaisesRegex( + ValueError, + "Factor k is only used in compound factors; must specify its size"): + SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")), + (ArrayMapping(CompoundFactor("j", "k")),)) + + def test_sdy_sharding_rule_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule((ArrayMapping("i"),), (ArrayMapping("j"),), i=10) + + def test_sdy_sharding_rule_compound_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule((ArrayMapping(CompoundFactor("i", "j")),), + (ArrayMapping("i"),), i=10, j=20) + + def test_sdy_sharding_rule_str(self): + r = SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")), + (ArrayMapping(CompoundFactor("j", "k")),), k=10) + self.assertEqual(str(r), "SdyShardingRule((('i',), ('j',)), ((('j', 'k'),),), {'k': 10})") + + +class StrToSdyShardingRuleTest(jtu.JaxTestCase): + + def test_rule_is_not_a_str(self): + with self.assertRaisesRegex(TypeError, "rule must be a str"): + str_to_sdy_sharding_rule(1) + + def test_factor_sizes_is_not_a_proper_dict(self): + with self.assertRaisesRegex( + TypeError, "factor_sizes must be a dict of str to int"): + str_to_sdy_sharding_rule("i->j", i="j") + + def test_sharding_rule_ellipsis_not_complete(self): + with self.assertRaisesRegex( + ValueError, "Character '.' must be used inside ellipsis '...'"): + str_to_sdy_sharding_rule(".i -> j") + + def test_sharding_rule_invalid_factor_name(self): + with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"): + str_to_sdy_sharding_rule("2i -> j") + + def test_sharding_rule_missing_results(self): + with self.assertRaisesRegex(ValueError, "There is no -> in rule"): + str_to_sdy_sharding_rule("i") + + def test_sharding_rule_inbalenced_brackets(self): + with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): + str_to_sdy_sharding_rule("i j, k)->j") + + def test_sharding_rule_inbalenced_brackets2(self): + with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): + str_to_sdy_sharding_rule("i (j k->j") + + def test_sharding_rule_empty_compound_dim(self): + with self.assertRaisesRegex( + ValueError, "Brackets should contain at least two factors"): + str_to_sdy_sharding_rule("i ( ) j k->j") + + def test_sharding_rule_one_factorcompound_dim(self): + with self.assertRaisesRegex( + ValueError, "Brackets should contain at least two factors"): + str_to_sdy_sharding_rule("i (j ) k->j") + + def test_sharding_rule_nested_brackets(self): + with self.assertRaisesRegex( + ValueError, "Compound factors should be one level"): + str_to_sdy_sharding_rule("i (j (k))->j") + + def test_sharding_rule_unknown_char(self): + with self.assertRaisesRegex(ValueError, "Unknown character"): + str_to_sdy_sharding_rule("i; j->j") + + def test_sharding_rule_unknown_single_char_ellipse(self): + with self.assertRaisesRegex(ValueError, "Unknown character"): + str_to_sdy_sharding_rule("…j->…j") + + def test_sharding_rule_ellipsis_not_leading_dim(self): + with self.assertRaisesRegex( + ValueError, "Ellipsis can only be used at the beginning of a dimension"): + str_to_sdy_sharding_rule("i ... -> j") + + def test_sharding_rule_ellipsis_inside_compound_dim(self): + with self.assertRaisesRegex( + ValueError, "Ellipsis can only be used at the beginning of a dimension"): + str_to_sdy_sharding_rule("i, (..., j) -> j") + + def test_sharding_rule_scalar_operand_scalar_result(self): + rule = str_to_sdy_sharding_rule("->") + self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})") + + def test_sharding_rule_one_scalar_operand(self): + rule = str_to_sdy_sharding_rule("i j, , k->j") + self.assertEqual( + str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})") + + def test_sharding_rule_factor_elementwise_add(self): + rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j") + self.assertEqual( + str(rule), + "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i'," + " 'j'),), {})") + + def test_sharding_rule_factor_vector_scalar_add(self): + rule = str_to_sdy_sharding_rule("...i, -> ...i") + self.assertEqual( + str(rule), + "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})") + + def test_sharding_rule_factor_reshape_combining(self): + rule = str_to_sdy_sharding_rule("i j -> (i j)") + self.assertEqual( + str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})") + + def test_sharding_rule_factor_reshape_reordering(self): + rule = str_to_sdy_sharding_rule("(j i) -> (i j)", i=10, j=20) + self.assertEqual( + str(rule), + "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':" + " 20})") + + def test_sharding_rule_factor_compound_then_individual(self): + rule = str_to_sdy_sharding_rule("(i j) (j k) i -> j k") + self.assertEqual( + str(rule), + "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})") + + def test_sharding_rule_factor_individual_then_compound(self): + rule = str_to_sdy_sharding_rule("i j k -> (i j) (j k)") + self.assertEqual( + str(rule), + "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})") + + def test_sharding_rule_factor_infer_k(self): + rule = str_to_sdy_sharding_rule("i_ (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20) + self.assertEqual( + str(rule), + "SdyShardingRule((('i_', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" + ",), {'k': 10, 'm': 10, 'bar_24': 20})") + + +class SdyShardingRuleConversionTest(jtu.JaxTestCase): + + def run(self, result=None): + with ir.Context() as ctx, ir.Location.unknown(ctx): + sdy.register_dialect(ctx) + stablehlo.register_dialect(ctx) + module = ir.Module.create() + with ir.InsertionPoint(module.body): + super().run(result) + + def get_tensor_type(self, shape): + return ir.RankedTensorType.get(shape, ir.F32Type.get()) + + def create_tensor_value(self, shape): + return ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type(shape)], + attributes=dict(call_target_name=ir.StringAttr.get("dummy_target")) + ).result + + def test_conversion_rule_op_mismatch_in_operands_num(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("i j-> i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule has 1 operands, but the operation has 2 operands"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_operands_rank(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("i j, i j k-> i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 1th operand has rank 3, but the operation 1th " + "operand has rank 2"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_results_num(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, + opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("i j, i j -> i j, i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule has 2 results, but the operation has 1 results"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_results_dim(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("i j, i j -> i j k") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 0th result has rank 3, but the operation 0th " + "result has rank 2"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_factor_has_two_sizes(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 64))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("i j, i j -> i j") + with self.assertRaisesRegex( + ValueError, + "Factor j corresponds to two sizes: 32 and 64"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_batching_dim_has_two_sizes(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 64))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("..., ... -> ...") + with self.assertRaisesRegex( + ValueError, + "Batching dimension 1 corresponds to two sizes: 32 and 64"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,],) + + def test_conversion_invalid_batching_dim(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("... i j k, ... i j k -> ... i j k") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 0th operand has rank 3, but the operation 0th operand has rank 2"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_compound_dimension_size_mismatch(self): + opnd = self.create_tensor_value((2, 4)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((9,))], + operands=[opnd,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("i j -> (i j)") + with self.assertRaisesRegex( + ValueError, + "0th result actual size 9 doesn't match the size 8 derived from the" + " compound factors"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type], + [result.result.type,]) + + def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16,)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("..., ... -> ...") + with self.assertRaisesRegex( + ValueError, + "Ellipsis represents different number of leading dimensions 2 and 1"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_compound_then_individual(self): + opnd = self.create_tensor_value((8,)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((2,4))], + operands=[opnd,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("(i j) -> i j") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>") + + def test_conversion_elementwise_rule_scalar_instance(self): + opnd0 = self.create_tensor_value(()) + opnd1 = self.create_tensor_value(()) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type(())], + operands=[opnd0, opnd1], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("..., ... -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([], [])->([])>") + + def test_conversion_elementwise_rule_2D_instance(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("..., ... -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>") + + def test_conversion_vector_scalar_add_2D_instance(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value(()) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("..., -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>") + + def test_conversion_reshape_rule(self): + opnd0 = self.create_tensor_value((2, 4)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((8,))], + operands=[opnd0,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("i j -> (i j)") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>") + + def test_conversion_contracting_dim_matmul(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((32, 8)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 8))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("... contracting_dim, contracting_dim k -> ... k") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>") + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 020c9f744833..4573f542c14f 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -75,6 +75,7 @@ def testSingleResultPrimitiveNaN(self): @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION) def testCallDeoptimized(self, jit): + raise SkipTest("re-enable once we handle contexts properly") # TODO(dougalm) @jit def f(x): return jax.lax.cond( diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 89d70871a8f9..6c7e9e3ab712 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -64,6 +64,10 @@ fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz)] +if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] +if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index c20fc95350c2..ae0848a74a37 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -44,6 +44,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm from jax._src.internal_test_util.export_back_compat_test_data import cpu_hessenberg_lapack_gehrd +from jax._src.internal_test_util.export_back_compat_test_data import cpu_tridiagonal_lapack_sytrd_hetrd from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32 from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf @@ -68,6 +69,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions +from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -119,8 +121,10 @@ def test_custom_call_coverage(self): cpu_eig_lapack_geev.data_2024_08_19, cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, + cpu_schur_lapack_gees.data_2024_11_29, cpu_svd_lapack_gesdd.data_2024_08_13, cpu_hessenberg_lapack_gehrd.data_2024_08_31, + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01, ] # Add here all the testdatas that should cover the targets guaranteed # stable @@ -143,6 +147,7 @@ def test_custom_call_coverage(self): cpu_svd_lapack_gesdd.data_2023_06_19, cpu_triangular_solve_blas_trsm.data_2023_07_16, cpu_hessenberg_lapack_gehrd.data_2024_08_30, + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16, @@ -611,10 +616,10 @@ def compute_max_backward_error(operand, reconstructed_operand): self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), np.asarray(out), atol=1e-4, rtol=1e-4)) - @jtu.parameterized_filterable( - kwargs=[ - dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) - for dtype_name in ("f32", "f64", "c64", "c128")]) + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", + dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) @jax.default_matmul_precision("float32") def test_cpu_schur_lapack_gees(self, dtype_name="f32"): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: @@ -640,6 +645,14 @@ def check_schur_results(res_run, res_expected, *, rtol, atol): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_schur_results) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 37) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_schur_results) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -759,6 +772,40 @@ def func(): ) self.run_one_test(func, data, rtol=rtol, atol=atol) + + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + @jax.default_matmul_precision("float32") + def test_cpu_tridiagonal_lapack_sytrd_hetrd(self, dtype_name="f32"): + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + shape = (2, 4, 4) + input_data = jtu.rand_default(self.rng())(shape, dtype) + # del input_data # Input is in the testdata, here for readability + def func(): + return lax.linalg.tridiagonal(input_data, lower=True) + + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] + + data = self.load_testdata( + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 37) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) + def test_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 0f0c20fd78e3..e8b1afc224b7 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -45,24 +45,6 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: else: return re.compile("(" + "|".join(parts) + ")") -# TODO(necula): Failures to be investigated (on GPU). -_known_failures_gpu = make_disjunction_regexp( - # Failures on GPU due to failure to export custom call targets, these - # involve GPU custom call targets withoutbackwards compatibility tests. - "custom_linear_solve_", - "lu_", - "svd_", - "tridiagonal_solve_", -) - -# Some primitive lowering rules need the GPU backend to be able to create -# CUDA lowering. -_skip_cuda_lowering_unless_have_gpus = make_disjunction_regexp( - "svd_", "lu_", "eigh_", "qr_", "custom_linear_", "tridiagonal_solve_", - # TODO(b/350111820): random should work once we enable FFI threefry2x32 - "random_", -) - class PrimitiveTest(jtu.JaxTestCase): @@ -105,8 +87,8 @@ def test_prim(self, harness: test_harnesses.Harness): "decompositions for equality.") if (jtu.device_under_test() == "gpu" - and _known_failures_gpu.search(harness.fullname)): - self.skipTest("failure to be investigated") + and "tridiagonal_solve_" in harness.fullname): + self.skipTest("tridiagonal_solve_ is not yet guaranteed stable.") if harness.params.get("enable_xla", False): self.skipTest("enable_xla=False is not relevant") @@ -118,11 +100,14 @@ def test_prim(self, harness: test_harnesses.Harness): for l in harness.jax_unimplemented: if l.filter(dtype=harness.dtype): unimplemented_platforms = unimplemented_platforms.union(l.devices) - if (_skip_cuda_lowering_unless_have_gpus.search(harness.fullname) + # Some primitive lowering rules need the GPU backend to be able to create + # CUDA lowering. + if ("tridiagonal_solve_" in harness.fullname and all(d.platform != "gpu" for d in self.devices)): unimplemented_platforms.add("gpu") - logging.info("Harness is not implemented on %s", unimplemented_platforms) + if unimplemented_platforms: + logging.info("Harness is not implemented on %s", unimplemented_platforms) # Tolerances. tol = None @@ -164,7 +149,7 @@ def export_and_compare_to_native( logging.info("Exporting harness for %s", lowering_platforms) exp = export.export(jax.jit(func_jax), - lowering_platforms=lowering_platforms)(*args) + platforms=lowering_platforms)(*args) for device in devices: if device.platform in skip_run_on_platforms: diff --git a/tests/export_test.py b/tests/export_test.py index fd6bef11ee43..2946854aa549 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -196,7 +196,7 @@ def test_pytree_export_only(self): def f(a_b_pair, *, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b) + exp = get_exported(jax.jit(f), platforms=("cpu",))((a, b), a=a, b=b) a_aval = core.ShapedArray(a.shape, a.dtype) b_aval = core.ShapedArray(b.shape, b.dtype) self.assertEqual(exp.platforms, ("cpu",)) @@ -244,22 +244,6 @@ def test_export_error_no_jit(self): "Function to be exported must be the result of `jit`"): _ = export.export(lambda x: jnp.sin(x)) - @jtu.ignore_warning(category=DeprecationWarning, - message="The jax.experimental.export module is deprecated") - def test_export_experimental_back_compat(self): - if not CAN_SERIALIZE: - self.skipTest("serialization disabled") - from jax.experimental import export - # Can export a lambda, without jit - exp = export.export(lambda x: jnp.sin(x))(.1) - self.assertAllClose(exp.call(1.), np.sin(1.)) - - blob = export.serialize(exp, vjp_order=1) - rehydrated = export.deserialize(blob) - - self.assertAllClose(export.call(exp)(1.), np.sin(1.)) - self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.)) - def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name f = jax.jit(lambda x: jnp.sin(x)) @@ -479,7 +463,7 @@ def test_default_export_platform(self): def test_error_wrong_platform(self, platform): a = np.arange(4, dtype=np.float32) - exp_f = get_exported(jnp.sin, lowering_platforms=(platform,))(a) + exp_f = get_exported(jnp.sin, platforms=(platform,))(a) if xb.canonicalize_platform(jtu.device_under_test()) == platform: raise unittest.SkipTest("Uninteresting scenario") @@ -489,7 +473,7 @@ def test_error_wrong_platform(self, platform): # Now try with the platform check disabled exp_f_no_platform_check = get_exported( - jnp.sin, lowering_platforms=(platform,), + jnp.sin, platforms=(platform,), disabled_checks=[export.DisabledSafetyCheck.platform()])(a) res = exp_f_no_platform_check.call(a) self.assertAllClose(res, jnp.sin(a)) @@ -1480,7 +1464,7 @@ def f(x): def test_multi_platform(self): x = np.arange(8, dtype=np.float32) exp = get_exported(jax.jit(_testing_multi_platform_func), - lowering_platforms=("tpu", "cpu", "cuda", "rocm"))(x) + platforms=("tpu", "cpu", "cuda", "rocm"))(x) self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "rocm")) module_str = str(exp.mlir_module()) expected_main_re = ( @@ -1503,14 +1487,14 @@ def test_multi_platform(self): def test_multi_platform_nested(self): x = np.arange(5, dtype=np.float32) exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))), - lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + platforms=("cpu", "tpu", "cuda", "rocm"))(x) self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call to the exported using a different sequence of # lowering platforms, but included in the lowering platforms for the # nested exported. exp2 = get_exported(jax.jit(exp.call), - lowering_platforms=("cpu", "cuda", "rocm"))(x) + platforms=("cpu", "cuda", "rocm"))(x) # Ensure that we do not have multiple lowerings of the exported function exp2_module_str = str(exp2.mlir_module()) @@ -1529,7 +1513,7 @@ def test_multi_platform_nested(self): def test_multi_platform_nested_inside_single_platform_export(self): x = np.arange(5, dtype=np.float32) exp = get_exported(jax.jit(_testing_multi_platform_func), - lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + platforms=("cpu", "tpu", "cuda", "rocm"))(x) self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call for the current platform. @@ -1602,14 +1586,14 @@ def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, def f(x): return times_2_or_3_or_4.bind(x) x = np.float32(42.) - exp = export.export(f, lowering_platforms=["cpu", "cuda", "rocm", "tpu"])(x) + exp = export.export(f, platforms=["cpu", "cuda", "rocm", "tpu"])(x) expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()]) self.assertAllClose(exp.call(x), expected) def test_multi_platform_unknown_platform(self): x = np.arange(8, dtype=np.float32) exp = get_exported(jax.jit(jnp.sin), - lowering_platforms=("tpu", "cpu", "cuda", "other"))(x) + platforms=("tpu", "cpu", "cuda", "other"))(x) self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other")) @@ -1636,7 +1620,7 @@ def test_multi_platform_and_poly(self): # The export is not applicable to GPU raise unittest.SkipTest("Not intended for running on GPU") exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))), - lowering_platforms=("cpu", "tpu"))( + platforms=("cpu", "tpu"))( jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32) ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) @@ -1659,8 +1643,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] return b * 2. res_native = f_jax(a) - exp = get_exported(f_jax, - lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(a) + exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a) # Call with argument placed on different plaforms for platform in self.__class__.platforms: @@ -1806,7 +1789,7 @@ def f_jax(x): # x: f32[b1, b2] effect_class_name="ForTestingOrderedEffect1") exp = get_exported( jax.jit(f_jax), - lowering_platforms=("cpu", "tpu") + platforms=("cpu", "tpu") )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype)) mlir_module_str = str(exp.mlir_module()) wrapped_main_expected_re = ( diff --git a/tests/extend_test.py b/tests/extend_test.py index 0fc8821f1984..3561e716f09c 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -14,6 +14,7 @@ import os import unittest +from functools import partial import numpy as np from absl.testing import absltest @@ -23,9 +24,11 @@ from jax import lax import jax.extend as jex import jax.numpy as jnp +import jax.sharding as shd from jax._src import abstract_arrays from jax._src import api +from jax._src import config from jax._src import core from jax._src import linear_util from jax._src import prng @@ -33,7 +36,10 @@ from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.layout import DeviceLocalLayout +from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo +from jax._src.lax import linalg as lax_linalg_internal +from jax.experimental.shard_map import shard_map jax.config.parse_flags_with_absl() @@ -67,35 +73,41 @@ def test_symbols(self): class RandomTest(jtu.JaxTestCase): - def test_key_make_with_custom_impl(self): - shape = (4, 2, 7) - + def make_custom_impl(self, shape, seed=False, split=False, fold_in=False, + random_bits=False): + assert not split and not fold_in and not random_bits # not yet implemented def seed_rule(_): return jnp.ones(shape, dtype=jnp.dtype('uint32')) def no_rule(*args, **kwargs): assert False, 'unreachable' - impl = jex.random.define_prng_impl( - key_shape=shape, seed=seed_rule, split=no_rule, fold_in=no_rule, - random_bits=no_rule) + return jex.random.define_prng_impl( + key_shape=shape, seed=seed_rule if seed else no_rule, split=no_rule, + fold_in=no_rule, random_bits=no_rule) + + def test_key_make_with_custom_impl(self): + impl = self.make_custom_impl(shape=(4, 2, 7), seed=True) k = jax.random.key(42, impl=impl) self.assertEqual(k.shape, ()) self.assertEqual(impl, jax.random.key_impl(k)) def test_key_wrap_with_custom_impl(self): - def no_rule(*args, **kwargs): - assert False, 'unreachable' - shape = (4, 2, 7) - impl = jex.random.define_prng_impl( - key_shape=shape, seed=no_rule, split=no_rule, fold_in=no_rule, - random_bits=no_rule) + impl = self.make_custom_impl(shape=shape) data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32')) k = jax.random.wrap_key_data(data, impl=impl) self.assertEqual(k.shape, (3,)) self.assertEqual(impl, jax.random.key_impl(k)) + def test_key_impl_is_spec(self): + # this is counterpart to random_test.py: + # KeyArrayTest.test_key_impl_builtin_is_string_name + spec_ref = self.make_custom_impl(shape=(4, 2, 7), seed=True) + key = jax.random.key(42, impl=spec_ref) + spec = jax.random.key_impl(key) + self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})") + class FfiTest(jtu.JaxTestCase): @@ -122,7 +134,6 @@ def testLoweringLayouts(self, layout_spec, expected_layout): # layouts. def lowering_rule(ctx, x): aval, = ctx.avals_in - ndim = len(aval.shape) return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec], result_layouts=[layout_spec])(ctx, x) prim = core.Primitive("test_ffi") @@ -228,51 +239,42 @@ def fun(x): fun(jnp.ones(5)) self.assertNotIsInstance(manager.exception, TypeError) - @jtu.sample_product( - shape=[(1,), (4,), (5,)], - dtype=(np.int32,), - ) - @jtu.run_on_devices("gpu") - def testFfiCall(self, shape, dtype): - pivots_size = shape[-1] - permutation_size = 2 * pivots_size - pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) - pivots = jnp.broadcast_to(pivots, shape) - expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) - actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size) - self.assertArraysEqual(actual, expected) + @jtu.sample_product(shape=[(6, 5), (4, 5, 6)]) + @jtu.run_on_devices("gpu", "cpu") + def testFfiCall(self, shape): + x = self.rng().randn(*shape).astype(np.float32) + expected = lax_linalg_internal.geqrf(x) + actual = ffi_call_geqrf(x) + for a, b in zip(actual, expected): + self.assertArraysEqual(a, b) @jtu.sample_product( - shape=[(1,), (4,), (5,)], - dtype=(np.int32,), - vmap_method=("expand_dims", "broadcast_all", "sequential", - "legacy_vectorized"), + shape=[(6, 5), (4, 5, 6)], + vmap_method=["expand_dims", "broadcast_all", "sequential"], ) - @jtu.run_on_devices("gpu") - def testFfiCallBatching(self, shape, dtype, vmap_method): + @jtu.run_on_devices("gpu", "cpu") + def testFfiCallBatching(self, shape, vmap_method): shape = (10,) + shape - pivots_size = shape[-1] - permutation_size = 2 * pivots_size - pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) - pivots = jnp.broadcast_to(pivots, shape) - expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) - actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation( - x, permutation_size, vmap_method=vmap_method))(pivots) - self.assertArraysEqual(actual, expected) - - @jtu.run_on_devices("gpu") + x = self.rng().randn(*shape).astype(np.float32) + expected = lax_linalg_internal.geqrf(x) + actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x) + for a, b in zip(actual, expected): + if vmap_method == "sequential" and len(shape) == 3: + # On GPU, the batched FFI call to geqrf uses an algorithm with + # different numerics than the unbatched version (which is used when + # vmap_method="sequential"). Therefore, we need to include floating + # point tolerance for this check. + self.assertArraysAllClose(a, b) + else: + self.assertArraysEqual(a, b) + + @jtu.run_on_devices("gpu", "cpu") def testVectorizedDeprecation(self): - pivots_size = 4 - shape = (10, pivots_size) - permutation_size = 2 * pivots_size - pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, - dtype=np.int32) - pivots = jnp.broadcast_to(pivots, shape) + x = self.rng().randn(3, 5, 4).astype(np.float32) with self.assertWarns(DeprecationWarning): - ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) + ffi_call_geqrf(x, vectorized=True) with self.assertWarns(DeprecationWarning): - jax.vmap( - lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots) + jax.vmap(ffi_call_geqrf)(x) def testBackwardCompatSyntax(self): def fun(x): @@ -280,20 +282,114 @@ def fun(x): with self.assertWarns(DeprecationWarning): jax.jit(fun).lower(jnp.ones(5)) + def testInputOutputAliases(self): + def fun(x): + return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]") + + def testInvalidInputOutputAliases(self): + def fun(x): + return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x) + with self.assertRaisesRegex(ValueError, "with input index"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x) + with self.assertRaisesRegex(ValueError, "with output index"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32), + input_output_aliases={0: 0})(x) + with self.assertRaisesRegex(ValueError, + "referring to an input with abstract value"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape, + x.dtype), + input_output_aliases={0: 0})(x) + with self.assertRaisesRegex(ValueError, + "referring to an input with abstract value"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def testLegacyBackendConfig(self): + def fun(x): + return jex.ffi.ffi_call("test", x, custom_call_api_version=2, + legacy_backend_config="12345")(x) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertRegex(hlo, 'backend_config = "12345"') -# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` -# custom call target because that's the only one in jaxlib that uses the -# new FFI interface. Once more are available, consider using something that -# can be run on multiple platforms. -def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs): - return jex.ffi.ffi_call( - "cu_lu_pivots_to_permutation", - jax.ShapeDtypeStruct( - shape=pivots.shape[:-1] + (permutation_size,), - dtype=pivots.dtype, - ), - **kwargs, - )(pivots) + def testInvalidBackendConfig(self): + def fun(x): + return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x) + with self.assertRaisesRegex(ValueError, + "The use of the legacy_backend_config"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", x, + custom_call_api_version=2)(x, attribute=1) + with self.assertRaisesRegex(ValueError, + "The use of ffi_call attributes requires"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def testAllow64(self): + if config.enable_x64.value: + self.skipTest("Requires enable_x64=False") + def fun(): + return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))() + self.assertIn("tensor", jax.jit(fun).lower().as_text()) + + def testInvalidResultType(self): + with self.assertRaisesRegex( + ValueError, "All elements of result_shape_dtypes.*position 0"): + jex.ffi.ffi_call("test", None)() + with self.assertRaisesRegex( + ValueError, "All elements of result_shape_dtypes.*position 1"): + jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))() + + @jtu.run_on_devices("gpu", "cpu") + def testShardMap(self): + mesh = jtu.create_mesh((1,), ("i",)) + x = self.rng().randn(8, 4, 5).astype(np.float32) + + @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'), + out_specs=shd.PartitionSpec('i')) + def f(x): + return ffi_call_geqrf(x) + + f(x) # eager mode doesn't crash + jax.jit(f)(x) # neither does JIT + self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text()) + + +def ffi_call_geqrf(x, **kwargs): + if jtu.test_device_matches(["cpu"]): + lapack._lapack.initialize() + + assert x.dtype == np.float32 + ndim = x.ndim + x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2) + output_types = [ + x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)] + + def call(platform, x): + target_name = dict( + cpu="lapack_sgeqrf_ffi", + rocm="hipsolver_geqrf_ffi", + cuda="cusolver_geqrf_ffi", + )[platform] + return jex.ffi.ffi_call( + target_name, output_types, input_output_aliases={0: 0}, + input_layouts=[x_major_to_minor], + output_layouts=[x_major_to_minor, None], + **kwargs)(x) + + return lax.platform_dependent( + x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"), + cuda=partial(call, "cuda")) class MlirRegisterLoweringTest(jtu.JaxTestCase): diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index f34b8211eb33..53f69cdb3f19 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -41,7 +41,7 @@ def main(_): print_ir(np.float32(1), np.float32(2))(lax.add) # CHECK-LABEL: TEST: acos float32[] - # CHECK: hlo.atan2 + # CHECK: chlo.acos # CHECK-SAME: tensor print_ir(np.float32(1))(lax.acos) diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 438ba55203a9..9e0ebd4ff922 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -223,7 +223,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase): [dict(for_impl=for_impl, impl_name=impl_name) for for_impl, impl_name in FOR_LOOP_IMPLS], ) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name, impl_name): for_ = for_impl @@ -255,7 +255,7 @@ def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name, [dict(for_impl=for_impl, impl_name=impl_name) for for_impl, impl_name in FOR_LOOP_IMPLS], ) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? def test_for_linearize(self, f, ref, body_shapes, n, for_impl, for_body_name, impl_name): for_ = for_impl @@ -365,7 +365,7 @@ def g(a, b): [dict(for_impl=for_impl, impl_name=impl_name) for for_impl, impl_name in FOR_LOOP_IMPLS], ) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name, impl_name): @@ -385,7 +385,7 @@ def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name, jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2, rtol=7e-3, atol=1e-2) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? @jax.legacy_prng_key('allow') def test_grad_of_triple_nested_for_loop(self): diff --git a/tests/garbage_collection_guard_test.py b/tests/garbage_collection_guard_test.py index d23d239dda1b..ce0585ba5b49 100644 --- a/tests/garbage_collection_guard_test.py +++ b/tests/garbage_collection_guard_test.py @@ -20,7 +20,6 @@ from absl.testing import absltest import jax from jax._src import config -from jax._src.lib import xla_extension_version import jax._src.test_util as jtu import jax.numpy as jnp @@ -37,8 +36,8 @@ def __init__(self, data): def _create_array_cycle(): """Creates a reference cycle of two jax.Arrays.""" - n1 = GarbageCollectionGuardTestNodeHelper(jnp.ones((2, 2))) - n2 = GarbageCollectionGuardTestNodeHelper(jnp.zeros((2, 2))) + n1 = GarbageCollectionGuardTestNodeHelper(jax.jit(lambda: jnp.ones( (2, 2)))()) + n2 = GarbageCollectionGuardTestNodeHelper(jax.jit(lambda: jnp.zeros((2, 2)))()) n1.next = n2 n2.next = n1 @@ -46,9 +45,6 @@ def _create_array_cycle(): class GarbageCollectionGuardTest(jtu.JaxTestCase): def test_gced_array_is_not_logged_by_default(self): - if xla_extension_version < 293: - self.skipTest("Requires xla_extension_version >= 293") - # Create a reference cycle of two jax.Arrays. _create_array_cycle() @@ -66,9 +62,6 @@ def test_gced_array_is_not_logged_by_default(self): ) def test_gced_array_is_logged(self): - if xla_extension_version < 293: - self.skipTest("Requires xla_extension_version >= 293") - # Use mock_stderr to be able to inspect stderr. mock_stderr = io.StringIO() diff --git a/tests/infeed_test.py b/tests/infeed_test.py index e378fe37a2f5..5dd52b4167d5 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -37,6 +37,7 @@ def setUp(self): @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. def testInfeed(self): + raise SkipTest("skipping temporarily for stackless") @jax.jit def f(x): @@ -56,6 +57,7 @@ def f(x): self.assertAllClose(f(x), x + y + z) def testInfeedPytree(self): + raise SkipTest("skipping temporarily for stackless") x = np.float32(1.5) y = np.reshape(np.arange(12, dtype=np.int16), (3, 4)) diff --git a/tests/jet_test.py b/tests/jet_test.py index 4e437c044426..7c2c71e9bbfa 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -319,6 +319,8 @@ def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0)) def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1))) @jtu.skip_on_devices("tpu") def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3))) + @jtu.skip_on_devices("tpu") + def test_copy(self): self.unary_check(jnp.array) @jtu.skip_on_devices("tpu") diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 286088eebe48..3364c9be91dd 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -18,8 +18,8 @@ import numpy as np import jax -from jax import core import jax.numpy as jnp +from jax._src import core from jax._src import prng from jax._src import random from jax._src import test_util as jtu diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 7fb118d47256..4b0420fda8f9 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -322,6 +322,19 @@ def testWhileTypeErrors(self): lax.while_loop(lambda c: True, lambda c: (True, True), (np.bool_(True), np.float32(0.))) + def testWhileLoopCustomPytreeDiffAuxData(self): + class Node: + def __init__(self, x, y): + self.x = x + self.y = y + tree_util.register_pytree_with_keys( + Node, + lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys + lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved) + lambda o: ((o.x, o.y), 'without_keys'), # flatten + ) + lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.)) + def testNestedWhileWithDynamicUpdateSlice(self): num = 5 @@ -2095,6 +2108,7 @@ def apply_carry(x, i): jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash def testIssue804(self): + # https://github.com/google/jax/issues/804 num_devices = jax.device_count() f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.) jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash @@ -2423,6 +2437,7 @@ def f(c, a): scan = lambda c, xs: lax.scan(f, c, xs) scan_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=2) + scan_fully_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=True) # jaxprs should be the same size self.assertEqual( @@ -2430,9 +2445,19 @@ def f(c, a): len(str(jax.make_jaxpr(scan_unrolled)(c, xs)))) # but HLO should grow due to unrolling - self.assertLess( - len(str(jax.jit(scan).lower(c, xs).as_text('hlo'))), - len(str(jax.jit(scan_unrolled).lower(c, xs).as_text('hlo')))) + scan_hlo = str(jax.jit(scan).lower(c, xs).as_text("hlo")) + scan_unrolled_hlo = str(jax.jit(scan_unrolled).lower(c, xs).as_text("hlo")) + scan_fully_unrolled_hlo = str( + jax.jit(scan_fully_unrolled).lower(c, xs).as_text("hlo")) + + self.assertLess(len(scan_hlo), len(scan_unrolled_hlo)) + self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo)) + + # and the lowering should contain a while loop, unless the scan is fully + # unrolled + self.assertIn("while(", scan_hlo) + self.assertIn("while(", scan_unrolled_hlo) + self.assertNotIn("while(", scan_fully_unrolled_hlo) def test_scan_xs_none(self): def f(h, _): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 392af2688c1d..ab625d10b4d8 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -399,6 +399,14 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)), out_shape=(3,)), ]), + ("EllipsisWithArrayIndices", [ + IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 1]), ..., np.array([0, 1])), + out_shape=(2, 4)), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])), + out_shape=(2, 3)), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])), + out_shape=(3, 2)), + ]), ] diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 623c11a51998..2bef35fbdcef 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -448,6 +448,34 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product(rec=JAX_REDUCER_INITIAL_RECORDS) + def testReducerWhereNonBooleanErrorInitial(self, rec): + dtype = rec.dtypes[0] + x = jnp.zeros((10,), dtype) + where = jnp.ones(10, dtype=int) + func = getattr(jnp, rec.name) + def assert_warns_or_errors(msg): + if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + func(x, where=where, initial=jnp.array(0, dtype=dtype)) + + @jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS) + def testReducerWhereNonBooleanErrorNoInitial(self, rec): + dtype = rec.dtypes[0] + x = jnp.zeros((10,), dtype) + where = jnp.ones(10, dtype=int) + func = getattr(jnp, rec.name) + def assert_warns_or_errors(msg): + if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + func(x, where=where) + @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact, @@ -820,15 +848,6 @@ def test_f16_mean(self, dtype): def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): rng = jtu.rand_some_zero(self.rng()) - def np_mock_op(x, axis=None, dtype=None, include_initial=False): - axis = axis or 0 - out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype) - if include_initial: - zeros_shape = list(x.shape) - zeros_shape[axis] = 1 - out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis) - return out - # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as # input because we rely on JAX-specific casting behavior def args_maker(): @@ -836,13 +855,24 @@ def args_maker(): if out_dtype in unsigned_dtypes: x = 10 * jnp.abs(x) return [x] - - np_op = getattr(np, "cumulative_sum", np_mock_op) kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial) + if jtu.numpy_version() >= (2, 1, 0): + np_op = np.cumulative_sum + else: + def np_op(x, axis=None, dtype=None, include_initial=False): + axis = axis or 0 + out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis) + return out + np_fun = lambda x: np_op(x, **kwargs) jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + rtol={jnp.bfloat16: 5e-2}) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( @@ -866,5 +896,53 @@ def testCumulativeSumBool(self): dtype=jnp.bool_) np.testing.assert_array_equal(np.array([[True], [True], [False]]), out) + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list( + range(-len(shape), len(shape)) + ) + ([None] if len(shape) == 1 else [])], + [dict(dtype=dtype, out_dtype=out_dtype) + for dtype in (all_dtypes+[None]) + for out_dtype in ( + complex_dtypes if np.issubdtype(dtype, np.complexfloating) + else all_dtypes + ) + ], + include_initial=[False, True], + ) + @jtu.ignore_warning(category=NumpyComplexWarning) + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): + if jtu.is_device_tpu(6): + raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e") + rng = jtu.rand_some_zero(self.rng()) + + # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as + # input because we rely on JAX-specific casting behavior + def args_maker(): + x = jnp.array(rng(shape, dtype)) + if out_dtype in unsigned_dtypes: + x = 10 * jnp.abs(x) + return [x] + kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial) + + if jtu.numpy_version() >= (2, 1, 0): + np_op = np.cumulative_prod + else: + def np_op(x, axis=None, dtype=None, include_initial=False): + axis = axis or 0 + out = np.cumprod(x, axis=axis, dtype=dtype or x.dtype) + if include_initial: + ones_shape = list(x.shape) + ones_shape[axis] = 1 + out = jnp.concat([jnp.ones(ones_shape, dtype=out.dtype), out], axis=axis) + return out + + np_fun = lambda x: np_op(x, **kwargs) + jnp_fun = lambda x: jnp.cumulative_prod(x, **kwargs) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a6d9c613379c..24c9dd0ca863 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,8 +51,8 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements -from jax._src.util import safe_zip, NumpyComplexWarning +from jax._src.lib import version as jaxlib_version +from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace config.parse_flags_with_absl() @@ -1493,8 +1493,10 @@ def testTrimZerosNotOneDArray(self): def testPoly(self, a_shape, dtype, rank): if dtype in (np.float16, jnp.bfloat16, np.int16): self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") - elif rank == 2 and not jtu.test_device_matches(["cpu"]): - self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") + elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]): + self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.") + if rank == 2 and jaxlib_version <= (0, 4, 35) and jtu.test_device_matches(["gpu"]): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } if jtu.test_device_matches(["tpu"]): @@ -2785,7 +2787,7 @@ def testSearchsortedDtype(self): message="NumPy will stop allowing conversion.*"): out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) else: - with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype.*int64"): with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"): out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) @@ -3426,13 +3428,8 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CompileAndCheck(jnp_fun, args_maker) def testReshapeDeprecatedArgs(self): - msg = "The newshape argument of jax.numpy.reshape is deprecated." - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-reshape-newshape"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(msg): + msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." + with self.assertRaisesRegex(TypeError, msg): jnp.reshape(jnp.arange(4), newshape=(2, 2)) @jtu.sample_product( @@ -4906,7 +4903,7 @@ def testAtLeastNdLiterals(self, dtype, op): @jtu.sample_product( shape=[(0,), (5,), (10,)], - dtype=int_dtypes, + dtype=int_dtypes + bool_dtypes, weights=[True, False], minlength=[0, 20], length=[None, 8], @@ -5963,6 +5960,45 @@ def np_fun(a, i, v): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis) + for a_shape in nonempty_array_shapes + for axis in list(range(-len(a_shape), len(a_shape))) + for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)] + for v_shape in [(), (1,), i_shape] + ] + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None) + for a_shape in nonempty_array_shapes + for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)] + for v_shape in [(), (1,), i_shape] + ], + dtype=jtu.dtypes.all, + mode=[None, "promise_in_bounds", "clip"], + ) + def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode): + a_rng = jtu.rand_default(self.rng()) + if axis is None: + size = math.prod(a_shape) + else: + size = a_shape[axis] + i_rng = jtu.rand_indices_unique_along_axis(self.rng()) + + def args_maker(): + a = a_rng(a_shape, dtype) + i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis) + v = a_rng(v_shape, dtype) + return a, i, v + + def np_fun(a, i, v): + a_copy = a.copy() + np.put_along_axis(a_copy, i, v, axis=axis) + return a_copy + + jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + def test_rot90_error(self): with self.assertRaisesRegex( ValueError, @@ -6186,10 +6222,115 @@ class NumpySignaturesTest(jtu.JaxTestCase): def testWrappedSignaturesMatch(self): """Test that jax.numpy function signatures match numpy.""" - jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)} - func_pairs = {name: (fun, fun.__np_wrapped__) for name, fun in jnp_funcs.items() - if getattr(fun, '__np_wrapped__', None) is not None} - assert len(func_pairs) > 0 + # NumPy functions explicitly not implemented in JAX: + skip = {'array2string', + 'asanyarray', + 'asarray_chkfinite', + 'ascontiguousarray', + 'asfortranarray', + 'asmatrix', + 'base_repr', + 'binary_repr', + 'bmat', + 'broadcast', + 'busday_count', + 'busday_offset', + 'busdaycalendar', + 'common_type', + 'copyto', + 'datetime_as_string', + 'datetime_data', + 'errstate', + 'flatiter', + 'format_float_positional', + 'format_float_scientific', + 'fromregex', + 'genfromtxt', + 'get_include', + 'getbufsize', + 'geterr', + 'geterrcall', + 'in1d', + 'info', + 'is_busday', + 'isfortran', + 'isnat', + 'loadtxt', + 'matrix', + 'matvec', + 'may_share_memory', + 'memmap', + 'min_scalar_type', + 'mintypecode', + 'ndenumerate', + 'ndindex', + 'nditer', + 'nested_iters', + 'poly1d', + 'putmask', + 'real_if_close', + 'recarray', + 'record', + 'require', + 'row_stack', + 'savetxt', + 'savez_compressed', + 'setbufsize', + 'seterr', + 'seterrcall', + 'shares_memory', + 'show_config', + 'show_runtime', + 'test', + 'trapz', + 'typename', + 'vecmat'} + + # symbols removed in NumPy 2.0 + skip |= {'add_docstring', + 'add_newdoc', + 'add_newdoc_ufunc', + 'alltrue', + 'asfarray', + 'byte_bounds', + 'compare_chararrays', + 'cumproduct', + 'deprecate', + 'deprecate_with_doc', + 'disp', + 'fastCopyAndTranspose', + 'find_common_type', + 'get_array_wrap', + 'geterrobj', + 'issctype', + 'issubclass_', + 'issubsctype', + 'lookfor', + 'mat', + 'maximum_sctype', + 'msort', + 'obj2sctype', + 'product', + 'recfromcsv', + 'recfromtxt', + 'round_', + 'safe_eval', + 'sctype2char', + 'set_numeric_ops', + 'set_string_function', + 'seterrobj', + 'sometrue', + 'source', + 'who'} + + self.assertEmpty(skip.intersection(dir(jnp))) + + names = (name for name in dir(np) if not (name.startswith('_') or name in skip)) + names = (name for name in names if callable(getattr(np, name))) + names = {name for name in names if not isinstance(getattr(np, name), type)} + self.assertEmpty(names.difference(dir(jnp))) + + self.assertNotEmpty(names) # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names. unsupported_params = { @@ -6200,6 +6341,7 @@ def testWrappedSignaturesMatch(self): 'copy': ['subok'], 'corrcoef': ['ddof', 'bias', 'dtype'], 'cov': ['dtype'], + 'cumulative_prod': ['out'], 'cumulative_sum': ['out'], 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], @@ -6211,9 +6353,7 @@ def testWrappedSignaturesMatch(self): 'full': ['order', 'like'], 'full_like': ['subok', 'order'], 'fromfunction': ['like'], - 'histogram': ['normed'], - 'histogram2d': ['normed'], - 'histogramdd': ['normed'], + 'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'], 'nanpercentile': ['weights'], 'nanquantile': ['weights'], 'nanstd': ['correction', 'mean'], @@ -6223,29 +6363,30 @@ def testWrappedSignaturesMatch(self): 'partition': ['kind', 'order'], 'percentile': ['weights'], 'quantile': ['weights'], - 'reshape': ['shape', 'copy'], 'row_stack': ['casting'], 'stack': ['casting'], 'std': ['mean'], 'tri': ['like'], + 'trim_zeros': ['axis'], 'var': ['mean'], 'vstack': ['casting'], 'zeros_like': ['subok', 'order'] } extra_params = { - # TODO(micky774): Remove when np.clip has adopted the Array API 2023 - # standard - 'clip': ['x', 'max', 'min'], + 'compress': ['size', 'fill_value'], 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], + 'load': ['args', 'kwargs'], 'take_along_axis': ['mode', 'fill_value'], 'fill_diagonal': ['inplace'], } mismatches = {} - for name, (jnp_fun, np_fun) in func_pairs.items(): + for name in names: + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) if name in ['histogram', 'histogram2d', 'histogramdd']: # numpy 1.24 re-orders the density and weights arguments. # TODO(jakevdp): migrate histogram APIs to match newer numpy versions. @@ -6259,12 +6400,15 @@ def testWrappedSignaturesMatch(self): # TODO(dfm): After our deprecation period for the clip arguments ends # it should be possible to reintroduce the check. continue - # Note: can't use inspect.getfullargspec due to numpy issue + if name == "reshape": + # Similar issue to clip: we'd need logic specific to the NumPy version + # because of the change in argument name from `newshape` to `shape`. + continue + # Note: can't use inspect.getfullargspec for some functions due to numpy issue # https://github.com/numpy/numpy/issues/12225 try: np_params = inspect.signature(np_fun).parameters except ValueError: - # Some functions cannot be inspected continue jnp_params = inspect.signature(jnp_fun).parameters extra = set(extra_params.get(name, [])) @@ -6299,7 +6443,8 @@ def testWrappedSignaturesMatch(self): _available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all if dtype != dtypes.bfloat16] -UNIMPLEMENTED_UFUNCS = {'spacing'} +# TODO(jakevdp): implement missing ufuncs. +UNIMPLEMENTED_UFUNCS = {'spacing', 'matvec', 'vecmat'} def _all_numpy_ufuncs() -> Iterator[str]: @@ -6351,8 +6496,6 @@ def testUfuncInputTypes(self, name, arg_dtypes): class NumpyDocTests(jtu.JaxTestCase): def test_lax_numpy_docstrings(self): - # Test that docstring wrapping & transformation didn't fail. - unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', 'amax', 'amin', 'around', 'bitwise_invert', 'bitwise_left_shift', @@ -6372,15 +6515,6 @@ def test_lax_numpy_docstrings(self): elif hasattr(np, name) and obj is getattr(np, name): # Some APIs are imported directly from NumPy; we don't check these. pass - elif hasattr(obj, '__np_wrapped__'): - # Functions decorated with @implements(...) should have __np_wrapped__ - wrapped_fun = obj.__np_wrapped__ - if wrapped_fun is not None: - # If the wrapped function has a docstring, obj should too - if wrapped_fun.__doc__ and not obj.__doc__: - raise Exception(f"jnp.{name} does not contain wrapped docstring.") - if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__: - raise Exception(f"jnp.{name} does not have a wrapped docstring.") elif name in aliases: assert "Alias of" in obj.__doc__ elif name not in skip_args_check: @@ -6392,84 +6526,6 @@ def test_lax_numpy_docstrings(self): if name not in ["frompyfunc", "isdtype", "promote_types"]: self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}") - @parameterized.named_parameters( - {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False]) - def test_wrapped_function_parameters(self, jit): - def orig(x): - """Example Docstring - - Parameters - ---------- - x : array_like - Input Data - - .. versionadded:: 1.8.0 - out : array_like, optional - Output to overwrite - other_arg : Any - not used - - Returns - ------- - x : input - """ - return x - - def wrapped(x, out=None): - return x - - if jit: - wrapped = jax.jit(wrapped) - - wrapped = implements(orig)(wrapped) - doc = wrapped.__doc__ - - self.assertStartsWith(doc, "Example Docstring") - self.assertIn("Original docstring below", doc) - self.assertIn("Parameters", doc) - self.assertIn("Returns", doc) - self.assertNotIn('other_arg', doc) - self.assertNotIn('versionadded', doc) - - - def test_parse_numpydoc(self): - # Unit test ensuring that _parse_numpydoc correctly parses docstrings for all - # functions in NumPy's top-level namespace. - section_titles = {'Attributes', 'Examples', 'Notes', - 'Parameters', 'Raises', 'References', - 'Returns', 'See also', 'See Also', 'Warnings', 'Warns'} - headings = [title + '\n' + '-'*len(title) for title in section_titles] - - for name in dir(np): - if name.startswith('_'): - continue - obj = getattr(np, name) - if isinstance(obj, type): - continue - if not callable(obj): - continue - if 'built-in function' in repr(obj): - continue - parsed = _parse_numpydoc(obj.__doc__) - - # Check that no docstring is handled gracefully. - if not obj.__doc__: - self.assertEqual(parsed, ParsedDoc(obj.__doc__)) - continue - - # Check that no unexpected section names are found. - extra_keys = parsed.sections.keys() - section_titles - if extra_keys: - raise ValueError(f"Extra section headers found in np.{name}: {extra_keys}") - - # Check that every docstring has a summary. - if not parsed.summary: - raise ValueError(f"No summary found for np.{name}") - - # Check that no expected headings are missed. - for heading in headings: - assert heading not in parsed.front_matter - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 61c86c0a05e4..20a1a58a9dbe 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -179,13 +179,15 @@ def test_unary_ufunc_call(self, name, dtype, shape): rhs_shape=broadcast_compatible_shapes, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_bimary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + def test_binary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( @@ -218,7 +220,9 @@ def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker, tol=tol) self._CompileAndCheck(jnp_fun.outer, args_maker) @jtu.sample_product( @@ -259,7 +263,9 @@ def test_binary_ufunc_reduce(self, name, shape, axis, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( @@ -315,7 +321,9 @@ def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): rng_where = jtu.rand_bool(self.rng()) args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)] - self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( @@ -356,8 +364,10 @@ def np_fun_accumulate(x): result = np_fun.accumulate(x, axis=axis) return result if x.dtype == bool else result.astype(x.dtype) - self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker) - self._CompileAndCheck(jnp_fun_accumulate, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_accumulate, args_maker, tol=tol) @jtu.sample_product( SCALAR_FUNCS, @@ -400,7 +410,9 @@ def np_fun_at(x, idx): np_fun.at(x_copy, idx) return x_copy - self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_at, args_maker) @jtu.sample_product( @@ -422,7 +434,9 @@ def np_fun_at(x, idx, y): np_fun.at(x_copy, idx, y) return x_copy - self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_at, args_maker) def test_frompyfunc_at_broadcasting(self): @@ -483,7 +497,9 @@ def np_fun_reduceat(x, i): # Numpy has different casting behavior. return np_fun.reduceat(x, i).astype(x.dtype) - self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker, tol=tol) self._CompileAndCheck(jnp_fun.reduceat, args_maker) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index bd3bca5385b7..5753628957c7 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -20,10 +20,11 @@ from absl.testing import parameterized import numpy as np +import scipy import scipy.special as osp_special import jax -from jax._src import deprecations +import jax.numpy as jnp from jax._src import test_util as jtu from jax.scipy import special as lsp_special @@ -215,7 +216,7 @@ def partial_lax_op(*vals): n=[0, 1, 2, 3, 10, 50] ) def testScipySpecialFunBernoulli(self, n): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. scipy_op = lambda: osp_special.bernoulli(n).astype(dtype) lax_op = functools.partial(lsp_special.bernoulli, n) args_maker = lambda: [] @@ -223,16 +224,33 @@ def testScipySpecialFunBernoulli(self, n): self._CompileAndCheck(lax_op, args_maker, atol=0, rtol=1E-5) def testGammaSign(self): - # Test that the sign of `gamma` matches at integer-valued inputs. - dtype = jax.numpy.zeros(0).dtype # default float dtype. - args_maker = lambda: [np.arange(-10, 10).astype(dtype)] - rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 - self._CheckAgainstNumpy(osp_special.gamma, lsp_special.gamma, args_maker, rtol=rtol) - self._CompileAndCheck(lsp_special.gamma, args_maker, rtol=rtol) + dtype = jnp.zeros(0).dtype # default float dtype. + typ = dtype.type + testcases = [ + (np.arange(-10, 0).astype(dtype), np.array([np.nan] * 10, dtype=dtype)), + (np.nextafter(np.arange(-5, 0).astype(dtype), typ(-np.inf)), + np.array([1, -1, 1, -1, 1], dtype=dtype)), + (np.nextafter(np.arange(-5, 0).astype(dtype), typ(np.inf)), + np.array([-1, 1, -1, 1, -1], dtype=dtype)), + (np.arange(0, 10).astype(dtype), np.ones((10,), dtype)), + (np.nextafter(np.arange(0, 10).astype(dtype), typ(np.inf)), + np.ones((10,), dtype)), + (np.nextafter(np.arange(1, 10).astype(dtype), typ(-np.inf)), + np.ones((9,), dtype)), + (np.array([-np.inf, -0.0, 0.0, np.inf, np.nan]), + np.array([np.nan, -1.0, 1.0, 1.0, np.nan])) + ] + for inp, out in testcases: + self.assertArraysEqual(out, lsp_special.gammasgn(inp)) + self.assertArraysEqual(out, jnp.sign(lsp_special.gamma(inp))) + if jtu.parse_version(scipy.__version__) >= (1, 15): + self.assertArraysEqual(out, osp_special.gammasgn(inp)) + self.assertAllClose(osp_special.gammasgn(inp), + lsp_special.gammasgn(inp)) def testNdtriExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.arange(-10, 10).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) @@ -240,7 +258,7 @@ def testNdtriExtremeValues(self): def testRelEntrExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype), np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 @@ -252,22 +270,8 @@ def testBetaParameterDeprecation(self): lsp_special.beta(1, 1) lsp_special.beta(1, b=1) lsp_special.beta(a=1, b=1) - if deprecations.is_accelerated('jax-scipy-beta-args'): - with self.assertRaises(ValueError): - lsp_special.beta(x=1, y=1) - else: - with self.assertWarns(DeprecationWarning): - lsp_special.beta(1, y=1) - with self.assertWarns(DeprecationWarning): - lsp_special.beta(a=1, y=1) - with self.assertWarns(DeprecationWarning): - lsp_special.beta(x=1, b=1) - with self.assertWarns(DeprecationWarning): - lsp_special.beta(x=1, y=1) - with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): - lsp_special.beta(1, x=1) - with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): - lsp_special.beta(b=1, y=1) + with self.assertRaises(TypeError): + lsp_special.beta(x=1, y=1) if __name__ == "__main__": diff --git a/tests/lax_test.py b/tests/lax_test.py index 1581c61d57eb..89d41d0b9312 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -41,7 +41,6 @@ from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util @@ -1077,15 +1076,15 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): if jtu.dtypes.supported([dtype]) ]) def testDotAlgorithm(self, algorithm, dtype): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): if algorithm not in { lax.DotAlgorithmPreset.DEFAULT, lax.DotAlgorithmPreset.F16_F16_F16, lax.DotAlgorithmPreset.F32_F32_F32, lax.DotAlgorithmPreset.F64_F64_F64, + lax.DotAlgorithmPreset.BF16_BF16_F32, + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, }: raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") @@ -1122,11 +1121,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on TPU." ) - if algorithm != lax.DotAlgorithmPreset.DEFAULT and dtype != np.float32: - raise SkipTest( - f"The dot algorithm '{algorithm}' is only supported for float32 on" - " TPU." - ) lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) @@ -1135,9 +1129,6 @@ def testDotAlgorithm(self, algorithm, dtype): self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype) def testDotAlgorithmInvalidFloat8Type(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): raise SkipTest("Not supported on CPU.") lhs_shape = (3, 4) @@ -1148,9 +1139,6 @@ def testDotAlgorithmInvalidFloat8Type(self): lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32") def testDotAlgorithmCasting(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["tpu"]): raise SkipTest("F32_F32_F32 is not supported on TPU.") def fun(lhs, rhs): @@ -1161,6 +1149,31 @@ def fun(lhs, rhs): lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) self.assertEqual(fun(lhs, rhs).dtype, np.float16) + def testDotAlgorithmAllowedOutputStorage(self): + # see https://github.com/jax-ml/jax/issues/24794 + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Only supported on GPU.") + def fun(lhs, rhs): + return lax.dot(lhs, rhs, precision="F16_F16_F32", + preferred_element_type=np.float16) + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) + self.assertNotIn("convert", jax.jit(fun).lower(lhs, rhs).as_text()) + + def testDotAlgorithmConfig(self): + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + + expected = ("algorithm = :" + " tensor<1x3xi64>}}", + mlir_module, + ) + @jtu.sample_product( [ {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, @@ -1587,6 +1650,8 @@ def testPadAgainstNumpy(self, shape, dtype, pads): self._CheckAgainstNumpy(numpy_op, op, args_maker) def testPadErrors(self): + with self.assertRaisesRegex(ValueError, "padding_value must be a scalar"): + lax.pad(np.zeros(2), np.zeros(2), [(0, 0, 0)]) with self.assertRaisesRegex(ValueError, "padding_config"): lax.pad(np.zeros(2), 0., [(0, 1, 0), (0, 1, 0)]) with self.assertRaisesRegex(ValueError, "interior padding in padding_config must be nonnegative"): @@ -3536,7 +3601,7 @@ def testAsarray(self, typ): with jax.transfer_guard('disallow'): jax.jit(asarray_closure)() - def testOptimizationBarrier(self): + def test_optimization_barrier(self): x = lax.optimization_barrier((2, 3)) self.assertEqual((2, 3), x) @@ -3820,11 +3885,11 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(xs, shardings, layouts): +def shard_foo_array_handler(xs, shardings, layouts, copy_semantics): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment - aval = core.raise_to_shaped(core.get_aval(x.data)) + aval = core.get_aval(x.data) results.append(pxla.batched_device_put( aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) return results @@ -4352,12 +4417,6 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'sign': regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4') - elif name == 'square': - if is_cuda: - regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real') - if is_cpu: - regions_with_inaccuracies_keep('ninf.real', 'pinf.real', 'q1.real', 'q2.real', 'q3.real', 'q4.real') - elif name == 'log': regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag') @@ -4400,8 +4459,8 @@ def regions_with_inaccuracies_keep(*to_keep): elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') - elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan', - 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh'}: + elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'tan', 'log1p', + 'arcsin', 'arcsinh', 'arccosh', 'arctan', 'arctanh', 'square'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 83d4d657751b..bfe9fecd6c7e 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -691,6 +691,25 @@ def testTopK(self, shape, dtype, k, bdims): op2 = lambda x: lax.top_k(x, k=k)[1] self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng) + @jtu.sample_product( + [dict(shape=shape, bdims=bdims) + for shape in [(8,), (3, 4, 5)] + for bdims in lax_test_util.all_bdims(shape)], + dtype=lax_test_util.default_dtypes, + ) + def test_optimization_barrier_vmap(self, shape, dtype, bdims): + rng = jtu.rand_small(self.rng()) + self._CheckBatching(lax.optimization_barrier, 5, bdims, (shape,), (dtype,), + rng) + + def test_optimization_barrier_vmap_out_axes(self): + x = jnp.arange(8) + y = x.reshape(1, 8) + out = jax.vmap(lax.optimization_barrier, in_axes=((0, 1),), + out_axes=(0, 1))((x, y)) + self.assertArraysEqual(out[0], x) + self.assertArraysEqual(out[1], y) + @jtu.sample_product( [dict(shape=shape, bdims=bdims, dimension=dimension, arity=arity) for shape in [(2, 3)] diff --git a/tests/layout_test.py b/tests/layout_test.py index 9d26d96e2ae5..afddab916723 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -25,6 +25,7 @@ from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() @@ -600,6 +601,98 @@ def g(x): ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"): g(jnp.arange(8)) + def test_sparsecore_compute(self): + if not (jax.devices()[0].device_kind == 'TPU v5' or + jtu.is_device_tpu_at_least(6)): + self.skipTest('Does not have a sparsecore present') + shape = (128, 128) + inp = jnp.arange(math.prod(shape)).reshape(shape) + + dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) + s = SingleDeviceSharding(jax.devices()[0]) + sparse_layout = Layout(dll, s) + sparecore_arr = jax.device_put(inp, sparse_layout) + dense_layout = Layout(DLL(major_to_minor=(0, 1)), s) + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_compute(x): + return x * x + + @partial(jax.jit, out_shardings=(dense_layout, sparse_layout)) + def f(x, y): + return x * 2, sparsecore_compute(y) + + f(inp, sparecore_arr) + + def test_sparsecore_compute_twice(self): + if not ( + jax.devices()[0].device_kind == 'TPU v5' + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest('Does not have a sparsecore present') + shape = (4096, 8) + inp = jnp.arange(math.prod(shape)).reshape(shape) + + dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) + s = SingleDeviceSharding(jax.devices()[0]) + sparse_layout = Layout(dll, s) + sparecore_arr = jax.device_put(inp, sparse_layout) + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_multiply(x, y): + return x * y + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_add(x, y): + return x + y + + @partial(jax.jit, donate_argnums=0, out_shardings=sparse_layout) + def f(x): + return sparsecore_multiply(sparsecore_add(x, x) + 1, x) + + f(sparecore_arr) + + def test_sparsecore_and_host_compute(self): + if not ( + jax.devices()[0].device_kind == 'TPU v5' + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest('Does not have a sparsecore present') + shape = (128, 128) + inp = jnp.arange(math.prod(shape)).reshape(shape) + s = SingleDeviceSharding(jax.devices()[0]) + + sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) + sparse_layout = Layout(sparse_dll, s) + sparecore_arr = jax.device_put(inp, sparse_layout) + + host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),)) + host_layout = Layout(host_dll, s) + host_arr = jax.device_put(inp, host_layout) + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_compute(x): + return x * x + + @compute_on('device_host') + @jax.jit + def host_compute(x): + return x + x + + @partial( + jax.jit, + in_shardings=(sparse_layout, host_layout), + out_shardings=(sparse_layout, host_layout), + ) + def f(x, y): + return sparsecore_compute(x), host_compute(y) + + f(sparecore_arr, host_arr) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5ace4b5ecf18..0da09e232deb 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,6 +16,8 @@ from functools import partial import itertools +from typing import Iterator +from unittest import skipIf import numpy as np import scipy @@ -53,6 +55,36 @@ def _is_required_cuda_version_satisfied(cuda_version): else: return int(version.split()[-1]) >= cuda_version + +def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]: + """ + Generate a range of valid axis arguments for a reduction over + an array with a given number of dimensions. + """ + yield from (None, ()) + if ndim > 0: + yield from (0, (-1,)) + if ndim > 1: + yield from (1, (0, 1), (-1, 0)) + if ndim > 2: + yield (-1, 0, 1) + + +def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: + """scipy.linalg.toeplitz with v1.17+ batching semantics.""" + if scipy_version >= (1, 17, 0): + return scipy.linalg.toeplitz(c, r) + elif r is None: + c = np.atleast_1d(c) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c) + else: + c = np.atleast_1d(c) + r = np.atleast_1d(r) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r) + + class NumpyLinalgTest(jtu.JaxTestCase): @jtu.sample_product( @@ -234,11 +266,11 @@ def testIssue1213(self): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] args_maker = lambda: [rng(shape, dtype)] @@ -277,12 +309,12 @@ def check_left_eigenvectors(a, w, vl): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, @@ -293,15 +325,15 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, @jtu.sample_product( shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types, - ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + ) + @jtu.run_on_devices("cpu", "gpu") def testEigvalsGrad(self, shape, dtype): # This test sometimes fails for large matrices. I (@j-towns) suspect, but # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -313,10 +345,10 @@ def testEigvalsGrad(self, shape, dtype): shape=[(4, 4), (5, 5), (50, 50)], dtype=float_types + complex_types, ) - # TODO: enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -324,9 +356,11 @@ def testEigvals(self, shape, dtype): w2 = jnp.linalg.eigvals(a) self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -334,8 +368,10 @@ def testEigvalsInf(self): shape=[(1, 1), (4, 4), (5, 5)], dtype=float_types + complex_types, ) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape args = rng(shape, dtype) @@ -687,29 +723,25 @@ def testMatrixNorm(self, shape, dtype, keepdims, ord): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3) self._CompileAndCheck(jnp_fn, args_maker) + @skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0") @jtu.sample_product( - shape=[(3,), (3, 4), (2, 3, 4, 5)], + [ + dict(shape=shape, axis=axis) + for shape in [(3,), (3, 4), (2, 3, 4, 5)] + for axis in _axis_for_ndim(len(shape)) + ], dtype=float_types + complex_types, keepdims=[True, False], - axis=[0, None], ord=[1, -1, 2, -2, np.inf, -np.inf], ) def testVectorNorm(self, shape, dtype, keepdims, axis, ord): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - def np_fn(x, *, ord, keepdims, axis): - x = np.asarray(x) - if axis is None: - result = np_fn(x.ravel(), ord=ord, keepdims=False, axis=0) - return np.reshape(result, (1,) * x.ndim) if keepdims else result - return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis) - else: - np_fn = np.linalg.vector_norm - np_fn = partial(np_fn, ord=ord, keepdims=keepdims, axis=axis) + np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis) jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis) - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3) - self._CompileAndCheck(jnp_fn, args_maker) + tol = 1E-3 if jtu.test_device_matches(['tpu']) else None + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + self._CompileAndCheck(jnp_fn, args_maker, tol=tol) # jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here. @jtu.sample_product( @@ -1990,11 +2022,11 @@ def testSqrtmEdgeCase(self, diag, expected, dtype): self.assertAllClose(root, expected, check_dtypes=False) @jtu.sample_product( - cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)], + cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)], cdtype=float_types + complex_types, - rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)], + rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)], rdtype=float_types + complex_types + int_types) - def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): + def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype): if ((rdtype in [np.float64, np.complex128] or cdtype in [np.float64, np.complex128]) and not config.enable_x64.value): @@ -2007,10 +2039,11 @@ def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)] - with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) - self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) + with jax.numpy_rank_promotion("allow"): + with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): + self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz), + jsp.linalg.toeplitz, args_maker) + self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) @jtu.sample_product( shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)], @@ -2028,8 +2061,7 @@ def testToeplitzSymmetricConstruction(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) + self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker) self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) def testToeplitzConstructionWithKnownCases(self): diff --git a/tests/logging_test.py b/tests/logging_test.py index a1d6695a1e37..a83058095ce6 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,8 +15,9 @@ import contextlib import io import logging -import os import platform +import re +import shlex import subprocess import sys import tempfile @@ -26,6 +27,7 @@ import jax import jax._src.test_util as jtu from jax._src import xla_bridge +from jax._src.logging_config import _default_TF_CPP_MIN_LOG_LEVEL # Note: importing absltest causes an extra absl root log handler to be # registered, which causes extra debug log messages. We don't expect users to @@ -49,10 +51,23 @@ def jax_debug_log_modules(value): finally: jax.config.update("jax_debug_log_modules", original_value) +@contextlib.contextmanager +def jax_logging_level(value): + # jax_logging_level doesn't have a context manager, because it's + # not thread-safe. But since tests are always single-threaded, we + # can define one here. + original_value = jax.config.jax_logging_level + jax.config.update("jax_logging_level", value) + try: + yield + finally: + jax.config.update("jax_logging_level", original_value) + @contextlib.contextmanager def capture_jax_logs(): log_output = io.StringIO() + handler = logging.StreamHandler(log_output) logger = logging.getLogger("jax") @@ -91,21 +106,8 @@ def test_no_log_spam(self): """)) python = sys.executable assert "python" in python - env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} - if os.getenv("ASAN_OPTIONS"): - env_variables["ASAN_OPTIONS"] = os.getenv("ASAN_OPTIONS") - if os.getenv("PYTHONPATH"): - env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") - if os.getenv("LD_LIBRARY_PATH"): - env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") - if os.getenv("LD_PRELOAD"): - env_variables["LD_PRELOAD"] = os.getenv("LD_PRELOAD") # Make sure C++ logging is at default level for the test process. - proc = subprocess.run( - [python, f.name], - capture_output=True, - env=env_variables, - ) + proc = subprocess.run([python, f.name], capture_output=True) lines = proc.stdout.split(b"\n") lines.extend(proc.stderr.split(b"\n")) @@ -155,6 +157,159 @@ def test_debug_logging(self): jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) + @jtu.skip_on_devices("tpu") + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_stderr_info_logging(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + # test INFO + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + info_lines = log_output.split("\n") + self.assertGreater(len(info_lines), 0) + self.assertIn("INFO", log_output) + self.assertNotIn("DEBUG", log_output) + + @jtu.skip_on_devices("tpu") + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_stderr_debug_logging(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + # test DEBUG + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + self.assertIn("INFO", log_output) + self.assertIn("DEBUG", log_output) + + # test JAX_DEBUG_MODULES + cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + self.assertIn("DEBUG", log_output) + + @jtu.skip_on_devices("tpu") + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_toggling_logging_level(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + _separator = "---------------------------" + program = f""" + import sys + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + jax.config.update("jax_logging_level", None) + sys.stderr.write("{_separator}") + jax.jit(lambda x: x)(1) # should not log anything now + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + m = re.search(_separator, log_output) + self.assertTrue(m is not None) + log_output_verbose = log_output[:m.start()] + log_output_silent = log_output[m.end():] + + self.assertIn("Finished tracing + transforming for pjit", + log_output_verbose) + self.assertEqual(log_output_silent, "") + + @jtu.skip_on_devices("tpu") + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_double_logging_absent(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import jax # this prints INFO logging from backend imports + jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch") + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + self.assertNotEmpty(log_output) + log_lines = log_output.strip().split("\n") + # only one tracing line should be printed, if there's more than one + # then logs are printing duplicated + self.assertLen([line for line in log_lines + if "Finished tracing + transforming" in line], 1) + + @jtu.skip_on_devices("tpu") + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_cpp_logging_level(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import sys + import jax # this prints INFO logging from backend imports + jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + # verbose logging: DEBUG, VERBOSE + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + self.assertIn("Initializing CoordinationService", p.stderr) + + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + self.assertIn("Initializing CoordinationService", p.stderr) + + # verbose logging: WARNING, None + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + self.assertNotIn("Initializing CoordinationService", p.stderr) + + cmd = shlex.split(f"{sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1: + self.assertNotIn("Initializing CoordinationService", p.stderr) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py new file mode 100644 index 000000000000..bf9c0fb6b51d --- /dev/null +++ b/tests/magma_linalg_test.py @@ -0,0 +1,124 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np + +from absl.testing import absltest + +import jax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import linalg as lax_linalg +from jax._src.lib import gpu_solver + +config.parse_flags_with_absl() + +float_types = jtu.dtypes.floating +complex_types = jtu.dtypes.complex + + +class MagmaLinalgTest(jtu.JaxTestCase): + + @jtu.sample_product( + shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEig(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + rng = jtu.rand_default(self.rng()) + n = shape[-1] + args_maker = lambda: [rng(shape, dtype)] + + # Norm, adjusted for dimension and type. + def norm(x): + norm = np.linalg.norm(x, axis=(-2, -1)) + return norm / ((n + 1) * jnp.finfo(dtype).eps) + + def check_right_eigenvectors(a, w, vr): + self.assertTrue( + np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100)) + + def check_left_eigenvectors(a, w, vl): + rank = len(a.shape) + aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2])) + wC = jnp.conj(w) + check_right_eigenvectors(aH, wC, vl) + + a, = args_maker() + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + w = results[0] + + if compute_left_eigenvectors: + check_left_eigenvectors(a, w, results[1]) + if compute_right_eigenvectors: + check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) + + self._CompileAndCheck(jnp.linalg.eig, args_maker, rtol=1e-3) + + @jtu.sample_product( + shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + a = jnp.full(shape, jnp.nan, dtype) + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + for result in results: + self.assertTrue(np.all(np.isnan(result))) + + def testEigMagmaConfig(self): + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + rng = jtu.rand_default(self.rng()) + a = rng((5, 5), np.float32) + with config.gpu_use_magma("on"): + hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text() + self.assertIn('magma = "on"', hlo) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/memories_test.py b/tests/memories_test.py index 7f05ac424127..fcb1d6bdc2b1 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -26,7 +26,6 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.layout import DeviceLocalLayout as DLL, Layout -from jax._src.lib import xla_extension_version from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp @@ -655,8 +654,6 @@ def f(): @jtu.run_on_devices('tpu') def test_ragged_copy_on_host(self): - if xla_extension_version < 290: - self.skipTest('Requires xla_extension_version >= 290') mesh = jtu.create_mesh((2,), ('x')) sharding = jax.sharding.NamedSharding(mesh, P(('x'))) cpu_sharding = sharding.with_memory_kind('pinned_host') @@ -698,6 +695,34 @@ def foo(x): if compiled_text is not None: self.assertIn('custom_call_target="AllocateBuffer"', compiled_text) + def test_disallow_alias_copies_arrays(self): + mesh = jtu.create_mesh((2,), ("x",)) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P("x"), memory_kind="pinned_host") + inp_host = jax.device_put(np_inp, s) + + inp_host_copy = jax.device_put(inp_host, may_alias=False) + + for a in jax.tree.leaves(inp_host): + a.delete() + + jax.block_until_ready(inp_host_copy) + + def test_disallow_alias_copies_arrays_with_donated_input(self): + mesh = jtu.create_mesh((2,), ("x",)) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P("x"), memory_kind="pinned_host") + inp_host = jax.device_put(np_inp, s) + + inp_host_donate = jax.jit(lambda x: x, donate_argnums=0)(inp_host) + + inp_host_donate_copy = jax.device_put(inp_host_donate, may_alias=False) + + for a in jax.tree.leaves(inp_host_donate): + a.delete() + + jax.block_until_ready(inp_host_donate_copy) + class ComputeOffload(jtu.BufferDonationTestCase): @@ -778,6 +803,46 @@ def h(x): self.assertArraysEqual(out2, inp * 6) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_on_host_shared_sharding(self): + mesh = jtu.create_mesh((2,), ("x")) + device_sharding = NamedSharding(mesh, P("x")) + host_sharding = device_sharding.with_memory_kind("pinned_host") + + @compute_on("device_host") + @functools.partial( + jax.jit, + in_shardings=(host_sharding, device_sharding), + out_shardings=(host_sharding, device_sharding), + donate_argnums=(0, 1), + ) + def host_func(x, y): + return (x * y), ((x**2) * (y**2)) + + @functools.partial( + jax.jit, + in_shardings=(host_sharding, device_sharding), + out_shardings=(host_sharding, device_sharding), + donate_argnums=(0), + ) + def device_func(host_data, device_data): + host_data, device_data = host_func(host_data, device_data) + device_data = device_data * 2 + host_data, device_data = host_func(host_data, device_data) + return (host_data, device_data) + + input_x = jnp.ones(8) + input_host = jax.device_put(input_x, host_sharding) + + input_device = jnp.arange(8) + input_device = jnp.where(input_device < 4, 0, 1) + input_device = jax.device_put(input_device, device_sharding) + + output_host, output_device = device_func(input_host, input_device) + self.assertEqual(output_host.sharding.memory_kind, 'pinned_host') + self.assertEqual(output_device.sharding.memory_kind, 'device') + self.assertArraysEqual(output_host, [0., 0., 0., 0., 2., 2., 2., 2.]) + self.assertArraysEqual(output_device, [0., 0., 0., 0., 4., 4., 4., 4.]) + def test_compute_on_basic_inline(self): @compute_on('device_host') @jax.jit @@ -1508,7 +1573,7 @@ def test_fn(x_in, y_in): test_fn, out_shardings=( Layout(custom_dll, sharding), - Layout(custom_dll, p_sharding), + Layout(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) @@ -1516,10 +1581,6 @@ def test_fn(x_in, y_in): self.assertArraysEqual(y_out, y1 + y1) def test_compute_offload_mesh_with_linear_layout(self): - if config.use_shardy_partitioner.value: - self.skipTest( - "Shardy inlines the host compute. Remove when that's fixed." - ) mesh = jtu.create_mesh((2, 2), ("x", "y")) sharding = NamedSharding(mesh, P("x", "y")) p_sharding = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") @@ -1551,7 +1612,7 @@ def test_fn(x_in, y_in): test_fn, out_shardings=( Layout(custom_dll, sharding), - Layout(custom_dll, p_sharding), + Layout(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 42522d7f4b1b..4f1b1fb037d6 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -205,31 +205,6 @@ def mock_2x2x2_v5e_devices(one_device_per_chip=True): class MeshUtilsTest(test_util.JaxTestCase): - @parameterized.named_parameters( - ('1x1', mock_1x1_devices, (1, 1, 1, 2)), - ('2x2', mock_2x2_devices, (2, 2, 1, 2)), - ('4x4', mock_4x4_devices, (4, 4, 1, 2)), - ('8x8', mock_8x8_devices, (8, 8, 1, 2)), - ) - def test_bounds_from_last_device_2d(self, devices, expected_bounds): - self.assertEqual( - mesh_utils._bounds_from_last_device(devices()[-1]), - expected_bounds) - - @parameterized.named_parameters( - ('1x2x1_t', mock_1x2x1_devices, True, (1, 2, 1, 1)), - ('1x2x1_f', mock_1x2x1_devices, False, (1, 2, 1, 2)), - ('2x2x1_t', mock_2x2x1_devices, True, (2, 2, 1, 1)), - ('2x2x1_f', mock_2x2x1_devices, False, (2, 2, 1, 2)), - ('8x8x16_t', mock_8x8x16_devices, True, (8, 8, 16, 1)), - ('8x8x16_f', mock_8x8x16_devices, False, (8, 8, 16, 2)), - ) - def test_bounds_from_last_device_3d(self, devices, one_device_per_chip, - expected_bounds): - self.assertEqual( - mesh_utils._bounds_from_last_device(devices(one_device_per_chip)[-1]), - expected_bounds) - @parameterized.named_parameters( ('1x2x1_t', (1, 2, 1), True), ('4x4x4_t', (4, 4, 4), True), @@ -378,6 +353,12 @@ def test_create_device_mesh_for_nd_torus( ) self.assertArraysEqual(assignment, expected_assignment_matrix) + def test_create_device_mesh_non_int_error(self): + with self.assertRaisesRegex( + ValueError, + "`mesh_shape` passed to `create_device_mesh` should be a sequence of ints"): + mesh_utils.create_device_mesh(((4,), 4)) + @parameterized.named_parameters( ('2x2x1', mock_2x2x1_devices,), ('2x2x4', mock_2x2x4_devices, ), diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index 1a4de7456167..b84903618fab 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding @@ -58,10 +59,16 @@ def f(x, y): hlo = f_lowered.compiler_ir() mocked_count = NUM_SHARDS * jax.local_device_count() - self.assertIn( - f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"', - str(hlo) - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}', + str(hlo) + ) + else: + self.assertIn( + f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"', + str(hlo) + ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py new file mode 100644 index 000000000000..44ec4e2f9529 --- /dev/null +++ b/tests/mock_gpu_topology_test.py @@ -0,0 +1,60 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + +jax.config.parse_flags_with_absl() + +NUM_SLICES = 2 +NUM_HOSTS_PER_SLICE = 4 + + +@jtu.with_config( + jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1", + jax_cuda_visible_devices="0") +class MockGPUTopologyTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Mocking devices only works on the GPU backend.") + super().setUp() + + @jtu.skip_under_pytest("Test must run in an isolated process") + def testMockDeviceCount(self): + self.assertEqual(jax.device_count(), NUM_SLICES * NUM_HOSTS_PER_SLICE) + + @jtu.skip_under_pytest("Test must run in an isolated process") + def testMockWithSharding(self): + mesh = jax.sharding.Mesh(jax.devices(), ('x',)) + f = jax.jit(jnp.sum, + in_shardings=NamedSharding(mesh, P('x')), + out_shardings=NamedSharding(mesh, P())) + + f_lowered = f.lower(jnp.arange(16)) + hlo = f_lowered.compiler_ir() + + mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE + self.assertIn(f'num_partitions = {mocked_count}', str(hlo)) + self.assertIn( + f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"', + str(hlo) + ) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 3d1348371f07..6ea9c02b9639 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -16,6 +16,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", "jax_multiplatform_test", + "jax_py_test", "py_deps", ) @@ -36,13 +37,23 @@ jax_multiplatform_test( "gpu_h100", "gpu_h100_2gpu", ], - shard_count = 4, + shard_count = 8, tags = ["multiaccelerator"], deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_py_test( + name = "gpu_dialect_test", + srcs = ["gpu_dialect_test.py"], + deps = [ + "//jax", + "//jax:mosaic_gpu", + "//jax:test_util", + ] + py_deps("absl/testing"), +) + jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], diff --git a/tests/mosaic/flash_attention_test.py b/tests/mosaic/flash_attention_test.py index 1d15159ca44e..46a2199e19cc 100644 --- a/tests/mosaic/flash_attention_test.py +++ b/tests/mosaic/flash_attention_test.py @@ -43,8 +43,8 @@ def setUp(self): if flash_attention is None: self.skipTest("Mosaic GPU not available.") if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on GPU with capability >= sm90") + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") @parameterized.product( batch_size=(1,), diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py new file mode 100644 index 000000000000..3edddaad9d12 --- /dev/null +++ b/tests/mosaic/gpu_dialect_test.py @@ -0,0 +1,544 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""(Deviceless) tests for the Mosaic GPU MLIR dialect.""" + +from typing import Callable + +from absl.testing import parameterized +from jax._src import config +from jax._src import test_util as jtu +from jax._src.interpreters import mlir as mlir_interpreter +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import gpu +from jax._src.lib.mlir.dialects import llvm +from jax._src.lib.mlir.dialects import memref +from jax._src.lib.mlir.dialects import nvvm +from jax._src.lib.mlir.dialects import scf +from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member +from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx # pylint: disable=g-importing-member,g-multiple-import +from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import + +_cext = mgpu._cext if mgpu is not None else None + + +config.parse_flags_with_absl() + + +def _make_ir_context(): + context = ir.Context() + context.append_dialect_registry(mlir_interpreter.upstream_dialects) + context.load_all_available_dialects() + mgpu.register_dialect(context) + return context + + +def walk_operations(op: ir.OpView, callback): + for region in op.operation.regions: + for block in region: + for block_op in block: + walk_operations(block_op, callback) + callback(op) + + +def find_if( + module: ir.Module, predicate: Callable[[ir.OpView], bool] +) -> list[ir.OpView]: + result = [] + + def callback(op: ir.OpView): + if predicate(op): + result.append(op) + + for op in module.body.operations: + walk_operations(op, callback) + return result + + +def is_mosaic_gpu_op(op: ir.OpView) -> bool: + return op.name.startswith("mosaic_gpu.") + + +def workgroup_ptr_ty() -> ir.Type: + workgroup_nvptx_address_space = gpu_address_space_to_nvptx( + gpu.AddressSpace.Workgroup) + return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") + + +class DialectTest(parameterized.TestCase): + + def setUp(self): + if mgpu is None: + raise self.skipTest("Test requires Mosaic GPU dialect") + super().setUp() + self.enter_context(_make_ir_context()) + self.enter_context(ir.Location.unknown()) + self.module = ir.Module.create() + + def test_dialect_module_is_loaded(self): + self.assertTrue(_cext.globals._check_dialect_module_loaded("mosaic_gpu")) + + def test_initialize_barrier_op_result_memref_must_wrap_barriers(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.F32Type.get()), + llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1) + with self.assertRaisesRegex( + ir.MLIRError, "must be memref of barrier values" + ): + self.module.operation.verify() + + def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=0) + with self.assertRaisesRegex(ir.MLIRError, "value is positive"): + self.module.operation.verify() + + def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")), + arrival_count=1) + with self.assertRaisesRegex(ir.MLIRError, "pointer in address space 3"): + self.module.operation.verify() + + def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=1) + self.assertTrue(self.module.operation.verify()) + self.assertIsInstance(self.module.body.operations[1], + mgpu.InitializeBarrierOp) + + def test_async_load_op_dest_must_be_contiguous(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get( + [4, 8], + ir.F32Type.get(), + layout=ir.Attribute.parse("strided<[16, 1]>"), + ), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `destination` memref must be contiguous", + ): + self.module.operation.verify() + + def test_async_load_op_source_and_dest_must_have_same_element_type(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F64Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`source` and `destination` memrefs must have the same element", + ): + self.module.operation.verify() + + def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[-2, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `slice_lengths` attribute must not contain values less than -1", + ): + self.module.operation.verify() + + def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([1, 4, 8], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[-1, 4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`destination` plus the number of collapsed dimensions as indicated", + ): + self.module.operation.verify() + + def test_async_load_op_indices_size_must_match_source_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `indices` must be equal to the rank of `source`", + ): + self.module.operation.verify() + + def test_async_load_op_slice_lengths_size_must_match_source_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `slice_lengths` must be equal to the rank of `source`", + ): + self.module.operation.verify() + + def test_async_load_op_slice_collective_must_be_unique(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([ + ir.Attribute.parse( + f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>" + ), + ir.Attribute.parse( + f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>" + ), + ]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `collective` attribute must not contain duplicate dimensions", + ): + self.module.operation.verify() + + def test_async_store_op_source_must_be_contiguous(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get( + [4, 8], + ir.F32Type.get(), + layout=ir.Attribute.parse("strided<[16, 1]>"), + ), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `source` memref must be contiguous", + ): + self.module.operation.verify() + + def test_async_store_op_source_and_dest_must_have_same_element_type(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F64Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`source` and `destination` memrefs must have the same element", + ): + self.module.operation.verify() + + def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[-2, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `slice_lengths` attribute must not contain values less than -1", + ): + self.module.operation.verify() + + def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([1, 4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[-1, 4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`source` plus the number of collapsed dimensions as indicated", + ): + self.module.operation.verify() + + def test_async_store_op_indices_size_must_match_destination_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `indices` must be equal to the rank of `destination`", + ): + self.module.operation.verify() + + def test_async_store_op_slice_lengths_size_must_match_source_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `slice_lengths` must be equal to the rank of" + " `destination`", + ): + self.module.operation.verify() + + +class DialectLoweringTest(DialectTest): + + def test_lowering_removes_mosaic_gpu_ops(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=1) + lower_mgpu_dialect(self.module) + + self.assertEmpty( + list(filter(is_mosaic_gpu_op, self.module.body.operations)) + ) + + def test_lowering_traverses_regions_correctly(self): + with ir.InsertionPoint(self.module.body): + bool_type = ir.IntegerType.get_signless(1) + cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1)) + if_op = scf.IfOp(cst_true) + with ir.InsertionPoint(if_op.then_block): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=1) + scf.yield_([]) + lower_mgpu_dialect(self.module) + + self.assertEmpty( + list(filter(is_mosaic_gpu_op, if_op.then_block.operations)) + ) + + def test_initialize_barrier_op_lowering_rule(self): + shape = (3, 4) + num_shape_elements = shape[0] * shape[1] + arrival_count = 1337 + + with ir.InsertionPoint(self.module.body): + barriers_ref = mgpu.initialize_barrier( + ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=arrival_count) + # Add a user for barriers_ref to make sure that the lowering keeps types + # consistent. + memref.copy(barriers_ref, barriers_ref) + + self.assertTrue(self.module.operation.verify()) + lower_mgpu_dialect(self.module) + self.assertTrue(self.module.operation.verify()) + + all_mbarrier_init_shared_ops = find_if( + self.module, + lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME, + ) + + # One nvvm.mbarrier_init_shared is issued per barrier. + self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements) + + # Each barrier has its count equal to the arrival count. + for op in all_mbarrier_init_shared_ops: + count = op.count.owner.opview + self.assertIsInstance(count, arith.ConstantOp) + self.assertEqual(count.literal_value, arrival_count) + + +if __name__ == "__main__": + parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1ece3f62e3a3..26d6bfafd84d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -18,6 +18,8 @@ import itertools import math import operator +import os +import re import unittest from absl.testing import absltest, parameterized @@ -29,6 +31,7 @@ from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +from jax.experimental.mosaic.gpu import fragmented_array as fa import jax.numpy as jnp import numpy as np try: @@ -362,6 +365,26 @@ def kernel(ctx, inp, out, _): else: do_test() + @parameterized.parameters(jnp.uint64, jnp.uint32, jnp.uint16, jnp.uint8) + def test_scalar_argument(self, dtype): + scalar = 42 + expected = np.full((128, 128), scalar, dtype=dtype) + + def kernel(ctx, inp, out, _): + del ctx + inp = memref.load(inp, []) + mgpu.FragmentedArray.splat(inp, expected.shape, is_signed=True).store_untiled(out) + + res = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + jax.ShapeDtypeStruct(shape=(), dtype=expected.dtype), + expected, + (), + )(scalar) + np.testing.assert_array_equal(res, expected) + def get_packed_shape(strides, shape): perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True) @@ -1060,6 +1083,30 @@ def kernel(ctx, src, dst, scratch): y = f(x) np.testing.assert_array_equal(y, x) + def test_tma_load_indexed_tiled(self): + shape = (128, 2, 128) + tiling = mgpu.TileTransform((32, 32)) + def kernel(ctx, src, dst, scratch): + tmp, barrier = scratch + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + barrier=barrier, + gmem_transform=tiling, + gmem_slice=(slice(None), 1, slice(None)), + ) + barrier.wait() + ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_transform=tiling) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=jnp.float32).reshape(shape) + smem = ( + jax.ShapeDtypeStruct((4, 4, 32, 32), jnp.float32), + mgpu.TMABarrier(), + ) + out_shape = jax.ShapeDtypeStruct((128, 128), jnp.float32) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem) + np.testing.assert_array_equal(f(x), x[:, 1, :]) + @parameterized.product( swizzle=(None, 128), dtype=(jnp.float16, jnp.float32), @@ -1193,7 +1240,7 @@ def run_kernel(shape): x = np.arange(np.prod(shape)).reshape(shape) _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) - with self.assertRaisesRegex(ValueError, "only support striding up to 5"): + with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"): run_kernel([1] * 6) with self.assertRaisesRegex( @@ -1209,27 +1256,22 @@ class FragmentedArrayTest(TestCase): operator.add, operator.mul, operator.sub, - operator.truediv, - operator.mod, + (lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum), (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), ), dtype=[jnp.float32, jnp.int32, jnp.uint32], m=(64, 128), n=(8, 16, 32, 64, 80, 128, 256), ) - @jtu.ignore_warning(message="(invalid value|divide by zero)", - category=RuntimeWarning) + @jtu.ignore_warning( + message="(invalid value|divide by zero)", category=RuntimeWarning + ) def test_binary(self, op, dtype, m=64, n=32): if isinstance(op, tuple): op, np_op = op else: np_op = op - if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv: - self.skipTest("Unsupported for integer types") - if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod: - self.skipTest("Unsupported for floating types") - for scalar_rhs in [None, 2]: def kernel(ctx, dst, _): mlir_dtype = utils.dtype_to_ir_type(dtype) @@ -1242,10 +1284,56 @@ def kernel(ctx, dst, _): )() ref_x = np.arange(m * n, dtype=dtype).reshape(m, n) ref_rhs = scalar_rhs or ref_x - if op is operator.truediv: - np.testing.assert_allclose(result, np_op(ref_x, ref_rhs), atol=2e-7) - else: - np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + + def test_minimum_np_compatibility(self): + one = np.ones((128, 128)).astype(np.float32) + negz = one * -0. + posz = one * 0. + nan = one * np.nan + expectation = (np.minimum(negz, posz) == negz) & (np.minimum(nan, one) != one) + assert np.all(expectation), expectation + + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + splat = lambda i: mgpu.FragmentedArray.splat(c(i, f32), (128, 128)) + negz = splat(-0.) + posz = splat(0.) + nan = splat(np.nan) + one = splat(1.) + res = (negz.min(posz) == negz) & (one.min(nan) != one) & (nan.min(one) != one) + i8 = ir.IntegerType.get_signless(8) + res.astype(i8, is_signed=False).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((128, 128), np.int8) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + # astype() uses extsi so i1=True becomes -1 + np.testing.assert_array_equal(result == -1, expectation) + + @parameterized.product( + op=[operator.truediv, operator.floordiv, operator.mod], + dtype=[jnp.float32, jnp.int32, jnp.uint32], + ) + def test_division(self, op, dtype, m=64, n=32): + if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv: + self.skipTest("Unsupported for integer types") + if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod: + self.skipTest("Unsupported for floating types") + + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, dtype) + op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_allclose( + result, op(dtype(4.2).item() * iota, iota + 1), atol=2e-7 + ) @parameterized.product( op=[ @@ -1257,28 +1345,83 @@ def kernel(ctx, dst, _): operator.ne, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], + rhs_is_literal=[False, True] ) - def test_comparison(self, op, dtype, m=64, n=32): + def test_comparison(self, op, dtype, rhs_is_literal, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + rhs = 0 if rhs_is_literal else iota + 1 + op(iota, rhs).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) + rhs = rhs = 0 if rhs_is_literal else iota + 1 + np.testing.assert_array_equal(result, op(iota, rhs)) + + def test_foreach(self): + dtype = jnp.int32 + swizzle = 128 + tile = 64, swizzle // jnp.dtype(dtype).itemsize + shape = 128, 192 + tiled_shape = mgpu.tile_shape(shape, tile) + mlir_dtype = utils.dtype_to_ir_type(dtype) + cst = 9999 + def causal(val, idx): + row, col = idx + mask = arith.cmpi(arith.CmpIPredicate.uge, row, col) + return arith.select(mask, val, c(cst, mlir_dtype)) + + tiling = mgpu.TileTransform(tile) + def kernel(ctx, dst, smem): + x = iota_tensor(shape[0], shape[1], dtype) + x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem, dst_ref=dst) + ctx.await_async_copy(0) + + iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape) + result = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + (), + jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + )() + expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst + np.testing.assert_array_equal(result, expected) + + @parameterized.product( + op=[operator.and_, operator.or_, operator.xor], + dtype=[jnp.uint32], + ) + def test_bitwise(self, op, dtype, m=64, n=8): + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, dtype) + op(iota, iota + 1).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(m * n, dtype=dtype).reshape(m, n) np.testing.assert_array_equal(result, op(iota, iota + 1)) @parameterized.product( ops=( (lambda x: -x, jax.lax.neg), (lambda x: x + 42, lambda x: x + 42), + (lambda x: x.tanh(), jax.lax.tanh), ), dtype=[jnp.float32, jnp.int32, jnp.uint32], ) def test_unary(self, ops, dtype, m=64, n=32): op, np_op = ops + if np_op is jax.lax.tanh and jnp.issubdtype(dtype, jnp.integer): + raise self.skipTest("Tanh not supported for integer types") def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) @@ -1338,11 +1481,7 @@ def kernel(ctx, src, dst, scratch): src = mgpu.FragmentedArray.load_strided( src, is_signed=utils.is_signed(dtype) ) - acc = mgpu.FragmentedArray.splat( - src.reduce_sum(scratch), - (m,), - is_signed=src.is_signed - ) + acc = src.reduce_sum(scratch).broadcast((m,)) acc.store_untiled(dst) in_shape = jax.ShapeDtypeStruct((m, n), dtype) @@ -1409,6 +1548,29 @@ def kernel(ctx, dst, _): )() np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) + + def test_splat_binary_ops(self): + def kernel(ctx, src, dst, _): + f32 = ir.F32Type.get() + pi_arr = mgpu.FragmentedArray.load_strided(src) + assert isinstance(pi_arr.layout, mgpu.WGStridedFragLayout) + pi_scalar = arith.constant(f32, ir.FloatAttr.get(f32, 3.14)) + pi_splat = mgpu.FragmentedArray.splat(pi_scalar, ()) + assert isinstance(pi_splat.layout, mgpu.WGSplatFragLayout) + pi_arr_sq = pi_arr * pi_splat.broadcast(pi_arr.shape) + assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) + pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq + assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) + (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) + inp = jnp.ones_like(out_shape) * 3.14 + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, () + )(inp) + np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32)) + + @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128))) def test_strided_load_store(self, in_shape): def kernel(ctx, *args): @@ -1442,19 +1604,19 @@ def kernel(ctx, out, *_): np.testing.assert_array_equal(result, x) - @parameterized.named_parameters( - ("_bf16", jnp.bfloat16) - ) - def test_fast_i8_convert(self, jax_dtype_to): - jax_dtype_to = jnp.dtype(jax_dtype_to) + @parameterized.parameters(2, 4) + def test_fast_i8_convert(self, reg_length): + jax_dtype_to = jnp.dtype(jnp.bfloat16) jax_dtype_from = jnp.dtype(jnp.int8) mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) def kernel(ctx, inp, out, smem): del ctx, smem arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) + assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length] arr.astype(mlir_dtype_to).store_untiled(out) x = jnp.arange(-128, 128, dtype=jax_dtype_from) + x = jnp.tile(x, reg_length // 2) reference = x.astype(jax_dtype_to) result = mgpu.as_gpu_kernel( @@ -1474,6 +1636,54 @@ def kernel(ctx, _): _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)() + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ) + def test_bitcast(self, in_dtype, out_dtype): + out_ir_type = utils.dtype_to_ir_type(out_dtype) + in_is_signed = utils.is_signed(in_dtype) + out_is_signed = utils.is_signed(out_dtype) + + def kernel(ctx, inp, out, smem): + del ctx, smem + arr = mgpu.FragmentedArray.load_strided(inp, is_signed=in_is_signed) + arr = arr.bitcast(out_ir_type, output_is_signed=out_is_signed) + arr.store_untiled(out) + + x = jnp.arange(256, dtype=in_dtype) + reference = jax.lax.bitcast_convert_type(x, out_dtype) + + result = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + x, + reference, + None, + )(x) + np.testing.assert_array_equal(result, reference) + + @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) + def test_optimization_barrier(self, dtype): + def kernel(ctx, inp, out, smem): + del ctx, smem + arr = mgpu.FragmentedArray.load_strided(inp) + arr2 = arr * 2 + arr, arr2 = mgpu.optimization_barrier(arr, arr2) + (arr + arr2).store_untiled(out) + + x = jnp.arange(256, dtype=dtype) + + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None) + np.testing.assert_array_equal(f(x), x * 3) + class ProfilerTest(TestCase): @@ -1481,6 +1691,17 @@ def test_measure(self): x = jnp.arange(1024 * 1024) profiler.measure(lambda x, y: x + y, x, x) # This is just a smoke test + def test_profile(self): + def kernel(ctx, src, dst, _): + mgpu.FragmentedArray.load_strided(src).store_untiled(dst) + x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + spec = profiler.ProfilerSpec(1024) + # This is just a smoke test. + f = jax.jit(mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (), prof_spec=spec + )) + jax.block_until_ready(f(x)) + def test_multigpu(self): if len(jax.devices()) < 2: self.skipTest("Need at least 2 devices") @@ -1518,5 +1739,96 @@ def kernel(ctx, i_gmem, o_gmem, _): del y # Make sure the destructor runs successfully. +class LayoutTest(TestCase): + + @parameterized.product( + shape=((128, 128), (64, 8), (64, 256)), + dtype=(jnp.int32, jnp.int16, jnp.int8), + ) + def test_wgmma_tiled_layout(self, shape, dtype): + def kernel(ctx, dst, _): + iota = iota_tensor(*shape, dtype) + tiled = iota.to_layout(fa._tiled_wgmma_layout(shape)) + # Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1) + self.assertEqual( + tiled.registers.shape, + (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1), + ) + self.assertEqual(tiled.shape, shape) + self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) + tiled.store_untiled(dst) + ty = jax.ShapeDtypeStruct(shape, dtype) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ()) + expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) + np.testing.assert_array_equal(f(), expected) + + @parameterized.product( + load_tiled=[False, True], + store_tiled=[False, True], + dtype=[jnp.int8, jnp.int16, jnp.int32], + swizzle=[32, 64, 128], + num_col_tiles=[1, 2, 3], + ) + def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles): + mlir_dtype = utils.dtype_to_ir_type(dtype) + bw = bytewidth(mlir_dtype) + col_tiling = swizzle // bw + m, n = 128, col_tiling * num_col_tiles + tiling = (64, col_tiling) + tiled_layout = fa._tiled_wgmma_layout((m, n)) + load_layout = tiled_layout if load_tiled else mgpu.WGMMA_LAYOUT + store_layout = tiled_layout if store_tiled else mgpu.WGMMA_LAYOUT + if (not load_tiled or not store_tiled) and bw == 4 and swizzle == 32: + self.skipTest("Old code path does not support this") + def kernel(ctx, in_, out, smems): + smem_in, smem_out, barrier = smems + ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) + barrier.wait() + t = mgpu.FragmentedArray.load_tiled( + smem_in, swizzle=swizzle, is_signed=True, layout=load_layout + ) + t.to_layout(store_layout).store_tiled(smem_out, swizzle=swizzle) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle) + ctx.await_async_copy(0) + expected = ( + np.arange(m * n, dtype=dtype) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + + prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) + os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" + try: + with jtu.capture_stdout() as get_sass: + iota = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, + [expected, expected, mgpu.TMABarrier()], + )(expected) + finally: + if prev_dump is not None: + os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump + np.testing.assert_array_equal(iota, expected) + + # Verify that we don't use too many registers for the transfers. + # We verify LDS and STS separately, because they might use two different + # methods of computing offsets and we don't rely on CSE between them. + expected_regs = swizzle // bytewidth(mlir_dtype) // 8 + # When the bytewidth is smaller than 2 the swizzle pattern changes every 2 + # column tiles, so we only need half the registers. + if load_tiled and store_tiled: # The old code doesn't optimize properly. + if bytewidth(mlir_dtype) < 2: + expected_regs //= 2 + for instr in ("STS", "LDS"): + with self.subTest(instr + " count"): + addrs = re.findall(instr + r".* \[(.*)\]", get_sass()) + def get_reg(addr): + if (pos := addr.find("+")) != -1: + return addr[:pos] + return addr + used_regs = {get_reg(addr) for addr in addrs} + self.assertLessEqual(len(used_regs), expected_regs) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index 27ce4e3f02d7..d598d7d0c0ec 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -55,8 +55,8 @@ def setUp(self): if matmul is None: self.skipTest("Mosaic GPU not available.") if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on GPU with capability >= sm90") + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") @parameterized.named_parameters( (f"_shard{i}", i) for i in range(5) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index d3e32873c597..f1b80f32446a 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -232,5 +232,18 @@ def f(): x = f() self.assertArraysEqual(x, jnp.zeros(8)) + def test_grad_mutable_array(self): + @jax.jit + def f(x): + x_ = core.mutable_array(x) + x_[()] = x_[()] + x_[()] + y = core.freeze(x_) + return y + + ans = jax.grad(f)(1.) + expected = 2.0 + self.assertAllClose(ans, expected, check_dtypes=False) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/nn_test.py b/tests/nn_test.py index df719256a921..0856b259c190 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -99,8 +99,8 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) - self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) - self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) + self.assertAllClose(dK_ref, dK_ans, rtol=.01, atol=.01) + self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01) @parameterized.product( mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), @@ -164,10 +164,10 @@ def testDotProductAttentionMask(self, mask_mode): self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) - self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) + self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02) self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) - self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) - self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03) + self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01) + self.assertAllClose(dbias_ref, dbias_ans, rtol=.02, atol=.02) @parameterized.product( batch_size=[1, 16], @@ -224,7 +224,7 @@ def bwd_ans(x, bias, mask): else: _, dbias_ref, _ = bwd_ref(x, bias, mask) _, dbias_ans, _ = bwd_ans(x, bias, mask) - self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03) + self.assertAllClose(dbias_ans, dbias_ref, rtol=.02, atol=.02) @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSoftplusGrad(self): diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 71d48c2b121c..d80c750ae859 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -32,10 +32,55 @@ class PackageStructureTest(jtu.JaxTestCase): @parameterized.parameters([ # TODO(jakevdp): expand test to other public modules. _mod("jax.errors", exclude=["JaxRuntimeError"]), + _mod( + "jax.numpy", + exclude=[ + "array_repr", + "array_str", + "can_cast", + "character", + "complexfloating", + "dtype", + "iinfo", + "index_exp", + "inexact", + "integer", + "iterable", + "finfo", + "flexible", + "floating", + "generic", + "get_printoptions", + "ndarray", + "ndim", + "number", + "object_", + "printoptions", + "save", + "savez", + "set_printoptions", + "shape", + "signedinteger", + "size", + "s_", + "unsignedinteger", + "ComplexWarning", + ], + ), + _mod("jax.numpy.linalg"), _mod("jax.nn.initializers"), _mod( "jax.tree_util", - exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"], + exclude=[ + "PyTreeDef", + "default_registry", + "KeyEntry", + "KeyPath", + "DictKey", + "GetAttrKey", + "SequenceKey", + "FlattenedIndexKey", + ], ), ]) def test_exported_names_match_module(self, module_name, include, exclude): @@ -46,7 +91,8 @@ def test_exported_names_match_module(self, module_name, include, exclude): if name not in include and (name.startswith('_') or name in exclude): continue obj = getattr(module, name) - if isinstance(obj, types.ModuleType): + if obj is None or isinstance(obj, (bool, int, float, complex, types.ModuleType)): + # No __module__ attribute expected. continue self.assertEqual(obj.__module__, module_name, f"{obj} has {obj.__module__=}, expected {module_name}") diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index b5af90272510..fd1166d66df6 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -56,6 +56,20 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "pallas_cost_estimate_test", + srcs = [ + "pallas_cost_estimate_test.py", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", + "//jax:pallas_gpu_ops", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "pallas_jumble_test", srcs = [ @@ -162,6 +176,7 @@ jax_multiplatform_test( ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "JAX_PALLAS_VERBOSE_ERRORS": "0", }, deps = [ "//jax:pallas", @@ -221,7 +236,6 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep - "//jax/experimental/export", ], ) @@ -380,9 +394,13 @@ jax_multiplatform_test( "tpu_pallas_random_test.py", ], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5p_2x2", + ], deps = [ "//jax:pallas", "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", "//jax/_src/pallas/mosaic:random", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -390,6 +408,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], + disable_configs = [ + "tpu_v5p_1x1", + ], enable_backends = ["tpu"], shard_count = 5, tags = [ @@ -467,3 +488,32 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", ] + py_deps("absl/testing") + py_deps("numpy"), ) + +jax_multiplatform_test( + name = "mgpu_attention_run", + srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"], + enable_backends = [], + enable_configs = ["gpu_h100_x32"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "mgpu_attention_test", + srcs = ["mgpu_attention_test.py"], + enable_backends = [], + enable_configs = ["gpu_h100_x32"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 9e9935884b3a..462597e567f2 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -18,6 +18,7 @@ """ import math +import unittest from absl.testing import absltest import jax @@ -47,6 +48,10 @@ def setUp(self): self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() + @unittest.skip("This test is checking backwards compatibility " + "of Triton IR, but Triton doesn't promise backwards " + "compatibility for its IR, and we have since removed " + "the corresponding custom call from the guaranteed stable list.") def test_triton_add_one(self): def func(x): def add_one(x_ref, o_ref): diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py index ee7fe4ffff47..8b18f706a1d0 100644 --- a/tests/pallas/export_pallas_test.py +++ b/tests/pallas/export_pallas_test.py @@ -49,7 +49,11 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: a = np.arange(8 * 16, dtype=np.int32).reshape((8, 16)) exp = export.export( add_vectors, - lowering_platforms=["tpu", "cuda"], + platforms=["tpu", "cuda"], + # The Pallas GPU custom call is not enabled for export by default. + disabled_checks=[ + export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton") + ] )(a, a) if (jtu.device_under_test() == "tpu" or diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index ed059c235329..3b4aa1551591 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -21,6 +21,7 @@ from jax import random from jax._src import config from jax._src import test_util as jtu + if sys.platform != "win32": from jax.experimental.pallas.ops.gpu import decode_attention else: @@ -48,8 +49,9 @@ def setUp(self): self.skipTest("On CPU, the test works only in interpret mode") if jax.config.x64_enabled: self.skipTest("The test works only in 32-bit") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): + if jtu.test_device_matches( + ["cuda"] + ) and not jtu.is_cuda_compute_capability_at_least("8.0"): self.skipTest("Only works on GPU with capability >= sm80") if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") @@ -62,12 +64,18 @@ class DecodeAttentionTest(PallasBaseTest): @parameterized.named_parameters(*[ ( - f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}", + ( + f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_" + f"{start_idx=}_{kv_seq_len=}_{return_residuals=}" + ), batch_size, seq_len, num_heads, head_dim, kwargs, + start_idx, + kv_seq_len, + return_residuals, ) for ( batch_size, @@ -80,6 +88,9 @@ class DecodeAttentionTest(PallasBaseTest): (2, 1024, 2, 64, {}), (1, 1024, 8, 64, {}), ] + for start_idx in [None, 123] + for kv_seq_len in [None, 250] + for return_residuals in [False, True] ]) @jax.numpy_dtype_promotion("standard") def test_mqa( @@ -89,6 +100,9 @@ def test_mqa( num_heads, head_dim, kwargs, + start_idx, + kv_seq_len, + return_residuals, ): del kwargs @@ -97,19 +111,45 @@ def test_mqa( k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16) v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16) - o = decode_attention.mqa(q, k, v, interpret=self.INTERPRET) - o_ref = decode_attention.mqa_reference(q, k, v) + o, *res = decode_attention.mqa( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + interpret=self.INTERPRET, + ) + o_ref, *res_ref = decode_attention.mqa_reference( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + ) np.testing.assert_allclose(o, o_ref, atol=0.05) + if return_residuals: + l, m = res[0] + l_ref, m_ref = res_ref[0] + np.testing.assert_allclose(l, l_ref, atol=0.05) + np.testing.assert_allclose(m, m_ref, atol=0.05) @parameterized.named_parameters(*[ ( - f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}_{kwargs=}", + ( + f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}" + f"_{kwargs=}_{start_idx=}_{kv_seq_len=}_{return_residuals=}" + ), batch_size, seq_len, num_q_heads, num_kv_heads, head_dim, kwargs, + start_idx, + kv_seq_len, + return_residuals, ) for ( batch_size, @@ -123,6 +163,9 @@ def test_mqa( (1, 1024, 16, 16, 64, {}), (1, 1024, 32, 32, 64, {}), ] + for start_idx in [None, 123] + for kv_seq_len in [None, 250] + for return_residuals in [False, True] ]) @jax.numpy_dtype_promotion("standard") def test_gqa( @@ -133,6 +176,9 @@ def test_gqa( num_kv_heads, head_dim, kwargs, + start_idx, + kv_seq_len, + return_residuals, ): del kwargs @@ -146,10 +192,30 @@ def test_gqa( v = random.normal( k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16 ) - - o = decode_attention.gqa(q, k, v, interpret=self.INTERPRET) - o_ref = decode_attention.gqa_reference(q, k, v) + o, *res = decode_attention.gqa( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + interpret=self.INTERPRET, + ) + o_ref, *res_ref = decode_attention.gqa_reference( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + ) np.testing.assert_allclose(o, o_ref, atol=0.05) + if return_residuals: + l, m = res[0] + l_ref, m_ref = res_ref[0] + np.testing.assert_allclose(l, l_ref, atol=0.05) + np.testing.assert_allclose(m, m_ref, atol=0.05) + class DecodeAttentionInterpretTest(DecodeAttentionTest): INTERPRET = True diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py new file mode 100644 index 000000000000..43727f47338b --- /dev/null +++ b/tests/pallas/mgpu_attention_test.py @@ -0,0 +1,76 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of FlashAttention.""" + +import os + +import numpy as np +from absl.testing import absltest, parameterized +from jax._src import config +from jax._src import test_util as jtu +import jax.numpy as jnp + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. + import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + attention_mgpu = None +else: + from jax.experimental.pallas.ops.gpu import attention_mgpu + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +@jtu.with_config(jax_traceback_filtering="off") +class FlashAttentionTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if attention_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") + + @parameterized.product( + batch_size=(1, 4), + q_seq_len=(4096,), + kv_seq_len=(4096,), + num_q_and_kv_heads=((4, 1), # MQA + (6, 3), # GQA + (4, 4),), # MHA + head_dim=(64, 128, 256), + ) + def test_flash_attention( + self, batch_size, q_seq_len, kv_seq_len, num_q_and_kv_heads, head_dim + ): + num_q_heads, num_kv_heads = num_q_and_kv_heads + k1, k2, k3 = jax.random.split(jax.random.key(42), 3) + q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + out = attention_mgpu.attention( + q, k, v, attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2) + ) + out_ref = attention_mgpu.attention_reference(q, k, v) + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 22ae3e699b38..283e3e1a83c6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import functools import math +import os import re +import tempfile import traceback from absl.testing import absltest @@ -26,6 +29,10 @@ from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib +except ImportError: + mosaic_gpu_lib = None jax.config.parse_flags_with_absl() @@ -41,6 +48,19 @@ def setUp(self): super().setUp() + @contextlib.contextmanager + def capture_stdout(self): + if mosaic_gpu_lib is None: + raise ValueError("Running tests but missing Mosaic GPU extension") + with jtu.capture_stdout() as stdout: + yield stdout + # We need to cudaDeviceSynchronize to make sure printfs are flushed. + mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() + + def skip_unless_sm90a(self): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Only works on GPU with capability sm90a") + class PallasCallTest(PallasTest): @@ -50,8 +70,9 @@ class PallasCallTest(PallasTest): ("exp", jax.lax.exp), ("square", lambda x: x ** 2), ("rsqrt", jax.lax.rsqrt), + ("tanh", jax.lax.tanh, 1e-6), ) - def test_unary_ops(self, unary): + def test_unary_ops(self, unary, rtol=1e-7): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), @@ -60,7 +81,26 @@ def kernel(x_ref, o_ref): o_ref[...] = unary(x_ref[...]) x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), unary(x)) + np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol) + + @parameterized.named_parameters( + ("add", lambda x, y: x + y), + ("mul", lambda x, y: x * y), + ("div", lambda x, y: x / y), + ("min", lambda x, y: jnp.minimum(x, y)), + ("max", lambda x, y: jnp.maximum(x, y)), + ) + def test_binary_op(self, bop): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = bop(x_ref[...], y_ref[...]) + + x = jnp.arange(256).astype(jnp.float32) + y = x + 1 + np.testing.assert_array_equal(kernel(x, y), bop(x, y)) def test_add_first(self): @functools.partial( @@ -74,17 +114,21 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) - def test_add_xy(self): + def test_reshape(self): + shape1, shape2 = (128,), (2, 16, 4) + @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), ) - def kernel(x_ref, y_ref, o_ref): - o_ref[...] = x_ref[...] + y_ref[...] + def kernel(x_ref, out_ref): + x_ref_reshaped = x_ref.reshape(shape2) + self.assertEqual(x_ref.shape, shape1) + self.assertEqual(x_ref_reshaped.shape, shape2) + out_ref[...] = x_ref_reshaped[...] - x = jnp.arange(256).astype(jnp.float32) - y = x + 1 - np.testing.assert_array_equal(kernel(x, y), x + y) + x = jnp.arange(math.prod(shape1)).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) def test_add_xy_indexed(self): @functools.partial( @@ -197,6 +241,18 @@ def kernel(x_ref, o_ref): # are never written to. np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16]) + @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) + def test_iota(self, dtype): + dimension = 1 + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + ) + def kernel(o_ref): + o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA) + + np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)) + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_smem_to_gmem(self, indexer): @functools.partial( @@ -207,6 +263,7 @@ def test_copy_smem_to_gmem(self, indexer): ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 + plgpu.commit_smem() plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer]) plgpu.wait_smem_to_gmem(0) @@ -226,7 +283,7 @@ def test_copy_gmem_to_smem(self, indexer): ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.copy_gmem_to_smem( - x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier=barrier_ref + x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier_ref ) plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] + 1 @@ -247,7 +304,7 @@ def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.copy_gmem_to_smem( - x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer] + x_ref_gmem, scratch_ref, barrier_ref.at[indexer] ) plgpu.barrier_wait(barrier_ref.at[indexer]) o_ref[...] = scratch_ref[...] + 1 @@ -259,9 +316,10 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_copy_with_transforms(self, to_smem): def kernel(x_ref, o_ref, barrier_ref): if to_smem: - plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) else: + plgpu.commit_smem() plgpu.copy_smem_to_gmem(x_ref, o_ref) plgpu.wait_smem_to_gmem(0) @@ -291,7 +349,7 @@ def test_scoped_copy_with_transforms(self): ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): - plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) o_ref[...] = tmp_ref[...] * 2 pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) @@ -313,7 +371,7 @@ def body(tmp_ref): def test_copy_with_transforms_and_indexing(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): - plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) @@ -341,7 +399,7 @@ def test_indexing_before_transpose(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem( - x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier=barrier_ref + x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier_ref ) plgpu.barrier_wait(barrier_ref) @@ -369,7 +427,7 @@ def test_copy_gmem_to_smem_in_run_scoped(self): def kernel(x_ref_gmem, o_ref): def body(barrier_ref): def inner_body(scratch_ref): - plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] + 1 pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32)) @@ -434,9 +492,7 @@ def layer_norm_np(x): jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32) * input_factor ) - # TODO(cperivol): find out why in this particular case we have a small-ish error. - rtol = 1e-07 if input_factor > 10 else 5e-5 - np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol) + np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5) def test_print(self): @functools.partial( @@ -448,9 +504,8 @@ def kernel(x_ref, o_ref): pl.debug_print("It works!") x = jnp.arange(256).astype(jnp.float32) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) - self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): @@ -462,7 +517,7 @@ def kernel(x_ref, o_ref): x = jnp.arange(size, dtype=jnp.float32).reshape(shape) f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) - with jtu.capture_stdout() as get_output: + with self.capture_stdout() as get_output: jax.block_until_ready(f(x)) output = get_output() @@ -482,7 +537,7 @@ def kernel(x_ref, o_ref): pl.debug_print("x.sum() = {}", x_ref[...].sum()) x = jnp.arange(256) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn(f"x.sum() = {x.sum()}", output()) @@ -497,7 +552,7 @@ def kernel(x_ref, o_ref): pl.debug_print("x.sum() = {}", x_ref[...].sum() + 1) x = jnp.arange(256) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn(f"x.sum() = {x.sum() + 1}", output()) @@ -514,11 +569,23 @@ def kernel(x_ref, o_ref): pl.debug_print("x: {}", x_ref[...]) x = jnp.arange(math.prod(in_shape)).reshape(in_shape) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) + def test_load_scalar(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), + in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.broadcast_to(x_ref[10], (128,)) + + np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)), + jnp.full((128,), 10, dtype=jnp.int32)) + def test_run_scoped(self): def kernel(x_ref, o_ref): def body(tmp_ref): @@ -554,6 +621,30 @@ def kernel(o_ref): jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32), ) + def test_program_id_in_squashed_grid(self): + # Tests whether a grid with >3 logical dimensions is correctly squashed to + # 3 CUDA grid dimensions. + grid = (2, 3, 4, 5) + @functools.partial( + pl.pallas_call, + in_specs=(), + out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)), + out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32), + grid=grid, + ) + def kernel(o_ref): + mult = 1 + idx = 0 + for axis in range(len(grid)-1, -1, -1): + idx += pl.program_id(axis) * mult + mult *= pl.num_programs(axis) + o_ref[...] = jnp.full(o_ref.shape, idx) + + np.testing.assert_array_equal( + kernel()[:, :, :, :, 0], + jnp.arange(math.prod(grid), dtype=jnp.int32).reshape(*grid) + ) + def test_program_id_in_block_spec(self): @functools.partial( pl.pallas_call, @@ -616,28 +707,64 @@ def kernel(x_ref, o_ref): def test_fori_loop_array(self): @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...]) - x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0) + x = jnp.arange(256).astype(jnp.int32) + np.testing.assert_array_equal(kernel(x), x + 2 + 3) def test_fori_loop_scalar(self): + @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(o_ref): + # Equivalent to 2 + 3. + o_ref[...] = jax.lax.broadcast( + jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0), o_ref.shape + ) + + np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) + + def test_fori_loop_dynamic_bounds(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + grid=(1,) ) def kernel(o_ref): + zero = pl.program_id(0) # Equivalent to 2 + 3. o_ref[...] = jax.lax.broadcast( - jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0.0), o_ref.shape + jax.lax.fori_loop(2 + zero, 4 + zero, lambda i, x: x + i, 0), o_ref.shape + ) + + np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) + + def test_fori_loop_tuple(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(o_ref): + def body(step, xs): + return tuple( + jax.lax.cond(step % 2 == 0, lambda x: x + 1, lambda x: x, x) + for x in xs + ) + + # Equivalent to 3 * (0 + 1). + o_ref[...] = jax.lax.broadcast( + sum(jax.lax.fori_loop(2, 4, body, (0, 0, 0))), o_ref.shape ) np.testing.assert_array_equal( - kernel(), jnp.full([256], 5.0, dtype=jnp.float32) + kernel(), jnp.full([256], 3 * (0 + 1), dtype=jnp.int32) ) def test_fori_loop_indexed_store(self): @@ -657,7 +784,6 @@ def body(idx, _): np.testing.assert_array_equal(kernel(x, y), x + y) def test_cond(self): - @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), @@ -672,13 +798,31 @@ def kernel(x_ref, o_ref): o_ref[...] = jnp.broadcast_to(acc, o_ref.shape) x = jnp.arange(256) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn("acc * 2:", output()) + def test_cond_returning_array(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + acc = x_ref[...].sum() + acc2, acc = jax.lax.cond( + acc % 2 == 0, + lambda: (acc * 2, acc), + lambda: (acc, acc * 2), + ) + o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape) + + x = jnp.arange(256) + np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) + @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): + self.skip_unless_sm90a() # TensorCores can only fuse transposes of 16-bit values, and RHS # is expected to be column major by default. rhs_transpose = jnp.dtype(dtype).itemsize != 2 @@ -729,6 +873,7 @@ def scope(acc_ref): ) def test_wgmma_registers(self): + self.skip_unless_sm90a() def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -751,7 +896,33 @@ def scope(acc_ref): )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) + def test_wgmma_registers_init(self): + self.skip_unless_sm90a() + def kernel(a_ref, b_ref, i_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref[...], b_ref) + o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...])) + + key1, key2, key3 = jax.random.split(jax.random.key(42), 3) + a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) + i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 + + transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec((64, 192), lambda: (0, 0), transforms=transforms), + ], + out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), + )(a, b, i) + np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) + def test_wgmma_sliced_ref(self): + self.skip_unless_sm90a() def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0]) @@ -787,6 +958,7 @@ def scope(acc_ref): np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) def test_wgmma_sliced_acc(self): + self.skip_unless_sm90a() swizzle = 128 elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize def kernel(a_ref, b_ref, o_ref): @@ -842,6 +1014,7 @@ def kernel(a_ref, b_ref): np.testing.assert_array_equal(b, np.ones_like(a)) def test_realistic_matmul(self): + self.skip_unless_sm90a() dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -939,10 +1112,63 @@ def kernel(o_ref): x = jnp.full(shape, 42.0) np.testing.assert_array_equal(kernel(), x) + def test_profiler(self): + def kernel(x_ref, o_ref): + with jax.named_scope("add"): + with jax.named_scope("load"): + x = x_ref[...] + o = x + x + with jax.named_scope("store"): + o_ref[...] = o + with tempfile.TemporaryDirectory() as tmpdir: + x = jnp.arange(256).astype(jnp.float32) + y = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + profile_space=16, profile_dir=tmpdir + ), + )(x) + jax.block_until_ready(y) + jax.effects_barrier() + [name] = os.listdir(tmpdir) + with open(os.path.join(tmpdir, name), "r") as f: + data = f.read() + self.assertEqual(data.count('"name": "add"'), 2) + self.assertEqual(data.count('"name": "load"'), 2) + self.assertEqual(data.count('"name": "store"'), 2) + np.testing.assert_array_equal(y, x + x) + + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ) + def test_bitcast_convert_type(self, in_dtype, out_dtype): + m, n = 16, 8 + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () + + @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) + + x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + class PipelineTest(PallasTest): - def test_manual(self, max_concurrent_steps=2, num_steps=4): + def test_manual(self): + max_concurrent_steps = 2 + num_steps = 4 def kernel(x_gmem, o_gmem): return pl.run_scoped( @@ -963,8 +1189,9 @@ def body(step, _): # Wait for the previous output SMEM->GMEM copy to complete. plgpu.wait_smem_to_gmem(max_concurrent_steps - 1) - o_smem[...] = x_smem[...] + 1.0 + o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0 + plgpu.commit_smem() plgpu.copy_smem_to_gmem( o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)] ) @@ -976,7 +1203,7 @@ def body(step, _): lambda: plgpu.copy_gmem_to_smem( x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)], x_smem.at[fetch_slot], - barrier=barrier.at[fetch_slot], + barrier.at[fetch_slot], ), lambda: None, ) @@ -987,7 +1214,7 @@ def body(step, _): plgpu.copy_gmem_to_smem( x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)], x_smem.at[slot], - barrier=barrier.at[slot], + barrier.at[slot], ) jax.lax.fori_loop(0, num_steps, body, ()) @@ -1005,6 +1232,164 @@ def body(step, _): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + @parameterized.parameters( + ((),), + ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),), + ) + def test_emit(self, transforms): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[ + plgpu.GPUBlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + out_specs=[ + plgpu.GPUBlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + # +1 for the indexing done by ``emit_pipeline`. + self.assertLen(x_smem.transforms, len(transforms) + 1) + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(64 * num_steps * 64) + x = x.reshape(-1, num_steps * 64).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + def test_nested_emit(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + grid=(), + )(x_gmem, o_gmem) + + def nested_kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def nested_kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + def test_emit_with_grid_invariant_output(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + y = jnp.empty_like(x) + for i in range(num_steps): + i_slice = slice(16 * i, 16 * (i + 1)) + y = y.at[:, :16].set(x[:, i_slice] + 1) + # We only compare the elements in the first 16 columns, because the rest + # are never written to. + np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16]) + + def test_emit_with_parallel_grid(self): + num_steps1 = 4 + num_steps2 = 5 + + def kernel(x_gmem, o_gmem): + pid = pl.program_id(0) + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + grid=(num_steps2,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) + x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(num_steps1,), + ) + y = x + 1.0 + np.testing.assert_array_equal(kernel_fn(x), y) + + def test_emit_with_2d_grid(self): + num_steps1 = 4 + num_steps2 = 5 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], + out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], + grid=(num_steps1, num_steps2), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) + x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + class CoreMapTest(PallasTest): @@ -1047,6 +1432,61 @@ def kernel(): f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) + def test_multiple_wg_with_squashed_grid(self): + # Tests whether a grid with >3 logical dimensions is correctly squashed to + # 3 CUDA grid dimensions. + b = 4 + x_dim = 3 + y_dim = 5 + z_dim = 7 + num_threads = 2 + mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim), + num_threads=num_threads, + axis_names=("b", "x", "y", "z", "wg")) + + @jax.jit + def f(): + @pl.run_state + def inner(y_ref): + @pl.core_map(mesh) + def _(): + b_idx = jax.lax.axis_index("b") + x_idx = jax.lax.axis_index("x") + y_idx = jax.lax.axis_index("y") + z_idx = jax.lax.axis_index("z") + wg_idx = jax.lax.axis_index("wg") + bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) + y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( + bxyzw_idx, (128,) + ) + y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32) + return inner(y_init) + result = f()[:, :, :, :, :, 0] + ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( + result.shape) + np.testing.assert_array_equal(result, ref) + + + def test_cross_wg_barrier(self): + mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) + + @jax.jit + def f(): + @pl.run_state + def inner(y_ref): + @pl.core_map(mesh) + def kernel(): + def scoped(barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) + # Each warpgroup is a single logical thread! + pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) + y_init = jnp.zeros((2, 128), np.int32) + return inner(y_init) + np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 7b54ef5f9f88..38e359aef3a1 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -17,24 +17,25 @@ from collections.abc import Sequence import functools import itertools +import math import sys from typing import Any import unittest -import numpy as np from absl.testing import absltest from absl.testing import parameterized - import jax -import jax.numpy as jnp from jax import lax from jax import random +from jax._src import config from jax._src import dtypes from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu -from jax.interpreters import partial_eval as pe from jax.experimental import pallas as pl +from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import numpy as np if sys.platform != "win32": from jax.experimental.pallas import triton as plgpu @@ -62,6 +63,10 @@ floatx = dtypes.canonicalize_dtype(jnp.float64) +def is_power_of_two(n: int) -> bool: + return (n > 0) and (n & (n - 1) == 0) + + def smem_on_tpu(): if jtu.test_device_matches(["tpu"]): return pltpu.SMEM @@ -524,9 +529,6 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): tol = 1e-6 elif name == "exp2": tol = 1e-6 - elif jtu.test_device_matches(["tpu"]): - if not jtu.is_device_tpu_at_least(version=5) and False: - self.skipTest("TODO: not implemented on TPU v{3,4}") def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...]) @@ -559,10 +561,6 @@ def test_cast(self, from_dtype, to_dtype, data): self.skipTest("Not supported: bad canonicalization") if from_dtype == "bool" and to_dtype in {"int16", "int8"}: self.skipTest("Not supported: cannot extend to sub-32 bit types") - if jtu.test_device_matches(["gpu"]): - if (from_dtype in {"bfloat16", "float32"} and - to_dtype in {"int8", "int16", "int32"}): - self.skipTest("TODO: wrong result on GPU") if from_dtype == "bfloat16": from_dtype = jnp.bfloat16 @@ -721,6 +719,28 @@ def kernel(x_ref, o_ref): expected.astype(jnp.float32), ) + # TODO(twsung): Add more types once lowering is implemented. + @parameterized.parameters( + jnp.float32, + jnp.bfloat16, + jnp.int32, + ) + def test_add_constant(self, dtype): + + shape = (256, 256) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + np.testing.assert_array_equal( + kernel(jnp.zeros(shape, dtype=dtype)), + jnp.ones(shape, dtype=dtype), + ) + @parameterized.parameters( -3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4, ) @@ -756,6 +776,7 @@ def kernel(x_ref, o_ref): ["float32", "float64"], ), ([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]), + ([jnp.logical_not], ["bool"]) ] @parameterized.named_parameters( @@ -793,12 +814,15 @@ def test_elementwise(self, fn, dtype): self.skipTest(f"{fn.__name__} not implemented on TPU") @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1 + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + grid=1, ) def kernel(x_ref, o_ref): o_ref[:] = fn(x_ref[...]) - x = jnp.array([0.42, 2.4]).astype(dtype) + # create an array with shape (8, 128) + x = jnp.array([0.42, 2.4] * (8 * 128 // 2)).reshape(8, 128).astype(dtype) self.assertAllClose(kernel(x), fn(x), rtol=1e-6) @parameterized.named_parameters( @@ -831,9 +855,9 @@ def test_elementwise_scalar(self, fn, dtype): # TODO(b/370578663): implement these lowerings on TPU if jtu.test_device_matches(["tpu"]) and fn in ( - jnp.abs, jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, + jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh, jnp.cbrt, jnp.cosh, jnp.expm1, - jnp.sinh, lax.rsqrt, + jnp.sinh, ): self.skipTest(f"{fn.__name__} not implemented on TPU") @@ -868,9 +892,6 @@ def kernel(x_ref, o_ref): ("float64", "float64"), ) def test_pow(self, x_dtype, y_dtype): - if jtu.test_device_matches(["tpu"]): - self.skipTest("TODO: Error on TPU") - if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -898,24 +919,32 @@ def kernel(x_ref, o_ref): x = jnp.array([1, 2, 3, 4]).astype(jnp.float32) / 10 np.testing.assert_allclose(kernel(x), lax.integer_pow(x, y)) - @parameterized.parameters("float32", "float64") - def test_nextafter(self, dtype): + _NEXTAFTER_VALUES = (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf) + + @parameterized.named_parameters( + (f"{dtype.__name__} ({x=}, {y=})", dtype, x, y) + for dtype, x, y in itertools.product( + (jnp.float32, jnp.float64), _NEXTAFTER_VALUES, _NEXTAFTER_VALUES, + ) + ) + def test_nextafter(self, dtype, x, y): if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") - # TODO: implement this on TPU - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not implemented: nextafter") - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), dtype), ) def kernel(x_ref, y_ref, o_ref): - o_ref[:] = jnp.nextafter(x_ref[...], y_ref[...]) + o_ref[...] = jnp.nextafter(x_ref[...], y_ref[...]) + + x = jnp.full((4,), x, dtype=dtype) + y = jnp.full((4,), y, dtype=dtype) + out = kernel(x, y) + expected = jnp.nextafter(x, y) - x = jnp.array([1, 2, 3, 4]).astype(dtype) - y = jnp.array([1, 2, 3, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y)) + # `nextafter` requires exact equality + self.assertArraysEqual(out, expected) COMPARISON_OPS = [ jnp.equal, @@ -927,16 +956,17 @@ def kernel(x_ref, y_ref, o_ref): ] @parameterized.named_parameters( - (f"{fn.__name__}_{dtype}", fn, dtype) + (f"{fn.__name__}_{dtype.__name__}", fn, dtype) for fn, dtype in itertools.product( - COMPARISON_OPS, ["int32", "uint32", "float16", "float32", "bool"] + COMPARISON_OPS, + (jnp.int32, jnp.uint32, jnp.float16, jnp.float32, jnp.bool_), ) ) def test_comparison(self, fn, dtype): - if jtu.test_device_matches(["gpu"]) and dtype == "bool": + if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: self.skipTest("Not implemented on GPU.") - if jtu.test_device_matches(["tpu"]) and dtype == "float16": + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") @functools.partial( @@ -949,16 +979,19 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), fn(x, y)) + out = kernel(x, y) + expected = fn(x, y) + self.assertArraysEqual(out, expected) @parameterized.named_parameters( - (f"{fn.__name__}_{dtype}", fn, dtype) + (f"{fn.__name__}_{dtype.__name__}", fn, dtype) for fn, dtype in itertools.product( - COMPARISON_OPS, ["int32", "uint32", "float16", "float32", "bool"] + COMPARISON_OPS, + (jnp.int32, jnp.uint32, jnp.float16, jnp.float32, jnp.bool_), ) ) def test_comparison_scalar(self, fn, dtype): - if jtu.test_device_matches(["tpu"]) and dtype == "float16": + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") if ( @@ -983,7 +1016,9 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), fn(x, y)) + out = kernel(x, y) + expected = fn(x, y) + self.assertArraysEqual(out, expected) def test_isnan(self): @functools.partial( @@ -996,6 +1031,22 @@ def isnan(x_ref, o_ref): x = x.at[3].set(jnp.nan) np.testing.assert_allclose(isnan(x), jnp.isnan(x)) + def test_jnp_einsum_grad_y_pallas(self): + if jtu.test_device_matches(["gpu"]): + self.skipTest("This test ooms on gpu") + + x = jnp.arange(128 * 256, dtype=jnp.float32).reshape((128, 256)) + y = jnp.arange(256 * 128, dtype=jnp.float32).reshape((128, 256)) + + def kernel(x_ref, y_ref, out_ref): + # grad_y side of grouped matmul + out_ref[...] = jnp.einsum('mk,mn->kn', x_ref[...], y_ref[...]) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32) + )(x, y) + np.testing.assert_array_equal(out, jnp.einsum('mk,mn->kn', x, y)) + @parameterized.parameters( ("int32", "float32"), ("float32", "float32"), @@ -1056,14 +1107,6 @@ def test_binary(self, f, dtype): if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") - # TODO(ayx): Fix these operations on TPU - if ( - jtu.test_device_matches(["tpu"]) - and f in (jnp.floor_divide, jnp.subtract) - and dtype == "uint32" - ): - self.skipTest("Not supported on TPU") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1 ) @@ -1089,14 +1132,6 @@ def test_binary_scalar(self, f, dtype): if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") - # TODO(ayx): Fix these operations on TPU - if ( - jtu.test_device_matches(["tpu"]) - and f in (jnp.floor_divide, jnp.subtract) - and dtype == "uint32" - ): - self.skipTest("Not supported on TPU") - @functools.partial( self.pallas_call, in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), @@ -1210,6 +1245,8 @@ def kernel(x_ref, o_ref): "plgpu.TritonCompilerParams unavailable on Windows", ) def test_debug_print(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1378,21 +1415,69 @@ def f(x_ref, o_ref): np.testing.assert_allclose(f(x), expected) @parameterized.product( - size=[16, 32, 64], - dtype=["float32", "float16"], + lhs_and_rhs_shape=[ + ((16, 16), (16, 16)), + ((32, 32), (32, 32)), + ((64, 64), (64, 64)), + ((128, 128), (128, 128)), + ((256, 256), (256, 256)), + ((8, 128), (128, 256)), + ((8, 128), (256, 128)), + ((8, 256), (256, 128)), + ((16, 128), (128, 256)), + ((16, 128), (256, 128)), + ((16, 256), (256, 128)), + ((24, 128), (128, 256)), + ((24, 128), (256, 128)), + ((24, 256), (256, 128)), + ((128, 8), (128, 256)), + ((128, 8), (256, 128)), + ((256, 8), (256, 128)), + ((128, 16), (128, 256)), + ((128, 16), (256, 128)), + ((256, 16), (256, 128)), + ((128, 24), (128, 256)), + ((128, 24), (256, 128)), + ((256, 24), (256, 128)), + ], + dtype=[jnp.float32, jnp.float16, jnp.bfloat16], trans_x=[False, True], trans_y=[False, True], ) - def test_dot(self, size, dtype, trans_x, trans_y): - if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: - self.skipTest("16-bit types are not supported on TPU") + def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): + lhs_shape, rhs_shape = lhs_and_rhs_shape + + final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape + final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape + if final_lhs_shape[1] != final_rhs_shape[0]: + self.skipTest("Contraction dimensions do not match") + + out_shape = (final_lhs_shape[0], final_rhs_shape[1]) if jtu.test_device_matches(["tpu"]): - self.skipTest("Not implemented: Transposed LHS") + if dtype == jnp.float16: + self.skipTest("float16 type is not supported on TPU") + if dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4): + self.skipTest("bfloat16 matmul is supported on TPUv4+") + if trans_x: + self.skipTest("Not implemented: Transposed LHS") + + if jtu.test_device_matches(["gpu"]): + if dtype == jnp.bfloat16: + self.skipTest("bfloat16 type are not supported on GPU") + if ( + math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) + > (256 * 256) * 2 + ): + self.skipTest("Shared memory size limit exceeded") + if min(*lhs_shape, *rhs_shape) < 16: + self.skipTest("All dimensions of lhs and rhs must be >= 16") + if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape): + self.skipTest("All dimensions of lhs and rhs must be power of two") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((size, size), dtype), + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), grid=1, ) def dot(x_ref, y_ref, o_ref): @@ -1401,11 +1486,16 @@ def dot(x_ref, y_ref, o_ref): o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype) k1, k2 = random.split(random.key(0)) - x = random.normal(k1, (size, size), dtype=dtype) - y = random.normal(k2, (size, size), dtype=dtype) + x = random.normal(k1, lhs_shape, dtype=dtype) + y = random.normal(k2, rhs_shape, dtype=dtype) out = dot(x, y) expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y) - np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected.astype(jnp.float32), + atol=0.05, + rtol=0.05, + ) @parameterized.product( size=[1, 2, 64, 129, 1021], @@ -1865,11 +1955,131 @@ def reduce(x_ref, y_ref): y_ref = jnp.cumsum(x, axis=axis) np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) + @parameterized.parameters( + (0, jnp.float32), + (0, jnp.bfloat16), + (1, jnp.float32), + (1, jnp.bfloat16), + (-1, jnp.float32), + (-1, jnp.bfloat16), + ) + def test_triu(self, k, dtype): + if dtype == jnp.bfloat16 and jtu.test_device_matches(["tpu"]): + # TODO(mvoz): b/376330700 + raise unittest.SkipTest('NYI - bf16 select') + + x = jnp.arange(128 * 256, dtype=dtype).reshape((128, 256)) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.triu(x_ref[...], k=k) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((128, 256), dtype) + )(x) + np.testing.assert_array_equal(out, np.triu(x, k=k)) + + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ) + def test_bitcast_convert_type(self, in_dtype, out_dtype): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + m, n = 4, 4 + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () + + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) + + x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + + def test_bitcast_convert_type_scalar(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + x = jnp.int32(42) + out_dtype = jnp.float32 + out_shape = jax.ShapeDtypeStruct(x.shape, out_dtype) + grid = () + + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_dtype) + + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + + @parameterized.product( + array_shapes=[(4, 128), (10, 100), (8, 128), (17, 257)], + padding=[ + ((5, 8), (0, 0)), + ((0, 0), (5, 100)), + ((1, 1), (1, 1)), + ((0, 0), (0, 0)), + ], + pad_type=["constant", "wrap"], + dtype=( + jnp.float32, + jnp.bfloat16, + ), + ) + def test_arbitrary_padding_jnp_pad( + self, array_shapes, padding, pad_type, dtype + ): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not implemented on GPU") + + x = jnp.arange(np.prod(array_shapes), dtype=dtype).reshape(array_shapes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.pad(x_ref[...], padding, mode=pad_type) + + ref = jnp.pad(x, padding, mode=pad_type) + + out_shape = jax.ShapeDtypeStruct(ref.shape, x.dtype) + try: + out = self.pallas_call( + kernel, + out_shape=out_shape, + )(x) + np.testing.assert_array_equal(out, jnp.pad(x, padding, mode=pad_type)) + except Exception as e: + self.assertEqual( + dtype, + jnp.bfloat16, + "some bfloat16 combinations can fail with not implemented", + ) + # The first two options are expected to fail due to current limitations + # in the Pallas TPU lowering. However, the last one is unexpected, and + # should be fixed, it is a pjrt bug. + # b/379787665 + acceptable_errors = ( + "Only 32-bit types supported" in str(e) + or "Not implemented" in str(e) + or "Expected mask vector type" in str(e) + ) + self.assertTrue(acceptable_errors, "Failed with error: " + str(e)) + class OpsInterpretTest(OpsTest): INTERPRET = True def test_debug_print(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py new file mode 100644 index 000000000000..d9eb18e6f540 --- /dev/null +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -0,0 +1,114 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import lax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import cost_estimate +from jax._src.state import discharge + + +config.parse_flags_with_absl() + + +class PallasCostEstimateTest(jtu.JaxTestCase): + + def test_exp_add(self): + def exp_add(x, y): + return jnp.exp(x + y) + cost = cost_estimate.estimate_cost(exp_add, + jnp.ones(10, dtype=jnp.float32), + jnp.ones(10, dtype=jnp.float32)) + self.assertEqual(cost.flops, 10) + self.assertEqual(cost.transcendentals, 10) + self.assertEqual(cost.bytes_accessed, 4 * 30) + + def test_very_large_matmul(self): + def matmul(a, b): + return a @ b + m, k, n = 400_000, 800_000, 900_000 + cost = cost_estimate.estimate_cost( + matmul, + jax.ShapeDtypeStruct((m, k), jnp.bfloat16), + jax.ShapeDtypeStruct((k, n), jnp.bfloat16)) + self.assertEqual(cost.flops, 2*m*k*n) + self.assertEqual(cost.transcendentals, 0) + self.assertEqual(cost.bytes_accessed, 2*(m*k + n*k + m*n)) + + def test_batched_matmul(self): + def matmul(a, b): + return jnp.matmul(a, b) + b, m, k, n = 7, 37, 91, 23 + cost = cost_estimate.estimate_cost( + matmul, + jax.ShapeDtypeStruct((b, m, k), jnp.float32), + jax.ShapeDtypeStruct((b, k, n), jnp.float32)) + self.assertEqual(cost.flops, 2*b*m*k*n) + self.assertEqual(cost.transcendentals, 0) + self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n)) + + def test_attention(self): + qk_dim = 16 + v_dim = 4 + kv_len = 128 + q_len = 64 + def attention(q, k, v): + return jax.nn.softmax(q @ k.T, axis=-1) @ v + cost = cost_estimate.estimate_cost( + attention, + jnp.zeros((q_len, qk_dim), dtype=jnp.float32), + jnp.zeros((kv_len, qk_dim), dtype=jnp.float32), + jnp.zeros((kv_len, v_dim), dtype=jnp.float32)) + qk_cost = 2 * q_len * kv_len * qk_dim + v_cost = 2 * q_len * kv_len * v_dim + softmax_flops = kv_len * q_len + self.assertEqual(cost.flops, qk_cost + v_cost + 2 * softmax_flops + q_len) + self.assertEqual(cost.transcendentals, softmax_flops) + input_bytes = q_len * qk_dim + kv_len * qk_dim + kv_len * v_dim + output_bytes = q_len * v_dim + self.assertEqual(cost.bytes_accessed, 4 * (input_bytes + output_bytes)) + + @parameterized.parameters( + (1, 0), (7, 5), (8, 4), (9, 5) + ) + def test_integer_pow(self, power, expected_flops_per_element): + cost = cost_estimate.estimate_cost(lambda x: lax.integer_pow(x, power), + jnp.ones(10, dtype=jnp.float32)) + self.assertEqual(cost.flops, 10 * expected_flops_per_element) + self.assertEqual(cost.transcendentals, 0) + self.assertEqual(cost.bytes_accessed, 80) + + def test_run_state(self): + def add_refs(refs): + x_ref, y_ref, z_ref = refs + x = x_ref[:] + y = y_ref[:] + z = x + y + z_ref[:] = z + input_shape = jax.ShapeDtypeStruct((100,), jnp.float32) + cost = cost_estimate.estimate_cost( + discharge.run_state(add_refs), + (input_shape, input_shape, input_shape)) + self.assertEqual(cost.flops, 100) + self.assertEqual(cost.transcendentals, 0) + # TODO(justinfu): This is off by a factor of 2 because run_state + # has all inputs/outputs as both arguments and return values. + self.assertEqual(cost.bytes_accessed / 2, 3 * 4 * 100) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py index f26352da0f38..509ef08a987f 100644 --- a/tests/pallas/pallas_jumble_test.py +++ b/tests/pallas/pallas_jumble_test.py @@ -41,6 +41,21 @@ floatx = dtypes.canonicalize_dtype(jnp.float64) +def _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res, ref +): + total_columns = col_grid_size * 128 + mask = jnp.zeros((len(ragged_shape), row_count, total_columns), dtype=bool) + + for i, r in enumerate(ragged_shape): + mask = mask.at[i, :, : r * 128].set(True) + + res_valid = jnp.where(mask, res, -1) + ref_valid = jnp.where(mask, ref, -1) + + np.testing.assert_allclose(res_valid, ref_valid) + + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -104,24 +119,16 @@ def invoke_kernel(x): axis_size=3, )(x) - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - - def correct(v): - return np.count_nonzero(v == jnp.sin(1.0)) - - for b, batch in enumerate(res): - ragged_val = ragged_shape[b] - for r, row in enumerate(batch): - row_total = ragged_val * 128 - self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}") + ref = jax.vmap( + jnp.sin, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) - self.assertEqual(correct(res), ragged_total) + _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res.data, ref.data + ) def test_vmap_jumble_over_add_kernel(self): if not jtu.test_device_matches(["tpu"]): @@ -156,36 +163,34 @@ def invoke_kernel(x, y): (8, col_grid_size * 128), dtype=jnp.float32 ), grid=(1, col_grid_size), - interpret=False, + interpret=self.INTERPRET, )(x, y) - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) + # We've had this test fail with data corruption due to multiple + # invocations, so we run it k times to make sure it's not setting up + # memory incorrectly for subsequent invocations. + for _ in range(4): + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - - def correct(v): - return np.count_nonzero(v == 2.0) - - for r, row in enumerate(res): - ragged_val = ragged_shape[r] - row_total = ragged_val * 128 * row_count - self.assertEqual(correct(row), row_total) - for col in row: - col_total = ragged_val * 128 - self.assertEqual(correct(col), col_total) - - self.assertEqual(np.count_nonzero(res == 2.0), ragged_total) + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + + ref = jax.vmap( + lambda x, y: x + y, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) + _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res, ref.data + ) def test_vmap_jumble_over_sin_kernel_grid_remapping(self): if not jtu.test_device_matches(["tpu"]): @@ -212,7 +217,7 @@ def invoke_kernel(x): out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), grid=(1, 5), - interpret=False, + interpret=self.INTERPRET, )(x) with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): @@ -227,6 +232,9 @@ def test_vmap_jumble_over_matmul_kernel(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Only tested on TPU") + if jtu.is_device_tpu(version=4): + self.skipTest("Flaky 15% of the time on tpuv4?") + m = 128 k = 640 n = 640 @@ -277,7 +285,7 @@ def matmul( ), grid=grid, input_output_aliases={2: 0}, - interpret=False, + interpret=self.INTERPRET, )(x, y, x_sentinel) # TODO(mvoz): parameterize this shape? diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index bc2c237ffa94..39bd279e8bce 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -687,6 +687,46 @@ def f(x): self.assertEqual(f(x), 2.) self.assertEqual(trace_count, 1) + @parameterized.parameters( + ("float32", None), + ("float32", jax.lax.Precision.DEFAULT), + ("float32", jax.lax.Precision.HIGH), + ("float32", jax.lax.Precision.HIGHEST), + ("float32", jax.lax.DotAlgorithmPreset.DEFAULT), + ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), + ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), + ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), + ("bfloat16", None), + ("bfloat16", jax.lax.Precision.DEFAULT), + ("bfloat16", jax.lax.Precision.HIGHEST), + ("bfloat16", jax.lax.DotAlgorithmPreset.DEFAULT), + ("bfloat16", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ) + def test_dot_precision(self, dtype, precision): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("`DotAlgorithmPreset` only supported on GPU.") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((32, 64), jnp.float32), + grid=1, + ) + def dot_kernel(x_ref, y_ref, o_ref): + o_ref[()] = pl.dot(x_ref[()], y_ref[()], precision=precision) + + key0, key1 = random.split(random.key(0)) + x = random.normal(key0, (32, 16), dtype=dtype) + y = random.normal(key1, (16, 64), dtype=dtype) + expected = jnp.dot( + x, + y, + precision=jax.lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + class PallasCallInterpretTest(PallasCallTest): INTERPRET = True @@ -1300,11 +1340,7 @@ def if_true(z): np.testing.assert_allclose(f(jnp.bool_(False), arg), -arg) - # We actually expect the assertion failure in linearize, but this also - # covers another case where an effect was causing an earlier assertion - # failure. - with self.assertRaises(AssertionError): - # Notably, we should not have a ValueError for mismatched Read effect. + with self.assertRaisesRegex(ValueError, "Linearization failed"): _ = jax.grad(lambda x: jnp.sum(f(jnp.bool_(True), x)**2))(arg) # np.testing.assert_allclose( # dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14])) @@ -1357,7 +1393,7 @@ def body_fn(i, args): 16 * x * params[4, 2]) np.testing.assert_allclose(f(program, params, x), expected) - with self.assertRaises(AssertionError): + with self.assertRaisesRegex(ValueError, "Linearization failed"): jax.value_and_grad(lambda params, x: f(program, params, x).sum())( params, x) @@ -1411,7 +1447,7 @@ def body_fn(i, args): 16 * x * params[4, 2]) np.testing.assert_allclose(f(program, params, x), expected) - with self.assertRaises(AssertionError): + with self.assertRaisesRegex(ValueError, "Linearization failed"): jax.value_and_grad(lambda params, x: f(program, params, x).sum())( params, x) diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index fefccfe7eb4f..ffa6195625dd 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -22,6 +22,7 @@ import jax from jax import random from jax._src import config +from jax._src import dtypes from jax._src import test_util as jtu from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl @@ -35,6 +36,10 @@ config.parse_flags_with_absl() +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -42,8 +47,6 @@ class PallasBaseTest(jtu.JaxTestCase): def setUp(self): if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled: - self.skipTest("On GPU the test works only in 32-bit") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") @@ -67,7 +70,7 @@ def setUp(self): def test_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -77,7 +80,7 @@ def add_one(x_ref, o_ref): def test_vmap_of_simple_kernel_with_in_axes_None(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add(x_ref, y_ref, o_ref): o_ref[()] = x_ref[()] + y_ref[()] @@ -87,7 +90,7 @@ def add(x_ref, y_ref, o_ref): def test_double_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -97,7 +100,7 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -108,7 +111,7 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_batched_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), intx), grid=(7,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -120,7 +123,7 @@ def add_one(x_ref, o_ref): def test_vmap_of_slicing_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -151,7 +154,7 @@ def kernel(src, dst): def test_vmap_of_kernel_with_input_output_aliases(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), input_output_aliases={1:0}, grid=()) def add(x_ref, _, o_ref): @@ -163,7 +166,7 @@ def add(x_ref, _, o_ref): def test_vmap_of_kernel_with_input_output_aliases_different_axes(self): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + out_shape=jax.ShapeDtypeStruct((4,), intx), input_output_aliases={0: 0}, grid=(), ) @@ -176,7 +179,7 @@ def add(x_ref, o_ref): def test_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -194,7 +197,7 @@ def add_one(x_ref, o_ref): def test_double_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx), grid=(4,)) def sin(x_ref, o_ref): i = pl.program_id(0) @@ -211,7 +214,7 @@ def sin(x_ref, o_ref): def test_small_large_vmap(self): # Catches https://github.com/jax-ml/jax/issues/18361 @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -230,7 +233,7 @@ def add_one(x_ref, o_ref): def test_small_small_large_vmap(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -249,12 +252,6 @@ def add_one(x_ref, o_ref): class PallasCallVmapInterpretTest(PallasCallVmapTest): INTERPRET = True - def setUp(self): - super().setUp() - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") - if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index ca5361a70051..8843c6a58064 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -233,6 +233,24 @@ def run(cond, lhs, rhs): assert (run(cond, lhs, rhs) == lhs).all() + def test_logical_and_relayouted_mask(self): + def get_mask(x_ref): + x = x_ref[...] == 1 + iota = jax.lax.broadcasted_iota(jnp.int32, x_ref.shape, 1) + iota = iota > 7 + return jnp.logical_and(x, iota) + + def body(x_ref, y_ref): + y_ref[...] = jnp.where(get_mask(x_ref), 0.0, -1.0) + + shape = (2, 512) + out = jax.ShapeDtypeStruct(shape, jnp.float32) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(shape) + result = self.pallas_call(body, out_shape=out)(x) + expected = jnp.ones(x.shape, dtype=jnp.float32) + expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0)) + np.testing.assert_array_equal(result, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index f7e3eaaf0736..7b3bd70efafe 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -15,6 +15,8 @@ """Tests for distributed pallas TPU operations.""" import functools +import os +import tempfile from absl.testing import absltest from absl.testing import parameterized import jax @@ -513,5 +515,62 @@ def _(): atol=1e-5, rtol=1e-3) + +class VerificationTest(jtu.JaxTestCase): + + def test_verification(self): + if (num_devices := jax.local_device_count()) <= 1: + self.skipTest('Test requires multiple devices.') + if not jtu.is_device_tpu_at_least(4) or jax.devices()[0].num_cores > 1: + self.skipTest('Test requires a new single-core TPU.') + def kernel_body(in_ref, out_ref, scratch_ref, send_sem, recv_sem, capacity_sem): + my_id = lax.axis_index('x') + dst_id = jnp.where(my_id == num_devices - 1, 0, my_id + 1) + src_id = jnp.where(my_id == 0, num_devices - 1, my_id - 1) + pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id) + out_ref[...] = jnp.zeros_like(out_ref) + scratch_ref[0] = in_ref[0] + + @functools.partial(lax.fori_loop, 0, num_devices - 1, init_val=None) + def _(i, _): + slot = i % 2 + next_slot = 1 - slot + pltpu.semaphore_wait(capacity_sem, 1) + copy = pltpu.async_remote_copy( + scratch_ref.at[slot], + scratch_ref.at[next_slot], + send_sem, + recv_sem, + device_id=dst_id, + ) + out_ref[...] += scratch_ref[slot] + copy.wait() + pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id) + out_ref[...] += scratch_ref[(num_devices - 1) % 2] + pltpu.semaphore_wait(capacity_sem, 1) + + kernel = pl.pallas_call( + kernel_body, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + scratch_shapes=[ + pltpu.VMEM((2, 128, 128), jnp.float32), + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, + ], + ) + devices = mesh_utils.create_device_mesh((num_devices,)) + mesh = jax.sharding.Mesh(devices, ['x']) + # This is just a smoke test to ensure that the verification does not crash. + with tempfile.TemporaryDirectory() as tmpdir: + previous_config = jax.config.read('jax_pallas_dump_promela_to') + jax.config.update('jax_pallas_dump_promela_to', tmpdir) + shard_map.shard_map( + kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_rep=False + )(jnp.ones((8, 128, 128), jnp.float32)) + jax.config.update('jax_pallas_dump_promela_to', previous_config) + self.assertNotEmpty(os.listdir(tmpdir)) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index ca64275d3f09..2af00cf6b8c6 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -486,12 +486,11 @@ def _wait_on_prev_dma(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs ), - compiler_params=dict( - mosaic=dict(collective_id=0, - # must set scoped vmem flag *larger* than below! e.g.: - # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB ), ) diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index 2b5c315263c9..88c33a020ce9 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -20,10 +20,14 @@ from jax._src import test_util as jtu from jax._src.pallas.mosaic import random as plrandom from jax.experimental import pallas as pl +from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.random import threefry # pylint: disable=unused-import # noqa: F401 import jax.numpy as jnp import numpy as np +P = jax.sharding.PartitionSpec + jax.config.parse_flags_with_absl() @@ -253,6 +257,53 @@ def body(key_ref, o_ref): ) np.testing.assert_array_equal(result, jax_result) + @parameterized.parameters( + ((512, 512),), + ((137, 275),), # Non block-aligned shape + ((4, 512, 512),), # Greater than 2D shape + ((34,),), # 1D + (tuple(),), # 0D + ) + def test_threefry_kernel_matches_jax_threefry(self, shape): + with jax.threefry_partitionable(True): + key_jax = jax_random.key(0, impl="threefry2x32") + jax_gen = jax_random.bits(key_jax, shape=shape) + key_pl = jax_random.key(0, impl="pallas_threefry2x32") + pl_gen = jax_random.bits(key_pl, shape=shape) + + np.testing.assert_array_equal(jax_gen, pl_gen) + + @parameterized.parameters( + ((256, 256),), + ((35, 113),), # Non block-aligned shape + ((331,),), # 1D + ) + def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): + if jax.device_count() < 2: + self.skipTest("Need at least 2 devices") + num_devices = jax.device_count() + partition = P("x") + mesh = jax.make_mesh((num_devices,), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + with jax.threefry_partitionable(True): + key_jax = jax_random.split( + jax_random.key(0, impl="threefry2x32"), num_devices) + key_pallas = jax_random.split( + jax_random.key(0, impl="pallas_threefry2x32"), num_devices) + key_jax = jax.device_put(key_jax, sharding) + key_pallas = jax.device_put(key_pallas, sharding) + generate = shard_map.shard_map( + lambda x: jax_random.bits(x[0], shape=shape), + mesh=mesh, + in_specs=partition, + out_specs=partition, + ) + jax_gen = generate(key_jax) + pl_gen = generate(key_pallas) + + np.testing.assert_array_equal(jax_gen, pl_gen) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 49dd127b76fe..9c4788d7447f 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -870,8 +870,9 @@ def scope(): pl.run_scoped(scope) return [] - aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) - in_avals = [aref, aref] + aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + in_avals = [aref1, aref2] stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) discharged_jaxpr, _ = state_discharge.discharge_state( @@ -1471,6 +1472,40 @@ def kernel(index, x, y, sem): np.testing.assert_array_equal(y, i) del y + def test_dynamic_dma_on_2nd_minor(self): + def kernel(array, data, index, size, _, sem): + pltpu.async_copy( + data.at[pl.ds(0, size[0])], array.at[pl.ds(index[0], size[0])], sem + ).wait() + + def run(array, data, index, size): + return pl.pallas_call( + kernel, + out_shape=array, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + scratch_shapes=[ + pltpu.SemaphoreType.DMA, + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(array, data, index, size) + + array = jnp.zeros((1024, 128), jnp.int32) + data = jnp.ones((8, 128), jnp.int32) + index = jnp.array([3], jnp.int32) + size = jnp.array([5], jnp.int32) + + expected = array.at[index[0] : index[0] + size[0]].set( + data[index[0] : index[0] + size[0]] + ) + result = run(array, data, index, size) + np.testing.assert_array_equal(result, expected) + class PallasCallDMAInterpretTest(PallasCallDMATest): INTERPRET = True @@ -1587,7 +1622,6 @@ def kernel(x, y): self.assertEqual(analysis_result['transcendentals'], 21) self.assertEqual(analysis_result['bytes accessed'], 12345) - def test_cost_analysis_vmap(self): def kernel(x, y): y[:] = x[:] @@ -1606,7 +1640,6 @@ def kernel(x, y): self.assertEqual(analysis_result['transcendentals'], batch_size * 21) self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345) - def test_vmem_limit(self): shape = (128, 128) @@ -1673,6 +1706,59 @@ def kernel(x_ref, y_ref): ), )(x) + @parameterized.product(dtype=[jnp.bfloat16, jnp.float32]) + def test_pltpu_repeat(self, dtype): + def test_kernel(x_ref, o_ref): + x = x_ref[...] + o_ref[...] = pltpu.repeat(x, 2, axis=1) + + @jax.jit + def test(x: jax.Array) -> jax.Array: + return pl.pallas_call( + test_kernel, + out_shape=jax.ShapeDtypeStruct([x.shape[0], x.shape[1] * 2], x.dtype), + )(x) + + x = jnp.arange(2048, dtype=dtype).reshape((8, 256)) + y = test(x) + np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1)) + + def test_masked_store(self): + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("Test requires masked store support") + shape = (16, 256) + mask_shape = (10, 130) + mask_start = (4, 5) + dtype = jnp.float32 + def body(scalar_ref, x_ref, o_ref): + o_ref[...] = jnp.full(shape, -1, dtype=dtype) + b0, b1 = scalar_ref[0], scalar_ref[1] + e0, e1 = b0 + mask_shape[0], b1 + mask_shape[1] + iota0 = lax.broadcasted_iota(jnp.int32, shape, 0) + iota1 = lax.broadcasted_iota(jnp.int32, shape, 1) + mask0 = jnp.logical_and(b0 <= iota0, iota0 < e0) + mask1 = jnp.logical_and(b1 <= iota1, iota1 < e1) + pl.store( + o_ref, + (slice(None), slice(None)), + x_ref[...], + mask=jnp.logical_and(mask0, mask1), + ) + + s = jnp.array(mask_start, jnp.int32) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + out = pl.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + ), + )(s, x) + slices = tuple(slice(b, b + l) for b, l in zip(mask_start, mask_shape)) + expected = jnp.full(shape, -1, dtype=dtype) + expected = expected.at[slices].set(x[slices]) + np.testing.assert_array_equal(out, expected) + class PallasUXTest(PallasBaseTest): diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 8f3f7b2d3c0f..fb144cacbc98 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -17,34 +17,38 @@ import logging import math import os +import shutil import tempfile -import unittest from absl.testing import absltest import jax +from jax._src import api +from jax._src import compilation_cache as cc from jax._src import config -from jax._src import profiler -from jax._src import pjit from jax._src import monitoring +from jax._src import pjit +from jax._src import profiler from jax._src import test_util as jtu -from jax._src import api from jax.experimental import profiler as exp_profiler -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec -from jax._src import compilation_cache as cc -import numpy as np - from jax.experimental.serialize_executable import ( deserialize_and_load, serialize, ) +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec +import numpy as np jax.config.parse_flags_with_absl() @jtu.pytest_mark_if_available('multiaccelerator') class PgleTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["gpu"]): + self.skipTest('Profile-guideded latency estimation only supported on GPU') + cc.set_cache_dir(None) cc.reset_cache() @@ -52,7 +56,6 @@ def tearDown(self): cc.set_cache_dir(None) super().tearDown() - @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfile(self): mesh = jtu.create_mesh((2,), ('x',)) @@ -60,6 +63,7 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y @@ -81,7 +85,6 @@ def f(x, y): self.assertIsNotNone(fdo_profile) self.assertIn(b'custom', fdo_profile) - @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfileLarge(self): mesh = jtu.create_mesh((2,), ('x',)) its = 500 @@ -90,6 +93,11 @@ def testPGLEProfilerGetFDOProfileLarge(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + }, ) def f(x): agg = x @@ -100,53 +108,80 @@ def f(x): shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - with config.pgle_profiling_runs(0): - f_lowered = f.lower(x) - f_compiled = f_lowered.compile() - pgle_profiler = profiler.PGLEProfiler(1, 90) with config.enable_pgle(False): with profiler.PGLEProfiler.trace(pgle_profiler): - f_compiled(x) + f(x) fdo_profile = pgle_profiler.consume_fdo_profile() self.assertEqual(fdo_profile.count(b'custom'), its) + def get_fdo_profiles(self, dump_dir): + jit_f_fdo_profiles = [ + x + for x in os.listdir(dump_dir) + if 'jit_f' in x and x.endswith('.fdo_profile') + ] + return jit_f_fdo_profiles + def testAutoPgle(self): mesh = jtu.create_mesh((2,), ('x',)) - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - ) - def f(x): - return x * 2 - - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - expected = x * 2 - - with config.pgle_profiling_runs(2), config.enable_pgle(True): - # Run 1: Module should be compiled without FDO. Two modules are expected - # One is the funtion f, the other one is multi slice module - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 2) - - # Run 2: Second PGLE run should not recompile the module - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 0) - - # Run 3: The module should be recompiled with FDO profiles - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 2) - - # Run 4: Fast-path should be used after PGLE is done - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 0) + with tempfile.TemporaryDirectory() as dump_dir: + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True' + }, + ) + def f(x): + return x * 2 + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + expected = x * 2 + + with config.pgle_profiling_runs(2), config.enable_pgle(True): + # Run 1: Module should be compiled without FDO. Two modules are expected + # One is the funtion f, the other one is multi slice module + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + + # Run 2: Second PGLE run. Profile should be empty. + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) + # One for before and one for after optimization. + self.assertLen(fdo_profiles_before_pgle, 2) + # The FDO profile file should be empty. + self.assertEqual( + os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0) + + # Run 3: The module should be recompiled with FDO profiles + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir) + # One for before and one for after optimization. + self.assertLen(fdo_profiles_after_pgle, 4) + + for fdo_profile in fdo_profiles_after_pgle: + if fdo_profile not in fdo_profiles_before_pgle: + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0 + ) + + # Run 4: Fast-path should be used after PGLE is done + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertLess(cache_miss_count[0], 2) def testAutoPgleWithAot(self): @jax.jit @@ -171,101 +206,110 @@ def f(x): self.assertArraysEqual(compiled(x), expected) self.assertEqual(cache_miss_count[0], 0) - @unittest.skip("Test failing in CI") def testAutoPgleWithPersistentCache(self): its = 50 mesh = jtu.create_mesh((2,), ('x',)) - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - ) - def f(x): - agg = x - for _ in range(its): - agg = agg @ x - return agg - - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - - profilers_dict = ( - pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict) - with (config.enable_compilation_cache(True), - config.enable_pgle(True), - config.raise_persistent_cache_errors(True), - config.raise_persistent_cache_errors(True), - config.persistent_cache_min_entry_size_bytes(0), - config.persistent_cache_min_compile_time_secs(0), - config.pgle_profiling_runs(2), - tempfile.TemporaryDirectory() as tmpdir): - cc.set_cache_dir(tmpdir) - # Run 1: Module should be compiled without FDO - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - f(x) - self.assertEqual(cache_miss_count[0], 1) - - # Non-pgle profiled version of module should be saved - non_pgle_profiled_files = os.listdir(tmpdir) - if len(non_pgle_profiled_files) > 1: - non_pgle_profiled_files = [ - f for f in non_pgle_profiled_files if 'cache' in f + with tempfile.TemporaryDirectory() as dump_dir: + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True' + }, + ) + def f(x): + agg = x + for _ in range(its): + agg = agg @ x + return agg + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + + with (config.enable_compilation_cache(True), + config.enable_pgle(True), + config.raise_persistent_cache_errors(True), + config.raise_persistent_cache_errors(True), + config.persistent_cache_min_entry_size_bytes(0), + config.persistent_cache_min_compile_time_secs(0), + config.pgle_profiling_runs(2), + tempfile.TemporaryDirectory() as cache_dir): + cc.reset_cache() + cc.set_cache_dir(cache_dir) + # Run 1: Module should be compiled without FDO + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertGreater(cache_miss_count[0], 0) + + # Non-pgle profiled version of module should be saved + non_pgle_profiled_files = os.listdir(cache_dir) + self.assertNotEmpty(non_pgle_profiled_files) + + # Run 2: Compilation should not be called + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertGreater(cache_miss_count[0], 0) + + fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) + # Run 3: Module should be compiled with FDO and stored to persistent cache + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertGreater(cache_miss_count[0], 0) + + # Check if FDO profile file of the biggest module is not empty + fdo_profiles_after_pgle = [ + x + for x in self.get_fdo_profiles(dump_dir) + if x not in fdo_profiles_before_pgle ] - - self.assertLen(non_pgle_profiled_files, 1) - - # Run 2: Compilation should not be called - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertNotEmpty(fdo_profiles_after_pgle) + + # Check if FDO profile file in dump directory is not empty + for fdo_profile in fdo_profiles_after_pgle: + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0 + ) + + for pgle_profiler in pjit._pgle_profiler_dict.values(): + self.assertTrue(pgle_profiler.is_enabled()) + self.assertTrue(pgle_profiler.is_fdo_consumed()) + + files_after_pgle_profile = os.listdir(cache_dir) + self.assertGreater( + len(files_after_pgle_profile), len(non_pgle_profiled_files) + ) + + # Removing non-pgle profiled module from cache to check that later pgle + # profiled version will be used. + for non_pgle_file in non_pgle_profiled_files: + path = os.path.join(cache_dir, non_pgle_file) + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + # Run 4: Persistent compilation cache should be hit PGLE profiler should + # be disabled + cache_hit = 0 + def check_if_cache_hit(event): + nonlocal cache_hit + if event == '/jax/compilation_cache/cache_hits': + cache_hit += 1 + + monitoring.register_event_listener(check_if_cache_hit) f(x) - self.assertEqual(cache_miss_count[0], 0) - - # Run 3: Module should be compiled with FDO and stored to persistent cache - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - f(x) - self.assertEqual(cache_miss_count[0], 1) - - for pgle_profiler in profilers_dict.values(): - self.assertTrue(pgle_profiler.is_enabled()) - self.assertTrue(pgle_profiler.is_fdo_consumed()) - # One module is PGLEd version another one is not PGLEd - files_after_pgle_profile = os.listdir(tmpdir) - if len(files_after_pgle_profile) > 2: - files_after_pgle_profile = [ - f for f in files_after_pgle_profile if 'cache' in f - ] - self.assertLen(os.listdir(tmpdir), 2) - - self.assertLen(files_after_pgle_profile, 2) - non_pgled_file_size = os.path.getsize( - os.path.join(tmpdir, files_after_pgle_profile[0]) - ) - pgled_file_size = os.path.getsize( - os.path.join(tmpdir, files_after_pgle_profile[1]) - ) - # Make sure that FDO profile were applied to the module - self.assertNotEqual(pgled_file_size, non_pgled_file_size) - - # Removing non-pgle profiled module from cache to check that later pgle - # profiled version will be used. - os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0])) - - api.clear_caches() - profilers_dict.clear() - - # Run 4: Persistent compilation cache should be hit PGLE profiler should - # be disabled - cache_hit = 0 - def check_if_cache_hit(event): - nonlocal cache_hit - if event == '/jax/compilation_cache/cache_hits': - cache_hit += 1 - - monitoring.register_event_listener(check_if_cache_hit) - f(x) - monitoring._unregister_event_listener_by_callback(check_if_cache_hit) + monitoring._unregister_event_listener_by_callback(check_if_cache_hit) - self.assertEqual(cache_hit, 1) + self.assertGreater(cache_hit, 0) def testPassingFDOProfile(self): mesh = jtu.create_mesh((2,), ('x',)) @@ -274,6 +318,7 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y @@ -286,11 +331,11 @@ def f(x, y): f_lowered = f.lower(x, y) compiled = f_lowered.compile() - with tempfile.TemporaryDirectory() as tmpdir: - jax.profiler.start_trace(tmpdir) + with tempfile.TemporaryDirectory() as cache_dir: + jax.profiler.start_trace(cache_dir) compiled(x, y) jax.profiler.stop_trace() - directories = glob.glob(os.path.join(tmpdir, 'plugins/profile/**/')) + directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) directories = [d for d in directories if os.path.isdir(d)] rundir = directories[-1] logging.info('rundir: %s', rundir) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d3b96676afdc..5bb09043568c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -17,6 +17,7 @@ import re from functools import partial import logging +import json import math import textwrap import threading @@ -38,7 +39,6 @@ from jax import stages from jax import lax from jax._src.lax import lax as lax_internal -from jax._src.lib import xla_extension_version from jax.lax import with_sharding_constraint from jax._src import prng from jax.sharding import PartitionSpec as P, Mesh @@ -52,7 +52,6 @@ from jax._src.sharding_impls import ( AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, SingleDeviceSharding, parse_flatten_op_sharding) -import jax._src.pjit as pjit_lib from jax._src.pjit import pjit from jax._src import mesh as mesh_lib from jax._src.interpreters import pxla @@ -1293,6 +1292,34 @@ def f(x): with self.assertRaisesRegex(ValueError, "spmd_axis_name"): jax.vmap(f, spmd_axis_name='x')(xs) + def test_cache_bug(self): + devices = list(jax.devices()) + if len(devices) < 2: + raise unittest.SkipTest("Test requires 2 devices") + + def under_jvp(f): + return jax.jvp(f, (), ()) + + x0 = jnp.zeros(1, device=devices[0]) + x1 = jnp.zeros(1, device=devices[1]) + + # comments describe how caches worked under the old `_most_recent_pjit_call_executable` system + under_jvp(lambda: jnp.sin(x0)) # cpp_pjit miss, pjit_call_impl miss + jnp.sin(x1) # cpp_pjit miss, pjit_call_impl miss + ans1 = jnp.sin(x0) # cpp_pjit miss, pjit_call_impl hit. Bad cpp_pjit entry created + ans2 = jnp.sin(x0) # cpp_pjit hit with bad cache entry + assert(ans1.devices() == ans2.devices()) + + def test_zero_literal_equality(self): + # This test verifies that we don't accidentally conflate positive and + # negative zeros when deduplicating literals in the IR. + f = jax.jit(lambda x: (x / np.float32(-0.0), x / np.float32(0.0))) + a, b = f(np.float32(1.0)) + self.assertEqual(a, -np.inf) + self.assertEqual(b, np.inf) + ir = f.lower(np.float32(1.0)).as_text() + self.assertIn("stablehlo.constant dense<0.000000e+00>", ir) + self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): @@ -1871,11 +1898,6 @@ def _checks(out, input_data): ) def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, s2_shape, s3_shape, s4_shape): - if config.use_shardy_partitioner.value: - self.skipTest( - 'TODO(b/355263220) Shardy conflict resolution is not complete. Issue ' - 'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while ' - 'Shardy gives it fully replicated.') global_mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2) @@ -2133,13 +2155,13 @@ def add(x, y): return x + y out = add(a, b) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + cache_info1 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, a + b) self.assertFalse(out._committed) out2 = add(out, out) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + cache_info2 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out2, array.ArrayImpl) self.assertArraysEqual(out2, 2 * (a + b)) self.assertFalse(out2._committed) @@ -2149,7 +2171,7 @@ def add(x, y): c = jax.device_put(a, jax.devices()[0]) out3 = add(c, c) - cache_info3 = pjit_lib._pjit_lower_cached.cache_info() + cache_info3 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(out3, 2 * c) self.assertTrue(out3._committed) @@ -2192,14 +2214,11 @@ def test_pjit_different_device_recompilation(self): f = pjit(lambda x: x) - out1 = f(a) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() - - out2 = f(b) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + with jtu.count_jit_compilation_cache_miss() as count: + out1 = f(a) + out2 = f(b) + self.assertEqual(count[0], 2) - self.assertEqual(cache_info2.hits, cache_info1.hits) - self.assertEqual(cache_info2.misses, cache_info1.misses + 1) self.assertArraysEqual(out1, val1) self.assertArraysEqual(out2, val2) @@ -2856,13 +2875,13 @@ def f(x, y, z): return x, y, z o1, o2, o3 = f(a, y=b, z=c) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + cache_info1 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(o1, a) self.assertArraysEqual(o2, b) self.assertArraysEqual(o3, c) o4, o5, o6 = f(x=a, y=b, z=c) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + cache_info2 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(o4, a) self.assertArraysEqual(o5, b) self.assertArraysEqual(o6, c) @@ -2871,7 +2890,7 @@ def f(x, y, z): self.assertEqual(cache_info2.misses, cache_info1.misses + 1) o7, o8, o9 = f(a, b, c) - cache_info3 = pjit_lib._pjit_lower_cached.cache_info() + cache_info3 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(o7, a) self.assertArraysEqual(o8, b) self.assertArraysEqual(o9, c) @@ -2958,26 +2977,19 @@ def _check(out, expected_device, expected_out): x = jnp.arange(8).reshape(4, 2) f_out = f(x) f_out2 = f(f_out) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() _check(f_out, jax.devices()[1], x) _check(f_out2, jax.devices()[1], f_out) y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y'))) out2 = f(y) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() _check(out2, jax.devices()[1], y) - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): h = pjit(mul, device=jax.devices()[-1]) h_out = h(y) - cache_info3 = pjit_lib._pjit_lower_cached.cache_info() _check(h_out, jax.devices()[-1], y) - self.assertEqual(cache_info3.hits, cache_info2.hits) - # AOT test compiled = f.lower(core.ShapedArray(y.shape, y.dtype)).compile() out3 = compiled(y) @@ -3399,6 +3411,9 @@ def test_device_assignment_mismatch_apply_primitive(self): def test_device_put_grad(self): if jax.device_count() < 8: self.skipTest("Requires >=8 devices.") + if jtu.is_device_tpu(5, 'e'): + self.skipTest('TPU v5e does not support computations that run on a ' + 'non-singleton subset of cores.') def _test(fun, inp, np_inp, in_s): out = fun(inp) @@ -3507,11 +3522,11 @@ def mul(x): with jtu.count_pjit_cpp_cache_miss() as count: out = f(arr) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + cache_info1 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out.sharding, NamedSharding) out2 = f(np_arr) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + cache_info2 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out2.sharding, NamedSharding) # Drops out of C++ cache i.e. cache miss @@ -3545,6 +3560,13 @@ def identity(x): out2 = pjit(identity)(arr2) self.assertIsInstance(out2.sharding, PositionalSharding) + def test_wsc_error_on_none(self): + with self.assertRaisesRegex( + ValueError, + 'One of with_sharding_constraint arguments got sharding None which is' + ' not allowed'): + with_sharding_constraint(jnp.arange(8), None) + def test_sharding_preserved_aot(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3593,13 +3615,11 @@ def test_jit_mul_sum_sharding_preserved(self): f = jax.jit(lambda x: x * 2) out = f(arr) cache_info1 = pxla._cached_compilation.cache_info() - pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info() self.assertIsInstance(out.sharding, NamedSharding) with jtu.count_pjit_cpp_cache_miss() as count: out2 = f(arr2) cache_info2 = pxla._cached_compilation.cache_info() - pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info() self.assertIsInstance(out2.sharding, PositionalSharding) # This will hit the cpp cache. @@ -3610,9 +3630,6 @@ def test_jit_mul_sum_sharding_preserved(self): self.assertEqual(cache_info2.hits, cache_info1.hits + 1) self.assertEqual(cache_info2.misses, cache_info1.misses) - self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits) - self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1) - out4 = jnp.sum(arr) self.assertIsInstance(out4.sharding, NamedSharding) @@ -3802,6 +3819,15 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P())) self.assertEqual(out.sharding.memory_kind, 'device') + def test_jit_static_argnames_non_interned(self): + def do_nothing(foobar: int): + return foobar + + argname = "foobar" + # Has the side effect of ensuring argname is not interned. + argname = str(json.loads(json.dumps(argname))) + jax.jit(do_nothing, static_argnames=[argname])(foobar=2) # doesn't crash + def test_most_recent_executable_outer_inner_cache(self): x = np.zeros((20, 20), dtype=jnp.float64) @@ -3811,7 +3837,7 @@ def trace_to_jaxpr(x): constant_values= ((0.0, 0.0), (0.0, 0.0))) jaxpr = jax.make_jaxpr(trace_to_jaxpr)(x) - jax.core.jaxpr_as_fun(jaxpr)(x) + jax._src.core.jaxpr_as_fun(jaxpr)(x) jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') # doesn't crash @@ -4433,6 +4459,14 @@ def g(x): self.assertEqual(out2.sharding, s) self.assertEqual(out2.dtype, np.float32) + def test_make_mesh_non_int_error(self): + with self.assertRaisesRegex( + ValueError, + "`axis_shapes` passed to `make_mesh` should be a sequence of ints"): + jax.make_mesh(((4,), 4), ('x', 'y')) + + jax.make_mesh((1, np.int32(1), np.int64(1)), ('x', 'y', 'z')) # doesn't crash + def test_jnp_array_reshard_error(self): if jax.device_count() < 2: self.skipTest('Requires >=2 devices') @@ -4589,17 +4623,25 @@ def f(x): jax.jit(f, out_shardings=s)(np.arange(8)) self.assertEqual(count[0], 1) + def test_input_shardings_single_device(self): + @jax.jit + def f(x): + return x * 2 + + ins, _ = f.lower(np.arange(8)).compile().input_shardings + self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0])) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") -@jtu.with_config(jax_sharding_in_types=True, jax_use_shardy_partitioner=False) +@jtu.with_config(jax_use_shardy_partitioner=False) class ShardingInTypesTest(jtu.JaxTestCase): - def test_basic_mul(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(8, 2) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_basic_mul(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4620,10 +4662,21 @@ def f(x): if config.use_shardy_partitioner.value: self.assertIn('sdy.sharding_constraint', lowered_text) else: - self.assertEqual(lowered_text.count('@Sharding'), 2) + self.assertEqual(lowered_text.count('@Sharding'), 3) - def test_fully_replicated_array_mul(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jax.jit + def g(x): + x = f(x) + return jnp.sum(x) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_fully_replicated_array_mul(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr1 = jax.device_put(np_inp1, s) @@ -4671,9 +4724,9 @@ def g(x, y): ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'), ('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce') ) - def test_dot_general_basic(self, spec1, spec2, out_spec, collective_name): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp1 = np.arange(16).reshape(8, 2) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh): + np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2)) @@ -4694,6 +4747,55 @@ def f(x, y): if collective_name is not None and compiled_text is not None: self.assertIn(collective_name, compiled_text) + @jax.jit + def g(x, y): + out = f(x, y) + return jnp.sum(out) + + out = jax.grad(g, argnums=(0, 1))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + out = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + @jtu.with_user_mesh((4,), ('x',)) + def test_dot_general_out_type(self, mesh): + np_inp1 = np.arange(16.).reshape(8, 2) + arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) + arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x'))) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return jnp.sum(out) + + out = f(arr1, arr2) + self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T)) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + out = jax.grad(f, argnums=(0, 1))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + jitted_grad = jax.jit(jax.grad(f, argnums=(0, 1))) + out = jitted_grad(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + jaxpr = jitted_grad.trace(arr1, arr2).jaxpr + bwd_jaxpr = jaxpr.eqns[1] + expected_spec = [('broadcast_in_dim', P('x', None)), + ('dot_general', P('x', None)), + ('transpose', P(None, 'x')), + ('dot_general', P('x', None))] + for eqn, spec in zip(bwd_jaxpr.params['jaxpr'].eqns, expected_spec): + self.assertEqual(eqn.primitive.name, spec[0]) + self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec[1]) + @parameterized.named_parameters( ('fail1', P('x', 'y'), P('y', 'x'), "PartitionSpec.*x.*x.*has duplicate entries", ValueError), @@ -4701,8 +4803,8 @@ def f(x, y): "dot_general requires contracting dimensions to have consistent sharding", TypeError), ) - def test_dot_general_error(self, spec1, spec2, error_msg, error_type): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2)) @@ -4714,8 +4816,8 @@ def f(x, y): with self.assertRaisesRegex(error_type, error_msg): f(arr1, arr2) - def test_dot_general_batch_error(self): - mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_dot_general_batch_error(self, mesh): arr1 = jax.device_put(np.ones((8, 4, 2)), NamedSharding(mesh, P('x', 'y', 'z'))) arr2 = jax.device_put(np.ones((8, 2, 4)), @@ -4733,9 +4835,8 @@ def test_dot_general_batch_error(self): ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) - def test_aval_repr(self): - mesh = jtu.create_mesh((2, 2), ('model', 'data')) - + @jtu.with_user_mesh((2, 2), ('model', 'data')) + def test_aval_repr(self, mesh): aval = core.ShapedArray((128, 64), np.float32, sharding=NamedSharding(mesh, P('model', 'data'))) self.assertEqual(aval.str_short(), 'float32[128@model,64@data]') @@ -4753,14 +4854,14 @@ def test_aval_repr(self): self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]') @parameterized.named_parameters( - ('all', None, P('x', 'y'), P()), - ('first', 0, P('x', 'y'), P('y')), - ('second', 1, P('x', 'y'), P('x')), - ('first2', 0, P(('x', 'y'), None), P(None)), + ('all', None, P('x', 'y'), P(), True), + ('first', 0, P('x', 'y'), P('y'), True), + ('second', 1, P('x', 'y'), P('x'), True), + ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, in_spec) arr = jax.device_put(np_inp, s) @@ -4784,15 +4885,15 @@ def f(x): self.assertIn('all-reduce', compiled_text) @parameterized.named_parameters( - ('all', None, P('x', 'y'), P()), - ('first', 0, P('x', 'y'), P('y')), - ('second', 1, P('x', 'y'), P('x')), - ('first2', 0, P(('x', 'y'), None), P(None)), + ('all', None, P('x', 'y'), P(), True), + ('first', 0, P('x', 'y'), P('y'), True), + ('second', 1, P('x', 'y'), P('x'), True), + ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - def test_reduce_max(self, axis, in_spec, out_spec, reduce=True): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(8, 2) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh): + np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, in_spec) arr = jax.device_put(np_inp, s) @@ -4814,14 +4915,25 @@ def f(x): if reduce and compiled_text is not None: self.assertIn('all-reduce', compiled_text) + @jax.jit + def g(x): + out = f(x) + return jnp.mean(out) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + @parameterized.named_parameters( ('0', 0, P(None, 'x', 'y')), ('1', 1, P('x', None, 'y')), ('2', 2, P('x', 'y', None)), ('-1', -1, P('x', 'y', None)), ) - def test_broadcast_in_dim(self, axis, out_spec): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_broadcast_in_dim(self, axis, out_spec, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4846,8 +4958,8 @@ def f(x): ('3', 3), ('4', 4), ) - def test_integer_pow(self, pow): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_integer_pow(self, pow, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4865,8 +4977,24 @@ def f(x): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) - def test_sin_unop(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + def test_broadcasting_nary_error(self): + mesh1 = Mesh([jax.devices()[0]], 'x') + mesh2 = Mesh([jax.devices()[0]], 'y') + + arr1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) + arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P())) + + @jax.jit + def f(x, y): + return x + y + + with config.sharding_in_types(True): + with self.assertRaisesRegex( + ValueError, "Mesh for all inputs should be equal"): + f(arr1, arr2) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_sin_unop(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4883,8 +5011,8 @@ def f(x): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) - def test_jnp_array(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_jnp_array(self, mesh): np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4899,8 +5027,8 @@ def f(x): f(arr) - def test_lax_transpose_rule(self): - mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_lax_transpose_rule(self, mesh): np_inp = np.arange(16).reshape(4, 2, 2) s = NamedSharding(mesh, P('x', 'y', 'z')) arr = jax.device_put(np_inp, s) @@ -4918,8 +5046,8 @@ def f(x): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) - def test_broadcasted_iota_with_sharding(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_broadcasted_iota_with_sharding(self, mesh): np_inp = np.arange(4) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np_inp, s) @@ -4945,6 +5073,415 @@ def g(x): _, out = g(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_einsum_with_out_type(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertArraysEqual(out, np_inp @ np_inp.T) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + lowered_text = f.lower(arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + + @jax.jit + def g(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr4 = jax.device_put(np_inp.T, NamedSharding(mesh, P('x', 'y'))) + out2 = g(arr3, arr4) + self.assertArraysEqual(out2, np_inp @ np_inp.T) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + + @jax.jit + def h2(x, y): + out = g(x, y) + return jnp.sum(out) + + out = jax.grad(h2, argnums=(0, 1))(arr3, arr4) + self.assertEqual(out[0].sharding, arr3.sharding) + self.assertEqual(out[1].sharding, arr4.sharding) + + out = jax.jit(jax.grad(h2, argnums=(0, 1)))(arr3, arr4) + self.assertEqual(out[0].sharding, arr3.sharding) + self.assertEqual(out[1].sharding, arr4.sharding) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_einsum_inverse(self, mesh): + np_inp = np.arange(64.) + + @jax.jit + def h(x, y): + s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None)) + out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s) + self.assertEqual(out.sharding.spec, s.spec) + return out + + arr1 = jax.device_put(np_inp.reshape(8, 4, 2), + NamedSharding(mesh, P('x', 'y', None))) + arr2 = jax.device_put(np_inp.reshape(2, 4, 8), + NamedSharding(mesh, P(None, 'x', 'y'))) + out = h(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None))) + + lowered_text = h.lower(arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + + @jax.jit + def h2(x, y): + out = h(x, y) + return jnp.sum(out) + + out = jax.grad(h2, argnums=(0, 1))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + out = jax.jit(jax.grad(h2, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + @parameterized.named_parameters( + ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False), + ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), + ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True) + ) + @jtu.with_user_mesh((2,), ('x',)) + def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, + use_sharding_arg, mesh): + np_inp = np.arange(math.prod(src_shape), + dtype=np.float32).reshape(src_shape) + arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec)) + + @partial(jax.jit, static_argnums=1) + def f(x, new_sharding): + y = lax.reshape(x, dst_shape, sharding=new_sharding) + y = y * 2 + self.assertEqual(y.sharding.spec, dst_spec) + return y + + new_s = (NamedSharding(mesh.abstract_mesh, dst_spec) + if use_sharding_arg else None) + out = f(arr, new_s) + self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec)) + self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2) + + lowered_text = f.lower(arr, new_s).as_text() + self.assertIn('@Sharding', lowered_text) + + def g(x): + out = f(x, new_s) + return jnp.square(jnp.sum(out)) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_select(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np_inp, s) + + @jax.jit + def f(pred, on_true, on_false): + y = lax.select(pred, on_true, on_false) + self.assertEqual(y.sharding.spec, s.spec) + return y + + out = f(arr1 == arr2, arr1, arr2) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, arr1) + + lowered_text = f.lower(arr1 == arr2, arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + + arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x'))) + with self.assertRaisesRegex( + TypeError, "select cases must have the same shardings"): + f(arr1 == arr2, arr1, arr3) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_device_put_reshard(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = jax.device_put(x, NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertArraysEqual(out, np_inp) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_shard_map_full_manual(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective) + return x * y + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', 'y'))(x, y) + self.assertEqual(z.sharding.spec, P('x', 'y')) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', 'y')) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp * np_inp) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_shard_map_dot(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective) + allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) + z = x @ allgatherd_y + return jax.lax.psum(z, axis_name='y') + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', None))(x, y) + self.assertEqual(z.sharding.spec, P('x', None)) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_slice(self, mesh): + np_inp = np.arange(16.).reshape(4, 4) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) + + @jax.jit + def f(x): + y = lax.slice(x, (0, 0), (4, 3)) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertIn('@Sharding', f.lower(arr).as_text()) + + def g(x): + out = f(x) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) + + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_squeeze(self, mesh): + np_inp = np.arange(16.).reshape(4, 4, 1) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) + + @jax.jit + def f(x): + y = lax.squeeze(x, (2,)) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertIn('@Sharding', f.lower(arr).as_text()) + self.assertArraysEqual(out, np.squeeze(np_inp, axis=2)) + + def g(x): + out = f(x) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_pad(self, mesh): + np_inp = np.arange(8.) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + + @partial(jax.jit, static_argnums=(1, 2)) + def f(x, padding_config, spec): + y = lax.pad(x, 0., padding_config) + self.assertEqual(y.sharding.spec, spec) + return y + + out = f(arr, ((2, 2, 0),), P('x')) + self.assertArraysEqual(out, np.pad(np_inp, 2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertIn('@Sharding', f.lower(arr, ((2, 2, 0),), P('x')).as_text()) + + out = f(arr, ((0, 0, 0),), P('x')) + self.assertArraysEqual(out, np_inp) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + f(arr, ((0, 3, 1), ), P('x')) # doesn't crash + + def g(x): + out = f(x, ((2, 2, 0),), P('x')) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + f(arr, ((2, 3, 0), ), None) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + f(arr, ((0, 3, 0), ), None) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) + f(arr, ((4, 4, 1),), None) + + @jtu.with_user_mesh((2, 1), ('x', 'y')) + def test_concatenate(self, mesh): + np_inp = np.arange(16.).reshape(4, 4) + s = NamedSharding(mesh, P('x', 'y')) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np.arange(4.).reshape(4, 1), s) + + @partial(jax.jit, static_argnums=2) + def f(x, y, method='jnp'): + if method == 'jnp': + y = jnp.concatenate([x, y], axis=1) + else: + assert method == 'lax' + y = lax.concatenate([x, y], dimension=1) + self.assertEqual(y.sharding.spec, P('x', 'y')) + return y + + out = f(arr1, arr2) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) + self.assertIn('@Sharding', f.lower(arr1, arr2).as_text()) + + out = f(arr1, arr2, method='lax') + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) + + with self.assertRaisesRegex( + TypeError, "All operands should have the same sharding"): + arr3 = jax.device_put(np.arange(4.).reshape(4, 1), + NamedSharding(mesh, P('x'))) + f(arr1, arr3) + + def g(x, y): + out = f(x, y) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr1, arr2) + self.assertEqual(out.sharding, s) + + out = jax.jit(jax.grad(g))(arr1, arr2) + self.assertEqual(out.sharding, s) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_scan(self, mesh): + carry = jax.device_put(np.arange(16.).reshape(2, 8), + NamedSharding(mesh, P(None, 'x'))) + arr = jax.device_put(np.arange(128.).reshape(8, 8, 2), + NamedSharding(mesh, P(None, 'x', 'y'))) + + @jax.jit + def f(carry, xs): + def g(carry, x): + self.assertEqual(carry.sharding.spec, P(None, 'x')) + self.assertEqual(x.sharding.spec, P('x', 'y')) + y = carry @ x + self.assertEqual(y.sharding.spec, P(None, 'y')) + z = jax.nn.relu(y) + self.assertEqual(z.sharding.spec, P(None, 'y')) + a = z @ x.T + self.assertEqual(a.sharding.spec, P(None, 'x')) + return a, y + return jax.lax.scan(g, carry, xs) + + activation, mean = f(carry, arr) + self.assertEqual(activation.sharding, NamedSharding(mesh, P(None, 'x'))) + self.assertEqual(mean.sharding, NamedSharding(mesh, P(None, None, 'y'))) + + f.lower(carry, arr).compile()(carry, arr) # doesn't crash + + def g(carry, arr): + out = f(carry, arr) + return jnp.sum(out[0]) + out = jax.jit(jax.grad(g, argnums=(0, 1)))(carry, arr) + self.assertEqual(out[0].sharding, carry.sharding) + self.assertEqual(out[1].sharding, arr.sharding) + + with self.assertRaisesRegex( + ValueError, "0th dimension of all xs should be replicated"): + f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_argminmax(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + z = jnp.argmax(x, axis=0) + self.assertEqual(z.sharding.spec, P('y')) + a = jnp.argmin(x, axis=1) + self.assertEqual(a.sharding.spec, P('x')) + return z, a + + out1, out2 = f(arr) + self.assertArraysEqual(out1, np.argmax(np_inp, axis=0)) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('y'))) + self.assertArraysEqual(out2, np.argmin(np_inp, axis=1)) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) + + self.assertIn('@Sharding', f.lower(arr).as_text()) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): @@ -5550,9 +6087,6 @@ def test_hlo_sharding_manual_replicated(self): self.assertTrue(hs4.is_tiled()) def test_hlo_sharding_with_device_ordering(self): - if xla_extension_version < 291: - self.skipTest('Requires xla_extension_version >= 291') - hs1 = xc.HloSharding.subgroup_with_device_ordering( np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.int64), subgroup_types=[xc.OpSharding.Type.REPLICATED], @@ -5664,7 +6198,6 @@ def f(x): self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text()) - @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292") def test_lowering_with_sharding_constraint(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) arr = np.arange(16).reshape(4, 2, 2) @@ -5690,7 +6223,6 @@ def f(x): self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str) # TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline. - @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292") @jtu.skip_on_devices('cpu') def test_compile_with_inferred_out_sharding(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 9a8d0b91272b..f611ee981335 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2057,7 +2057,7 @@ def testSizeOverflow(self): def test_axis_env_length(self): f = lambda x: jax.pmap(g)(jnp.array([x]))[0] def g(x): - assert len(core.thread_local_state.trace_state.axis_env) == 1 + assert len(core.get_axis_env().axis_names()) == 1 return x jax.grad(f)(3.) # doesn't fail @@ -2215,8 +2215,6 @@ def test_cache_uses_jax_key(self): pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) - config.update_thread_local_jit_state() - pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) @@ -3015,7 +3013,7 @@ def testShardArgs(self, shape, spec, make_arg): x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) - results = pxla.shard_args([sharding], [None], [arg]) + results = pxla.shard_args([sharding], [None], [None], [arg]) self.assertEqual(len(results), 1) if isinstance(results[0], array.ArrayImpl): bufs = results[0]._arrays diff --git a/tests/random_test.py b/tests/random_test.py index da182dbccae9..d8c5a70b995a 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -936,6 +936,11 @@ def f(x): x = jnp.array([True, False, False]) f(x) # doesn't crash + def test_device_get(self): + keys = self.make_keys(4) + keys_on_host = jax.device_get(keys) + self.assertKeysEqual(keys, keys_on_host) + def test_device_put(self): device = jax.devices()[0] keys = self.make_keys(4) @@ -1120,10 +1125,10 @@ class A: pass jax.random.key(42, impl=A()) @jtu.sample_product(name=[name for name, _ in PRNG_IMPLS]) - def test_key_spec_repr(self, name): + def test_key_impl_builtin_is_string_name(self, name): key = jax.random.key(42, impl=name) spec = jax.random.key_impl(key) - self.assertEqual(repr(spec), f"PRNGSpec({name!r})") + self.assertEqual(spec, name) def test_keyarray_custom_vjp(self): # Regression test for https://github.com/jax-ml/jax/issues/18442 @@ -1155,6 +1160,12 @@ def _f_bwd(_, state_bar): result = jax.grad(lambda theta: f(theta, state)[0])(3.0) self.assertEqual(result, 1.0) + def test_keyarray_array_conversion_fails(self): + key = jax.random.key(0) + msg = "JAX array with PRNGKey dtype cannot be converted to a NumPy array." + with self.assertRaisesRegex(TypeError, msg): + np.asarray(key) + # TODO(frostig,mattjj): more polymorphic primitives tests diff --git a/tests/roofline_test.py b/tests/roofline_test.py new file mode 100644 index 000000000000..e5003947181b --- /dev/null +++ b/tests/roofline_test.py @@ -0,0 +1,426 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from functools import partial +import contextlib + +from absl.testing import absltest +from jax.sharding import PartitionSpec as P +import jax +import jax.lax as lax +import jax.numpy as jnp + +from jax._src import test_util as jtu + +from jax.experimental import roofline + + +jax.config.parse_flags_with_absl() + + +def create_inputs( + *shardings: P, + dtype: jnp.dtype = jnp.float32, + mesh_shape: tuple[int, ...] = (2, 2, 2), +) -> tuple[jax.sharding.Mesh, tuple[jax.ShapeDtypeStruct, ...]]: + mesh = jtu.create_mesh(mesh_shape, ("x", "y", "z")) + arrays = [] + for sharding in shardings: + array = jax.ShapeDtypeStruct( + (8, 8), dtype, sharding=jax.sharding.NamedSharding(mesh, sharding) + ) + arrays.append(array) + return mesh, tuple(arrays) + + +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + + +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + + +def tearDownModule(): + _exit_stack.close() + + +class RooflineTest(jtu.JaxTestCase): + def test_scalar_collectives(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(P("z", None), P(("x", "y"), None)), + ) + def scalar_collectives(a, b): + a = lax.pmin(a, ("x", "y")) + b = lax.pmax(b, "z") + return a, b + + _, results = scalar_collectives(a, b) + + itemsize = 4 + + axis_size = 2 + axis_size_m1 = axis_size - 1 + + xy_num_axes = 2 + xy_ici_bytes = int( + itemsize + # 2 phases. + * ( + (1 / xy_num_axes * axis_size_m1) + (1 * axis_size / xy_num_axes * axis_size_m1) + ) + ) + # 2 phases times 2 hops. + xy_ici_latency = 2 * 2 + + z_ici_bytes = int(itemsize * 1 * axis_size_m1) + # 2 hops. + z_ici_latency = 2 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_collective_matmul(self): + a_spec = P(None, "x") + b_spec = P(None, "x") + c_spec = P("x", None) + mesh, (a, b, c) = create_inputs(a_spec, b_spec, c_spec, dtype=jnp.int8) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec, c_spec), + out_specs=a_spec, + ) + def collective_matmul(a, b, c): + a = lax.all_gather(a, "x", axis=1, tiled=True) + # Test broadcasting and slicing works. + a = a[None, :, :] + b = b[:, None, :] + ab = jnp.einsum("bij,jbk->ikb", a, b).astype(jnp.int8)[..., 0] + abc = jnp.einsum("ik,kc->ic", ab, c).astype(jnp.int8) + abc = lax.psum_scatter(abc, "x", scatter_dimension=1, tiled=True) + return abc + + _, results = collective_matmul(a, b, c) + + itemsize = 1 + m, k, n = 8, 4, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_mk = mk + + # Times 2 for ag + rs. + ici_bytes = 2 * int(itemsize * sharded_mk * axis_size_m1) + ici_latency = 2 * 2 + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=2 * itemsize * (mk + kn + mn), + # Right after all_gather. + peak_hbm_bytes=itemsize * (mk * axis_size + mk + kn), + ) + self.assertDataclassEqual(results, expected) + + def test_matmul_psum(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), None) + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=P("z", None), + ) + def matmul_psum(a, b): + c = a @ b + c = lax.psum(c, ("x", "y")) + return c + + _, results = matmul_psum(a, b) + + itemsize = 4 + m, k, n = 4, 2, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + num_axes = 2 + sharded_mn = mn / axis_size / num_axes + + # Times 2 for ag + rs. + ici_bytes = 2 * int( + itemsize + # 2 phases. + * ( + (sharded_mn / num_axes * axis_size_m1) + + (sharded_mn * axis_size / num_axes * axis_size_m1) + ) + ) + ici_latency = 2 * 2 * 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={axis: ici_bytes for axis in ("x", "y")}, + ici_latency={axis: ici_latency for axis in ("x", "y")}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mn), + ) + self.assertDataclassEqual(results, expected) + + def test_all_to_all(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(P(("z", "x", "y"), None), P(("x", "y", "z"), None)), + ) + def all_to_all(a, b): + a = lax.all_to_all(a, ("x", "y"), split_axis=0, concat_axis=1, tiled=True) + b = lax.all_to_all(b, "z", split_axis=0, concat_axis=1, tiled=True) + return a, b + + _, results = all_to_all(a, b) + + itemsize = 4 + + xy_size = itemsize * 8 * 8 / 2 + # Half the data over 2 links. + xy_ici_bytes = int(xy_size / 2 / 2) + # 2 hops. + xy_ici_latency = 2 + + z_size = itemsize * 8 * 8 / 2 / 2 + # Half the data over 1 link. + z_ici_bytes = int(z_size / 2) + # 1 hop. + z_ici_latency = 1 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_ppermute(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(a_spec, b_spec), + ) + def ppermute(a, b): + a = lax.ppermute(a, ("x", "y"), perm=((0, 3), (3, 0), (1, 2), (2, 1))) + b = lax.ppermute(b, "z", perm=((1, 0), (0, 1))) + return a, b + + _, results = ppermute(a, b) + + itemsize = 4 + shard_size = itemsize * 4 * 2 + + # At most 2 shards contend for 1 link. + xy_ici_bytes = int(shard_size * 2) + # 2 hops. + xy_ici_latency = 2 + + # No contention but there is a single link. + z_ici_bytes = int(shard_size * 2) + # 1 hop. + z_ici_latency = 1 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_grad_matmuls(self): + a_spec = P(None, "x") + b_spec = P(None, None) + mesh, (a, b) = create_inputs(a_spec, b_spec, dtype=jnp.int8) + + @partial( + roofline.roofline_and_grad, + mesh=mesh, + in_specs=(a_spec, b_spec), + # Numerically incorrect AD, but tests that we handle it properly. + out_specs=P("x", None), + ) + def collective_matmul(a, b): + a = lax.all_gather(a, "x", axis=1, tiled=True) + return a @ b + + c, fwd_results, bwd_results = collective_matmul(a, b) + + itemsize = 1 + m, k, n = 8, 8, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_mk = mk // axis_size + + ici_bytes = int(itemsize * sharded_mk * axis_size_m1) + ici_latency = 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mk + kn), + ) + self.assertDataclassEqual(fwd_results, expected) + + bwd_itemsize = 2 + # 2 for psum + 1 for rs. + bwd_ici_bytes = 3 * int(bwd_itemsize * sharded_mk * axis_size_m1) + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": bwd_ici_bytes}, + ici_latency={"x": 3 * ici_latency}, + hbm_bytes=2 * bwd_itemsize * (mk + kn + mn), + # Residuals + cotangents. + peak_hbm_bytes=bwd_itemsize * (mk + kn + mn), + ) + self.assertDataclassEqual(bwd_results, expected) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=c.sharding.spec, + out_specs=c.sharding.spec, + ) + def mul_2(c): + return c * 2 + + results = mul_2(c) + self.assertLen(results, 2) + + def test_one_sized_axis_collectives(self): + a_spec = P("x") + mesh, (a,) = create_inputs(a_spec, mesh_shape=(1, 2, 4)) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=a_spec, + out_specs=a_spec, + ) + def one_sized_axis_collectives(a): + a = lax.pmin(a, "x") + a = lax.all_gather(a, "x", axis=1, tiled=True) + a = lax.psum_scatter(a, "x", scatter_dimension=1, tiled=True) + a = lax.psum(a, "x") + a = lax.all_to_all(a, "x", split_axis=0, concat_axis=1, tiled=True) + a = lax.ppermute(a, "x", perm=((1, 0), (0, 1))) + return a + + _, results = one_sized_axis_collectives(a) + expected = roofline.RooflineResult( + ici_bytes={"x": 0}, + ici_latency={"x": 0}, + peak_hbm_bytes=4 * 8 * 8, + ) + self.assertDataclassEqual(results, expected) + + def test_remat(self): + a_spec = P("x", None) + b_spec = P("x", None) + mesh, (a, b) = create_inputs(a_spec, b_spec) + + def fsdp_checkpoint_policy(prim, *args, **kwargs): + if prim == lax.all_gather_p and kwargs["axis_name"] == "x": + return True + return False + + @partial( + roofline.roofline_and_grad, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=P("x", None), + ) + @partial(jax.checkpoint, policy=fsdp_checkpoint_policy) + def collective_matmul(a, b): + b = lax.all_gather(b, "x", axis=0, tiled=True) + return a @ b + + _, fwd_results, bwd_results = collective_matmul(a, b) + + itemsize = 4 + m, k, n = 4, 8, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_kn = kn // axis_size + + ici_bytes = int(itemsize * sharded_kn * axis_size_m1) + ici_latency = 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mk + kn), + ) + self.assertDataclassEqual(fwd_results, expected) + + bwd_itemsize = 2 + # Remat ag + rs. + bwd_ici_bytes = 2 * int(bwd_itemsize * sharded_kn * axis_size_m1) + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": bwd_ici_bytes}, + ici_latency={"x": 2 * ici_latency}, + hbm_bytes=2 * bwd_itemsize * (mk + kn + mn), + # Residuals + cotangents. + # We gather kn while computing the kn cotangents. + peak_hbm_bytes=bwd_itemsize * (kn + kn + mn), + ) + self.assertDataclassEqual(bwd_results, expected) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index 540136b33870..fe2232d7ffe6 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -164,7 +164,7 @@ def testRotationConcatenate(self, shape, other_shape, dtype): @jtu.sample_product( dtype=float_dtypes, shape=[(10, 4)], - indexer=[slice(1, 5), slice(0), slice(-5, -3)], + indexer=[slice(1, 5), slice(0, 1), slice(-5, -3)], ) def testRotationGetItem(self, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index f02ed0fc04bb..88a126c284a7 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -543,6 +543,13 @@ def testGammaLogPdfZero(self): self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) + def testGammaDebugNans(self): + # Regression test for https://github.com/jax-ml/jax/issues/24939 + with jax.debug_nans(True): + self.assertAllClose( + osp_stats.gamma.pdf(0.0, 1.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0, 1.0) + ) + @genNamedParametersNArgs(4) def testGammaLogCdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 1b213a8b5bb4..0d1da6ceaeef 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -48,6 +48,9 @@ from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.state import discharge +from jax._src.state import primitives as ref_primitives + import numpy as np config.parse_flags_with_absl() @@ -125,11 +128,11 @@ def sampled_assertion(self, ): """Checks `assertion(e, fun(*operands))` symbolically and concretely. - For the concrete check, it will same the space of dimension variable + For the concrete check, it will sample the space of dimension variable assignments for the dimension variables in `e`. - This is useful when `fun` can operate both with polynomials and with - concrete values, and we want to double-check that the behavior is sound. + This is useful when `fun` can operate both with symbolic and with + concrete values, and we want to check that the behavior is sound. """ computed_sym = fun(*operands_sym) assertion_fun = { @@ -1426,6 +1429,29 @@ def test_non_trivial_dim_expr(self, expr=lambda d: d % -2): arg_descriptors=[RandArg((3,), np.int64)], polymorphic_shapes=["b"]) + @jtu.parameterized_filterable( + # The function `f` will be called with x: f32[b] + kwargs=[ + dict(testcase_name="cube", f=lambda x: x.shape[0] ** 3), + dict(testcase_name="zero", f=lambda x: x.shape[0] ** 0), + dict(testcase_name="rpow", f=lambda x: 2 ** x.shape[0]), + dict(testcase_name="negative", + f=lambda x: x.shape[0] ** -2, + expect_error=(ValueError, "cannot be raised to negative powers")), + dict(testcase_name="non_integer", + f=lambda x: x.shape[0] ** 1.5, + expect_error=(ValueError, "cannot be raised to non-integer powers")), + dict(testcase_name="sym_pow", + f=lambda x: x.shape[0] ** x.shape[0]), + ] + ) + def test_pow(self, f, expect_error: tuple[Exception, str] | None = None): + check_shape_poly(self, + f, + arg_descriptors=[RandArg((3,), np.float32)], + polymorphic_shapes=["b"], + expect_error=expect_error) + def test_static_shape_result(self): """The result has static shape.""" @@ -2062,6 +2088,31 @@ def test_vmap_error(self): polymorphic_shapes=["b, ...", "c, ...", None]) + @jtu.parameterized_filterable( + kwargs=[ + dict(slc=slc) + for slc in [ + slice(None, None, None), + slice(2, 5), + ] + ]) + def test_stateful(self, slc: slice): + w, = export.symbolic_shape("w", constraints=["w >= 3"]) + def f(x_ref): + ones = jnp.ones_like(x_ref)[slc] + ref_primitives.ref_addupdate(x_ref, slc, ones) + x1 = ref_primitives.ref_get(x_ref, slc) + x2 = x1 + ones + ref_primitives.ref_set(x_ref, slc, x2) + + exp = export.export(jax.jit(discharge.run_state(f)))( + jax.ShapeDtypeStruct((w,), dtype=_f32)) + x = np.ones((32,), dtype=_f32) + expected = np.copy(x) + expected[slc] = 3. + self.assertAllClose(exp.call(x), expected) + + # List containing either harnesses, or lists of harnesses _POLY_SHAPE_TEST_HARNESSES = [ PolyHarness("add", "", @@ -2941,6 +2992,40 @@ def test_vmap_error(self): RandArg((3, 5, 0), _f32)], polymorphic_shapes=[None, "b0, b1, ..."], override_jax_config_flags=override_jax_config_flags), # type: ignore + [ + PolyHarness("random_choice", f"{flags_name}_arr_poly={arr_poly}_shape_poly={shape_poly}_replace={replace}_use_p={use_p}", + lambda key, a, res_shape, use_p: jax.random.choice( + jax.random.wrap_key_data(key), + a, + shape=res_shape.shape, + p=jnp.full((a.shape[1],), 0.1, dtype=_f32) if use_p else None, + axis=1, + replace=replace), + arg_descriptors=[RandArg((key_size,), np.uint32), + RandArg((64, 12, 4), _f32), # sample on axis=1 + RandArg((3, 4), _f32), + StaticArg(use_p)], + # TODO(necula): threefry requires even-sized samples. + polymorphic_shapes=[None, + "_, 2*b1, _" if arr_poly else None, + "b3, b4" if shape_poly else None], + # The array sampled dimension must be larger than res_shape.size + symbolic_constraints=[ + "2*b1 >= 12" if arr_poly else "1 >= 0", + "2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0", + "12 >= b3*b4" if shape_poly else "1 >= 0" + ], + override_jax_config_flags=override_jax_config_flags, + expect_error=( + (NotImplementedError, "permutation") + if arr_poly and not use_p else None)) # type: ignore + # np.insert used in random.choice tries to coerce shape_poly to + # integer arrays, but only when the arr_poly is False. + for arr_poly in [True, False] + for shape_poly in [True, False] + for replace in [True, False] + for use_p in [True, False] + ], PolyHarness("random_split", f"{flags_name}", lambda key, a: jax.random.key_data( jax.random.split(jax.random.wrap_key_data(key), @@ -2971,7 +3056,7 @@ def test_vmap_error(self): polymorphic_shapes=[None, "b0, ..."], expect_error=( (core.InconclusiveDimensionOperation, - "the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else None), + "array size .* must be even") if flags_name == "threefry_non_partitionable" else None), override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [ @@ -3268,6 +3353,14 @@ def test_vmap_error(self): lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1 + x.shape[0] // 4, axis=0), arg_descriptors=[RandArg((13, 4), _f32)], polymorphic_shapes=["b, ..."]), + PolyHarness("sort", "", + lambda a: lax.sort(a), + arg_descriptors=[RandArg((16,), _f32)], + polymorphic_shapes=["b"]), + PolyHarness("jvp_sort", "", + lambda a: jax.jvp(lax.sort, (a,), (a,)), + arg_descriptors=[RandArg((16,), _f32)], + polymorphic_shapes=["b"]), PolyHarness("jnp_split", "idx_tuple_ct", # The indices are a tuple with constants lambda a: jnp.split(a, (2,)), @@ -3561,7 +3654,7 @@ def test_harness(self, harness: PolyHarness): not harness.polymorphic_shapes[0].endswith("...") and jtu.test_device_matches(["tpu"])): raise unittest.SkipTest( - "Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.") + "Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.") config_flags = harness.override_jax_config_flags # Update this here rather than in harness object because vmap_random_gamma is derived diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 10267ff5eb98..c5f80a6d97f3 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -221,18 +221,16 @@ def test_shard_alike_inputs(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) - rep_s = NamedSharding(mesh, P()) arr = jax.device_put(np_inp, s) - arr2 = jax.device_put(np_inp, rep_s) def f(x, y): return shard_alike(x, y) - eager_out1, eager_out2 = f(arr, arr2) + eager_out1, eager_out2 = f(arr, np_inp) self.assertEqual(eager_out1.sharding, s) self.assertEqual(eager_out2.sharding, s) - out1, out2 = jax.jit(f)(arr, arr2) + out1, out2 = jax.jit(f)(arr, np_inp) self.assertEqual(out1.sharding, s) self.assertEqual(out2.sharding, s) @@ -282,6 +280,5 @@ def test_sharding_preserverd_single_device(self): _, y = shard_alike(x, jnp.arange(8)) self.assertEqual(y.sharding, s) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 364a90621fa9..56cf9987911d 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -37,7 +37,6 @@ from jax._src import core from jax._src import prng from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals @@ -710,6 +709,26 @@ def f(x): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) + def test_vmap_of_grad_spmd_axis_name(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + @partial( + shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False + ) + def f(x): + return jnp.sin(jnp.sum(x)) + + x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4) + put_x = jax.device_put( + x, + jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')), + ) + vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x) + vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x) + self.assertArraysEqual( + vmap_spmd_axisname_result, vmap_no_spmd_axisname_result + ) + def test_vmap_spmd_axis_name_pair(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -994,6 +1013,73 @@ def body(c, _): shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_cond_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + x = jnp.arange(4) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(True, true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return y + return jax.lax.cond(True, true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + + # https://github.com/jax-ml/jax/issues/24418 + def f(a): + c = jax.lax.cond(jnp.any(a), lambda: 1, lambda: 0) + return jnp.reshape(c, a.shape) + + mesh = jtu.create_mesh((2,), ('x',)) + a = jnp.array([True, False]) + shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + + def test_switch_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + x = jnp.arange(4) + + def f(n, x, y): + return jax.lax.switch( + n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y) + + shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) + def test_eager_custom_jvp_basic(self): @jax.custom_jvp def foo(x): @@ -1236,7 +1322,11 @@ def foo(x): hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo')) if config.use_shardy_partitioner.value: - self.assertEqual(2, hlo_str.count('sdy.manual_computation')) + if len(jax.devices()) > 1: + self.assertEqual(2, hlo_str.count('sdy.manual_computation')) + else: + # When devices == 1, the `sdy.manual_computation` is inlined. + self.assertEqual(0, hlo_str.count('sdy.manual_computation')) else: self.assertIn('call @shmap_body', hlo_str) self.assertIn('call @shmap_body_0', hlo_str) @@ -1496,6 +1586,55 @@ def f(x): self.assertEqual(str(e1.primitive), 'psum2') self.assertEqual(str(e2.primitive), 'pbroadcast') + def test_transpose_float0(self): + mesh = jtu.create_mesh((4,), ('x',)) + + s = jax.sharding.NamedSharding(mesh, P(None, 'x')) + + # vjp that triggers float0 + @jax.custom_vjp + def f(x, _): + return x + def f_fwd(x, y): + return x, jnp.zeros(shape=y.shape, dtype=np.int32) + def f_rev(tmp, g): + return (g, tmp) + f.defvjp(f_fwd, f_rev) + + # trivial vjp that consumes float0 + @jax.custom_vjp + def g(x, y): + return x, y + def g_fwd(x, y): + return jax.vjp(lambda x, y: (x, y), x, y) + def g_bwd(vjp_fn, result): + return vjp_fn(result) + g.defvjp(g_fwd, g_bwd) + + @partial(shard_map, mesh=mesh, in_specs=(P('x'), P()), out_specs=P()) + def f_shmapped(x, y): + return jax.lax.psum(f(x, y).sum(), axis_name=('x')) + + @partial(shard_map, mesh=mesh, check_rep=False, + in_specs=P('x'), out_specs=(P('x'), P())) + def f_shmapped2(x, y): + return g(x, y) + + def f_wrapper(x, y): + x, y = jax.lax.map(lambda xs: f_shmapped2(xs[0], xs[1]), (x, y)) + return jax.lax.map(lambda xs: f_shmapped(xs[0], xs[1]), (x, y)).sum() + + @partial(jax.jit, in_shardings=s, + out_shardings=jax.sharding.NamedSharding(mesh, P())) + def example(x, y): + return jax.grad(f_wrapper, allow_int=True, argnums=(0, 1))(x, y) + + x = np.zeros(shape=(8,16), dtype=np.float32) + y = np.zeros(shape=(8,16), dtype=np.int32) + # Doesn't crash. + dx, dy = example(x, y) + self.assertEqual(dy.dtype, jax.dtypes.float0) + def test_rewrite_binops(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1772,8 +1911,8 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) if config.use_shardy_partitioner.value: self.assertIn( - 'in_shardings=[<@mesh, [{"i"}, {}]>] out_shardings=[<@mesh, [{"i"},' - ' {}]>] manual_axes={"i"}', + 'in_shardings=[<@mesh, [{"i", ?}, {?}]>]' + ' out_shardings=[<@mesh, [{"i", ?}, {?}]>] manual_axes={"i"}', f.lower(v).as_text(), ) else: @@ -1784,6 +1923,41 @@ def f(x): ) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_partial_auto_propagate_through(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) + sharding = jax.sharding.NamedSharding(mesh, P('i')) + + def g(x): + return jax.lax.with_sharding_constraint(x * x, sharding) + + @jax.jit + def f(x): + return shard_map( + g, + mesh, + in_specs=P(), + out_specs=P(), + check_rep=False, + auto=frozenset({'i'}), + )(x) + + v = jnp.arange(32.0).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i'))) + if config.use_shardy_partitioner.value: + self.assertIn( + 'in_shardings=[<@mesh, [{?}, {?}]>]' + ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j"}', + f.lower(v).as_text(), + ) + else: + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[2,2]T(1,0) last_tile_dims={manual, replicated}}', + f.lower(v).as_text('hlo'), + ) + actual = f(v) + self.assertAllClose(v * v, actual, check_dtypes=False) + self.assertEqual(actual.sharding, sharding) + def test_sharded_prng_with_abstract_mesh(self): shape = (8, 2, 2) mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) @@ -1892,6 +2066,29 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_grad_nested_partial_auto(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) + + def g(x): + return x * x + + def h(x): + return shard_map(g, mesh, + in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) + + @jax.jit + def f(x): + return shard_map(h, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x).sum() + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False) + def test_axis_size_1_partial_auto(self): mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) @@ -2642,7 +2839,6 @@ def fwd(a): @unittest.skipIf(sdy is None, "shardy is not enabled") class SdyIntegrationTest(jtu.JaxTestCase): - @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292") # Verify we can lower to a `ManualComputationOp`. def test_shardy_collective_permute(self): mesh = jtu.create_mesh((2,), ('x',)) diff --git a/tests/state_test.py b/tests/state_test.py index 0d6cddfc88c8..a930fe293709 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -30,6 +30,7 @@ from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu +from jax._src.state import types as state_types from jax._src.util import tuple_insert import jax.numpy as jnp from jax._src.lax.control_flow import for_loop @@ -638,6 +639,26 @@ def f(a_ref): refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval) self.assertTrue((refval == inval.at[jnp.array([0, 1])].set(1.)).all()) + def test_discharge_swap(self): + def f(a_ref): + a = ref_swap( + a_ref.at[0:4, 0:3, 0:2].at[1:3, :, 0], + (slice(None), slice(1, 3)), + jnp.zeros((2, 2), jnp.float32)) + return [a + 1] + in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)] + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(f), in_avals) + + discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 1) + self.assertLen(discharged_jaxpr.outvars, 2) + + inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) + outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval) + self.assertArraysEqual(outval, inval[1:3, 1:3, 0] + 1) + self.assertArraysEqual(refval, inval.at[1:3, 1:3, 0].set(0)) + def test_discharge_addupdate(self): def f(a_ref, b): ref_addupdate(a_ref, (), b + 1) @@ -745,12 +766,13 @@ def f(a_ref, b_ref): b_ref[...] = jnp.array(1., dtype=jnp.float32) return a_ref[...], b_ref[...] - scalar_ref = shaped_array_ref((), jnp.float32) + scalar_ref_1 = shaped_array_ref((), jnp.float32) + scalar_ref_2 = shaped_array_ref((), jnp.float32) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [scalar_ref, scalar_ref]) + lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) - prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) + prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns) self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr)) self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr)) @@ -1361,17 +1383,6 @@ def body(y, z): class GeneralRefTest(jtu.JaxTestCase): - def test_unshaped_ref(self): - def f(x_ref): - x = x_ref[...] - x_ref[...] = x - ref_addupdate(x_ref, (), x) - return [x] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))]) - self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray) - self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32")) - def test_token(self): def f(x_ref): x = x_ref[...] @@ -1411,6 +1422,26 @@ def f(refs): self.assertEqual(x, 2 + 2 * 3 * 2) self.assertEqual(y, 2 * 3 * 2) + def test_run_state_with_uninitialized_input(self): + def f(refs): + x_ref, y_ref = refs + # y_ref is uninitialized so we shouldn't read from it until we write into + # it. + x = x_ref[...] + y_ref[...] = x * 2 + x_ref[...] = y_ref[...] + x_ref[...] + # x + x * 2, x * 2 + # jax.ShapeDtypeStruct is weirdly special to JAX, so we make our own class. + class MyArrayType: + pass + state_types._ref_type_aval_mappings[MyArrayType] = lambda _: ( + AbstractRef(core.ShapedArray((), jnp.int32)), + state_types.uninitialized, + ) + x, y = run_state(f)((jnp.int32(2), MyArrayType())) + self.assertEqual(x, 2 + 2 * 2) + self.assertEqual(y, 2 * 2) + def test_nontrivial_run_state_jit(self): def f(refs): x_ref, y_ref = refs diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 378e3803bba2..1b921121e27d 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -13,11 +13,11 @@ # limitations under the License. import collections +from collections.abc import Hashable import dataclasses import functools import pickle import re -from typing import TypeVar from absl.testing import absltest from absl.testing import parameterized @@ -28,11 +28,16 @@ from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp +# Easier to read. +SequenceKey = tree_util.SequenceKey +DictKey = tree_util.DictKey +GetAttrKey = tree_util.GetAttrKey +FlattenedIndexKey = tree_util.FlattenedIndexKey + def _dummy_func(*args, **kwargs): return - ATuple = collections.namedtuple("ATuple", ("foo", "bar")) class ANamedTupleSubclass(ATuple): @@ -142,27 +147,6 @@ def tree_unflatten(cls, meta, data): data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data)) return FlatCache(None, leaves=data, treedef=meta) -_T = TypeVar("_T") - - -# Inspired by Flax. -def pytree_node_dataclass(clz: _T, **kwargs) -> _T: - data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore - meta_fields = [] - data_fields = [] - for field_info in dataclasses.fields(data_clz): - is_pytree_node = field_info.metadata.get("pytree_node", True) - if is_pytree_node: - data_fields.append(field_info.name) - else: - meta_fields.append(field_info.name) - - jax.tree_util.register_dataclass( - data_clz, data_fields, meta_fields - ) - - return data_clz - @tree_util.register_static class StaticInt(int): @@ -231,16 +215,18 @@ def __eq__(self, other): "PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))", ) -@pytree_node_dataclass +@jax.tree_util.register_dataclass +@dataclasses.dataclass class ADataclass: x: tuple[int, int] y: int -@pytree_node_dataclass +@jax.tree_util.register_dataclass +@dataclasses.dataclass class ADataclassWithMeta: x: tuple[int, int] y: int - z: int = dataclasses.field(metadata={"pytree_node": False}) + z: int = dataclasses.field(metadata={"static": True}) TREES += ( (ADataclass(x=(1, 2), y=3),), @@ -778,6 +764,74 @@ def is_empty(x): ], ) + def testTreeFlattenWithPathBuiltin(self): + x = (1, {"a": 2, "b": 3}) + flattened = tree_util.tree_flatten_with_path(x) + _, tdef = tree_util.tree_flatten(x) + self.assertEqual( + flattened[0], + [ + ((SequenceKey(0),), 1), + ((SequenceKey(1), DictKey("a")), 2), + ((SequenceKey(1), DictKey("b")), 3), + ], + ) + self.assertEqual(flattened[1], tdef) + + def testTreeFlattenWithPathCustom(self): + x = [ + AnObject2( + x=12, + y={"foo": SpecialWithKeys(x=2, y=3), "bar": None}, + z="constantdef", + ), + 5, + ] + flattened, _ = tree_util.tree_flatten_with_path(x) + self.assertEqual( + flattened, + [ + ((SequenceKey(0), "x"), 12), + ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("x")), 2), + ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("y")), 3), + ((SequenceKey(1),), 5), + ], + ) + + def testFlattenWithPathDefaultDict(self): + d = collections.defaultdict(int, {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) + leaves, treedef = tree_util.tree_flatten_with_path(d) + self.assertEqual( + leaves, + [ + ((DictKey("a"),), 1), + ((DictKey("b"),), 2), + ((DictKey("c"), DictKey("a")), 1), + ((DictKey("c"), DictKey("b")), 2), + ], + ) + restored_d = tree_util.tree_unflatten(treedef, [l for _, l in leaves]) + self.assertEqual(list(restored_d.keys()), ["a", "b", "c"]) + _, from_flatten = tree_util.tree_flatten(d) + self.assertEqual(treedef, from_flatten) + + def testFlattenWithPathOrderedDict(self): + d = collections.OrderedDict({"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) + leaves, treedef = tree_util.tree_flatten_with_path(d) + self.assertEqual( + leaves, + [ + ((DictKey("b"),), 2), + ((DictKey("a"),), 1), + ((DictKey("c"), DictKey("a")), 1), + ((DictKey("c"), DictKey("b")), 2), + ], + ) + restored_d = tree_util.tree_unflatten(treedef, [l for _, l in leaves]) + self.assertEqual(list(restored_d.keys()), ["b", "a", "c"]) + _, from_flatten = tree_util.tree_flatten(d) + self.assertEqual(treedef, from_flatten) + def testFlattenOneLevel(self): EmptyTuple = collections.namedtuple("EmptyTuple", ()) tree1 = {'a': 1, @@ -858,6 +912,87 @@ def testBadFlattenNonIterableLeaves(self): tree_util.tree_flatten(t) +class TreeKeyTest(absltest.TestCase): + + def testBasic(self): + def assert_equal_and_hash_equal(a, b): + self.assertEqual(a, b) + self.assertEqual(hash(a), hash(b)) + + key = SequenceKey(idx=1) + self.assertEqual(str(key), "[1]") + self.assertEqual(key.idx, 1) + assert_equal_and_hash_equal(key, SequenceKey(1)) + + class DictKeyEntry(Hashable): + + def __init__(self, s: str): + self.s = s + + def __hash__(self): + return hash(self.s) + + def __eq__(self, other): + return self.s == other.s + + key = DictKey(key="foo") + self.assertEqual(str(key), "['foo']") + self.assertEqual(key.key, "foo") + assert_equal_and_hash_equal(key, DictKey("foo")) + assert_equal_and_hash_equal( + DictKey(DictKeyEntry("foo")), DictKey(DictKeyEntry("foo")) + ) + + key = GetAttrKey(name="bar") + self.assertEqual(str(key), ".bar") + self.assertEqual(key.name, "bar") + assert_equal_and_hash_equal(key, GetAttrKey("bar")) + + key = FlattenedIndexKey(1) + self.assertEqual(str(key), "[]") + self.assertEqual(key.key, 1) + assert_equal_and_hash_equal(key, FlattenedIndexKey(1)) + self.assertNotEqual(hash(key), hash(SequenceKey(1))) + + def testPatternMatching(self): + keys = [ + SequenceKey(1), + DictKey("foo"), + GetAttrKey("bar"), + FlattenedIndexKey(1), + ] + for key in keys: + match key: + case jax.tree_util.SequenceKey(idx=idx): + self.assertEqual(idx, 1) + case jax.tree_util.DictKey(key=key): + self.assertEqual(key, "foo") + case jax.tree_util.GetAttrKey(name=name): + self.assertEqual(name, "bar") + case jax.tree_util.FlattenedIndexKey(key=idx_key): + self.assertEqual(idx_key, 1) + case _: + raise ValueError(f"key not matched: {key}") + match [ + DictKey("foo"), + ]: + case [DictKey("foo"), *_]: + pass + case _: + raise ValueError(f"keys are not matched: {keys}") + + def testPickle(self): + keys = [ + SequenceKey(1), + DictKey("foo"), + GetAttrKey("bar"), + FlattenedIndexKey(1), + ] + for key in keys: + unpickled = pickle.loads(pickle.dumps(key)) + self.assertEqual(key, unpickled) + + class StaticTest(parameterized.TestCase): @parameterized.parameters( @@ -1294,6 +1429,36 @@ def test_tree_unflatten(self): class RegistrationTest(jtu.JaxTestCase): + def test_register_dataclass_with_field_specifier(self): + @tree_util.register_dataclass + @dataclasses.dataclass + class Foo: + x: int + y: int = dataclasses.field(metadata=dict(static=True)) + + f = Foo(2, 3) + self.assertLen(jax.tree.leaves(f), 1) + + def test_register_dataclass_field_errors(self): + class Foo: # not a dataclass + x: int + y: int + + msg = ("register_dataclass: data_fields and meta_fields are required" + " when nodetype is not a dataclass. Got nodetype=") + with self.assertRaisesRegex(TypeError, msg): + tree_util.register_dataclass(Foo) + + msg = ("register_dataclass: data_fields and meta_fields must both be specified"\ + r" when either is specified. Got data_fields=\['x'\] meta_fields=None.") + with self.assertRaisesRegex(TypeError, msg): + tree_util.register_dataclass(Foo, data_fields=['x']) + + msg = ("register_dataclass: data_fields and meta_fields must both be specified"\ + r" when either is specified. Got data_fields=None meta_fields=\['y'\].") + with self.assertRaisesRegex(TypeError, msg): + tree_util.register_dataclass(Foo, meta_fields=['y']) + def test_register_dataclass_missing_fields(self): @dataclasses.dataclass class Foo: diff --git a/tests/util_test.py b/tests/util_test.py index 5f07d2f50880..5e99fff4b347 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -42,8 +42,8 @@ def f(*args, **kwargs): assert not kwargs return tuple(a * factor for a in args) - @lu.transformation_with_aux - def kw_to_positional(factor, *args, **kwargs): + @lu.transformation_with_aux2 + def kw_to_positional(f, store, factor, *args, **kwargs): """A transformation with auxiliary output. Turns all keyword parameters into positional ones. @@ -55,12 +55,12 @@ def kw_to_positional(factor, *args, **kwargs): kwargs_keys = kwargs.keys() new_args = tuple(kwargs[k] for k in kwargs_keys) new_kwargs = dict(factor=factor) - results = yield args + new_args, new_kwargs # Yield transformed (args, kwargs) + results = f(*(args + new_args), **new_kwargs) # Yield transformed (args, kwargs) # Assume results correspond 1:1 to the args + new_args assert len(results) == len(args) + len(new_args) aux_output = len(new_args) - yield (results[0:len(args)], - dict(zip(kwargs_keys, results[len(args):]))), aux_output + store.store(aux_output) + return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):]))) wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`. wf, out_thunk = kw_to_positional(wf, 2) diff --git a/tests/version_test.py b/tests/version_test.py index 7ce98c8588e5..51297a9716b1 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -26,11 +26,11 @@ # This is a subset of the full PEP440 pattern; for example we skip pre & post releases VERSION_PATTERN = re.compile(r""" - ^ # start of string - (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' - (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' - (?:\+(?P[a-zA-Z0-9_]+))? # optional local version; like '+g6643af3c3' - $ # end of string + ^ # start of string + (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' + (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + (?:\+(?P[a-zA-Z0-9_.]+))? # optional local version; like '+g6643af3c3' + $ # end of string """, re.VERBOSE) @@ -61,11 +61,12 @@ def assert_no_subprocess_call(): @contextlib.contextmanager -def assert_subprocess_call(): +def assert_subprocess_call(stdout: bytes | None = None): """Run code, asserting that subprocess.Popen *is* called at least once.""" with mock.patch("subprocess.Popen") as mock_Popen: + mock_Popen.return_value.communicate.return_value = (stdout, b"") yield - mock_Popen.assert_called() + mock_Popen.return_value.communicate.assert_called() class JaxVersionTest(unittest.TestCase): @@ -126,7 +127,7 @@ def testBuildVersionFromEnvironment(self): self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() datestring = datetime.date.today().strftime("%Y%m%d") @@ -134,19 +135,28 @@ def testBuildVersionFromEnvironment(self): self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE="1", JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() self.assertEqual(version, base_version) self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE="1", - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() self.assertEqual(version, base_version) self.assertValidVersion(version) + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None, + JAX_CUSTOM_VERSION_SUFFIX="test"): + with assert_subprocess_call(stdout=b"1731433958-1c0f1076e"): + version = jax.version._get_version_for_build() + self.assertTrue(version.startswith(f"{base_version}.dev")) + self.assertTrue(version.endswith("test")) + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3") diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py index 38bd7e05533e..d141bc15c249 100644 --- a/tests/xla_metadata_test.py +++ b/tests/xla_metadata_test.py @@ -20,7 +20,6 @@ from absl.testing import absltest import jax from jax._src import config -from jax._src import dispatch from jax._src import test_util as jtu from jax._src.lax import lax from jax.experimental.xla_metadata import set_xla_metadata @@ -65,7 +64,7 @@ def f(a, b): def test_f_nonjitted(self): def f_add(a, b): - return dispatch.apply_primitive(lax.add_p, a, b) + return lax.add(a, b) arg1 = jnp.arange(2) with set_xla_metadata(a="b"): @@ -126,7 +125,7 @@ def f_add_jit(a, b): def test_attr_caching_nonjit(self): def f_add(a, b): - return dispatch.apply_primitive(lax.add_p, a, b) + return lax.add(a, b) arg1 = jnp.arange(2) arg2 = jnp.arange(2) + 1 diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3dd8cbd33712..110b5e055b31 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "76da730179313b3bebad6dea6861768421b7358c" -XLA_SHA256 = "d67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757" +XLA_COMMIT = "a041e1b15524cd15751e9a5b5dc581b9f276958f" +XLA_SHA256 = "39ea15ad645a2973efbfe7d1b4761d114cb688b5d2934561009aab7c911473da" def repo(): tf_http_archive( @@ -37,7 +37,7 @@ def repo(): # local checkout by either: # a) overriding the TF repository on the build.py command line by passing a flag # like: - # python build/build.py --bazel_options=--override_repository=xla=/path/to/xla + # python build/build.py build --local_xla_path=/path/to/xla # or # b) by commenting out the http_archive above and uncommenting the following: # local_repository(