diff --git a/.circleci/config.yml b/.circleci/config.yml index c96fcc903b..6743e5ef02 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -13,38 +13,54 @@ jobs: - run: name: Update Meson command: pip3 install --upgrade meson==0.58.1 - - run: - name: Create Meson build dirs - command: mkdir build-gcc && mkdir build-clang - - run: - name: Meson Clang - environment: - CC: clang - CXX: clang++ - command: meson build-clang - run: name: Meson GCC environment: CC: gcc-8 CXX: g++-8 - command: meson build-gcc - - run: - name: Build Clang - command: | - cd build-clang - ninja + command: meson build-gcc -Dgtest=false - run: name: Build GCC command: | cd build-gcc ninja -j 4 + "mac": + macos: + xcode: 13.4.1 + steps: + - checkout - run: - command: cp build-clang/lc0 /tmp/lc0-clang + name: "Pull Submodules" + command: | + git submodule init + git submodule update --remote - run: - command: cp build-gcc/lc0 /tmp/lc0-g++ - - store_artifacts: - path: /tmp/lc0-clang - destination: lc0-ubuntu-18-04-clang + name: Install build tools + command: | + pip3 install meson==0.63 + pip3 install ninja + brew install ispc + - run: + name: Build lc0 + command: | + meson build --buildtype=release -Dgtest=false -Dopencl=false + cd build + ninja + - run: + name: Build lc0 arm + command: | + meson build-arm --buildtype=release -Dgtest=false -Dopencl=false --cross-file cross-files/aarch64-darwin + cd build-arm + ninja + - run: + name: Make universal binary + command: lipo -create -o /tmp/lc0 build/lc0 build-arm/lc0 - store_artifacts: - path: /tmp/lc0-g++ - destination: lc0-ubuntu-18-04-g++ + path: /tmp/lc0 + destination: lc0-macos_12.3.1 +workflows: + version: 2 + builds: + jobs: + - build + - "mac" diff --git a/.gitignore b/.gitignore index c90b403032..ea18c9ee58 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ testdata/ xcuserdata .clang-tidy compile_flags.txt -.vscode \ No newline at end of file +.vscode +.mesonpy* \ No newline at end of file diff --git a/FLAGS.md b/FLAGS.md index 7d687657f2..6a974ccd4b 100644 --- a/FLAGS.md +++ b/FLAGS.md @@ -41,7 +41,6 @@ List of command line flags: | --slowmover=NUM | Scale thinking time | Parameter value `X` means that the whole remaining time is split in such a way that the current move gets `X × Y` seconds, and next moves will get `1 × Y` seconds. However, due to smart pruning, the engine usually doesn't use all allocated time.
Default: `2.2`| | --move-overhead=NUM | Move time overhead in milliseconds | How much overhead should the engine allocate for every move (to counteract things like slow connection, interprocess communication, etc.).
Default: `100` ms. | | --minibatch-size=NUM | Minibatch size for NN inference | How many positions the engine tries to batch together for computation. Theoretically larger batches may reduce strengths a bit, especially on a small number of playouts.
Default is `256`. Every backend/hardware has different optimal value (e.g., `1` if batching is not supported). | -| --max-prefetch=NUM | Maximum prefetch nodes per NN call | When the engine can't gather a large enough batch for immediate use, try to prefetch up to `X` positions, which are likely to be useful soon, and put them in the cache.
Default: `32`. | | --cpuct=NUM | Cpuct MCTS option | C_puct constant from Upper Confidence Tree search algorithm. Higher values promote more exploration/wider search, lower values promote more confidence/deeper search.
Default: `1.2`. | | --temperature=NUM | Initial temperature | Tau value from softmax formula. If equal to 0, the engine also picks the best move to make. Larger values increase randomness while making the move.
Default: `0` | | --tempdecay-moves=NUM | Moves with temperature decay | Reduce temperature for every move linearly from initial temperature to `0`, during this number of moves since the game started. `0` disables temperature decay.
Default: `0` | diff --git a/README.md b/README.md index 8a76835ee9..1c2c6cb406 100644 --- a/README.md +++ b/README.md @@ -9,27 +9,27 @@ Lc0 is a UCI-compliant chess engine designed to play chess via neural network, s Lc0 can be acquired either via a git clone or an archive download from GitHub. Be aware that there is a required submodule which isn't included in source archives. -For essentially all purposes, including selfplay game generation and match play, we highly recommend using the latest `release/version` branch (for example `release/0.28`), which is equivalent to using the latest version tag. +For essentially all purposes, including selfplay game generation and match play, we highly recommend using the latest `release/version` branch (for example `release/0.29`), which is equivalent to using the latest version tag. Versioning follows the Semantic Versioning guidelines, with major, minor and patch sections. The training server enforces game quality using the versions output by the client and engine. Download using git: -``` -git clone -b release/0.28 --recurse-submodules https://github.com/LeelaChessZero/lc0.git +```shell +git clone -b release/0.29 --recurse-submodules https://github.com/LeelaChessZero/lc0.git ``` If you have cloned already an old version, fetch, view and checkout a new branch: -``` +```shell git fetch --all git branch --all -git checkout -t remotes/origin/release/0.28 +git checkout -t remotes/origin/release/0.29 ``` If you prefer to download an archive, you need to also download and place the submodule: - * Download the [.zip](https://api.github.com/repos/LeelaChessZero/lc0/zipball/release/0.28) file ([.tar.gz](https://api.github.com/repos/LeelaChessZero/lc0/tarball/release/0.28) archive is also available) + * Download the [.zip](https://api.github.com/repos/LeelaChessZero/lc0/zipball/release/0.29) file ([.tar.gz](https://api.github.com/repos/LeelaChessZero/lc0/tarball/release/0.29) archive is also available) * Extract * Download https://github.com/LeelaChessZero/lczero-common/archive/master.zip (also available as [.tar.gz](https://github.com/LeelaChessZero/lczero-common/archive/master.tar.gz)) * Move the second archive into the first archive's `libs/lczero-common/` folder and extract @@ -48,7 +48,7 @@ Backend support includes (in theory) any CBLAS-compatible library for CPU usage, Finally, lc0 requires a compiler supporting C++17. Minimal versions seem to be g++ v8.0, clang v5.0 (with C++17 stdlib) or Visual Studio 2017. -*Note* that cuda checks the compiler version and stops even with newer compilers, and to work around this we have added the `nvcc_ccbin` build option. This is more of an issue with new Linux versions, where we recommend to install `g++-7` and add `-Dnvcc_ccbin=g++-7` to the `build.sh` command. +*Note* that cuda checks the compiler version and stops even with newer compilers, and to work around this we have added the `nvcc_ccbin` build option. This is more of an issue with new Linux versions, but you can get around it by using an earlier version of gcc just for cuda. As an example, adding `-Dnvcc_ccbin=g++-9` to the `build.sh` command line will use g++-9 with cuda instead of the system compiler. Given those basics, the OS and backend specific instructions are below. @@ -179,7 +179,7 @@ You'll need to be running the latest Raspberry Pi OS "buster". 1. Install OpenBLAS -``` +```shell git clone https://github.com/xianyi/OpenBLAS.git cd OpenBLAS/ make @@ -189,20 +189,20 @@ cd .. 2. Install Meson -``` -pip3 install meson -pip3 install ninja +```shell +pip install meson +pip install ninja ``` 3. Install compiler and standard libraries -``` +```shell sudo apt install clang-6.0 libstdc++-8-dev ``` 4. Clone lc0 and compile -``` +```shell git clone https://github.com/LeelaChessZero/lc0.git cd lc0 git submodule update --init --recursive @@ -211,6 +211,18 @@ CC=clang-6.0 CXX=clang++-6.0 ./build.sh -Ddefault_library=static 5. The resulting binary will be in build/release +## Python bindings + +Python bindings can be built and installed as follows. + +```shell +pip install --user git+https://github.com/LeelaChessZero/lc0.git +``` + +This will build the package `lczero-bindings` and install it to your Python user install directory. +All the `lc0` functionality related to position evaluation is now available in the module `lczero.backends`. +An example interactive session can be found [here](https://github.com/LeelaChessZero/lc0/pull/1261#issuecomment-622951248). + ## License Leela Chess is free software: you can redistribute it and/or modify diff --git a/appveyor.yml b/appveyor.yml index fb64dc3d18..e111a6bdb6 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -109,17 +109,22 @@ cache: - C:\ndk\android-ndk-r19c\toolchains\llvm\prebuilt\windows-x86_64 before_build: - cmd: git submodule update --init --recursive -- cmd: IF %BLAS%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h -- cmd: IF %ANDROID%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h +- cmd: IF %BLAS%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h +- cmd: IF %ANDROID%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h - cmd: SET BUILD_BLAS=%BLAS% - cmd: IF %OPENCL%==true SET BUILD_BLAS=true - cmd: IF %DX%==true SET BUILD_BLAS=true - cmd: SET EMBED=false - cmd: IF %APPVEYOR_REPO_TAG%==true IF %ANDROID%==true SET EMBED=true +- cmd: SET POPCNT=true +- cmd: IF %NAME%==cpu-openblas SET POPCNT=false +- cmd: SET F16C=true +- cmd: IF %NAME%==cpu-openblas SET F16C=false +- cmd: IF %CUDA%==true SET F16C=false - cmd: SET EXTRA= - cmd: IF %ANDROID%==false SET EXTRA=-Db_vscrt=md - cmd: IF %ONNX_DML%==true SET EXTRA=-Db_vscrt=md -Donnx_libdir=C:\cache\%ONNX_NAME%\lib -Donnx_include=C:\cache\%ONNX_NAME%\include -- cmd: IF %ANDROID%==false meson build --backend vs2017 --buildtype release -Dgtest=%GTEST% -Dopencl=%OPENCL% -Dblas=%BUILD_BLAS% -Ddnnl=true -Ddx=%DX% -Dcudnn=%CUDNN% -Donednn=%ONEDNN% -Dispc_native_only=false -Dpopcnt=false -Dcudnn_include="%CUDA_PATH%\include","%CUDA_PATH%\cuda\include" -Dcudnn_libdirs="%CUDA_PATH%\lib\x64","%CUDA_PATH%\cuda\lib\x64" -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\dist64\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\dist64\lib" -Ddnnl_dir="%PKG_FOLDER%\%DNNL_NAME%" -Dopencl_include="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\include" -Dopencl_libdirs="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\lib\x64" -Ddefault_library=static -Dmalloc=mimalloc -Dmimalloc_libdir="%MIMALLOC_PATH%"\out\msvc-x64\Release %EXTRA% +- cmd: IF %ANDROID%==false meson build --backend vs2017 --buildtype release -Dgtest=%GTEST% -Dopencl=%OPENCL% -Dblas=%BUILD_BLAS% -Ddnnl=true -Ddx=%DX% -Dcudnn=%CUDNN% -Donednn=%ONEDNN% -Dispc_native_only=false -Dpopcnt=%POPCNT% -Df16c=%F16C% -Dcudnn_include="%CUDA_PATH%\include","%CUDA_PATH%\cuda\include" -Dcudnn_libdirs="%CUDA_PATH%\lib\x64","%CUDA_PATH%\cuda\lib\x64" -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\dist64\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\dist64\lib" -Ddnnl_dir="%PKG_FOLDER%\%DNNL_NAME%" -Dopencl_include="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\include" -Dopencl_libdirs="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\lib\x64" -Ddefault_library=static -Dmalloc=mimalloc -Dmimalloc_libdir="%MIMALLOC_PATH%"\out\msvc-x64\Release %EXTRA% - cmd: IF %ANDROID%==true meson arm64-v8a --buildtype release -Dgtest=false -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\android-aarch64\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\android-aarch64\lib" -Dembed=%EMBED% -Ddefault_library=static --cross-file crossfile-aarch64 - cmd: IF %ANDROID%==true meson armeabi-v7a --buildtype release -Dgtest=false -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\android-armv7a\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\android-armv7a\lib" -Dembed=%EMBED% -Ddefault_library=static --cross-file crossfile-armv7a -Dispc=false -Dneon=false build_script: diff --git a/build.cmd b/build.cmd index 9f0d3d3144..9b607e01be 100644 --- a/build.cmd +++ b/build.cmd @@ -30,7 +30,11 @@ set CXX=cl set CC_LD=link set CXX_LD=link -if exist "C:\Program Files (x86)\Microsoft Visual Studio\2019" ( +if exist "C:\Program Files\Microsoft Visual Studio\2022" ( + where /q cl + if errorlevel 1 call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64 + set backend=vs2022 +) else if exist "C:\Program Files (x86)\Microsoft Visual Studio\2019" ( where /q cl if errorlevel 1 call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64 set backend=vs2019 diff --git a/build.sh b/build.sh index deca6e99db..419c0229d2 100755 --- a/build.sh +++ b/build.sh @@ -1,9 +1,10 @@ #!/usr/bin/env bash -pushd "$(dirname "$0")" - set -e +# Move to this script's directory. +CDPATH= cd -- "$(dirname -- "$0")" + case $1 in plain|debug|debugoptimized|release|minsize) BUILDTYPE=$1 @@ -16,27 +17,15 @@ esac BUILDDIR=build/${BUILDTYPE} -if ! hash meson 2>/dev/null && [ -x ${HOME}/.local/bin/meson ] -then - export PATH=${PATH}:${HOME}/.local/bin -fi - -if [ -f ${BUILDDIR}/build.ninja ] -then - meson configure ${BUILDDIR} -Dbuildtype=${BUILDTYPE} -Dprefix=${INSTALL_PREFIX:-/usr/local} "$@" -else - meson ${BUILDDIR} --buildtype ${BUILDTYPE} --prefix ${INSTALL_PREFIX:-/usr/local} "$@" -fi - -cd ${BUILDDIR} - -NINJA=$(awk '/ninja/ {ninja=$4} END {print ninja}' meson-logs/meson-log.txt) +MESON=$(PATH="${PATH}:${HOME}/.local/bin" command -v meson || :) +MESON=${MESON:?"Could not find meson. Is it installed and in PATH?"} -if [ -n "${INSTALL_PREFIX}" ] +if [ -f "${BUILDDIR}/build.ninja" ] then - ${NINJA} install + "${MESON}" configure "${BUILDDIR}" -Dbuildtype="${BUILDTYPE}" -Dprefix="${INSTALL_PREFIX:-/usr/local}" "$@" else - ${NINJA} + "${MESON}" "${BUILDDIR}" --buildtype "${BUILDTYPE}" --prefix "${INSTALL_PREFIX:-/usr/local}" "$@" fi -popd +"${MESON}" compile -C "${BUILDDIR}" +[ -n "${INSTALL_PREFIX}" ] && "${MESON}" install -C "${BUILDDIR}" diff --git a/changelog.txt b/changelog.txt index 39444f5848..c869b9f5c5 100644 --- a/changelog.txt +++ b/changelog.txt @@ -1,4 +1,52 @@ -v0.29.0-rc0 (2022-04-03) +v0.30.0-rc1 (2023-04-24) +~~~~~~~ +* Support for networks with attention body and smolgen added to blas, cuda, + metal and onnx backends. +* Persistent L2 cache optimization for the cuda backend. Use the + `cache_opt=true` backend option to turn it on. +* Some performance improvements for the cuda, onnx and blas backends. +* Added the `threads` backend option to onnx, defaults to 0 (let the + onnxruntime decide) except for onnx-cpu that defaults to 1. +* The onnx-dml package now includes a `directml.dll` installation script. +* Some users experienced memory issues with onnx-dml, so the defaults were + changed. This may affect performance, in which case you can use the `steps=8` + backend option to get the old behavior. +* The Python bindings are available as a package, see the README for + instructions. +* Some assorted fixes and code cleanups. + +v0.29.0 (2022-12-13) +~~~~~~~ +* Updated onednn version to the latest one. + +v0.29.0-rc1 (2022-12-09) +~~~~~~~ +* New metal backend for apple systems. This is now the default backend for + macos builds. +* New onnx-dml backend to use DirectML under windows, has better net + compatibility than dx12 and is faster than opencl. See the README for use + instructions, a separate download of the DirectML dll is required. +* Full attention policy support in cuda, cudnn, metal, onnx, blas, dnnl, and + eigen backends. +* Partial attention policy support in onednn backend (good enough for T79). +* Now the onnx backends can use fp16 when running with a network file (not with + .onnx model files). This is the default for onnx-cuda and onnx-dml, can be + switched on or off with by setting the `fp16` backend option to `true` or + `false` respectively. +* The onednn package comes with a dnnl compiled to allow running on an intel gpu + by adding `gpu=0` to the backend options. +* The default net is now 791556 for most backends except opencl and dx12 that + get 753723 (as they lack attention policy support). +* Support for using pgn book with long lines in training: selfplay can start at + a random point in the book. +* New "simple" time manager. +* Support for double Fischer random chess (dfrc). +* Added TC-dependent output to the backendbench assistant. +* Starting with this version, the check backend compares policy for valid moves + after softmax. +* Some assorted fixes and code cleanups. + +v0.29.0-rc0 (2022-04-03) ~~~~~~~ * Initial support for attention policy, only cuda backend and partially in blas/dnnl/eigen (good enough for T79). diff --git a/cross-files/aarch64-darwin b/cross-files/aarch64-darwin new file mode 100644 index 0000000000..441f101d56 --- /dev/null +++ b/cross-files/aarch64-darwin @@ -0,0 +1,27 @@ + +[host_machine] +system = 'darwin' +cpu_family = 'aarch64' +cpu = 'aarch64' +endian = 'little' + +[properties] + + +[binaries] +c = 'clang' +cpp = 'clang++' +objc = 'clang' +objcpp = 'clang++' +ar = 'ar' +ld = 'ld' + +[built-in options] +c_args = ['-arch', 'arm64'] +cpp_args = ['-arch', 'arm64'] +objc_args = ['-arch', 'arm64'] +objcpp_args = ['-arch', 'arm64'] +c_link_args = ['-arch', 'arm64'] +cpp_link_args = ['-arch', 'arm64'] +objc_link_args = ['-arch', 'arm64'] +objcpp_link_args = ['-arch', 'arm64'] diff --git a/cross-files/aarch64-linux-android b/cross-files/aarch64-linux-android index 40e1faca3a..4a55d838de 100644 --- a/cross-files/aarch64-linux-android +++ b/cross-files/aarch64-linux-android @@ -7,7 +7,7 @@ [host_machine] system = 'android' -cpu_family = 'arm' +cpu_family = 'aarch64' cpu = 'aarch64' endian = 'little' diff --git a/dist/README-onnx-dml.txt b/dist/README-onnx-dml.txt index 6b721d2c8e..5e34b3eb52 100644 --- a/dist/README-onnx-dml.txt +++ b/dist/README-onnx-dml.txt @@ -4,8 +4,9 @@ Lc0 is a UCI-compliant chess engine designed to play chess via neural network, specifically those of the LeelaChessZero project (https://lczero.org). -To run this version you will most likely need a very recent DirectML dll. -You can download the currently latest nuget installer package from +To run this version you will most likely need a very recent DirectML dll, +which you can get by running the included `install.cmd` script. Alternatively, +you can download the currently latest nuget installer package from . If you don't know how to use nuget installer packages, you can change the extension to .zip and open it as a normal zip file, the dll you need is diff --git a/dist/install-dml.cmd b/dist/install-dml.cmd new file mode 100644 index 0000000000..099f42958c --- /dev/null +++ b/dist/install-dml.cmd @@ -0,0 +1,31 @@ +@echo off +where /q tar +if errorlevel 1 goto error + +where /q lc0.exe +if errorlevel 1 cd /d %~dp0 +where /q lc0.exe +if errorlevel 1 ( + echo This script must run in the lc0 folder. + pause + exit /b +) + +cls +echo Installing the DirectML.dll version required by the Lc0 onnx-dml backend. +curl -# --ssl-no-revoke -o tmp_directml.zip https://globalcdn.nuget.org/packages/microsoft.ai.directml.1.10.0.nupkg" +if errorlevel 1 goto error + +tar -xzOf tmp_directml.zip bin/x64-win/DirectML.dll >DirectML.dll +if errorlevel 1 goto error + +del /q tmp_directml.zip +echo Installation successful. +pause +exit /b + +:error +cls +echo Installation failed - see the README for an alternative approach. +pause + diff --git a/libs/lczero-common b/libs/lczero-common index 4dfa4ce833..fafda0f59c 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 4dfa4ce8339357819f7de01517e6297d4c768cdf +Subproject commit fafda0f59c8511b5d933ef758c1e4b10a62da1e0 diff --git a/meson.build b/meson.build index 49b60c1cd6..f073c42705 100644 --- a/meson.build +++ b/meson.build @@ -16,7 +16,7 @@ project('lc0', 'cpp', default_options : ['cpp_std=c++17', 'b_ndebug=if-release', 'warning_level=3', 'b_lto=true', 'b_vscrt=mt'], - meson_version: '>=0.52') + meson_version: '>=0.54') cc = meson.get_compiler('cpp') @@ -48,7 +48,7 @@ endif if host_machine.system() == 'windows' add_project_arguments('-DNOMINMAX', language : 'cpp') endif -if host_machine.cpu_family() == 'arm' +if ['arm', 'aarch64'].contains(host_machine.cpu_family()) if get_option('neon') add_project_arguments(cc.get_supported_arguments(['-mfpu=neon']), language : 'cpp') add_project_link_arguments(cc.get_supported_arguments(['-mfpu=neon']), language : 'cpp') @@ -209,7 +209,6 @@ files += [ 'src/utils/random.cc', 'src/utils/string.cc', 'src/utils/weights_adapter.cc', - 'src/utils/fp16_utils.cc', 'src/version.cc', ] includes += include_directories('src') @@ -267,12 +266,14 @@ if get_option('build_backends') if get_option('blas') if get_option('mkl') and mkl_lib.found() - add_project_arguments(['-DUSE_MKL', '-DUSE_BLAS'], language : 'cpp') mkl_inc = get_option('mkl_include') if run_command('scripts/checkdir.py', mkl_inc).returncode() == 0 includes += include_directories(mkl_inc) endif - deps += [ mkl_lib ] + if cc.has_header('mkl.h') + add_project_arguments(['-DUSE_MKL', '-DUSE_BLAS'], language : 'cpp') + deps += [ mkl_lib ] + endif elif get_option('dnnl') and dnnl_lib.found() add_project_arguments(['-DUSE_DNNL', '-DUSE_BLAS'], language : 'cpp') @@ -313,7 +314,7 @@ if get_option('build_backends') ispc_arch = 'x86-64' ispc_extra_args = [] if get_option('ispc') and ispc.found() - ispc_native_only = get_option('ispc_native_only') + ispc_native_only = get_option('ispc_native_only') and not meson.is_cross_build() if host_machine.system() == 'windows' outputnames = [ '@BASENAME@.obj'] if not ispc_native_only @@ -330,26 +331,27 @@ if get_option('build_backends') '@BASENAME@_avx512knl.o', '@BASENAME@_avx512skx.o' ] endif endif - if ispc_native_only - ispc_target = 'host' - else - ispc_target = 'sse2-i32x8,sse4-i32x8,avx1-i32x8,avx2-i32x8,avx512knl-i32x16,avx512skx-i32x16' - endif + ispc_target = 'sse2-i32x8,sse4-i32x8,avx1-i32x8,avx2-i32x8,avx512knl-i32x16,avx512skx-i32x16' if host_machine.system() == 'android' ispc_extra_args += ['--target-os=android'] - if host_machine.cpu_family() == 'arm' - outputnames = [ '@BASENAME@.o'] - if host_machine.cpu() == 'aarch64' - ispc_target = 'neon-i32x8' - ispc_arch = 'aarch64' - else - ispc_target = 'neon-i32x4' - ispc_arch = 'arm' - endif + endif + + if ['arm', 'aarch64'].contains(host_machine.cpu_family()) + outputnames = [ '@BASENAME@.o'] + if host_machine.cpu_family() == 'aarch64' + ispc_target = 'neon-i32x8' + ispc_arch = 'aarch64' + else + ispc_target = 'neon-i32x4' + ispc_arch = 'arm' endif endif + if ispc_native_only + ispc_target = 'host' + endif + iscp_gen = generator(ispc, output: [ '@BASENAME@_ispc.h', outputnames ], arguments: [ '-O2', '--wno-perf', '--arch=' + ispc_arch, @@ -377,6 +379,7 @@ if get_option('build_backends') if get_option('ispc') and ispc.found() files += iscp_gen.process('src/neural/blas/winograd_transform.ispc') + files += iscp_gen.process('src/neural/blas/layer_norm.ispc') files += iscp_gen.process('src/neural/shared/activation.ispc') add_project_arguments('-DUSE_ISPC', language : 'cpp') endif @@ -458,6 +461,7 @@ if get_option('build_backends') includes += include_directories('src/neural/cuda/') cuda_arguments = ['-c', '@INPUT@', '-o', '@OUTPUT@', + '-I', meson.source_root() + '/subprojects/abseil-cpp-20211102.0', '-I', meson.current_source_dir() + '/src'] if host_machine.system() == 'windows' if get_option('b_vscrt') == 'mt' @@ -497,27 +501,32 @@ if get_option('build_backends') ) # Handling of fp16 cuda code. - nvcc_arch = '-arch=compute_70' - nvcc_sm_list = ['sm_80', 'sm_75', 'sm_86', 'sm_70'] + nvcc_sm_list = ['80', '75', '86', '70', '89', '90'] if host_machine.system() != 'windows' - nvcc_arch = '-arch=compute_60' - nvcc_sm_list += ['sm_60'] + nvcc_sm_list += ['60'] if ['arm', 'aarch64'].contains(host_machine.cpu_family()) # Add Jetson versions to the list. message('Jetson support enabled.') - nvcc_arch = '-arch=compute_53' - nvcc_sm_list += ['sm_72', 'sm_62', 'sm_53'] + nvcc_sm_list += ['72', '62', '53', '87'] endif endif # Ignore the given CC for fp16 when it is not in the supported list. if cuda_cc == '' or not nvcc_sm_list.contains('sm_' + cuda_cc) - nvcc_extra_args = [nvcc_arch] + nvcc_extra_args = [] nvcc_help = run_command(nvcc, '-h').stdout() foreach x : nvcc_sm_list - if nvcc_help.contains(x) - nvcc_extra_args += '-code=' + x + if nvcc_help.contains('sm_' + x) + nvcc_extra_args += '-gencode=arch=compute_' + x + ',code=sm_' + x endif endforeach + # For forward compatibility. + if nvcc_help.contains('sm_90') # Cuda 12+ + nvcc_extra_args += '-gencode=arch=compute_90,code=compute_90' + elif nvcc_help.contains('sm_80') # Cuda 11+ + nvcc_extra_args += '-gencode=arch=compute_80,code=compute_80' + elif nvcc_help.contains('sm_75') # Cuda 10+ + nvcc_extra_args += '-gencode=arch=compute_75,code=compute_75' + endif endif files += custom_target('cuda fp16 code', input : 'src/neural/cuda/fp16_kernels.cu', @@ -619,6 +628,15 @@ endif ############################################################################# ## Dependencies ############################################################################# + + ## ~~~~~~ + ## Abseil + ## ~~~~~~ + # $ meson wrap install abseil-cpp + absl = subproject('abseil-cpp', default_options : ['warning_level=0']) + deps += absl.get_variable('absl_container_dep').as_system() + includes += absl.get_variable('absl_include_dir') + ## ~~~~ ## zlib ## ~~~~ @@ -650,6 +668,10 @@ if not get_option('popcnt') add_project_arguments('-DNO_POPCNT', language : 'cpp') endif +if not get_option('f16c') + add_project_arguments('-DNO_F16C', language : 'cpp') +endif + if not get_option('pext') add_project_arguments('-DNO_PEXT', language : 'cpp') endif @@ -723,5 +745,7 @@ if get_option('python_bindings') python.extension_module('backends', [py_files + files], include_directories: [includes], - dependencies: [cpython] + deps) + dependencies: [cpython] + deps, + subdir: 'lczero', + install: true) endif diff --git a/meson_options.txt b/meson_options.txt index 002584666e..cc2364be45 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -133,6 +133,11 @@ option('popcnt', value: true, description: 'Use the popcnt instruction') +option('f16c', + type: 'boolean', + value: true, + description: 'Use natice fp16 conversion instructions') + option('pext', type: 'boolean', value: false, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..b55ffdc83f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["meson-python"] +build-backend = "mesonpy" + +[project] +name = "lczero_bindings" +version = "0.1.0" +description = "Leela Chess Zero Python bindings" +authors = [{ name = "The LCZero Authors" }] +license = {file = "COPYING"} +readme = "README.md" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Python :: 3", + "Topic :: Games/Entertainment :: Board Games", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Environment :: GPU" +] + +[project.urls] +homepage = "https://github.com/LeelaChessZero/lc0" + +[tool.meson-python.args] +dist = [] +setup = ["-Dpython_bindings=true"] +compile = [] +install = [] \ No newline at end of file diff --git a/scripts/appveyor_win_package.cmd b/scripts/appveyor_win_package.cmd index 57296c35f3..76124cd5ec 100644 --- a/scripts/appveyor_win_package.cmd +++ b/scripts/appveyor_win_package.cmd @@ -27,10 +27,12 @@ IF %NAME%==onednn copy "%PKG_FOLDER%\%DNNL_NAME%\THIRD-PARTY-PROGRAMS" dist\DNNL IF %NAME%==onednn 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\DNNL-LICENSE IF %NAME%==onednn 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\DNNL-THIRD-PARTY-PROGRAMS IF %ONNX_DML%==true type dist\README-onnx-dml.txt |more /P > dist\README.txt +IF %ONNX_DML%==true type dist\install-dml.cmd |more /P > dist\install.cmd IF %ONNX_DML%==true copy "%PKG_FOLDER%\%ONNX_NAME%\LICENSE" dist\ONNX-DML-LICENSE IF %ONNX_DML%==true copy "%PKG_FOLDER%\%ONNX_NAME%\ThirdPartyNotices.txt" dist\ONNX-DML-ThirdPartyNotices.txt IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip "%PKG_FOLDER%\%ONNX_NAME%\lib\onnxruntime.dll" IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\README.txt +IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\install.cmd IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\ONNX-DML-LICENSE IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\ONNX-DML-ThirdPartyNotices.txt IF %OPENCL%==true type scripts\check_opencl.bat |more /P > dist\check_opencl.bat diff --git a/scripts/compile_proto.py b/scripts/compile_proto.py index b7bd4eda6b..1a21c2da4e 100755 --- a/scripts/compile_proto.py +++ b/scripts/compile_proto.py @@ -81,6 +81,7 @@ class Lexer: + def __init__(self, text): self.text = text self.grammar = [(re.compile(x, re.S + re.M), y) for x, y in GRAMMAR] @@ -130,13 +131,16 @@ def NextTokenOrWhitespace(self): def Error(self, text): '''Throws an error with context in the file read.''' + line = self.text[:self.cur_offset].count('\n') + 1 line_start = self.text.rfind('\n', 0, self.cur_offset) + 1 line_end = self.text.find('\n', line_start) + if line_end == -1: + line_end = len(self.text) sys.stderr.write('%s:\n' % text) sys.stderr.write(self.text[line_start:line_end] + '\n') sys.stderr.write(' ' * (self.cur_offset - line_start) + '^^^\n') - raise ValueError("Parse error: %s at offset %d." % - (text, self.cur_offset)) + raise ValueError("Parse error: %s at line %d column %d." % + (text, line, (self.cur_offset - line_start))) def ReadIdentifierPath(lexer): @@ -155,7 +159,7 @@ def LookupType(name, stack): for x in y: if x.GetName() == name[0]: if len(name) == 1: - return x.GetType() + return x else: return LookupType(name[1:], [x.GetTypes()]) raise ValueError("Cannot find type: %s." % '.'.join(name)) @@ -167,6 +171,7 @@ def LookupType(name, stack): class ProtoTypeParser: + def __init__(self, lexer, object_stack): token, match = lexer.Pick() if token in TYPES: @@ -175,10 +180,16 @@ def __init__(self, lexer, object_stack): lexer.Consume(token) elif token == 'identifier': self.name = ReadIdentifierPath(lexer) - self.typetype = LookupType(self.name, object_stack) + self.typetype = 'forward' else: lexer.Error('Type expected') + def LookupForwardFieldType(self, object_stack): + if self.IsForward(): + typ = LookupType(self.name, object_stack) + self.typetype = typ.GetType() + self.name = [typ.GetFullName()] + def IsZigzag(self): if self.typetype == 'basic': return self.name in ZIGZAG_TYPES @@ -188,7 +199,7 @@ def GetCppType(self): if self.typetype == 'basic': return TYPES[self.name] else: - return '::'.join(self.name) + return '_'.join(self.name) def GetVariableCppType(self): if self.IsBytesType(): @@ -196,6 +207,9 @@ def GetVariableCppType(self): else: return self.GetCppType() + def IsEnumType(self): + return self.typetype == 'enum' + def IsVarintType(self): return self.typetype == 'enum' or (self.typetype == 'basic' and self.name in VARINT_TYPES) @@ -231,6 +245,9 @@ def GetWireType(self): def IsMessage(self): return self.typetype == 'message' + def IsForward(self): + return self.typetype == 'forward' + def IsIntegralType(self): if self.typetype == 'basic': if self.name == 'double': @@ -251,6 +268,7 @@ def IsIntegralType(self): class ProtoFieldParser: + def __init__(self, lexer, object_stack): token, match = lexer.Pick() if token not in ['repeated', 'optional', 'required']: @@ -266,6 +284,9 @@ def __init__(self, lexer, object_stack): def IsType(self): return False + def LookupForwardFieldType(self, object_stack): + self.type.LookupForwardFieldType(object_stack) + def GetParser(self): name = self.name.group(0) if self.type.IsMessage(): @@ -331,42 +352,91 @@ def GenerateOutput(self, w): w.Write('%s %s(%d, %s, &out);' % (prefix, fname[wire_id], self.number, name)) - def GenerateFunctions(self, w): + def GenerateJsonOutput(self, w): + name = self.name.group(0) + if self.category == 'repeated': + prefix = 'if (!%s_.empty())' % name + funcname = 'AppendJsonRepeatedField' + else: + prefix = 'if (has_%s_)' % name + funcname = 'AppendJsonField' + if self.type.IsEnumType(): + value = '%s_Name(%s_)' % (self.type.GetCppType(), name) + else: + value = name + "_" + w.Write('%s %s("%s", %s, &first, &out);' % + (prefix, funcname, name, value)) + + def GenerateFunctionDeclarations(self, w): name = self.name.group(0) cpp_type = self.type.GetCppType() var_cpp_type = self.type.GetVariableCppType() if self.category == 'repeated': if self.type.IsMessage(): - w.Write("%s* add_%s() { return &%s_.emplace_back(); }" % - (cpp_type, name, name)) + w.Write("%s* add_%s();" % (cpp_type, name)) else: - w.Write("void add_%s(%s val) { %s_.emplace_back(val); }" % - (name, cpp_type, name)) - w.Write("const std::vector<%s>& %s() const { return %s_; }" % - (var_cpp_type, name, name)) + w.Write("void add_%s(%s val);" % (name, cpp_type)) + w.Write("const std::vector<%s>& %s() const;" % + (var_cpp_type, name)) if self.type.IsMessage(): - w.Write("const %s& %s(size_t idx) const { return %s_[idx]; }" % - (cpp_type, name, name)) + w.Write("const %s& %s(size_t idx) const;" % (cpp_type, name)) else: - w.Write("%s %s(size_t idx) const { return %s_[idx]; }" % - (cpp_type, name, name)) - w.Write("size_t %s_size() const { return %s_.size(); }" % - (name, name)) + w.Write("%s %s(size_t idx) const;" % (cpp_type, name)) + w.Write("size_t %s_size() const;" % (name)) else: - w.Write("bool has_%s() const { return has_%s_; }" % (name, name)) + w.Write("bool has_%s() const;" % (name)) if self.type.IsMessage(): - w.Write("const %s& %s() const { return %s_; }" % - (cpp_type, name, name)) - w.Write("%s* mutable_%s() {" % (cpp_type, name)) + w.Write("const %s& %s() const;" % (cpp_type, name)) + w.Write("%s* mutable_%s();" % (cpp_type, name)) + else: + w.Write("%s %s() const;" % (cpp_type, name)) + w.Write("void set_%s(%s val);" % (name, cpp_type)) + + def GenerateFunctionDefinitions(self, w, class_name): + name = self.name.group(0) + cpp_type = self.type.GetCppType() + var_cpp_type = self.type.GetVariableCppType() + if self.category == 'repeated': + if self.type.IsMessage(): + w.Write( + "inline %s* %s::add_%s() { return &%s_.emplace_back(); }" % + (cpp_type, class_name, name, name)) + else: + w.Write( + "inline void %s::add_%s(%s val) { %s_.emplace_back(val); }" + % (class_name, name, cpp_type, name)) + w.Write( + "inline const std::vector<%s>& %s::%s() const { return %s_; }" + % (var_cpp_type, class_name, name, name)) + if self.type.IsMessage(): + w.Write( + "inline const %s& %s::%s(size_t idx) const { return %s_[idx]; }" + % (cpp_type, class_name, name, name)) + else: + w.Write( + "inline %s %s::%s(size_t idx) const { return %s_[idx]; }" % + (cpp_type, class_name, name, name)) + w.Write( + "inline size_t %s::%s_size() const { return %s_.size(); }" % + (class_name, name, name)) + else: + w.Write("inline bool %s::has_%s() const { return has_%s_; }" % + (class_name, name, name)) + if self.type.IsMessage(): + w.Write("inline const %s& %s::%s() const { return %s_; }" % + (cpp_type, class_name, name, name)) + w.Write("inline %s* %s::mutable_%s() {" % + (cpp_type, class_name, name)) w.Indent() w.Write('has_%s_ = true;' % (name)) w.Write('return &%s_;' % name) w.Unindent() w.Write("}") else: - w.Write("%s %s() const { return %s_; }" % - (cpp_type, name, name)) - w.Write("void set_%s(%s val) {" % (name, cpp_type)) + w.Write("inline %s %s::%s() const { return %s_; }" % + (cpp_type, class_name, name, name)) + w.Write("inline void %s::set_%s(%s val) {" % + (class_name, name, cpp_type)) w.Indent() w.Write("has_%s_ = true;" % name) w.Write("%s_ = val;" % name) @@ -385,10 +455,12 @@ def GenerateVariable(self, w): class ProtoEnumParser: - def __init__(self, lexer): + + def __init__(self, lexer, scope): lexer.Consume('enum') self.name = lexer.Consume('identifier').group(0) self.values = [] + self.scope = scope[:] lexer.Consume('{') while True: token, match = lexer.Pick() @@ -404,21 +476,54 @@ def __init__(self, lexer): def GetName(self): return self.name + def GetFullName(self): + return '_'.join([x.GetName() for x in self.scope] + [self.name]) + def GetType(self): return 'enum' def IsType(self): return True - def Generate(self, w): + def ResolveForwardDeclarations(self, _): + pass + + def GenerateMessageDeclarations(self, w): + pass + + def GenerateMessageDefinitions(self, w): + pass + + def GenerateFunctionDefinitions(self, w): + pass + + def GenerateEnumDefinitions(self, w): # Protobuf enum is mapped directly to C++ enum. - w.Write('enum %s {' % self.name) + w.Write('enum %s : int {' % self.GetFullName()) w.Indent() for key, value in self.values: - w.Write('%s = %d,' % (key, value)) + w.Write('%s_%s = %d,' % (self.GetFullName(), key, value)) + w.Unindent() + w.Write('};') + w.Write('inline std::string %s_Name(%s val) {' % + (self.GetFullName(), self.GetFullName())) + w.Indent() + w.Write('switch (val) {') + w.Indent() + for key, _ in self.values: + w.Write('case %s_%s:' % (self.GetFullName(), key)) + w.Write(' return "%s";' % key) w.Unindent() w.Write('};') - # Static array of all possible enum values. + w.Write('return "%s(" + std::to_string(val) + ")";' % self.name) + w.Unindent() + w.Write('}') + + def GenerateUsingDirectives(self, w): + w.Write('using %s = %s;' % (self.name, self.GetFullName())) + for key, _ in self.values: + w.Write('static constexpr %s %s =' % (self.name, key)) + w.Write(' %s_%s;' % (self.GetFullName(), key)) w.Write('static constexpr std::array<%s,%d> %s_AllValues = {' % (self.name, len(self.values), self.name)) w.Indent() @@ -430,22 +535,18 @@ def Generate(self, w): w.Write('static std::string %s_Name(%s val) {' % (self.name, self.name)) w.Indent() - w.Write('switch (val) {') - w.Indent() - for key, _ in self.values: - w.Write('case %s:' % key) - w.Write(' return "%s";' % key) - w.Unindent() - w.Write('};') - w.Write('return "%s(" + std::to_string(val) + ")";' % self.name) + w.Write('return %s_Name(val);' % (self.GetFullName())) w.Unindent() w.Write('}') class ProtoMessageParser: - def __init__(self, lexer, type_stack): + + def __init__(self, lexer, type_stack, scope): + type_stack[0].append(self) self.types = [] self.fields = [] + self.scope = scope[:] lexer.Consume('message') self.name = lexer.Consume('identifier').group(0) lexer.Consume('{') @@ -454,10 +555,10 @@ def __init__(self, lexer, type_stack): if token == '}': break elif token == 'message': - self.types.append( - ProtoMessageParser(lexer, [self.types, *type_stack])) + ProtoMessageParser(lexer, [self.types, *type_stack], + self.scope + [self]) elif token == 'enum': - self.types.append(ProtoEnumParser(lexer)) + self.types.append(ProtoEnumParser(lexer, self.scope + [self])) elif token in ['repeated', 'optional', 'required']: self.fields.append( ProtoFieldParser(lexer, [self.types, *type_stack])) @@ -468,6 +569,9 @@ def __init__(self, lexer, type_stack): def GetName(self): return self.name + def GetFullName(self): + return '_'.join([x.GetName() for x in self.scope] + [self.name]) + def GetType(self): return 'message' @@ -483,7 +587,15 @@ def GetFieldsGruppedByWireType(self): type_to_fields.setdefault(x.type.GetWireType(), []).append(x) return type_to_fields - def WriteFieldParser(self, w, wire_id, fields): + def ResolveForwardDeclarations(self, type_stack): + type_stack.append(self.types) + for x in self.types: + x.ResolveForwardDeclarations(type_stack) + for x in self.fields: + x.LookupForwardFieldType(type_stack) + type_stack.pop() + + def WriteFieldParserDeclaration(self, w, wire_id, fields): fname = {0: 'SetVarInt', 1: 'SetInt64', 2: 'SetString', 5: 'SetInt32'} tname = { 0: 'std::uint64_t', @@ -491,8 +603,19 @@ def WriteFieldParser(self, w, wire_id, fields): 2: 'std::string_view', 5: 'std::uint32_t' } - w.Write('void %s(int field_id, %s val) override {' % + w.Write('void %s(int field_id, %s val) final;' % (fname[wire_id], tname[wire_id])) + + def WriteFieldParserDefinition(self, w, wire_id, fields): + fname = {0: 'SetVarInt', 1: 'SetInt64', 2: 'SetString', 5: 'SetInt32'} + tname = { + 0: 'std::uint64_t', + 1: 'std::uint64_t', + 2: 'std::string_view', + 5: 'std::uint32_t' + } + w.Write('inline void %s::%s(int field_id, %s val) {' % + (self.GetFullName(), fname[wire_id], tname[wire_id])) w.Indent() w.Write('switch (field_id) {') w.Indent() @@ -503,19 +626,67 @@ def WriteFieldParser(self, w, wire_id, fields): w.Unindent() w.Write('}') - def Generate(self, w): + def GenerateUsingDirectives(self, w): + w.Write('using %s = %s;' % (self.name, self.GetFullName())) + + def GenerateMessageDeclarations(self, w): + w.Write(f'class %s;' % self.GetFullName()) + for x in self.types: + x.GenerateMessageDeclarations(w) + + def GenerateEnumDefinitions(self, w): + for x in self.types: + x.GenerateEnumDefinitions(w) + + def GenerateMessageDefinitions(self, w): + # Writing nested messages. + for x in self.types: + if x.GetType() == 'message': + x.GenerateMessageDefinitions(w) # Protobuf message is a C++ class. - w.Write('class %s : public lczero::ProtoMessage {' % self.name) + w.Write('class %s final : public lczero::ProtoMessage {' % + self.GetFullName()) w.Write(' public:') w.Indent() - # Writing submessages and enums. + # Writing using directives. for x in self.types: - x.Generate(w) + x.GenerateUsingDirectives(w) + # Writing function declarations. for x in self.fields: w.Write('') - x.GenerateFunctions(w) + x.GenerateFunctionDeclarations(w) w.Write('') - w.Write('std::string OutputAsString() const override {') + w.Write('std::string OutputAsString() const final;') + w.Write('std::string OutputAsJson() const final;') + w.Write('void Clear() final;') + + w.Unindent() + w.Write('') + w.Write(' private:') + w.Indent() + for k, v in self.GetFieldsGruppedByWireType().items(): + self.WriteFieldParserDeclaration(w, k, v) + w.Write('') + for x in self.fields: + x.GenerateVariable(w) + w.Unindent() + w.Write('};') + w.Write('') + + def GenerateFunctionDefinitions(self, w): + # Writing nested messages. + for x in self.types: + if x.GetType() == 'message': + x.GenerateFunctionDefinitions(w) + self.GenerateOutputAsStringFunc(w) + self.GenerateOutputAsJsonFunc(w) + self.GenerateClearFunc(w) + self.GenerateParserFuncs(w) + self.GenerateFieldAccessorFuncs(w) + + def GenerateOutputAsStringFunc(self, w): + w.Write('inline std::string %s::OutputAsString() const {' % + self.GetFullName()) w.Indent() w.Write('std::string out;') for x in sorted(self.fields, key=lambda x: x.number): @@ -523,31 +694,44 @@ def Generate(self, w): w.Write('return out;') w.Unindent() w.Write('}') - w.Write('') - w.Write('void Clear() override {') + + def GenerateOutputAsJsonFunc(self, w): + w.Write('inline std::string %s::OutputAsJson() const {' % + self.GetFullName()) w.Indent() + if self.fields: + w.Write('bool first = true;') + w.Write('std::string out = "{";') for x in self.fields: - x.GenerateClear(w) + x.GenerateJsonOutput(w) + w.Write('out += "}";') + w.Write('return out;') w.Unindent() w.Write('}') - w.Unindent() - w.Write('') - w.Write(' private:') + + def GenerateClearFunc(self, w): + w.Write('inline void %s::Clear() {' % self.GetFullName()) w.Indent() - for k, v in self.GetFieldsGruppedByWireType().items(): - self.WriteFieldParser(w, k, v) - w.Write('') for x in self.fields: - x.GenerateVariable(w) + x.GenerateClear(w) w.Unindent() - w.Write('};') + w.Write('}') + + def GenerateParserFuncs(self, w): + for k, v in self.GetFieldsGruppedByWireType().items(): + self.WriteFieldParserDefinition(w, k, v) + + def GenerateFieldAccessorFuncs(self, w): + for x in self.fields: + x.GenerateFunctionDefinitions(w, self.GetFullName()) class ProtoFileParser: '''Root grammar of .proto file''' + def __init__(self, lexer): self.package = None - self.objects = [] + self.types = [] while True: token, match = lexer.Pick() if token == 'EOF': @@ -558,6 +742,8 @@ def __init__(self, lexer): self.ParsePackage(lexer) elif token == 'message': self.ParseMessage(lexer) + elif token == 'enum': + self.ParseEnum(lexer) else: lexer.Error('Expected message or something similar') @@ -575,7 +761,10 @@ def ParsePackage(self, lexer): lexer.Consume(';') def ParseMessage(self, lexer): - self.objects.append(ProtoMessageParser(lexer, [self.objects])) + ProtoMessageParser(lexer, [self.types], []) + + def ParseEnum(self, lexer): + self.types.append(ProtoEnumParser(lexer, [])) def Generate(self, w): w.Write('// This file is AUTOGENERATED, do not edit.') @@ -583,16 +772,32 @@ def Generate(self, w): w.Write('#include "utils/protomessage.h"') for x in self.package: w.Write('namespace %s {' % x) - w.Indent() - for object in self.objects: - object.Generate(w) - w.Unindent() + w.Write('') + w.Write('// Forward declarations.') + for object in self.types: + object.GenerateMessageDeclarations(w) + for object in self.types: + object.GenerateEnumDefinitions(w) + w.Write('') + w.Write('// Class declarations.') + for object in self.types: + object.GenerateMessageDefinitions(w) + w.Write('') + w.Write('// Function definitions.') + for object in self.types: + object.GenerateFunctionDefinitions(w) for x in reversed(self.package): w.Write('} // namespace %s' % x) + def ResolveForwardDeclarations(self): + type_stack = [self.types] + for object in self.types: + object.ResolveForwardDeclarations(type_stack) + class Writer: '''A helper class for writing file line by line with indent.''' + def __init__(self, fo): self.fo = fo self.indent = 0 @@ -626,5 +831,6 @@ def Write(self, text): with open(args.input, 'r') as input, open(dest_path, 'w') as output: proto_file = ProtoFileParser(Lexer(input.read())) + proto_file.ResolveForwardDeclarations() writer = Writer(output) proto_file.Generate(writer) diff --git a/src/benchmark/backendbench.cc b/src/benchmark/backendbench.cc index 6792f9b778..401d8bea1b 100644 --- a/src/benchmark/backendbench.cc +++ b/src/benchmark/backendbench.cc @@ -29,6 +29,7 @@ #include "chess/board.h" #include "mcts/node.h" +#include "neural/encoder.h" #include "neural/factory.h" #include "utils/optionsparser.h" diff --git a/src/benchmark/benchmark.cc b/src/benchmark/benchmark.cc index 5679ff45b6..a48d7b1e3f 100644 --- a/src/benchmark/benchmark.cc +++ b/src/benchmark/benchmark.cc @@ -97,12 +97,12 @@ void Benchmark::Run() { NNCache cache; cache.SetCapacity(option_dict.Get(kNNCacheSizeId)); - NodeTree tree; + NodeTree tree = {option_dict}; tree.ResetToPosition(position, {}); const auto start = std::chrono::steady_clock::now(); auto search = std::make_unique( - tree, network.get(), + &tree, network.get(), std::make_unique( std::bind(&Benchmark::OnBestMove, this, std::placeholders::_1), std::bind(&Benchmark::OnInfo, this, std::placeholders::_1)), diff --git a/src/chess/position.cc b/src/chess/position.cc index ec085f1e37..9efbde0f6a 100644 --- a/src/chess/position.cc +++ b/src/chess/position.cc @@ -78,7 +78,7 @@ Position::Position(const ChessBoard& board, int rule50_ply, int game_ply) } uint64_t Position::Hash() const { - return HashCat({us_board_.Hash(), static_cast(repetitions_)}); + return us_board_.Hash(); } std::string Position::DebugString() const { return us_board_.DebugString(); } @@ -130,7 +130,7 @@ int PositionHistory::ComputeLastMoveRepetitions(int* cycle_length) const { // TODO(crem) implement hash/cache based solution. if (last.GetRule50Ply() < 4) return 0; - for (int idx = positions_.size() - 3; idx >= 0; idx -= 2) { + for (int idx = positions_.size() - 5; idx >= 0; idx -= 2) { const auto& pos = positions_[idx]; if (pos.GetBoard() == last.GetBoard()) { *cycle_length = positions_.size() - 1 - idx; diff --git a/src/chess/position.h b/src/chess/position.h index 69241467c4..879e08fb2c 100644 --- a/src/chess/position.h +++ b/src/chess/position.h @@ -27,6 +27,7 @@ #pragma once +#include #include #include "chess/board.h" @@ -99,11 +100,22 @@ GameResult operator-(const GameResult& res); class PositionHistory { public: PositionHistory() = default; - PositionHistory(const PositionHistory& other) = default; + PositionHistory(const PositionHistory& other) { + positions_.reserve( + std::max(other.positions_.size() + 1, other.positions_.capacity())); + positions_ = other.positions_; + } PositionHistory(PositionHistory&& other) = default; - PositionHistory& operator=(const PositionHistory& other) = default; - PositionHistory& operator=(PositionHistory&& other) = default; + PositionHistory& operator=(const PositionHistory& other) { + if (this == &other) return *this; + positions_.clear(); + positions_.reserve( + std::max(other.positions_.size() + 1, other.positions_.capacity())); + positions_ = other.positions_; + return *this; + } + PositionHistory& operator=(PositionHistory&& other) = default; // Returns first position of the game (or fen from which it was initialized). const Position& Starting() const { return positions_.front(); } diff --git a/src/engine.cc b/src/engine.cc index a632ebb0ad..fe267a55b4 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -96,7 +96,7 @@ void EngineController::PopulateOptions(OptionsParser* options) { NetworkFactory::PopulateOptions(options); options->Add(kThreadsOptionId, 1, 128) = kDefaultThreads; - options->Add(kNNCacheSizeId, 0, 999999999) = 2000000; + options->Add(kNNCacheSizeId, 0, 999999999) = 2000; SearchParams::Populate(options); options->Add(kSyzygyTablebaseId); @@ -132,9 +132,11 @@ void EngineController::UpdateFromUciOptions() { if (!syzygy_tb_->init(tb_paths)) { CERR << "Failed to load Syzygy tablebases!"; syzygy_tb_ = nullptr; - } else { - tb_paths_ = tb_paths; } + tb_paths_ = tb_paths; + } else if (tb_paths.empty()) { + syzygy_tb_ = nullptr; + tb_paths_.clear(); } // Network. @@ -189,7 +191,7 @@ Position EngineController::ApplyPositionMoves() { board.SetFromFen(current_position_.fen, &no_capture_ply, &game_move); int game_ply = 2 * game_move - (board.flipped() ? 1 : 2); Position pos(board, no_capture_ply, game_ply); - for (std::string move_str: current_position_.moves) { + for (std::string move_str : current_position_.moves) { Move move(move_str); if (pos.IsBlackToMove()) move.Mirror(); pos = Position(pos, move); @@ -204,7 +206,7 @@ void EngineController::SetupPosition( UpdateFromUciOptions(); - if (!tree_) tree_ = std::make_unique(); + if (!tree_) tree_ = std::make_unique(options_); std::vector moves; for (const auto& move : moves_str) moves.emplace_back(move); @@ -294,7 +296,7 @@ void EngineController::Go(const GoParams& params) { auto stopper = time_manager_->GetStopper(params, *tree_.get()); search_ = std::make_unique( - *tree_, network_.get(), std::move(responder), + tree_.get(), network_.get(), std::move(responder), StringsToMovelist(params.searchmoves, tree_->HeadPosition().GetBoard()), *move_start_time_, std::move(stopper), params.infinite || params.ponder, options_, &cache_, syzygy_tb_.get()); diff --git a/src/lc0ctl/describenet.cc b/src/lc0ctl/describenet.cc index f67997f1fc..634589f66b 100644 --- a/src/lc0ctl/describenet.cc +++ b/src/lc0ctl/describenet.cc @@ -96,6 +96,20 @@ void ShowNetworkFormatInfo(const pblczero::Net& weights) { COUT << Justify("MLH") << NetworkFormat::MovesLeftFormat_Name(net_format.moves_left()); } + if (net_format.has_default_activation()) { + COUT << Justify("Default activation") + << NetworkFormat::DefaultActivation_Name( + net_format.default_activation()); + } + if (net_format.has_smolgen_activation()) { + COUT << Justify("Smolgen activation") + << NetworkFormat::ActivationFunction_Name( + net_format.smolgen_activation()); + } + if (net_format.has_ffn_activation()) { + COUT << Justify("FFN activation") + << NetworkFormat::ActivationFunction_Name(net_format.ffn_activation()); + } } void ShowNetworkTrainingInfo(const pblczero::Net& weights) { @@ -124,33 +138,84 @@ void ShowNetworkTrainingInfo(const pblczero::Net& weights) { } } +void ShowNetworkWeightsBodyInfo(const pblczero::Net& weights) { + const auto& w = weights.weights(); + if (w.encoder_size() > 0) { + COUT << Justify("Encoders") << w.encoder_size(); + COUT << Justify("Encoder heads") << w.headcount(); + COUT << Justify("Embedding size") << w.ip_emb_b().params().size() / 2; + COUT << Justify("Dmodel") << w.encoder(0).mha().q_b().params().size() / 2; + COUT << Justify("Encoder DFF") + << w.encoder(0).ffn().dense1_b().params().size() / 2; + } else { + COUT << Justify("Blocks") << w.residual_size(); + int se_count = 0; + for (size_t i = 0; i < w.residual_size(); i++) { + if (w.residual(i).has_se()) se_count++; + } + if (se_count > 0) { + COUT << Justify("SE blocks") << se_count; + } + COUT << Justify("Filters") + << w.input().weights().params().size() / 2 / 112 / 9; + } +} + +void ShowNetworkWeightsPolicyInfo(const pblczero::Net& weights) { + using pblczero::NetworkFormat; + const auto& w = weights.weights(); + const auto& format = weights.format().network_format(); + auto pol_activation = NetworkFormat::ACTIVATION_DEFAULT; + if (format.policy() == NetworkFormat::POLICY_ATTENTION) { + // Non-attentionbody nets use hardcoded SELU as policy activation and FFN + // activations. + auto ffn_activation = format.ffn_activation(); + if (w.encoder_size() == 0) { + pol_activation = NetworkFormat::ACTIVATION_SELU; + ffn_activation = NetworkFormat::ACTIVATION_SELU; + } + + COUT << Justify("Policy") << "Attention"; + COUT << Justify("Policy activation") + << NetworkFormat::ActivationFunction_Name(pol_activation); + + if (w.pol_encoder_size() > 0) { + COUT << Justify("Policy encoders") << w.pol_encoder_size(); + COUT << Justify("Policy encoder heads") << w.pol_headcount(); + COUT << Justify("Policy encoder Dmodel") + << w.pol_encoder(0).mha().q_b().params().size() / 2; + COUT << Justify("Policy encoder DFF") + << w.pol_encoder(0).ffn().dense1_b().params().size() / 2; + COUT << Justify("Policy FFN activation") + << NetworkFormat::ActivationFunction_Name(ffn_activation); + } + COUT << Justify("Policy Dmodel") << w.ip2_pol_b().params().size() / 2; + } else { + COUT << Justify("Policy") << (w.has_policy1() ? "Convolution" : "Dense"); + COUT << Justify("Policy activation") + << NetworkFormat::ActivationFunction_Name(pol_activation); + if (!w.has_policy1()) { + int policy_channels = w.policy().biases().params().size() / 2; + if (policy_channels == 0) { + policy_channels = w.policy().bn_means().params().size() / 2; + } + COUT << Justify("Policy channels") << policy_channels; + } + } +} + void ShowNetworkWeightsInfo(const pblczero::Net& weights) { if (!weights.has_weights()) return; COUT << "\nWeights"; COUT << "~~~~~~~"; + + ShowNetworkWeightsBodyInfo(weights); + ShowNetworkWeightsPolicyInfo(weights); + const auto& w = weights.weights(); - COUT << Justify("Blocks") << w.residual_size(); - int se_count = 0; - for (size_t i = 0; i < w.residual_size(); i++) { - if (w.residual(i).has_se()) se_count++; - } - if (se_count > 0) { - COUT << Justify("SE blocks") << se_count; - } - - COUT << Justify("Filters") - << w.input().weights().params().size() / 2 / 112 / 9; - COUT << Justify("Policy") << (w.has_policy1() ? "Convolution" : "Dense"); - if (!w.has_policy1()) { - int policy_channels = w.policy().biases().params().size() / 2; - if (policy_channels == 0) { - policy_channels = w.policy().bn_means().params().size() / 2; - } - COUT << Justify("Policy channels") << policy_channels; - } COUT << Justify("Value") << (w.ip2_val_w().params().size() / 2 % 3 == 0 ? "WDL" : "Classical"); - COUT << Justify("MLH") << (w.has_moves_left() ? "Present" : "Absent"); + COUT << Justify("MLH") << (w.has_ip2_mov_w() ? "Present" : "Absent"); } void ShowNetworkOnnxInfo(const pblczero::Net& weights, diff --git a/src/mcts/node.cc b/src/mcts/node.cc index 9bda7f144a..8547805e0e 100644 --- a/src/mcts/node.cc +++ b/src/mcts/node.cc @@ -27,95 +27,24 @@ #include "mcts/node.h" +#include + #include #include #include #include #include +#include +#include #include #include +#include -#include "neural/encoder.h" -#include "neural/network.h" #include "utils/exception.h" #include "utils/hashcat.h" namespace lczero { -///////////////////////////////////////////////////////////////////////// -// Node garbage collector -///////////////////////////////////////////////////////////////////////// - -namespace { -// Periodicity of garbage collection, milliseconds. -const int kGCIntervalMs = 100; - -// Every kGCIntervalMs milliseconds release nodes in a separate GC thread. -class NodeGarbageCollector { - public: - NodeGarbageCollector() : gc_thread_([this]() { Worker(); }) {} - - // Takes ownership of a subtree, to dispose it in a separate thread when - // it has time. - void AddToGcQueue(std::unique_ptr node, size_t solid_size = 0) { - if (!node) return; - Mutex::Lock lock(gc_mutex_); - subtrees_to_gc_.emplace_back(std::move(node)); - subtrees_to_gc_solid_size_.push_back(solid_size); - } - - ~NodeGarbageCollector() { - // Flips stop flag and waits for a worker thread to stop. - stop_.store(true); - gc_thread_.join(); - } - - private: - void GarbageCollect() { - while (!stop_.load()) { - // Node will be released in destructor when mutex is not locked. - std::unique_ptr node_to_gc; - size_t solid_size = 0; - { - // Lock the mutex and move last subtree from subtrees_to_gc_ into - // node_to_gc. - Mutex::Lock lock(gc_mutex_); - if (subtrees_to_gc_.empty()) return; - node_to_gc = std::move(subtrees_to_gc_.back()); - subtrees_to_gc_.pop_back(); - solid_size = subtrees_to_gc_solid_size_.back(); - subtrees_to_gc_solid_size_.pop_back(); - } - // Solid is a hack... - if (solid_size != 0) { - for (size_t i = 0; i < solid_size; i++) { - node_to_gc.get()[i].~Node(); - } - std::allocator alloc; - alloc.deallocate(node_to_gc.release(), solid_size); - } - } - } - - void Worker() { - while (!stop_.load()) { - std::this_thread::sleep_for(std::chrono::milliseconds(kGCIntervalMs)); - GarbageCollect(); - }; - } - - mutable Mutex gc_mutex_; - std::vector> subtrees_to_gc_ GUARDED_BY(gc_mutex_); - std::vector subtrees_to_gc_solid_size_ GUARDED_BY(gc_mutex_); - - // When true, Worker() should stop and exit. - std::atomic stop_{false}; - std::thread gc_thread_; -}; - -NodeGarbageCollector gNodeGc; -} // namespace - ///////////////////////////////////////////////////////////////////////// // Edge ///////////////////////////////////////////////////////////////////////// @@ -188,115 +117,102 @@ std::unique_ptr Edge::FromMovelist(const MoveList& moves) { } ///////////////////////////////////////////////////////////////////////// -// Node +// LowNode + Node ///////////////////////////////////////////////////////////////////////// -Node* Node::CreateSingleChildNode(Move move) { - assert(!edges_); - assert(!child_); - edges_ = Edge::FromMovelist({move}); - num_edges_ = 1; - child_ = std::make_unique(this, 0); - return child_.get(); +// Put @low_node at the end of TT @gc_queue, if both @gc_queue and @low_node +// are not null and &low_node is TT and about to become parent-less (has only +// one parent). +static void TTGCEnqueue(GCQueue* gc_queue, const LowNode* low_node) { + if (gc_queue && low_node && low_node->IsTT() && + low_node->GetNumParents() == 1) + gc_queue->push_back(low_node->GetHash()); } -void Node::CreateEdges(const MoveList& moves) { - assert(!edges_); - assert(!child_); - edges_ = Edge::FromMovelist(moves); - num_edges_ = moves.size(); -} +void Node::Trim(GCQueue* gc_queue) { + wl_ = 0.0f; + + TTGCEnqueue(gc_queue, low_node_); + UnsetLowNode(); + // sibling_ -Node::ConstIterator Node::Edges() const { - return {*this, !solid_children_ ? &child_ : nullptr}; + d_ = 0.0f; + m_ = 0.0f; + n_ = 0; + n_in_flight_ = 0; + + // edge_ + + // index_ + + terminal_type_ = Terminal::NonTerminal; + lower_bound_ = GameResult::BLACK_WON; + upper_bound_ = GameResult::WHITE_WON; + repetition_ = false; } -Node::Iterator Node::Edges() { - return {*this, !solid_children_ ? &child_ : nullptr}; + +Node* Node::GetChild() const { + if (!low_node_) return nullptr; + return low_node_->GetChild()->get(); } +bool Node::HasChildren() const { return low_node_ && low_node_->HasChildren(); } + float Node::GetVisitedPolicy() const { float sum = 0.0f; - for (auto* node : VisitedNodes()) sum += GetEdgeToNode(node)->GetP(); + for (auto* node : VisitedNodes()) sum += node->GetP(); return sum; } -Edge* Node::GetEdgeToNode(const Node* node) const { - assert(node->parent_ == this); - assert(node->index_ < num_edges_); - return &edges_[node->index_]; +uint32_t Node::GetNInFlight() const { + return n_in_flight_.load(std::memory_order_acquire); +} + +uint32_t Node::GetChildrenVisits() const { + return low_node_ ? low_node_->GetChildrenVisits() : 0; } -Edge* Node::GetOwnEdge() const { return GetParent()->GetEdgeToNode(this); } +uint32_t Node::GetTotalVisits() const { + return low_node_ ? low_node_->GetN() : 0; +} + +const Edge& LowNode::GetEdgeAt(uint16_t index) const { return edges_[index]; } std::string Node::DebugString() const { std::ostringstream oss; - oss << " Term:" << static_cast(terminal_type_) << " This:" << this - << " Parent:" << parent_ << " Index:" << index_ - << " Child:" << child_.get() << " Sibling:" << sibling_.get() - << " WL:" << wl_ << " N:" << n_ << " N_:" << n_in_flight_ - << " Edges:" << static_cast(num_edges_) + oss << " This:" << this << " LowNode:" << low_node_ + << " Index:" << index_ << " Move:" << GetMove().as_string() + << " Sibling:" << sibling_.get() << " P:" << GetP() << " WL:" << wl_ + << " D:" << d_ << " M:" << m_ << " N:" << n_ << " N_:" << n_in_flight_ + << " Term:" << static_cast(terminal_type_) << " Bounds:" << static_cast(lower_bound_) - 2 << "," - << static_cast(upper_bound_) - 2 - << " Solid:" << solid_children_; + << static_cast(upper_bound_) - 2; return oss.str(); } -bool Node::MakeSolid() { - if (solid_children_ || num_edges_ == 0 || IsTerminal()) return false; - // Can only make solid if no immediate leaf childredn are in flight since we - // allow the search code to hold references to leaf nodes across locks. - Node* old_child_to_check = child_.get(); - uint32_t total_in_flight = 0; - while (old_child_to_check != nullptr) { - if (old_child_to_check->GetN() <= 1 && - old_child_to_check->GetNInFlight() > 0) { - return false; - } - if (old_child_to_check->IsTerminal() && - old_child_to_check->GetNInFlight() > 0) { - return false; - } - total_in_flight += old_child_to_check->GetNInFlight(); - old_child_to_check = old_child_to_check->sibling_.get(); - } - // If the total of children in flight is not the same as self, then there are - // collisions against immediate children (which don't update the GetNInFlight - // of the leaf) and its not safe. - if (total_in_flight != GetNInFlight()) { - return false; - } - std::allocator alloc; - auto* new_children = alloc.allocate(num_edges_); - for (int i = 0; i < num_edges_; i++) { - new (&(new_children[i])) Node(this, i); - } - std::unique_ptr old_child = std::move(child_); - while (old_child) { - int index = old_child->index_; - new_children[index] = std::move(*old_child.get()); - // This isn't needed, but it helps crash things faster if something has gone wrong. - old_child->parent_ = nullptr; - gNodeGc.AddToGcQueue(std::move(old_child)); - new_children[index].UpdateChildrenParents(); - old_child = std::move(new_children[index].sibling_); - } - // This is a hack. - child_ = std::unique_ptr(new_children); - solid_children_ = true; - return true; +std::string LowNode::DebugString() const { + std::ostringstream oss; + oss << " This:" << this << " Hash:" << hash_ + << " Edges:" << edges_.get() + << " NumEdges:" << static_cast(num_edges_) + << " Child:" << child_.get() << " WL:" << wl_ << " D:" << d_ + << " M:" << m_ << " N:" << n_ << " NP:" << num_parents_ + << " Term:" << static_cast(terminal_type_) + << " Bounds:" << static_cast(lower_bound_) - 2 << "," + << static_cast(upper_bound_) - 2 + << " IsTransposition:" << is_transposition; + return oss.str(); } -void Node::SortEdges() { - assert(edges_); - assert(!child_); +void Edge::SortEdges(Edge* edges, int num_edges) { // Sorting on raw p_ is the same as sorting on GetP() as a side effect of // the encoding, and its noticeably faster. - std::sort(edges_.get(), (edges_.get() + num_edges_), + std::sort(edges, (edges + num_edges), [](const Edge& a, const Edge& b) { return a.p_ > b.p_; }); } -void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) { - if (type != Terminal::TwoFold) SetBounds(result, result); +void LowNode::MakeTerminal(GameResult result, float plies_left, Terminal type) { + SetBounds(result, result); terminal_type_ = type; m_ = plies_left; if (result == GameResult::DRAW) { @@ -308,34 +224,105 @@ void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) { } else if (result == GameResult::BLACK_WON) { wl_ = -1.0f; d_ = 0.0f; - // Terminal losses have no uncertainty and no reason for their U value to be - // comparable to another non-loss choice. Force this by clearing the policy. - if (GetParent() != nullptr) GetOwnEdge()->SetP(0.0f); } + + assert(WLDMInvariantsHold()); } -void Node::MakeNotTerminal() { +void LowNode::MakeNotTerminal(const Node* node) { + assert(edges_); + if (!IsTerminal()) return; + terminal_type_ = Terminal::NonTerminal; + lower_bound_ = GameResult::BLACK_WON; + upper_bound_ = GameResult::WHITE_WON; n_ = 0; + wl_ = 0.0; + d_ = 0.0; + m_ = 0.0; - // If we have edges, we've been extended (1 visit), so include children too. - if (edges_) { - n_++; - for (const auto& child : Edges()) { + // Include children too. + if (node->GetNumEdges() > 0) { + for (const auto& child : node->Edges()) { const auto n = child.GetN(); if (n > 0) { n_ += n; // Flip Q for opponent. // Default values don't matter as n is > 0. - wl_ += -child.GetWL(0.0f) * n; + wl_ += child.GetWL(0.0f) * n; d_ += child.GetD(0.0f) * n; + m_ += child.GetM(0.0f) * n; } } // Recompute with current eval (instead of network's) and children's eval. wl_ /= n_; d_ /= n_; + m_ /= n_; + } + + assert(WLDMInvariantsHold()); +} + +void LowNode::SetBounds(GameResult lower, GameResult upper) { + lower_bound_ = lower; + upper_bound_ = upper; +} + +uint8_t Node::GetNumEdges() const { + return low_node_ ? low_node_->GetNumEdges() : 0; +} + +void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) { + SetBounds(result, result); + terminal_type_ = type; + m_ = plies_left; + if (result == GameResult::DRAW) { + wl_ = 0.0f; + d_ = 1.0f; + } else if (result == GameResult::WHITE_WON) { + wl_ = 1.0f; + d_ = 0.0f; + } else if (result == GameResult::BLACK_WON) { + wl_ = -1.0f; + d_ = 0.0f; + // Terminal losses have no uncertainty and no reason for their U value to be + // comparable to another non-loss choice. Force this by clearing the policy. + SetP(0.0f); + } + + assert(WLDMInvariantsHold()); +} + +void Node::MakeNotTerminal(bool also_low_node) { + // At least one of node and low node pair needs to be a terminal. + if (!IsTerminal() && + (!also_low_node || !low_node_ || !low_node_->IsTerminal())) + return; + + terminal_type_ = Terminal::NonTerminal; + repetition_ = false; + if (low_node_) { // Two-fold or derived terminal. + // Revert low node first. + if (also_low_node && low_node_) low_node_->MakeNotTerminal(this); + + auto [lower_bound, upper_bound] = low_node_->GetBounds(); + lower_bound_ = -upper_bound; + upper_bound_ = -lower_bound; + n_ = low_node_->GetN(); + wl_ = -low_node_->GetWL(); + d_ = low_node_->GetD(); + m_ = low_node_->GetM() + 1; + } else { // Real terminal. + lower_bound_ = GameResult::BLACK_WON; + upper_bound_ = GameResult::WHITE_WON; + n_ = 0.0f; + wl_ = 0.0f; + d_ = 0.0f; + m_ = 0.0f; } + + assert(WLDMInvariantsHold()); } void Node::SetBounds(GameResult lower, GameResult upper) { @@ -344,108 +331,323 @@ void Node::SetBounds(GameResult lower, GameResult upper) { } bool Node::TryStartScoreUpdate() { - if (n_ == 0 && n_in_flight_ > 0) return false; - ++n_in_flight_; + if (n_ > 0) { + n_in_flight_.fetch_add(1, std::memory_order_acq_rel); + } else { + uint32_t expected_n_if_flight_ = 0; + if (!n_in_flight_.compare_exchange_strong(expected_n_if_flight_, 1, + std::memory_order_acq_rel)) { + return false; + } + } + return true; } -void Node::CancelScoreUpdate(int multivisit) { - n_in_flight_ -= multivisit; +void Node::CancelScoreUpdate(uint32_t multivisit) { + assert(GetNInFlight() >= (uint32_t)multivisit); + n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel); +} + +void LowNode::FinalizeScoreUpdate(float v, float d, float m, + uint32_t multivisit) { + assert(edges_); + // Recompute Q. + wl_ += multivisit * (v - wl_) / (n_ + multivisit); + d_ += multivisit * (d - d_) / (n_ + multivisit); + m_ += multivisit * (m - m_) / (n_ + multivisit); + + assert(WLDMInvariantsHold()); + + // Increment N. + n_ += multivisit; +} + +void LowNode::AdjustForTerminal(float v, float d, float m, + uint32_t multivisit) { + assert(static_cast(multivisit) <= n_); + + // Recompute Q. + wl_ += multivisit * v / n_; + d_ += multivisit * d / n_; + m_ += multivisit * m / n_; + + assert(WLDMInvariantsHold()); } -void Node::FinalizeScoreUpdate(float v, float d, float m, int multivisit) { +void Node::FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit) { // Recompute Q. wl_ += multivisit * (v - wl_) / (n_ + multivisit); d_ += multivisit * (d - d_) / (n_ + multivisit); m_ += multivisit * (m - m_) / (n_ + multivisit); + assert(WLDMInvariantsHold()); + // Increment N. n_ += multivisit; // Decrement virtual loss. - n_in_flight_ -= multivisit; + assert(GetNInFlight() >= (uint32_t)multivisit); + n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel); } -void Node::AdjustForTerminal(float v, float d, float m, int multivisit) { +void Node::AdjustForTerminal(float v, float d, float m, uint32_t multivisit) { + assert(static_cast(multivisit) <= n_); + // Recompute Q. wl_ += multivisit * v / n_; d_ += multivisit * d / n_; m_ += multivisit * m / n_; + + assert(WLDMInvariantsHold()); } -void Node::RevertTerminalVisits(float v, float d, float m, int multivisit) { - // Compute new n_ first, as reducing a node to 0 visits is a special case. - const int n_new = n_ - multivisit; - if (n_new <= 0) { - // If n_new == 0, reset all relevant values to 0. - wl_ = 0.0; - d_ = 1.0; - m_ = 0.0; - n_ = 0; - } else { - // Recompute Q and M. - wl_ -= multivisit * (v - wl_) / n_new; - d_ -= multivisit * (d - d_) / n_new; - m_ -= multivisit * (m - m_) / n_new; - // Decrement N. - n_ -= multivisit; +void Node::IncrementNInFlight(uint32_t multivisit) { + n_in_flight_.fetch_add(multivisit, std::memory_order_acq_rel); +} + +void LowNode::ReleaseChildren(GCQueue* gc_queue) { + for (auto child = GetChild()->get(); child != nullptr; + child = child->GetSibling()->get()) { + TTGCEnqueue(gc_queue, child->GetLowNode()); } + child_.reset(); } -void Node::UpdateChildrenParents() { - if (!solid_children_) { - Node* cur_child = child_.get(); - while (cur_child != nullptr) { - cur_child->parent_ = this; - cur_child = cur_child->sibling_.get(); - } - } else { - Node* child_array = child_.get(); - for (int i = 0; i < num_edges_; i++) { - child_array[i].parent_ = this; +void LowNode::ReleaseChildrenExceptOne(Node* node_to_save, GCQueue* gc_queue) { + // Stores node which will have to survive (or nullptr if it's not found). + atomic_unique_ptr saved_node; + // Pointer to unique_ptr, so that we could move from it. + for (auto node = &child_; *node != nullptr; node = (*node)->GetSibling()) { + // If current node is the one that we have to save. + if (node->get() == node_to_save) { + // Save the node, and take the ownership from the unique_ptr. + saved_node = std::move(*node); + node = &saved_node; + } else { + TTGCEnqueue(gc_queue, (*node)->GetLowNode()); } } + // Kill all remaining siblings. + saved_node->GetSibling()->reset(); + // Make saved node the only child. (kills previous siblings). + child_ = std::move(saved_node); } -void Node::ReleaseChildren() { - gNodeGc.AddToGcQueue(std::move(child_), solid_children_ ? num_edges_ : 0); +void Node::ReleaseChildrenExceptOne(Node* node_to_save, + GCQueue* gc_queue) const { + // Sometime we have no graph yet or a reverted terminal without low node. + if (low_node_) low_node_->ReleaseChildrenExceptOne(node_to_save, gc_queue); } -void Node::ReleaseChildrenExceptOne(Node* node_to_save) { - if (solid_children_) { - std::unique_ptr saved_node; - if (node_to_save != nullptr) { - saved_node = std::make_unique(this, node_to_save->index_); - *saved_node = std::move(*node_to_save); - } - gNodeGc.AddToGcQueue(std::move(child_), num_edges_); - child_ = std::move(saved_node); - if (child_) { - child_->UpdateChildrenParents(); +void Node::SetLowNode(LowNode* low_node) { + assert(!low_node_); + low_node->AddParent(); + low_node_ = low_node; +} +void Node::UnsetLowNode() { + if (low_node_) low_node_->RemoveParent(); + low_node_ = nullptr; +} + +static std::string PtrToNodeName(const void* ptr) { + std::ostringstream oss; + oss << "n_" << ptr; + return oss.str(); +} + +std::string LowNode::DotNodeString() const { + std::ostringstream oss; + oss << PtrToNodeName(this) << " [" + << "shape=box"; + // Adjust formatting to limit node size. + oss << std::fixed << std::setprecision(3); + oss << ",label=\"" // + << std::showpos // + << "WL=" << wl_ // + << std::noshowpos // + << "\\lD=" << d_ << "\\lM=" << m_ << "\\lN=" << n_ << "\\l\""; + // Set precision for tooltip. + oss << std::fixed << std::showpos << std::setprecision(5); + oss << ",tooltip=\"" // + << std::showpos // + << "WL=" << wl_ // + << std::noshowpos // + << "\\nD=" << d_ << "\\nM=" << m_ << "\\nN=" << n_ + << "\\nNP=" << num_parents_ + << "\\nTerm=" << static_cast(terminal_type_) // + << std::showpos // + << "\\nBounds=" << static_cast(lower_bound_) - 2 << "," + << static_cast(upper_bound_) - 2 + << "\\nIsTransposition=" << is_transposition // + << std::noshowpos // + << "\\n\\nThis=" << this << "\\nEdges=" << edges_.get() + << "\\nNumEdges=" << static_cast(num_edges_) + << "\\nChild=" << child_.get() << "\\n\""; + oss << "];"; + return oss.str(); +} + +std::string Node::DotEdgeString(bool as_opponent, const LowNode* parent) const { + std::ostringstream oss; + oss << (parent == nullptr ? "top" : PtrToNodeName(parent)) << " -> " + << (low_node_ ? PtrToNodeName(low_node_) : PtrToNodeName(this)) << " ["; + oss << "label=\"" + << (parent == nullptr ? "N/A" : GetMove(as_opponent).as_string()) + << "\\lN=" << n_ << "\\lN_=" << n_in_flight_; + oss << "\\l\""; + // Set precision for tooltip. + oss << std::fixed << std::setprecision(5); + oss << ",labeltooltip=\"" + << "P=" << (parent == nullptr ? 0.0f : GetP()) // + << std::showpos // + << "\\nWL= " << wl_ // + << std::noshowpos // + << "\\nD=" << d_ << "\\nM=" << m_ << "\\nN=" << n_ + << "\\nN_=" << n_in_flight_ + << "\\nTerm=" << static_cast(terminal_type_) // + << std::showpos // + << "\\nBounds=" << static_cast(lower_bound_) - 2 << "," + << static_cast(upper_bound_) - 2 << "\\n\\nThis=" << this // + << std::noshowpos // + << "\\nLowNode=" << low_node_ << "\\nParent=" << parent + << "\\nIndex=" << index_ << "\\nSibling=" << sibling_.get() << "\\n\""; + oss << "];"; + return oss.str(); +} + +std::string Node::DotGraphString(bool as_opponent) const { + std::ostringstream oss; + std::unordered_set seen; + std::list> unvisited_fifo; + + oss << "strict digraph {" << std::endl; + oss << "edge [" + << "headport=n" + << ",tooltip=\" \"" // Remove default tooltips from edge parts. + << "];" << std::endl; + oss << "node [" + << "shape=point" // For fake nodes. + << ",style=filled" // Show tooltip everywhere on the node. + << ",fillcolor=ivory" + << "];" << std::endl; + oss << "ranksep=" << 4.0f * std::log10(GetN()) << std::endl; + + oss << DotEdgeString(!as_opponent) << std::endl; + if (low_node_) { + seen.insert(low_node_); + unvisited_fifo.push_back(std::pair(this, as_opponent)); + } + + while (!unvisited_fifo.empty()) { + auto [parent_node, parent_as_opponent] = unvisited_fifo.front(); + unvisited_fifo.pop_front(); + + auto parent_low_node = parent_node->GetLowNode(); + seen.insert(parent_low_node); + oss << parent_low_node->DotNodeString() << std::endl; + + for (auto& child_edge : parent_node->Edges()) { + auto child = child_edge.node(); + if (child == nullptr) break; + + oss << child->DotEdgeString(parent_as_opponent) << std::endl; + auto child_low_node = child->GetLowNode(); + if (child_low_node != nullptr && + (seen.find(child_low_node) == seen.end())) { + seen.insert(child_low_node); + unvisited_fifo.push_back(std::pair(child, !parent_as_opponent)); + } } - solid_children_ = false; - } else { - // Stores node which will have to survive (or nullptr if it's not found). - std::unique_ptr saved_node; - // Pointer to unique_ptr, so that we could move from it. - for (std::unique_ptr* node = &child_; *node; - node = &(*node)->sibling_) { - // If current node is the one that we have to save. - if (node->get() == node_to_save) { - // Kill all remaining siblings. - gNodeGc.AddToGcQueue(std::move((*node)->sibling_)); - // Save the node, and take the ownership from the unique_ptr. - saved_node = std::move(*node); - break; + } + + oss << "}" << std::endl; + + return oss.str(); +} + +bool Node::ZeroNInFlight() const { + std::unordered_set seen; + std::list unvisited_fifo; + size_t nonzero_node_count = 0; + + if (GetNInFlight() > 0) { + std::cerr << DebugString() << std::endl; + ++nonzero_node_count; + } + if (low_node_) { + seen.insert(low_node_); + unvisited_fifo.push_back(this); + } + + while (!unvisited_fifo.empty()) { + auto parent_node = unvisited_fifo.front(); + unvisited_fifo.pop_front(); + + for (auto& child_edge : parent_node->Edges()) { + auto child = child_edge.node(); + if (child == nullptr) break; + + if (child->GetNInFlight() > 0) { + std::cerr << child->DebugString() << std::endl; + ++nonzero_node_count; + } + + auto child_low_node = child->GetLowNode(); + if (child_low_node != nullptr && + (seen.find(child_low_node) == seen.end())) { + seen.insert(child_low_node); + unvisited_fifo.push_back(child); } } - // Make saved node the only child. (kills previous siblings). - gNodeGc.AddToGcQueue(std::move(child_)); - child_ = std::move(saved_node); } - if (!child_) { - num_edges_ = 0; - edges_.reset(); // Clear edges list. + + if (nonzero_node_count > 0) { + std::cerr << "GetNInFlight() is nonzero on " << nonzero_node_count + << " nodes" << std::endl; + return false; } + + return true; +} + +void Node::SortEdges() const { + assert(low_node_); + low_node_->SortEdges(); +} + +uint64_t Node::GetHash() const { + if (low_node_) { + return low_node_->GetHash(); + } else { + return 0; + } +} +bool Node::IsTT() const { return low_node_ && low_node_->IsTT(); } + +static constexpr float wld_tolerance = 0.000001f; +static constexpr float m_tolerance = 0.000001f; + +static bool WLDMInvariantsHold(float wl, float d, float m) { + return -(1.0f + wld_tolerance) < wl && wl < (1.0f + wld_tolerance) && // + -(0.0f + wld_tolerance) < d && d < (1.0f + wld_tolerance) && // + -(0.0f + m_tolerance) < m && // + std::abs(wl + d) < (1.0f + wld_tolerance); +} + +bool Node::WLDMInvariantsHold() const { + if (lczero::WLDMInvariantsHold(GetWL(), GetD(), GetM())) return true; + + std::cerr << DebugString() << std::endl; + + return false; +} + +bool LowNode::WLDMInvariantsHold() const { + if (lczero::WLDMInvariantsHold(GetWL(), GetD(), GetM())) return true; + + std::cerr << DebugString() << std::endl; + + return false; } ///////////////////////////////////////////////////////////////////////// @@ -465,33 +667,64 @@ std::string EdgeAndNode::DebugString() const { void NodeTree::MakeMove(Move move) { if (HeadPosition().IsBlackToMove()) move.Mirror(); const auto& board = HeadPosition().GetBoard(); + auto hash = GetHistoryHash(history_); + move = board.GetModernMove(move); // TODO: Why convert here? + // Find edge for @move, if it exists. Node* new_head = nullptr; - for (auto& n : current_head_->Edges()) { - if (board.IsSameMove(n.GetMove(), move)) { - new_head = n.GetOrSpawnNode(current_head_); - // Ensure head is not terminal, so search can extend or visit children of - // "terminal" positions, e.g., WDL hits, converted terminals, 3-fold draw. - if (new_head->IsTerminal()) new_head->MakeNotTerminal(); - break; + while (new_head == nullptr) { + for (auto& n : current_head_->Edges()) { + if (board.IsSameMove(n.GetMove(), move)) { + new_head = n.GetOrSpawnNode(current_head_); + // Ensure head is not terminal, so search can extend or visit children + // of "terminal" positions, e.g., WDL hits, converted terminals, 3-fold + // draw. + if (new_head->IsTerminal()) new_head->MakeNotTerminal(); + break; + } + } + + if (new_head != nullptr) break; + + // Current head node (if any) is non-TT, does not have a matching edge and + // will be removed by NonTTMaintenance later. + current_head_->UnsetLowNode(); + + // Check TT first, then create, if necessary. + auto tt_iter = tt_.find(hash); + if (tt_iter != tt_.end()) { + current_head_->SetLowNode(tt_iter->second.get()); + if (current_head_->IsTerminal()) current_head_->MakeNotTerminal(); + } else { + non_tt_.emplace_back(std::make_unique(hash, MoveList({move}), + static_cast(0))); + current_head_->SetLowNode(non_tt_.back().get()); } } - move = board.GetModernMove(move); - current_head_->ReleaseChildrenExceptOne(new_head); - new_head = current_head_->child_.get(); - current_head_ = - new_head ? new_head : current_head_->CreateSingleChildNode(move); + + // Remove edges that will not be needed any more. + current_head_->ReleaseChildrenExceptOne(new_head, &gc_queue_); + new_head = current_head_->GetChild(); + + // Move damaged node from TT to non-TT to avoid reuse. + // It can have TT parents, until they get garbage collected. + if (current_head_->IsTT()) { + auto tt_iter = tt_.find(current_head_->GetHash()); + tt_iter->second->ClearTT(); + non_tt_.emplace_back(std::move(tt_iter->second)); + tt_.erase(tt_iter); + } + + current_head_ = new_head; + history_.Append(move); + moves_.push_back(move); } void NodeTree::TrimTreeAtHead() { - // If solid, this will be empty before move and will be moved back empty - // afterwards which is fine. - auto tmp = std::move(current_head_->sibling_); - // Send dependent nodes for GC instead of destroying them immediately. - current_head_->ReleaseChildren(); - *current_head_ = Node(current_head_->GetParent(), current_head_->index_); - current_head_->sibling_ = std::move(tmp); + current_head_->Trim(&gc_queue_); + // Free unused non-TT low nodes. + NonTTMaintenance(); } bool NodeTree::ResetToPosition(const std::string& starting_fen, @@ -508,11 +741,12 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen, } if (!gamebegin_node_) { - gamebegin_node_ = std::make_unique(nullptr, 0); + gamebegin_node_ = std::make_unique(0); } history_.Reset(starting_board, no_capture_ply, full_moves * 2 - (starting_board.flipped() ? 1 : 2)); + moves_.clear(); Node* old_head = current_head_; current_head_ = gamebegin_node_.get(); @@ -522,6 +756,9 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen, if (old_head == current_head_) seen_old_head = true; } + // Remove any non-TT nodes that were not reused. + NonTTMaintenance(); + // MakeMove guarantees that no siblings exist; but, if we didn't see the old // head, it means we might have a position that was an ancestor to a // previously searched position, which means that the current_head_ might @@ -532,11 +769,82 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen, } void NodeTree::DeallocateTree() { - // Same as gamebegin_node_.reset(), but actual deallocation will happen in - // GC thread. - gNodeGc.AddToGcQueue(std::move(gamebegin_node_)); - gamebegin_node_ = nullptr; + gamebegin_node_.reset(); current_head_ = nullptr; + // Free all nodes. + // There may be non-TT children of TT nodes that were not garbage collected + // fast enough. + NonTTMaintenance(); + TTClear(); + non_tt_.clear(); + gc_queue_.clear(); +} + +LowNode* NodeTree::TTFind(uint64_t hash) { + auto tt_iter = tt_.find(hash); + if (tt_iter != tt_.end()) { + return tt_iter->second.get(); + } else { + return nullptr; + } +} + +std::pair NodeTree::TTGetOrCreate(uint64_t hash) { + auto [tt_iter, is_tt_miss] = + tt_.insert({hash, std::make_unique(hash)}); + return {tt_iter->second.get(), is_tt_miss}; +} + +void NodeTree::TTMaintenance() { TTGCSome(0); } + +void NodeTree::TTClear() { + // Make sure destructors don't fail. + absl::c_for_each( + tt_, [](const auto& item) { item.second->ReleaseChildren(nullptr); }); + // Remove any released non-TT children of TT nodes that were not garbage + // collected fast enough. + NonTTMaintenance(); + tt_.clear(); + gc_queue_.clear(); +} + +LowNode* NodeTree::NonTTAddClone(const LowNode& node) { + non_tt_.push_back(std::make_unique(node)); + return non_tt_.back().get(); +} + +void NodeTree::NonTTMaintenance() { + // Release children of parent-less nodes. + absl::c_for_each(non_tt_, [this](const auto& item) { + if (item->GetNumParents() == 0) item->ReleaseChildren(&gc_queue_); + }); + // Erase parent-less nodes. + for (auto item = non_tt_.begin(); item != non_tt_.end();) { + if ((*item)->GetNumParents() == 0) { + item = non_tt_.erase(item); + } else { + ++item; + } + } +} + +bool NodeTree::TTGCSome(size_t count) { + if (gc_queue_.empty()) return false; + + for (auto n = count > 0 ? std::min(count, gc_queue_.size()) + : gc_queue_.size(); + n > 0; --n) { + auto hash = gc_queue_.front(); + gc_queue_.pop_front(); + auto tt_iter = tt_.find(hash); + if (tt_iter != tt_.end()) { + if (tt_iter->second->GetNumParents() == 0) { + tt_.erase(tt_iter); + } + } + } + + return gc_queue_.empty(); } } // namespace lczero diff --git a/src/mcts/node.h b/src/mcts/node.h index 2982de24da..324013b0ab 100644 --- a/src/mcts/node.h +++ b/src/mcts/node.h @@ -27,8 +27,12 @@ #pragma once +#include + #include #include +#include +#include #include #include #include @@ -36,26 +40,26 @@ #include "chess/board.h" #include "chess/callbacks.h" #include "chess/position.h" -#include "neural/cache.h" -#include "neural/encoder.h" -#include "proto/net.pb.h" +#include "mcts/params.h" #include "utils/mutex.h" namespace lczero { -// Children of a node are stored the following way: -// * Edges and Nodes edges point to are stored separately. -// * There may be dangling edges (which don't yet point to any Node object yet) -// * Edges are stored are a simple array on heap. -// * Nodes are stored as a linked list, and contain index_ field which shows -// which edge of a parent that node points to. -// Or they are stored a contiguous array of Node objects on the heap if -// solid_children_ is true. If the children have been 'solidified' their -// sibling links are unused and left empty. In this state there are no -// dangling edges, but the nodes may not have ever received any visits. +// Terminology: +// * Edge - a potential edge with a move and policy information. +// * Node - an existing edge with number of visits and evaluation. +// * LowNode - a node with number of visits, evaluation and edges. +// +// Storage: +// * Potential edges are stored in a simple array inside the LowNode as edges_. +// * Existing edges are stored in a linked list starting with a child_ pointer +// in the LowNode and continuing with a sibling_ pointer in each Node. +// * Existing edges have a copy of their potential edge counterpart, index_ +// among potential edges and are linked to the target LowNode via the +// low_node_ pointer. // // Example: -// Parent Node +// LowNode // | // +-------------+-------------+----------------+--------------+ // | | | | | @@ -64,21 +68,103 @@ namespace lczero { // Node, Q=0.5 Node, Q=-0.2 // // Is represented as: -// +--------------+ -// | Parent Node | -// +--------------+ +--------+ -// | edges_ | -------------------------------------> | Edge[] | -// | | +------------+ +--------+ -// | child_ | -> | Node | | Nf3 | -// +--------------+ +------------+ | Bc5 | -// | index_ = 1 | | a4 | -// | q_ = 0.5 | +------------+ | Qxf7 | -// | sibling_ | -> | Node | | a3 | -// +------------+ +------------+ +--------+ -// | index_ = 3 | -// | q_ = -0.2 | -// | sibling_ | -> nullptr -// +------------+ +// +-----------------+ +// | LowNode | +// +-----------------+ +--------+ +// | edges_ | -------------------------------------> | Edge[] | +// | | +------------+ +--------+ +// | child_ | -> | Node | | Nf3 | +// | | +------------+ | Bc5 | +// | ... | | edge_ | | a4 | +// | | | index_ = 1 | | Qxf7 | +// | | | q_ = 0.5 | +------------+ | a3 | +// | | | sibling_ | -> | Node | +--------+ +// | | +------------+ +------------+ +// | | | edge_ | +// +-----------------+ | index_ = 3 | +// | q_ = -0.2 | +// | sibling_ | -> nullptr +// +------------+ + +// Define __i386__ or __arm__ also for 32 bit Windows. +#if defined(_M_IX86) +#define __i386__ +#endif +#if defined(_M_ARM) && !defined(_M_AMD64) +#define __arm__ +#endif + +// Atomic unique_ptr based on the public domain code from +// https://stackoverflow.com/a/42811152 . +template +class atomic_unique_ptr { + using pointer = T*; + using unique_pointer = std::unique_ptr; + + public: + // Manage no pointer. + constexpr atomic_unique_ptr() noexcept : ptr() {} + + // Make pointer @p managed. + explicit atomic_unique_ptr(pointer p) noexcept : ptr(p) {} + + // Move the managed pointer ownership from another atomic_unique_ptr. + atomic_unique_ptr(atomic_unique_ptr&& p) noexcept : ptr(p.release()) {} + // Move the managed pointer ownership from another atomic_unique_ptr. + atomic_unique_ptr& operator=(atomic_unique_ptr&& p) noexcept { + reset(p.release()); + return *this; + } + + // Move the managed object ownership from a unique_ptr. + atomic_unique_ptr(unique_pointer&& p) noexcept : ptr(p.release()) {} + // Move the managed object ownership from a unique_ptr. + atomic_unique_ptr& operator=(unique_pointer&& p) noexcept { + reset(p.release()); + return *this; + } + + // Replace the managed pointer, deleting the old one. + void reset(pointer p = pointer()) noexcept { + auto old = ptr.exchange(p, std::memory_order_acquire); + if (old) delete old; + } + // Release ownership of and delete the owned pointer. + ~atomic_unique_ptr() { reset(); } + + // Returns the managed pointer. + operator pointer() const noexcept { return ptr; } + // Returns the managed pointer. + pointer operator->() const noexcept { return ptr; } + // Returns the managed pointer. + pointer get() const noexcept { return ptr; } + + // Checks whether there is a managed pointer. + explicit operator bool() const noexcept { return ptr != pointer(); } + + // Replace the managed pointer, only releasing returning the old one. + pointer set(pointer p = pointer()) noexcept { + return ptr.exchange(p, std::memory_order_acquire); + } + // Return the managed pointer and release its ownership. + pointer release() noexcept { return set(pointer()); } + + // Move managed pointer from @source, iff the managed pointer equals + // @expected. + bool compare_exchange(pointer expected, + atomic_unique_ptr& source) noexcept { + if (ptr.compare_exchange_strong(expected, source.ptr, + std::memory_order_acq_rel)) { + source.release(); + return true; + } else { + return false; + } + } + + private: + std::atomic ptr; +}; class Node; class Edge { @@ -98,6 +184,8 @@ class Edge { // Debug information about the edge. std::string DebugString() const; + static void SortEdges(Edge* edges, int num_edges); + private: // Move corresponding to this node. From the point of view of a player, // i.e. black's e7e5 is stored as e2e4. @@ -116,6 +204,29 @@ struct Eval { float ml; }; +struct NNEval { + // To minimize the number of padding bytes and to avoid having unnecessary + // padding when new fields are added, we arrange the fields by size, largest + // to smallest. + + // 8 byte fields on 64-bit platforms, 4 byte on 32-bit. + // Array of edges. + std::unique_ptr edges; + + // 4 byte fields. + float q = 0.0f; + float d = 0.0f; + float m = 0.0f; + + // 1 byte fields. + // Number of edges in @edges. + uint8_t num_edges = 0; +}; + +typedef std::pair Bounds; + +enum class Terminal : uint8_t { NonTerminal, EndOfGame, Tablebase }; + class EdgeAndNode; template class Edge_Iterator; @@ -123,47 +234,56 @@ class Edge_Iterator; template class VisitedNode_Iterator; +class LowNode; +typedef std::list GCQueue; class Node { public: using Iterator = Edge_Iterator; using ConstIterator = Edge_Iterator; - enum class Terminal : uint8_t { NonTerminal, EndOfGame, Tablebase, TwoFold }; - - // Takes pointer to a parent node and own index in a parent. - Node(Node* parent, uint16_t index) - : parent_(parent), + // Takes own @index in the parent. + Node(uint16_t index) + : index_(index), + terminal_type_(Terminal::NonTerminal), + lower_bound_(GameResult::BLACK_WON), + upper_bound_(GameResult::WHITE_WON), + repetition_(false) {} + // Takes own @edge and @index in the parent. + Node(const Edge& edge, uint16_t index) + : edge_(edge), index_(index), terminal_type_(Terminal::NonTerminal), lower_bound_(GameResult::BLACK_WON), upper_bound_(GameResult::WHITE_WON), - solid_children_(false) {} - - // We have a custom destructor, but its behavior does not need to be emulated - // during move operations so default is fine. - Node(Node&& move_from) = default; - Node& operator=(Node&& move_from) = default; - - // Allocates a new edge and a new node. The node has to be no edges before - // that. - Node* CreateSingleChildNode(Move m); - - // Creates edges from a movelist. There has to be no edges before that. - void CreateEdges(const MoveList& moves); - - // Gets parent node. - Node* GetParent() const { return parent_; } + repetition_(false) {} + ~Node() { UnsetLowNode(); } + + // Trim node, resetting everything except parent, sibling, edge and index. + void Trim(GCQueue* gc_queue); + + // Get first child. + Node* GetChild() const; + // Get next sibling. + atomic_unique_ptr* GetSibling() { return &sibling_; } + // Moves sibling in. + void MoveSiblingIn(atomic_unique_ptr& sibling) { + sibling_ = std::move(sibling); + } // Returns whether a node has children. - bool HasChildren() const { return static_cast(edges_); } + bool HasChildren() const; // Returns sum of policy priors which have had at least one playout. float GetVisitedPolicy() const; uint32_t GetN() const { return n_; } - uint32_t GetNInFlight() const { return n_in_flight_; } - uint32_t GetChildrenVisits() const { return n_ > 0 ? n_ - 1 : 0; } - // Returns n = n_if_flight. - int GetNStarted() const { return n_ + n_in_flight_; } + uint32_t GetNInFlight() const; + uint32_t GetChildrenVisits() const; + uint32_t GetTotalVisits() const; + // Returns n + n_in_flight. + int GetNStarted() const { + return n_ + n_in_flight_.load(std::memory_order_acquire); + } + float GetQ(float draw_score) const { return wl_ + draw_score * d_; } // Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1 // for terminal nodes. @@ -174,25 +294,16 @@ class Node { // Returns whether the node is known to be draw/lose/win. bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; } bool IsTbTerminal() const { return terminal_type_ == Terminal::Tablebase; } - bool IsTwoFoldTerminal() const { return terminal_type_ == Terminal::TwoFold; } - typedef std::pair Bounds; Bounds GetBounds() const { return {lower_bound_, upper_bound_}; } - uint8_t GetNumEdges() const { return num_edges_; } - // Output must point to at least max_needed floats. - void CopyPolicy(int max_needed, float* output) const { - if (!edges_) return; - int loops = std::min(static_cast(num_edges_), max_needed); - for (int i = 0; i < loops; i++) { - output[i] = edges_[i].GetP(); - } - } + uint8_t GetNumEdges() const; // Makes the node terminal and sets it's score. - void MakeTerminal(GameResult result, float plies_left = 0.0f, + void MakeTerminal(GameResult result, float plies_left = 1.0f, Terminal type = Terminal::EndOfGame); - // Makes the node not terminal and updates its visits. - void MakeNotTerminal(); + // Makes the node not terminal and recomputes bounds, visits and values. + // Changes low node as well unless @also_low_node is false. + void MakeNotTerminal(bool also_low_node = true); void SetBounds(GameResult lower, GameResult upper); // If this node is not in the process of being expanded by another thread @@ -201,28 +312,19 @@ class Node { // Otherwise return false. bool TryStartScoreUpdate(); // Decrements n-in-flight back. - void CancelScoreUpdate(int multivisit); + void CancelScoreUpdate(uint32_t multivisit); // Updates the node with newly computed value v. // Updates: // * Q (weighted average of all V in a subtree) - // * N (+=1) - // * N-in-flight (-=1) - void FinalizeScoreUpdate(float v, float d, float m, int multivisit); + // * N (+=multivisit) + // * N-in-flight (-=multivisit) + void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit); // Like FinalizeScoreUpdate, but it updates n existing visits by delta amount. - void AdjustForTerminal(float v, float d, float m, int multivisit); - // Revert visits to a node which ended in a now reverted terminal. - void RevertTerminalVisits(float v, float d, float m, int multivisit); + void AdjustForTerminal(float v, float d, float m, uint32_t multivisit); // When search decides to treat one visit as several (in case of collisions // or visiting terminal nodes several times), it amplifies the visit by // incrementing n_in_flight. - void IncrementNInFlight(int multivisit) { n_in_flight_ += multivisit; } - - // Updates max depth, if new depth is larger. - void UpdateMaxDepth(int depth); - - // Calculates the full depth if new depth is larger, updates it, returns - // in depth parameter, and returns true if it was indeed updated. - bool UpdateFullDepth(uint16_t* depth); + void IncrementNInFlight(uint32_t multivisit); // Returns range for iterating over edges. ConstIterator Edges() const; @@ -232,48 +334,55 @@ class Node { VisitedNode_Iterator VisitedNodes() const; VisitedNode_Iterator VisitedNodes(); - // Deletes all children. - void ReleaseChildren(); - // Deletes all children except one. // The node provided may be moved, so should not be relied upon to exist // afterwards. - void ReleaseChildrenExceptOne(Node* node); + void ReleaseChildrenExceptOne(Node* node_to_save, GCQueue* gc_queue) const; - // For a child node, returns corresponding edge. - Edge* GetEdgeToNode(const Node* node) const; + // Returns move from the point of view of the player making it (if as_opponent + // is false) or as opponent (if as_opponent is true). + Move GetMove(bool as_opponent = false) const { + return edge_.GetMove(as_opponent); + } + // Returns or sets value of Move policy prior returned from the neural net + // (but can be changed by adding Dirichlet noise or when turning terminal). + // Must be in [0,1]. + float GetP() const { return edge_.GetP(); } + void SetP(float val) { edge_.SetP(val); } - // Returns edge to the own node. - Edge* GetOwnEdge() const; + LowNode* GetLowNode() const { return low_node_; } + + void SetLowNode(LowNode* low_node); + void UnsetLowNode(); // Debug information about the node. std::string DebugString() const; + // Return string describing the edge from node's parent to its low node in the + // Graphviz dot format. + std::string DotEdgeString(bool as_opponent = false, + const LowNode* parent = nullptr) const; + // Return string describing the graph starting at this node in the Graphviz + // dot format. + std::string DotGraphString(bool as_opponent = false) const; - // Reallocates this nodes children to be in a solid block, if possible and not - // already done. Returns true if the transformation was performed. - bool MakeSolid(); + // Returns true if graph under this node has every n_in_flight_ == 0 and + // prints offending nodes and low nodes and stats to cerr otherwise. + bool ZeroNInFlight() const; - void SortEdges(); + void SortEdges() const; - // Index in parent edges - useful for correlated ordering. + // Index in parent's edges - useful for correlated ordering. uint16_t Index() const { return index_; } - ~Node() { - if (solid_children_ && child_) { - // As a hack, solid_children is actually storing an array in here, release - // so we can correctly invoke the array delete. - for (int i = 0; i < num_edges_; i++) { - child_.get()[i].~Node(); - } - std::allocator alloc; - alloc.deallocate(child_.release(), num_edges_); - } - } + void SetRepetition() { repetition_ = true; } + bool IsRepetition() const { return repetition_; } - private: - // For each child, ensures that its parent pointer is pointing to this. - void UpdateChildrenParents(); + uint64_t GetHash() const; + bool IsTT() const; + bool WLDMInvariantsHold() const; + + private: // To minimize the number of padding bytes and to avoid having unnecessary // padding when new fields are added, we arrange the fields by size, largest // to smallest. @@ -281,22 +390,16 @@ class Node { // 8 byte fields. // Average value (from value head of neural network) of all visited nodes in // subtree. For terminal nodes, eval is stored. This is from the perspective - // of the player who "just" moved to reach this position, rather than from the - // perspective of the player-to-move for the position. - // WL stands for "W minus L". Is equal to Q if draw score is 0. + // of the player who "just" moved to reach this position, rather than from + // the perspective of the player-to-move for the position. WL stands for "W + // minus L". Is equal to Q if draw score is 0. double wl_ = 0.0f; // 8 byte fields on 64-bit platforms, 4 byte on 32-bit. - // Array of edges. - std::unique_ptr edges_; - // Pointer to a parent node. nullptr for the root. - Node* parent_ = nullptr; - // Pointer to a first child. nullptr for a leaf node. - // As a 'hack' actually a unique_ptr to Node[] if solid_children. - std::unique_ptr child_; + // Pointer to the low node. + LowNode* low_node_ = nullptr; // Pointer to a next sibling. nullptr if there are no further siblings. - // Also null in the solid case. - std::unique_ptr sibling_; + atomic_unique_ptr sibling_; // 4 byte fields. // Averaged draw probability. Works similarly to WL, except that D is not @@ -309,48 +412,228 @@ class Node { // (AKA virtual loss.) How many threads currently process this node (started // but not finished). This value is added to n during selection which node // to pick in MCTS, and also when selecting the best move. - uint32_t n_in_flight_ = 0; + std::atomic n_in_flight_ = 0; + + // Move and policy for this edge. + Edge edge_; // 2 byte fields. // Index of this node is parent's edge list. uint16_t index_; + // 1 byte fields. + // Bit fields using parts of uint8_t fields initialized in the constructor. + // Whether or not this node end game (with a winning of either sides or + // draw). + Terminal terminal_type_ : 2; + // Best and worst result for this node. + GameResult lower_bound_ : 2; + GameResult upper_bound_ : 2; + // Edge was handled as a repetition at some point. + bool repetition_ : 1; +}; + +// Check that Node still fits into an expected cache line size. +static_assert(sizeof(Node) <= 64, "Node is too large"); + +class LowNode { + public: + // For TT nodes. + LowNode(uint64_t hash) + : hash_(hash), + terminal_type_(Terminal::NonTerminal), + lower_bound_(GameResult::BLACK_WON), + upper_bound_(GameResult::WHITE_WON), + is_transposition(false), + is_tt_(true) {} + // Init from another low node, but use it for NNEval only. + // For non-TT nodes. + LowNode(const LowNode& p) + : wl_(p.wl_), + hash_(p.hash_), + d_(p.d_), + m_(p.m_), + num_edges_(p.num_edges_), + terminal_type_(Terminal::NonTerminal), + lower_bound_(GameResult::BLACK_WON), + upper_bound_(GameResult::WHITE_WON), + is_transposition(false), + is_tt_(false) { + assert(p.edges_); + edges_ = std::make_unique(num_edges_); + std::memcpy(edges_.get(), p.edges_.get(), num_edges_ * sizeof(Edge)); + } + // Init @edges_ with moves from @moves and 0 policy. + // Also create the first child at @index. + // For non-TT nodes. + LowNode(uint64_t hash, const MoveList& moves, uint16_t index) + : hash_(hash), + num_edges_(moves.size()), + terminal_type_(Terminal::NonTerminal), + lower_bound_(GameResult::BLACK_WON), + upper_bound_(GameResult::WHITE_WON), + is_transposition(false), + is_tt_(false) { + edges_ = Edge::FromMovelist(moves); + child_ = std::make_unique(edges_[index], index); + } + + void SetNNEval(const NNEval* eval) { + assert(!edges_); + assert(n_ == 0); + assert(!child_); + + edges_ = std::make_unique(eval->num_edges); + std::memcpy(edges_.get(), eval->edges.get(), + eval->num_edges * sizeof(Edge)); + + wl_ = eval->q; + d_ = eval->d; + m_ = eval->m; + + assert(WLDMInvariantsHold()); + + num_edges_ = eval->num_edges; + } + + // Gets the first child. + atomic_unique_ptr* GetChild() { return &child_; } + + // Returns whether a node has children. + bool HasChildren() const { return num_edges_ > 0; } + + uint32_t GetN() const { return n_; } + uint32_t GetChildrenVisits() const { return n_ - 1; } + + // Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1 + // for terminal nodes. + float GetWL() const { return wl_; } + float GetD() const { return d_; } + float GetM() const { return m_; } + + // Returns whether the node is known to be draw/loss/win. + bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; } + Bounds GetBounds() const { return {lower_bound_, upper_bound_}; } + Terminal GetTerminalType() const { return terminal_type_; } + + uint8_t GetNumEdges() const { return num_edges_; } + // Gets pointer to the start of the edge array. + Edge* GetEdges() const { return edges_.get(); } + + // Makes the node terminal and sets it's score. + void MakeTerminal(GameResult result, float plies_left = 0.0f, + Terminal type = Terminal::EndOfGame); + // Makes the low node not terminal and recomputes bounds, visits and values + // using incoming @node. + void MakeNotTerminal(const Node* node); + void SetBounds(GameResult lower, GameResult upper); + + // Decrements n-in-flight back. + void CancelScoreUpdate(uint32_t multivisit); + // Updates the node with newly computed value v. + // Updates: + // * Q (weighted average of all V in a subtree) + // * N (+=multivisit) + // * N-in-flight (-=multivisit) + void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit); + // Like FinalizeScoreUpdate, but it updates n existing visits by delta amount. + void AdjustForTerminal(float v, float d, float m, uint32_t multivisit); + + // Deletes all children. + void ReleaseChildren(GCQueue* gc_queue); + + // Deletes all children except one. + // The node provided may be moved, so should not be relied upon to exist + // afterwards. + void ReleaseChildrenExceptOne(Node* node_to_save, GCQueue* gc_queue); + + // Return move policy for edge/node at @index. + const Edge& GetEdgeAt(uint16_t index) const; + + // Debug information about the node. + std::string DebugString() const; + // Return string describing this node in the Graphviz dot format. + std::string DotNodeString() const; + + void SortEdges() { + assert(edges_); + assert(!child_); + Edge::SortEdges(edges_.get(), num_edges_); + } + + // Add new parent with @n_in_flight visits. + void AddParent() { + ++num_parents_; + + assert(num_parents_ > 0); + + is_transposition |= num_parents_ > 1; + } + // Remove parent and its first visit. + void RemoveParent() { + assert(num_parents_ > 0); + --num_parents_; + } + uint16_t GetNumParents() const { return num_parents_; } + bool IsTransposition() const { return is_transposition; } + + uint64_t GetHash() const { return hash_; } + bool IsTT() const { return is_tt_; } + void ClearTT() { is_tt_ = !is_tt_; } + + bool WLDMInvariantsHold() const; + + private: + // To minimize the number of padding bytes and to avoid having unnecessary + // padding when new fields are added, we arrange the fields by size, largest + // to smallest. + + // 8 byte fields. + // Average value (from value head of neural network) of all visited nodes in + // subtree. For terminal nodes, eval is stored. This is from the perspective + // of the player who "just" moved to reach this position, rather than from the + // perspective of the player-to-move for the position. + // WL stands for "W minus L". Is equal to Q if draw score is 0. + double wl_ = 0.0f; + // Position hash and a TT key. + uint64_t hash_ = 0; + + // 8 byte fields on 64-bit platforms, 4 byte on 32-bit. + // Array of edges. + std::unique_ptr edges_; + // Pointer to the first child. nullptr when no children. + atomic_unique_ptr child_; + + // 4 byte fields. + // Averaged draw probability. Works similarly to WL, except that D is not + // flipped depending on the side to move. + float d_ = 0.0f; + // Estimated remaining plies. + float m_ = 0.0f; + // How many completed visits this node had. + uint32_t n_ = 0; + + // 2 byte fields. + // Number of parents. + uint16_t num_parents_ = 0; + // 1 byte fields. // Number of edges in @edges_. uint8_t num_edges_ = 0; - // Bit fields using parts of uint8_t fields initialized in the constructor. // Whether or not this node end game (with a winning of either sides or draw). Terminal terminal_type_ : 2; // Best and worst result for this node. GameResult lower_bound_ : 2; GameResult upper_bound_ : 2; - // Whether the child_ is actually an array of equal length to edges. - bool solid_children_ : 1; - - // TODO(mooskagh) Unfriend NodeTree. - friend class NodeTree; - friend class Edge_Iterator; - friend class Edge_Iterator; - friend class Edge; - friend class VisitedNode_Iterator; - friend class VisitedNode_Iterator; + // Low node is a transposition (for ever). + bool is_transposition : 1; + // Low node is in TT, i.e. was not evaluated or was modified. + bool is_tt_ : 1; }; -// Define __i386__ or __arm__ also for 32 bit Windows. -#if defined(_M_IX86) -#define __i386__ -#endif -#if defined(_M_ARM) && !defined(_M_AMD64) -#define __arm__ -#endif - -// A basic sanity check. This must be adjusted when Node members are adjusted. -#if defined(__i386__) || (defined(__arm__) && !defined(__aarch64__)) -static_assert(sizeof(Node) == 48, "Unexpected size of Node for 32bit compile"); -#else -static_assert(sizeof(Node) == 64, "Unexpected size of Node"); -#endif +// Check that LowNode still fits into an expected cache line size. +static_assert(sizeof(LowNode) <= 64, "LowNode is too large"); // Contains Edge and Node pair and set of proxy functions to simplify access // to them. @@ -391,13 +674,15 @@ class EdgeAndNode { // Whether the node is known to be terminal. bool IsTerminal() const { return node_ ? node_->IsTerminal() : false; } bool IsTbTerminal() const { return node_ ? node_->IsTbTerminal() : false; } - Node::Bounds GetBounds() const { + Bounds GetBounds() const { return node_ ? node_->GetBounds() - : Node::Bounds{GameResult::BLACK_WON, GameResult::WHITE_WON}; + : Bounds{GameResult::BLACK_WON, GameResult::WHITE_WON}; } // Edge related getters. - float GetP() const { return edge_->GetP(); } + float GetP() const { + return node_ != nullptr ? node_->GetP() : edge_->GetP(); + } Move GetMove(bool flip = false) const { return edge_ ? edge_->GetMove(flip) : Move(); } @@ -436,21 +721,20 @@ class EdgeAndNode { template class Edge_Iterator : public EdgeAndNode { public: - using Ptr = std::conditional_t*, - std::unique_ptr*>; + using Ptr = std::conditional_t*, + atomic_unique_ptr*>; // Creates "end()" iterator. Edge_Iterator() {} - // Creates "begin()" iterator. Also happens to be a range constructor. - // child_ptr will be nullptr if parent_node is solid children. - Edge_Iterator(const Node& parent_node, Ptr child_ptr) - : EdgeAndNode(parent_node.edges_.get(), nullptr), - node_ptr_(child_ptr), - total_count_(parent_node.num_edges_) { - if (edge_ && child_ptr != nullptr) Actualize(); - if (edge_ && child_ptr == nullptr) { - node_ = parent_node.child_.get(); + // Creates "begin()" iterator. + Edge_Iterator(LowNode* parent_node) + : EdgeAndNode(parent_node != nullptr ? parent_node->GetEdges() : nullptr, + nullptr) { + if (parent_node != nullptr) { + node_ptr_ = parent_node->GetChild(); + total_count_ = parent_node->GetNumEdges(); + if (edge_) Actualize(); } } @@ -466,11 +750,7 @@ class Edge_Iterator : public EdgeAndNode { edge_ = nullptr; } else { ++edge_; - if (node_ptr_ != nullptr) { - Actualize(); - } else { - ++node_; - } + Actualize(); } } Edge_Iterator& operator*() { return *this; } @@ -478,27 +758,43 @@ class Edge_Iterator : public EdgeAndNode { // If there is node, return it. Otherwise spawn a new one and return it. Node* GetOrSpawnNode(Node* parent) { if (node_) return node_; // If there is already a node, return it. - // Should never reach here in solid mode. - assert(node_ptr_ != nullptr); - Actualize(); // But maybe other thread already did that. - if (node_) return node_; // If it did, return. - // Now we are sure we have to create a new node. - // Suppose there are nodes with idx 3 and 7, and we want to insert one with - // idx 5. Here is how it looks like: - // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.7) - // Here is how we do that: - // 1. Store pointer to a node idx_.7: - // node_ptr_ -> &Node(idx_.3).sibling_ -> nullptr - // tmp -> Node(idx_.7) - std::unique_ptr tmp = std::move(*node_ptr_); - // 2. Create fresh Node(idx_.5): - // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.5) - // tmp -> Node(idx_.7) - *node_ptr_ = std::make_unique(parent, current_idx_); - // 3. Attach stored pointer back to a list: - // node_ptr_ -> - // &Node(idx_.3).sibling_ -> Node(idx_.5).sibling_ -> Node(idx_.7) - (*node_ptr_)->sibling_ = std::move(tmp); + + // We likely need to add a new node, prepare it now. + atomic_unique_ptr new_node = std::make_unique( + parent->GetLowNode()->GetEdgeAt(current_idx_), current_idx_); + while (true) { + auto node = Actualize(); // But maybe other thread already did that. + if (node_) return node_; // If it did, return. + + // New node needs to be added, but we might be in a race with another + // thread doing what we do or adding a different index to the same + // sibling. + + // Suppose there are nodes with idx 3 and 7, and we want to insert one + // with idx 5. Here is how it looks like: + // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.7) + // Here is how we do that: + // 1. Store pointer to a node idx_.7: + // node_ptr_ -> &Node(idx_.3).sibling_ -> nullptr + // tmp -> Node(idx_.7) + // 2. Create fresh Node(idx_.5): + // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.5) + // tmp -> Node(idx_.7) + // 3. Attach stored pointer back to a list: + // node_ptr_ -> + // &Node(idx_.3).sibling_ -> Node(idx_.5).sibling_ -> Node(idx_.7) + + // Atomically add the new node into the right place. + // Set new node's sibling to the expected sibling seen by Actualize in + // node_ptr_. + auto new_sibling = new_node->GetSibling(); + new_sibling->set(node); + // Try to atomically insert the new node and stop if it works. + if (node_ptr_->compare_exchange(node, new_node)) break; + // Recover from failure and try again. + // Release expected sibling to avoid double free. + new_sibling->release(); + } // 4. Actualize: // node_ -> &Node(idx_.5) // node_ptr_ -> &Node(idx_.5).sibling_ -> Node(idx_.7) @@ -507,24 +803,30 @@ class Edge_Iterator : public EdgeAndNode { } private: - void Actualize() { - // This must never be called in solid mode. - assert(node_ptr_ != nullptr); + // Moves node_ptr_ as close as possible to the target index and returns the + // contents of node_ptr_ for use by atomic insert in GetOrSpawnNode. + Node* Actualize() { // If node_ptr_ is behind, advance it. // This is needed (and has to be 'while' rather than 'if') as other threads // could spawn new nodes between &node_ptr_ and *node_ptr_ while we didn't // see. - while (*node_ptr_ && (*node_ptr_)->index_ < current_idx_) { - node_ptr_ = &(*node_ptr_)->sibling_; + // Read the direct pointer just once as other threads may change it between + // uses. + auto node = node_ptr_->get(); + while (node != nullptr && node->Index() < current_idx_) { + node_ptr_ = node->GetSibling(); + node = node_ptr_->get(); } // If in the end node_ptr_ points to the node that we need, populate node_ // and advance node_ptr_. - if (*node_ptr_ && (*node_ptr_)->index_ == current_idx_) { - node_ = (*node_ptr_).get(); - node_ptr_ = &node_->sibling_; + if (node != nullptr && node->Index() == current_idx_) { + node_ = node; + node_ptr_ = node->GetSibling(); } else { node_ = nullptr; } + + return node; } // Pointer to a pointer to the next node. Has to be a pointer to pointer @@ -534,6 +836,9 @@ class Edge_Iterator : public EdgeAndNode { uint16_t total_count_ = 0; }; +inline Node::ConstIterator Node::Edges() const { return {this->GetLowNode()}; } +inline Node::Iterator Node::Edges() { return {this->GetLowNode()}; } + // TODO(crem) Replace this with less hacky iterator once we support C++17. // This class has multiple hypostases within one class: // * Range (begin() and end() functions) @@ -549,16 +854,17 @@ class VisitedNode_Iterator { // Creates "end()" iterator. VisitedNode_Iterator() {} - // Creates "begin()" iterator. Also happens to be a range constructor. - // child_ptr will be nullptr if parent_node is solid children. - VisitedNode_Iterator(const Node& parent_node, Node* child_ptr) - : node_ptr_(child_ptr), - total_count_(parent_node.num_edges_), - solid_(parent_node.solid_children_) { - if (node_ptr_ != nullptr && node_ptr_->GetN() == 0) { - operator++(); + // Creates "begin()" iterator. + VisitedNode_Iterator(LowNode* parent_node) { + if (parent_node != nullptr) { + node_ptr_ = parent_node->GetChild()->get(); + total_count_ = parent_node->GetNumEdges(); + if (node_ptr_ != nullptr && node_ptr_->GetN() == 0) { + operator++(); + } } } + // These are technically wrong, but are usable to compare with end(). bool operator==(const VisitedNode_Iterator& other) const { return node_ptr_ == other.node_ptr_; @@ -574,72 +880,56 @@ class VisitedNode_Iterator { // Functions to support iterator interface. // Equality comparison operators are inherited from EdgeAndNode. void operator++() { - if (solid_) { - while (++current_idx_ != total_count_ && - node_ptr_[current_idx_].GetN() == 0) { - if (node_ptr_[current_idx_].GetNInFlight() == 0) { - // Once there is not even n in flight, we can skip to the end. This is - // due to policy being in sorted order meaning that additional n in - // flight are always selected from the front of the section with no n - // in flight or visited. - current_idx_ = total_count_; - break; - } - } - if (current_idx_ == total_count_) { + do { + node_ptr_ = node_ptr_->GetSibling()->get(); + // If n started is 0, can jump direct to end due to sorted policy + // ensuring that each time a new edge becomes best for the first time, + // it is always the first of the section at the end that has NStarted of + // 0. + if (node_ptr_ != nullptr && node_ptr_->GetN() == 0 && + node_ptr_->GetNInFlight() == 0) { node_ptr_ = nullptr; + break; } - } else { - do { - node_ptr_ = node_ptr_->sibling_.get(); - // If n started is 0, can jump direct to end due to sorted policy - // ensuring that each time a new edge becomes best for the first time, - // it is always the first of the section at the end that has NStarted of - // 0. - if (node_ptr_ != nullptr && node_ptr_->GetN() == 0 && - node_ptr_->GetNInFlight() == 0) { - node_ptr_ = nullptr; - break; - } - } while (node_ptr_ != nullptr && node_ptr_->GetN() == 0); - } - } - Node* operator*() { - if (solid_) { - return &(node_ptr_[current_idx_]); - } else { - return node_ptr_; - } + } while (node_ptr_ != nullptr && node_ptr_->GetN() == 0); } + Node* operator*() { return node_ptr_; } private: // Pointer to current node. Node* node_ptr_ = nullptr; uint16_t current_idx_ = 0; uint16_t total_count_ = 0; - bool solid_ = false; }; inline VisitedNode_Iterator Node::VisitedNodes() const { - return {*this, child_.get()}; + return {this->GetLowNode()}; } inline VisitedNode_Iterator Node::VisitedNodes() { - return {*this, child_.get()}; + return {this->GetLowNode()}; } class NodeTree { public: + // Transposition Table (TT) type for holding all normal low nodes in the DAG. + typedef absl::flat_hash_map> + TranspositionTable; + + // Apply search params. + NodeTree(const SearchParams& params) + : hash_history_length_(params.GetCacheHistoryLength() + 1) {} + // When search params are not available. + NodeTree() : hash_history_length_(1) {} ~NodeTree() { DeallocateTree(); } + // Adds a move to current_head_. void MakeMove(Move move); // Resets the current head to ensure it doesn't carry over details from a // previous search. void TrimTreeAtHead(); - // Sets the position in a tree, trying to reuse the tree. - // If @auto_garbage_collect, old tree is garbage collected immediately. (may - // take some milliseconds) - // Returns whether a new position the same game as old position (with some - // moves added). Returns false, if the position is completely different, + // Sets the position in the tree, trying to reuse the tree. + // Returns whether the new position is the same game as the old position (with + // some moves added). Returns false, if the position is completely different, // or if it's shorter than before. bool ResetToPosition(const std::string& starting_fen, const std::vector& moves); @@ -649,14 +939,60 @@ class NodeTree { Node* GetCurrentHead() const { return current_head_; } Node* GetGameBeginNode() const { return gamebegin_node_.get(); } const PositionHistory& GetPositionHistory() const { return history_; } + const std::vector& GetMoves() const { return moves_; } + + // Look up a low node in the Transposition Table by @hash and return it, or + // nullptr on failure. + LowNode* TTFind(uint64_t hash); + // Get a low node for the @hash from the Transposition Table or create a + // new low node and insert it into the Transposition Table if it is not there + // already. Return the low node for the hash. + std::pair TTGetOrCreate(uint64_t hash); + // Evict unused low nodes from the Transposition Table. + void TTMaintenance(); + // Clear the Transposition Table. + // NOTE: Safe only when non-TT nodes were already detached. + void TTClear(); + // Release the first @count items from TT GC list. @count == 0 means release + // all. Return true, if there is more to release. + bool TTGCSome(size_t count = 1); + + // Add a clone of low @node to special nodes outside of the Transposition + // Table and return it. + LowNode* NonTTAddClone(const LowNode& node); + + size_t AllocatedNodeCount() const { return tt_.size() + non_tt_.size(); }; + + // Get position hash used for TT nodes and NN cache. + uint64_t GetHistoryHash(const PositionHistory& history) const { + return history.HashLast(hash_history_length_); + } private: void DeallocateTree(); + + // Evict unused non-TT low nodes. + void NonTTMaintenance(); + // A node which to start search from. Node* current_head_ = nullptr; // Root node of a game tree. std::unique_ptr gamebegin_node_; PositionHistory history_; + std::vector moves_; + + // Transposition Table (TT) for holding references to all normal low nodes in + // the DAG. + TranspositionTable tt_; + // Collection of low nodes that are not fit for Transposition Table due to + // noise or incomplete information. + std::vector> non_tt_; + + // History positions to hash in node hashes used in TT and NN cache. + int hash_history_length_; + + // Garbage collection queue. + GCQueue gc_queue_; }; } // namespace lczero diff --git a/src/mcts/params.cc b/src/mcts/params.cc index 6fbc4cfb30..88a6eb14a5 100644 --- a/src/mcts/params.cc +++ b/src/mcts/params.cc @@ -28,8 +28,11 @@ #include "mcts/params.h" #include +#include +#include #include "utils/exception.h" +#include "utils/string.h" #if __has_include("params_override.h") #include "params_override.h" @@ -38,11 +41,8 @@ #ifndef DEFAULT_MINIBATCH_SIZE #define DEFAULT_MINIBATCH_SIZE 256 #endif -#ifndef DEFAULT_MAX_PREFETCH -#define DEFAULT_MAX_PREFETCH 32 -#endif #ifndef DEFAULT_TASK_WORKERS -#define DEFAULT_TASK_WORKERS 4 +#define DEFAULT_TASK_WORKERS 3 #endif namespace lczero { @@ -55,6 +55,97 @@ FillEmptyHistory EncodeHistoryFill(std::string history_fill) { return FillEmptyHistory::NO; } +ContemptPerspective EncodeContemptPerspective(std::string perspective) { + if (perspective == "sidetomove") return ContemptPerspective::STM; + if (perspective == "white") return ContemptPerspective::WHITE; + if (perspective == "black") return ContemptPerspective::BLACK; + assert(perspective == "none"); + return ContemptPerspective::NONE; +} + +float GetContempt(std::string name, std::string contempt_str) { + float contempt = 0; + for (auto& entry : StrSplit(contempt_str, ",")) { + auto parts = StrSplit(entry, "="); + if (parts.size() == 1) { + try { + contempt = std::stof(parts[0]); + } catch (std::exception& e) { + throw Exception("Invalid default contempt: " + entry); + } + } else if (parts.size() == 2) { + if (std::search(name.begin(), name.end(), parts[0].begin(), + parts[0].end(), [](unsigned char a, unsigned char b) { + return std::tolower(a) == std::tolower(b); + }) != name.end()) { + try { + contempt = std::stof(parts[1]); + } catch (std::exception& e) { + throw Exception("Invalid contempt entry: " + entry); + } + break; + } + } else { + throw Exception("Invalid contempt entry:" + entry); + } + } + return contempt; +} + +// Calculate ratio and diff for WDL conversion from the contempt settings. +// More accurate model, allowing book bias dependent Elo calculation. +// Doesn't take lower accuracy of opponent into account and needs clamping. +SearchParams::WDLRescaleParams AccurateWDLRescaleParams( + float contempt, float draw_rate_target, float draw_rate_reference, + float book_exit_bias, float contempt_max, float contempt_attenuation) { + float scale_target = + 1.0f / std::log((1.0f + draw_rate_target) / (1.0f - draw_rate_target)); + float scale_reference = 1.0f / std::log((1.0f + draw_rate_reference) / + (1.0f - draw_rate_reference)); + float ratio = scale_target / scale_reference; + float diff = + scale_target / (scale_reference * scale_reference) / + (1.0f / + std::pow(std::cosh(0.5f * (1 - book_exit_bias) / scale_target), 2) + + 1.0f / + std::pow(std::cosh(0.5f * (1 + book_exit_bias) / scale_target), 2)) * + std::log(10) / 200 * std::clamp(contempt, -contempt_max, contempt_max) * + contempt_attenuation; + return SearchParams::WDLRescaleParams(ratio, diff); +} + +// Calculate ratio and diff for WDL conversion from the contempt settings. +// Less accurate Elo model, but automatically chooses draw rate and accuracy +// based on the absolute Elo of both sides. Doesn't require clamping, but still +// uses the parameter. +SearchParams::WDLRescaleParams SimplifiedWDLRescaleParams( + float contempt, float draw_rate_reference, float elo_active, + float contempt_max, float contempt_attenuation) { + // Parameters for the Elo dependent draw rate and scaling: + const float scale_zero = 15.0f; + const float elo_slope = 425.0f; + const float offset = 6.75f; + + float scale_reference = 1.0f / std::log((1.0f + draw_rate_reference) / + (1.0f - draw_rate_reference)); + float elo_opp = + elo_active - std::clamp(contempt, -contempt_max, contempt_max); + float scale_active = + 1.0f / (1.0f / scale_zero + std::exp(elo_active / elo_slope - offset)); + float scale_opp = + 1.0f / (1.0f / scale_zero + std::exp(elo_opp / elo_slope - offset)); + float scale_target = + std::sqrt((scale_active * scale_active + scale_opp * scale_opp) / 2.0f); + float ratio = scale_target / scale_reference; + float mu_active = + -std::log(10) / 200 * scale_zero * elo_slope * + std::log(1.0f + std::exp(-elo_active / elo_slope + offset) / scale_zero); + float mu_opp = + -std::log(10) / 200 * scale_zero * elo_slope * + std::log(1.0f + std::exp(-elo_opp / elo_slope + offset) / scale_zero); + float diff = (mu_active - mu_opp) * contempt_attenuation; + return SearchParams::WDLRescaleParams(ratio, diff); +} } // namespace const OptionId SearchParams::kMiniBatchSizeId{ @@ -62,11 +153,6 @@ const OptionId SearchParams::kMiniBatchSizeId{ "How many positions the engine tries to batch together for parallel NN " "computation. Larger batches may reduce strength a bit, especially with a " "small number of playouts."}; -const OptionId SearchParams::kMaxPrefetchBatchId{ - "max-prefetch", "MaxPrefetch", - "When the engine cannot gather a large enough batch for immediate use, try " - "to prefetch up to X positions which are likely to be useful soon, and put " - "them into cache."}; const OptionId SearchParams::kCpuctId{ "cpuct", "CPuct", "cpuct_init constant from \"UCT search\" algorithm. Higher values promote " @@ -269,15 +355,52 @@ const OptionId SearchParams::kDrawScoreWhiteId{ const OptionId SearchParams::kDrawScoreBlackId{ "draw-score-black", "DrawScoreBlack", "Adjustment, added to a draw score of a black player."}; +const OptionId SearchParams::kContemptPerspectiveId{ + "contempt-perspective", "ContemptPerspective", + "Affects the way asymmetric WDL parameters are applied. Default is " + "'sidetomove' for matches, use 'white' and 'black' for analysis. Use " + "'none' to deactivate contempt and the WDL conversion."}; +const OptionId SearchParams::kContemptId{ + "contempt", "Contempt", + "The simulated Elo advantage for the WDL conversion. Comma separated " + "list in the form [name=]value, where the name is compared with the " + "`UCI_Opponent` value to find the appropriate contempt value. The default " + "value is taken from `UCI_RatingAdv` and will be overridden if either a " + "value without name is given, or if a name match is found."}; +const OptionId SearchParams::kContemptMaxValueId{ + "contempt-max-value", "ContemptMaxValue", + "The maximum value of contempt used. Higher values will be capped."}; +const OptionId SearchParams::kWDLCalibrationEloId{ + "wdl-calibration-elo", "WDLCalibrationElo", + "Elo of the active side, adjusted for time control relative to rapid."}; +const OptionId SearchParams::kWDLContemptAttenuationId{ + "wdl-contempt-attenuation", "WDLContemptAttenuation", + "Scales how Elo advantage is applied for contempt. Use 1.0 for realistic " + "analysis, and 0.5-0.6 for optimal match performance."}; +const OptionId SearchParams::kWDLEvalObjectivityId{ + "wdl-eval-objectivity", "WDLEvalObjectivity", + "When calculating the centipawn eval output, decides how objective/" + "contempt influenced the reported eval should be. Value 0.0 reports the " + "internally used WDL values, 1.0 attempts an objective eval."}; +const OptionId SearchParams::kWDLDrawRateTargetId{ + "wdl-draw-rate-target", "WDLDrawRateTarget", + "To define the accuracy of play, the target draw rate in equal " + "positions is used as a proxy."}; +const OptionId SearchParams::kWDLDrawRateReferenceId{ + "wdl-draw-rate-reference", "WDLDrawRateReference", + "Set this to the draw rate predicted by the used neural network at " + "default settings. The accuracy rescaling is done relative to the " + "reference draw rate."}; +const OptionId SearchParams::kWDLBookExitBiasId{ + "wdl-book-exit-bias", "WDLBookExitBias", + "The book exit bias used when measuring engine Elo. Value of startpos is " + "around 0.2, value of 50% white win is 1. Only relevant if target draw " + "rate is above 80%."}; const OptionId SearchParams::kNpsLimitId{ "nps-limit", "NodesPerSecondLimit", "An option to specify an upper limit to the nodes per second searched. The " "accuracy depends on the minibatch size used, increasing for lower sizes, " "and on the length of the search. Zero to disable."}; -const OptionId SearchParams::kSolidTreeThresholdId{ - "solid-tree-threshold", "SolidTreeThreshold", - "Only nodes with at least this number of visits will be considered for " - "solidification for improved cache locality."}; const OptionId SearchParams::kTaskWorkersPerSearchWorkerId{ "task-workers", "TaskWorkers", "The number of task workers to use to help the search worker."}; @@ -316,12 +439,19 @@ const OptionId SearchParams::kMaxCollisionVisitsScalingEndId{ const OptionId SearchParams::kMaxCollisionVisitsScalingPowerId{ "max-collision-visits-scaling-power", "MaxCollisionVisitsScalingPower", "Power to apply to the interpolation between 1 and max to make it curved."}; +const OptionId SearchParams::kUCIOpponentId{ + "", "UCI_Opponent", + "UCI option used by the GUI to pass the name and other information about " + "the current opponent."}; +const OptionId SearchParams::kUCIRatingAdvId{ + "", "UCI_RatingAdv", + "UCI extension used by some GUIs to pass the estimated Elo advantage over " + "the current opponent, used as the default contempt value."}; void SearchParams::Populate(OptionsParser* options) { // Here the uci optimized defaults" are set. // Many of them are overridden with training specific values in tournament.cc. options->Add(kMiniBatchSizeId, 1, 1024) = DEFAULT_MINIBATCH_SIZE; - options->Add(kMaxPrefetchBatchId, 0, 1024) = DEFAULT_MAX_PREFETCH; options->Add(kCpuctId, 0.0f, 100.0f) = 1.745f; options->Add(kCpuctAtRootId, 0.0f, 100.0f) = 1.745f; options->Add(kCpuctBaseId, 1.0f, 1000000000.0f) = 38739.0f; @@ -369,7 +499,8 @@ void SearchParams::Populate(OptionsParser* options) { "centipawn_2018", "win_percentage", "Q", - "W-L"}; + "W-L", + "WDL_mu"}; options->Add(kScoreTypeId, score_type) = "centipawn"; std::vector history_fill_opt{"no", "fen_only", "always"}; options->Add(kHistoryFillId, history_fill_opt) = "fen_only"; @@ -386,8 +517,22 @@ void SearchParams::Populate(OptionsParser* options) { options->Add(kDrawScoreOpponentId, -100, 100) = 0; options->Add(kDrawScoreWhiteId, -100, 100) = 0; options->Add(kDrawScoreBlackId, -100, 100) = 0; + std::vector perspective = {"sidetomove", "white", "black", + "none"}; + options->Add(kContemptPerspectiveId, perspective) = + "sidetomove"; + // The default kContemptId is empty, so the initial contempt value is taken + // from kUCIRatingAdvId. Adding any value (without name) in the comma + // separated kContemptId list will override this. + options->Add(kContemptId) = ""; + options->Add(kContemptMaxValueId, 0, 10000.0f) = 420.0f; + options->Add(kWDLCalibrationEloId, 0, 10000.0f) = 0.0f; + options->Add(kWDLContemptAttenuationId, -10.0f, 10.0f) = 1.0f; + options->Add(kWDLEvalObjectivityId, 0.0f, 1.0f) = 1.0f; + options->Add(kWDLDrawRateTargetId, 0.001f, 0.999f) = 0.5f; + options->Add(kWDLDrawRateReferenceId, 0.001f, 0.999f) = 0.5f; + options->Add(kWDLBookExitBiasId, -2.0f, 2.0f) = 0.65f; options->Add(kNpsLimitId, 0.0f, 1e6f) = 0.0f; - options->Add(kSolidTreeThresholdId, 1, 2000000000) = 100; options->Add(kTaskWorkersPerSearchWorkerId, 0, 128) = DEFAULT_TASK_WORKERS; options->Add(kMinimumWorkSizeForProcessingId, 2, 100000) = 20; @@ -397,6 +542,8 @@ void SearchParams::Populate(OptionsParser* options) { options->Add(kMinimumWorkPerTaskForProcessingId, 1, 100000) = 8; options->Add(kIdlingMinimumWorkId, 0, 10000) = 0; options->Add(kThreadIdlingThresholdId, 0, 128) = 1; + options->Add(kUCIOpponentId); + options->Add(kUCIRatingAdvId, -10000.0f, 10000.0f) = 0.0f; options->HideOption(kNoiseEpsilonId); options->HideOption(kNoiseAlphaId); @@ -415,6 +562,10 @@ void SearchParams::Populate(OptionsParser* options) { options->HideOption(kTemperatureEndgameId); options->HideOption(kTemperatureWinpctCutoffId); options->HideOption(kTemperatureVisitOffsetId); + options->HideOption(kContemptMaxValueId); + options->HideOption(kWDLContemptAttenuationId); + options->HideOption(kWDLDrawRateTargetId); + options->HideOption(kWDLBookExitBiasId); } SearchParams::SearchParams(const OptionsDict& options) @@ -465,12 +616,32 @@ SearchParams::SearchParams(const OptionsDict& options) kDrawScoreOpponent{options.Get(kDrawScoreOpponentId) / 100.0f}, kDrawScoreWhite{options.Get(kDrawScoreWhiteId) / 100.0f}, kDrawScoreBlack{options.Get(kDrawScoreBlackId) / 100.0f}, + kContemptPerspective(EncodeContemptPerspective( + options.Get(kContemptPerspectiveId))), + kContempt(options.IsDefault(kContemptId) + ? options.Get(kUCIRatingAdvId) + : GetContempt(options.Get(kUCIOpponentId), + options.Get(kContemptId))), + kWDLRescaleParams( + options.Get(kWDLCalibrationEloId) == 0 + ? AccurateWDLRescaleParams( + kContempt, options.Get(kWDLDrawRateTargetId), + options.Get(kWDLDrawRateReferenceId), + options.Get(kWDLBookExitBiasId), + options.Get(kContemptMaxValueId), + options.Get(kWDLContemptAttenuationId)) + : SimplifiedWDLRescaleParams( + kContempt, options.Get(kWDLDrawRateReferenceId), + options.Get(kWDLCalibrationEloId), + options.Get(kContemptMaxValueId), + options.Get(kWDLContemptAttenuationId))), + kWDLEvalObjectivity(options.Get(kWDLEvalObjectivityId)), kMaxOutOfOrderEvals(std::max( 1, static_cast(options.Get(kMaxOutOfOrderEvalsId) * options.Get(kMiniBatchSizeId)))), kNpsLimit(options.Get(kNpsLimitId)), - kSolidTreeThreshold(options.Get(kSolidTreeThresholdId)), - kTaskWorkersPerSearchWorker(options.Get(kTaskWorkersPerSearchWorkerId)), + kTaskWorkersPerSearchWorker( + options.Get(kTaskWorkersPerSearchWorkerId)), kMinimumWorkSizeForProcessing( options.Get(kMinimumWorkSizeForProcessingId)), kMinimumWorkSizeForPicking( diff --git a/src/mcts/params.h b/src/mcts/params.h index 7fcf0ef879..599992d0eb 100644 --- a/src/mcts/params.h +++ b/src/mcts/params.h @@ -33,19 +33,28 @@ namespace lczero { +enum class ContemptPerspective { STM, WHITE, BLACK, NONE }; + class SearchParams { public: SearchParams(const OptionsDict& options); SearchParams(const SearchParams&) = delete; + // Use struct for WDLRescaleParams calculation to make them const. + struct WDLRescaleParams { + WDLRescaleParams(float r, float d) { + ratio = r; + diff = d; + } + float ratio; + float diff; + }; + // Populates UciOptions with search parameters. static void Populate(OptionsParser* options); // Parameter getters. - int GetMiniBatchSize() const { return kMiniBatchSize; } - int GetMaxPrefetchBatch() const { - return options_.Get(kMaxPrefetchBatchId); - } + uint32_t GetMiniBatchSize() const { return kMiniBatchSize; } float GetCpuct(bool at_root) const { return at_root ? kCpuctAtRoot : kCpuct; } float GetCpuctBase(bool at_root) const { return at_root ? kCpuctBaseAtRoot : kCpuctBase; @@ -104,13 +113,18 @@ class SearchParams { } bool GetDisplayCacheUsage() const { return kDisplayCacheUsage; } int GetMaxConcurrentSearchers() const { return kMaxConcurrentSearchers; } + ContemptPerspective GetContemptPerspective() const { + return kContemptPerspective; + } float GetSidetomoveDrawScore() const { return kDrawScoreSidetomove; } float GetOpponentDrawScore() const { return kDrawScoreOpponent; } float GetWhiteDrawDelta() const { return kDrawScoreWhite; } float GetBlackDrawDelta() const { return kDrawScoreBlack; } - int GetMaxOutOfOrderEvals() const { return kMaxOutOfOrderEvals; } + float GetWDLRescaleRatio() const { return kWDLRescaleParams.ratio; } + float GetWDLRescaleDiff() const { return kWDLRescaleParams.diff; } + float GetWDLEvalObjectivity() const { return kWDLEvalObjectivity; } + uint32_t GetMaxOutOfOrderEvals() const { return kMaxOutOfOrderEvals; } float GetNpsLimit() const { return kNpsLimit; } - int GetSolidTreeThreshold() const { return kSolidTreeThreshold; } int GetTaskWorkersPerSearchWorker() const { return kTaskWorkersPerSearchWorker; @@ -141,7 +155,6 @@ class SearchParams { // Search parameter IDs. static const OptionId kMiniBatchSizeId; - static const OptionId kMaxPrefetchBatchId; static const OptionId kCpuctId; static const OptionId kCpuctAtRootId; static const OptionId kCpuctBaseId; @@ -184,13 +197,21 @@ class SearchParams { static const OptionId kMovesLeftSlopeId; static const OptionId kDisplayCacheUsageId; static const OptionId kMaxConcurrentSearchersId; + static const OptionId kContemptPerspectiveId; static const OptionId kDrawScoreSidetomoveId; static const OptionId kDrawScoreOpponentId; static const OptionId kDrawScoreWhiteId; static const OptionId kDrawScoreBlackId; + static const OptionId kContemptId; + static const OptionId kContemptMaxValueId; + static const OptionId kWDLCalibrationEloId; + static const OptionId kWDLContemptAttenuationId; + static const OptionId kWDLEvalObjectivityId; + static const OptionId kWDLDrawRateTargetId; + static const OptionId kWDLDrawRateReferenceId; + static const OptionId kWDLBookExitBiasId; static const OptionId kMaxOutOfOrderEvalsId; static const OptionId kNpsLimitId; - static const OptionId kSolidTreeThresholdId; static const OptionId kTaskWorkersPerSearchWorkerId; static const OptionId kMinimumWorkSizeForProcessingId; static const OptionId kMinimumWorkSizeForPickingId; @@ -201,6 +222,8 @@ class SearchParams { static const OptionId kMaxCollisionVisitsScalingStartId; static const OptionId kMaxCollisionVisitsScalingEndId; static const OptionId kMaxCollisionVisitsScalingPowerId; + static const OptionId kUCIOpponentId; + static const OptionId kUCIRatingAdvId; private: const OptionsDict& options_; @@ -244,9 +267,12 @@ class SearchParams { const float kDrawScoreOpponent; const float kDrawScoreWhite; const float kDrawScoreBlack; + const ContemptPerspective kContemptPerspective; + const float kContempt; + const WDLRescaleParams kWDLRescaleParams; + const float kWDLEvalObjectivity; const int kMaxOutOfOrderEvals; const float kNpsLimit; - const int kSolidTreeThreshold; const int kTaskWorkersPerSearchWorker; const int kMinimumWorkSizeForProcessing; const int kMinimumWorkSizeForPicking; diff --git a/src/mcts/search.cc b/src/mcts/search.cc index 6fe8d1c7e6..557f1c7a12 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -34,12 +34,11 @@ #include #include #include +#include #include #include #include "mcts/node.h" -#include "neural/cache.h" -#include "neural/encoder.h" #include "utils/fastmath.h" #include "utils/random.h" @@ -111,6 +110,11 @@ class MEvaluator { const float child_m = child.GetM(parent_m_); float m = std::clamp(m_slope_ * (child_m - parent_m_), -m_cap_, m_cap_); m *= FastSign(-q); + if (q_threshold_ > 0.0f && q_threshold_ < 1.0f) { + // This allows a smooth M effect with higher q thresholds, which is + // necessary for using MLH together with contempt. + q = std::max(0.0f, (std::abs(q) - q_threshold_)) / (1.0f - q_threshold_); + } m *= a_constant_ + a_linear_ * std::abs(q) + a_square_ * q * q; return m; } @@ -145,7 +149,7 @@ class MEvaluator { } // namespace -Search::Search(const NodeTree& tree, Network* network, +Search::Search(NodeTree* dag, Network* network, std::unique_ptr uci_responder, const MoveList& searchmoves, std::chrono::steady_clock::time_point start_time, @@ -154,10 +158,11 @@ Search::Search(const NodeTree& tree, Network* network, SyzygyTablebase* syzygy_tb) : ok_to_respond_bestmove_(!infinite), stopper_(std::move(stopper)), - root_node_(tree.GetCurrentHead()), + root_node_(dag->GetCurrentHead()), cache_(cache), + dag_(dag), syzygy_tb_(syzygy_tb), - played_history_(tree.GetPositionHistory()), + played_history_(dag->GetPositionHistory()), network_(network), params_(options), searchmoves_(searchmoves), @@ -194,6 +199,44 @@ void ApplyDirichletNoise(Node* node, float eps, double alpha) { } } // namespace +namespace { +// WDL conversion formula based on random walk model. +inline void WDLRescale(float& v, float& d, float* mu_uci, + float wdl_rescale_ratio, float wdl_rescale_diff, + float sign, bool invert) { + if (invert) { + wdl_rescale_diff = -wdl_rescale_diff; + wdl_rescale_ratio = 1.0f / wdl_rescale_ratio; + } + auto w = (1 + v - d) / 2; + auto l = (1 - v - d) / 2; + // Safeguard against numerical issues; skip WDL transformation if WDL is too + // extreme. + const float zero = 0.0001f; + const float one = 0.9999f; + if (w > zero && d > zero && l > zero && w < one && d < one && l < one) { + auto a = FastLog(1 / l - 1); + auto b = FastLog(1 / w - 1); + auto s = 2 / (a + b); + // Safeguard against unrealistically broad WDL distributions coming from + // the NN. Could be made into a parameter, but probably unnecessary. + if (!invert) s = std::min(1.4f, s); + auto mu = (a - b) / (a + b); + auto s_new = s * wdl_rescale_ratio; + if (invert) { + std::swap(s, s_new); + s = std::min(1.4f, s); + } + auto mu_new = mu + sign * s * s * wdl_rescale_diff; + auto w_new = FastLogistic((-1.0f + mu_new) / s_new); + auto l_new = FastLogistic((-1.0f - mu_new) / s_new); + v = w_new - l_new; + d = std::max(0.0f, 1.0f - w_new - l_new); + if (mu_uci) *mu_uci = mu_new; + } +} +} // namespace + void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) { const auto max_pv = params_.GetMultiPv(); const auto edges = GetBestChildrenNoTemperature(root_node_, max_pv, 0); @@ -235,12 +278,25 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) { ++multipv; uci_infos.emplace_back(common_info); auto& uci_info = uci_infos.back(); - const auto wl = edge.GetWL(default_wl); - const auto floatD = edge.GetD(default_d); + auto wl = edge.GetWL(default_wl); + auto d = edge.GetD(default_d); + float mu_uci = 0.0f; + // Only the diff effect is inverted, so we only need to call if diff != 0. + if (params_.GetContemptPerspective() != ContemptPerspective::NONE) { + auto sign = + ((params_.GetContemptPerspective() == ContemptPerspective::STM) || + ((params_.GetContemptPerspective() == ContemptPerspective::BLACK) == + played_history_.IsBlackToMove())) + ? 1.0f + : -1.0f; + WDLRescale(wl, d, &mu_uci, params_.GetWDLRescaleRatio(), + params_.GetWDLRescaleDiff() * params_.GetWDLEvalObjectivity(), + sign, true); + } const auto q = edge.GetQ(default_q, draw_score); if (edge.IsTerminal() && wl != 0.0f) { uci_info.mate = std::copysign( - std::round(edge.GetM(0.0f)) / 2 + (edge.IsTbTerminal() ? 101 : 1), + std::round(edge.GetM(0.0f) + 1) / 2 + (edge.IsTbTerminal() ? 100 : 0), wl); } else if (score_type == "centipawn_with_drawscore") { uci_info.score = 90 * tan(1.5637541897 * q); @@ -256,20 +312,30 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) { uci_info.score = q * 10000; } else if (score_type == "W-L") { uci_info.score = wl * 10000; + } else if (score_type == "WDL_mu") { + // Reports the WDL mu value whenever it is reasonable, and defaults to + // centipawn otherwise. + float centipawn_score = 90 * tan(1.5637541897 * wl); + uci_info.score = + mu_uci != 0.0f && std::abs(wl) + d < 0.99f && + (std::abs(mu_uci) < 1.0f || + std::abs(centipawn_score) < std::abs(100 * mu_uci)) + ? 100 * mu_uci + : centipawn_score; } - auto w = - std::max(0, static_cast(std::round(500.0 * (1.0 + wl - floatD)))); - auto l = - std::max(0, static_cast(std::round(500.0 * (1.0 - wl - floatD)))); + auto wdl_w = + std::max(0, static_cast(std::round(500.0 * (1.0 + wl - d)))); + auto wdl_l = + std::max(0, static_cast(std::round(500.0 * (1.0 - wl - d)))); // Using 1000-w-l so that W+D+L add up to 1000.0. - auto d = 1000 - w - l; - if (d < 0) { - w = std::min(1000, std::max(0, w + d / 2)); - l = 1000 - w; - d = 0; + auto wdl_d = 1000 - wdl_w - wdl_l; + if (wdl_d < 0) { + wdl_w = std::min(1000, std::max(0, wdl_w + wdl_d / 2)); + wdl_l = 1000 - wdl_w; + wdl_d = 0; } - uci_info.wdl = ThinkingInfo::WDL{w, d, l}; + uci_info.wdl = ThinkingInfo::WDL{wdl_w, wdl_d, wdl_l}; if (network_->GetCapabilities().has_mlh()) { uci_info.moves_left = static_cast( (1.0f + edge.GetM(1.0f + root_node_->GetM())) / 2.0f); @@ -278,10 +344,13 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) { if (per_pv_counters) uci_info.nodes = edge.GetN(); bool flip = played_history_.IsBlackToMove(); int depth = 0; + auto history = played_history_; for (auto iter = edge; iter; iter = GetBestChildNoTemperature(iter.node(), depth), flip = !flip) { uci_info.pv.push_back(iter.GetMove(flip)); - if (!iter.node()) break; // Last edge was dangling, cannot continue. + history.Append(iter.GetMove()); + // Last edge was dangling or a draw by repetition, cannot continue. + if (!iter.node() || history.Last().GetRepetitions() >= 2) break; depth += 1; } } @@ -371,13 +440,12 @@ inline float ComputeCpuct(const SearchParams& params, uint32_t N, } // namespace std::vector Search::GetVerboseStats(Node* node) const { - assert(node == root_node_ || node->GetParent() == root_node_); const bool is_root = (node == root_node_); const bool is_odd_depth = !is_root; const bool is_black_to_move = (played_history_.IsBlackToMove() == is_root); const float draw_score = GetDrawScore(is_odd_depth); const float fpu = GetFpu(params_, node, is_root, draw_score); - const float cpuct = ComputeCpuct(params_, node->GetN(), is_root); + const float cpuct = ComputeCpuct(params_, node->GetTotalVisits(), is_root); const float U_coeff = cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); std::vector edges; @@ -419,9 +487,13 @@ std::vector Search::GetVerboseStats(Node* node) const { std::optional v; if (n && n->IsTerminal()) { v = n->GetQ(sign * draw_score); - } else { - NNCacheLock nneval = GetCachedNNEval(n); - if (nneval) v = -nneval->q; + } else if (n) { + auto history = played_history_; + if (!is_root) { + history.Append(node->GetMove()); + } + NNCacheLock nneval = GetCachedNNEval(history); + if (nneval) v = -nneval->eval->q; } if (v) { print(oss, "(V: ", sign * *v, ") ", 7, 4); @@ -436,10 +508,13 @@ std::vector Search::GetVerboseStats(Node* node) const { up = -up; std::swap(lo, up); } - *oss << (lo == up ? "(T) " - : lo == GameResult::DRAW && up == GameResult::WHITE_WON ? "(W) " - : lo == GameResult::BLACK_WON && up == GameResult::DRAW ? "(L) " - : ""); + *oss << (lo == up + ? "(T) " + : lo == GameResult::DRAW && up == GameResult::WHITE_WON + ? "(W) " + : lo == GameResult::BLACK_WON && up == GameResult::DRAW + ? "(L) " + : ""); } }; @@ -502,18 +577,8 @@ void Search::SendMovesStats() const REQUIRES(counters_mutex_) { } } -NNCacheLock Search::GetCachedNNEval(const Node* node) const { - if (!node) return {}; - - std::vector moves; - for (; node != root_node_; node = node->GetParent()) { - moves.push_back(node->GetOwnEdge()->GetMove()); - } - PositionHistory history(played_history_); - for (auto iter = moves.rbegin(), end = moves.rend(); iter != end; ++iter) { - history.Append(*iter); - } - const auto hash = history.HashLast(params_.GetCacheHistoryLength() + 1); +NNCacheLock Search::GetCachedNNEval(const PositionHistory& history) const { + const auto hash = dag_->GetHistoryHash(history); NNCacheLock nneval(cache_, hash); return nneval; } @@ -838,7 +903,8 @@ void Search::PopulateCommonIterationStats(IterationStats* stats) { nps_start_time_ = std::chrono::steady_clock::now(); } } - stats->total_nodes = total_playouts_ + initial_visits_; + stats->total_visits = total_playouts_ + initial_visits_; + stats->total_allocated_nodes = dag_->AllocatedNodeCount(); stats->nodes_since_movestart = total_playouts_; stats->batches_since_movestart = total_batches_; stats->average_depth = cum_depth_ / (total_playouts_ ? total_playouts_ : 1); @@ -957,10 +1023,9 @@ void Search::Wait() { void Search::CancelSharedCollisions() REQUIRES(nodes_mutex_) { for (auto& entry : shared_collisions_) { - Node* node = entry.first; - for (node = node->GetParent(); node != root_node_->GetParent(); - node = node->GetParent()) { - node->CancelScoreUpdate(entry.second); + auto path = entry.first; + for (auto it = ++(path.crbegin()); it != path.crend(); ++it) { + std::get<0>(*it)->CancelScoreUpdate(entry.second); } } shared_collisions_.clear(); @@ -972,7 +1037,16 @@ Search::~Search() { { SharedMutex::Lock lock(nodes_mutex_); CancelSharedCollisions(); + +#ifndef NDEBUG + assert(root_node_->ZeroNInFlight()); +#endif } + + // Free previously released nodes that were not reused during this search. + dag_->TTMaintenance(); + dag_->TTMaintenance(); + LOGFILE << "Search destroyed."; } @@ -1036,14 +1110,13 @@ void SearchWorker::RunTasks(int tid) { if (task != nullptr) { switch (task->task_type) { case PickTask::kGathering: { - PickNodesToExtendTask(task->start, task->base_depth, - task->collision_limit, task->moves_to_base, - &(task->results), &(task_workspaces_[tid])); + PickNodesToExtendTask(task->start_path, task->collision_limit, + task->history, &(task->results), + &(task_workspaces_[tid])); break; } case PickTask::kProcessing: { - ProcessPickedTask(task->start_idx, task->end_idx, - &(task_workspaces_[tid])); + ProcessPickedTask(task->start_idx, task->end_idx); break; } } @@ -1059,7 +1132,7 @@ void SearchWorker::ExecuteOneIteration() { if (params_.GetMaxConcurrentSearchers() != 0) { while (true) { - // If search is stop, we've not gathered or done anything and we don't + // If search is stopped, we've not gathered or done anything and we don't // want to, so we can safely skip all below. But make sure we have done // at least one iteration. if (search_->stop_.load(std::memory_order_acquire) && @@ -1087,9 +1160,6 @@ void SearchWorker::ExecuteOneIteration() { // 2b. Collect collisions. CollectCollisions(); - // 3. Prefetch into cache. - MaybePrefetchIntoCache(); - if (params_.GetMaxConcurrentSearchers() != 0) { search_->pending_searchers_.fetch_add(1, std::memory_order_acq_rel); } @@ -1132,8 +1202,9 @@ void SearchWorker::ExecuteOneIteration() { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ void SearchWorker::InitializeIteration( std::unique_ptr computation) { - computation_ = std::make_unique(std::move(computation), - search_->cache_); + computation_ = std::make_unique( + std::move(computation), search_->network_->GetCapabilities().input_format, + params_.GetHistoryFill(), search_->cache_); computation_->Reserve(params_.GetMiniBatchSize()); minibatch_.clear(); minibatch_.reserve(2 * params_.GetMiniBatchSize()); @@ -1166,7 +1237,7 @@ int CalculateCollisionsLeft(int64_t nodes, const SearchParams& params) { void SearchWorker::GatherMinibatch() { // Total number of nodes to process. - int minibatch_size = 0; + uint32_t minibatch_size = 0; int cur_n = 0; { SharedMutex::Lock lock(search_->nodes_mutex_); @@ -1176,7 +1247,7 @@ void SearchWorker::GatherMinibatch() { // applied, which doesn't clearly make sense to include here... int64_t remaining_n = latest_time_manager_hints_.GetEstimatedRemainingPlayouts(); - int collisions_left = CalculateCollisionsLeft( + uint32_t collisions_left = CalculateCollisionsLeft( std::min(static_cast(cur_n), remaining_n), params_); // Number of nodes processed out of order. @@ -1223,39 +1294,45 @@ void SearchWorker::GatherMinibatch() { ++minibatch_size; } - bool needs_wait = false; - int ppt_start = new_start; - if (params_.GetTaskWorkersPerSearchWorker() > 0 && - non_collisions >= params_.GetMinimumWorkSizeForProcessing()) { - const int num_tasks = std::clamp( - non_collisions / params_.GetMinimumWorkPerTaskForProcessing(), 2, - params_.GetTaskWorkersPerSearchWorker() + 1); - // Round down, left overs can go to main thread so it waits less. - int per_worker = non_collisions / num_tasks; - needs_wait = true; - ResetTasks(); - int found = 0; - for (int i = new_start; i < static_cast(minibatch_.size()); i++) { - auto& picked_node = minibatch_[i]; - if (picked_node.IsCollision()) { - continue; - } - ++found; - if (found == per_worker) { - picking_tasks_.emplace_back(ppt_start, i + 1); - task_count_.fetch_add(1, std::memory_order_acq_rel); - ppt_start = i + 1; - found = 0; - if (picking_tasks_.size() == static_cast(num_tasks - 1)) { - break; + { + // This lock must be held until after the task_completed_ wait succeeds + // below. Since the tasks perform work which assumes they have the lock, + // even though actually this thread does. + SharedMutex::Lock lock(search_->nodes_mutex_); + + bool needs_wait = false; + int ppt_start = new_start; + if (params_.GetTaskWorkersPerSearchWorker() > 0 && + non_collisions >= params_.GetMinimumWorkSizeForProcessing()) { + const int num_tasks = std::clamp( + non_collisions / params_.GetMinimumWorkPerTaskForProcessing(), 2, + params_.GetTaskWorkersPerSearchWorker() + 1); + // Round down, left overs can go to main thread so it waits less. + int per_worker = non_collisions / num_tasks; + needs_wait = true; + ResetTasks(); + int found = 0; + for (int i = new_start; i < static_cast(minibatch_.size()); i++) { + auto& picked_node = minibatch_[i]; + if (picked_node.IsCollision()) { + continue; + } + ++found; + if (found == per_worker) { + picking_tasks_.emplace_back(ppt_start, i + 1); + task_count_.fetch_add(1, std::memory_order_acq_rel); + ppt_start = i + 1; + found = 0; + if (picking_tasks_.size() == static_cast(num_tasks - 1)) { + break; + } } } } - } - ProcessPickedTask(ppt_start, static_cast(minibatch_.size()), - &main_workspace_); - if (needs_wait) { - WaitForTasks(); + ProcessPickedTask(ppt_start, static_cast(minibatch_.size())); + if (needs_wait) { + WaitForTasks(); + } } bool some_ooo = false; for (int i = static_cast(minibatch_.size()) - 1; i >= new_start; i--) { @@ -1273,14 +1350,13 @@ void SearchWorker::GatherMinibatch() { // This may remove too many items, but hopefully most of the time they // will just be added back in the same in the next gather. if (minibatch_[i].IsCollision()) { - Node* node = minibatch_[i].node; - for (node = node->GetParent(); - node != search_->root_node_->GetParent(); - node = node->GetParent()) { - node->CancelScoreUpdate(minibatch_[i].multivisit); + for (auto it = ++(minibatch_[i].path.crbegin()); + it != minibatch_[i].path.crend(); ++it) { + std::get<0>(*it)->CancelScoreUpdate(minibatch_[i].multivisit); } minibatch_.erase(minibatch_.begin() + i); } else if (minibatch_[i].ooo_completed) { + FetchSingleNodeResult(&minibatch_[i], minibatch_[i], 0); DoBackupUpdateSingleNode(minibatch_[i]); minibatch_.erase(minibatch_.begin() + i); --minibatch_size; @@ -1292,15 +1368,13 @@ void SearchWorker::GatherMinibatch() { // If there was no OOO, there can stil be collisions. // There are no OOO though. // Also terminals when OOO is disabled. - if (!minibatch_[i].nn_queried) continue; + if (!minibatch_[i].ShouldAddToInput()) continue; if (minibatch_[i].is_cache_hit) { // Since minibatch_[i] holds cache lock, this is guaranteed to succeed. computation_->AddInputByHash(minibatch_[i].hash, std::move(minibatch_[i].lock)); } else { - computation_->AddInput(minibatch_[i].hash, - std::move(minibatch_[i].input_planes), - std::move(minibatch_[i].probabilities_to_cache)); + computation_->AddInput(minibatch_[i].hash, minibatch_[i].history); } } @@ -1316,11 +1390,9 @@ void SearchWorker::GatherMinibatch() { int extra = std::min(picked_node.maxvisit, collisions_left) - picked_node.multivisit; picked_node.multivisit += extra; - Node* node = picked_node.node; - for (node = node->GetParent(); - node != search_->root_node_->GetParent(); - node = node->GetParent()) { - node->IncrementNInFlight(extra); + for (auto it = ++(picked_node.path.crbegin()); + it != picked_node.path.crend(); ++it) { + std::get<0>(*it)->IncrementNInFlight(extra); } } if ((collisions_left -= picked_node.multivisit) <= 0) return; @@ -1330,51 +1402,20 @@ void SearchWorker::GatherMinibatch() { } } -void SearchWorker::ProcessPickedTask(int start_idx, int end_idx, - TaskWorkspace* workspace) { - auto& history = workspace->history; - history = search_->played_history_; - +void SearchWorker::ProcessPickedTask(int start_idx, int end_idx) { for (int i = start_idx; i < end_idx; i++) { auto& picked_node = minibatch_[i]; if (picked_node.IsCollision()) continue; - auto* node = picked_node.node; - - // If node is already known as terminal (win/loss/draw according to rules - // of the game), it means that we already visited this node before. + // If node is a collision, known as a terminal (win/loss/draw according to + // the rules of the game) or has a low node, it means that we have already + // visited this node before and can't extend it. if (picked_node.IsExtendable()) { // Node was never visited, extend it. - ExtendNode(node, picked_node.depth, picked_node.moves_to_visit, &history); - if (!node->IsTerminal()) { - picked_node.nn_queried = true; - const auto hash = history.HashLast(params_.GetCacheHistoryLength() + 1); - picked_node.hash = hash; - picked_node.lock = NNCacheLock(search_->cache_, hash); - picked_node.is_cache_hit = picked_node.lock; - if (!picked_node.is_cache_hit) { - int transform; - picked_node.input_planes = EncodePositionForNN( - search_->network_->GetCapabilities().input_format, history, 8, - params_.GetHistoryFill(), &transform); - picked_node.probability_transform = transform; - - std::vector& moves = picked_node.probabilities_to_cache; - // Legal moves are known, use them. - moves.reserve(node->GetNumEdges()); - for (const auto& edge : node->Edges()) { - moves.emplace_back(edge.GetMove().as_nn_index(transform)); - } - } else { - picked_node.probability_transform = TransformForPosition( - search_->network_->GetCapabilities().input_format, history); - } - } - } - if (params_.GetOutOfOrderEval() && picked_node.CanEvalOutOfOrder()) { - // Perform out of order eval for the last entry in minibatch_. - FetchSingleNodeResult(&picked_node, picked_node, 0); - picked_node.ooo_completed = true; + ExtendNode(picked_node); } + + picked_node.ooo_completed = + params_.GetOutOfOrderEval() && picked_node.CanEvalOutOfOrder(); } } @@ -1412,8 +1453,10 @@ void SearchWorker::PickNodesToExtend(int collision_limit) { // Since the tasks perform work which assumes they have the lock, even though // actually this thread does. SharedMutex::Lock lock(search_->nodes_mutex_); - PickNodesToExtendTask(search_->root_node_, 0, collision_limit, empty_movelist, - &minibatch_, &main_workspace_); + history_.Trim(search_->played_history_.GetLength()); + PickNodesToExtendTask({std::make_tuple(search_->root_node_, 0, 0)}, + collision_limit, history_, &minibatch_, + &main_workspace_); WaitForTasks(); for (int i = 0; i < static_cast(picking_tasks_.size()); i++) { @@ -1424,52 +1467,78 @@ void SearchWorker::PickNodesToExtend(int collision_limit) { } } -void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node, - int depth) { - // Check whether first repetition was before root. If yes, remove - // terminal status of node and revert all visits in the tree. - // Length of repetition was stored in m_. This code will only do - // something when tree is reused and twofold visits need to be - // reverted. - if (child_node->IsTwoFoldTerminal() && depth < child_node->GetM()) { - // Take a mutex - any SearchWorker specific mutex... since this is - // not safe to do concurrently between multiple tasks. - Mutex::Lock lock(picking_tasks_mutex_); - int depth_counter = 0; - // Cache node's values as we reset them in the process. We could - // manually set wl and d, but if we want to reuse this for reverting - // other terminal nodes this is the way to go. - const auto wl = child_node->GetWL(); - const auto d = child_node->GetD(); - const auto m = child_node->GetM(); - const auto terminal_visits = child_node->GetN(); - for (Node* node_to_revert = child_node; node_to_revert != nullptr; - node_to_revert = node_to_revert->GetParent()) { - // Revert all visits on twofold draw when making it non terminal. - node_to_revert->RevertTerminalVisits(wl, d, m + (float)depth_counter, - terminal_visits); - depth_counter++; - // Even if original tree still exists, we don't want to revert - // more than until new root. - if (depth_counter > depth) break; - // If wl != 0, we would have to switch signs at each depth. - } - // Mark the prior twofold draw as non terminal to extend it again. - child_node->MakeNotTerminal(); - // When reverting the visits, we also need to revert the initial - // visits, as we reused fewer nodes than anticipated. - search_->initial_visits_ -= terminal_visits; - // Max depth doesn't change when reverting the visits, and - // cum_depth_ only counts the average depth of new nodes, not reused - // ones. +// Depth starts with 0 at root, so number of plies in PV equals depth. +std::pair SearchWorker::GetRepetitions(int depth, + const Position& position) { + const auto repetitions = position.GetRepetitions(); + + if (repetitions == 0) return {0, 0}; + + if (repetitions >= 2) return {repetitions, 0}; + + const auto plies = position.GetPliesSincePrevRepetition(); + if (params_.GetTwoFoldDraws() && /*repetitions == 1 &&*/ depth >= 4 && + depth >= plies) { + return {1, plies}; } + + return {0, 0}; +} + +// Check if PickNodesToExtendTask should stop picking at this @node. +bool SearchWorker::ShouldStopPickingHere(Node* node, bool is_root_node, + int repetitions) { + constexpr double wl_diff_limit = 0.01f; + constexpr float d_diff_limit = 0.01f; + constexpr float m_diff_limit = 2.0f; + + if (node->GetN() == 0 || node->IsTerminal()) return true; + + // Only stop at root when there is no other option. + assert(!is_root_node || node == search_->root_node_); + if (is_root_node) return false; + + // Stop at draws by repetition. + if (repetitions >= 2) return true; + + // Check if Node and LowNode differ significantly. + auto low_node = node->GetLowNode(); + assert(low_node); + + // Only known transpositions can differ. + if (!low_node->IsTransposition()) return false; + + // LowNode is terminal when Node is not. + if (low_node->IsTerminal()) return true; + + // Bounds differ (swap). + auto [low_node_lower, low_node_upper] = low_node->GetBounds(); + auto [node_lower, node_upper] = node->GetBounds(); + if (low_node_lower != -node_upper || low_node_upper != -node_lower) + return true; + + // WL differs significantly (flip). + auto wl_diff = std::abs(low_node->GetWL() + node->GetWL()); + if (wl_diff >= wl_diff_limit) return true; + + // D differs significantly. + auto d_diff = std::abs(low_node->GetD() - node->GetD()); + if (d_diff >= d_diff_limit) return true; + + // M differs significantly (increment). + auto m_diff = std::abs(low_node->GetM() + 1 - node->GetM()); + if (m_diff >= m_diff_limit) return true; + + return false; } void SearchWorker::PickNodesToExtendTask( - Node* node, int base_depth, int collision_limit, - const std::vector& moves_to_base, + const BackupPath& path, int collision_limit, PositionHistory& history, std::vector* receiver, TaskWorkspace* workspace) NO_THREAD_SAFETY_ANALYSIS { + assert(path.size() == (size_t)history.GetLength() - + search_->played_history_.GetLength() + 1); + // TODO: Bring back pre-cached nodes created outside locks in a way that works // with tasks. // TODO: pre-reserve visits_to_perform for expected depth and likely maximum @@ -1481,15 +1550,16 @@ void SearchWorker::PickNodesToExtendTask( vtp_last_filled.clear(); auto& current_path = workspace->current_path; current_path.clear(); - auto& moves_to_path = workspace->moves_to_path; - moves_to_path = moves_to_base; + auto& full_path = workspace->full_path; + full_path = path; + assert(full_path.size() > 0); + auto [node, repetitions, moves_left] = full_path.back(); // Sometimes receiver is reused, othertimes not, so only jump start if small. if (receiver->capacity() < 30) { receiver->reserve(receiver->size() + 30); } - // These 2 are 'filled pre-emptively'. - std::array current_pol; + // This 1 is 'filled pre-emptively'. std::array current_util; // These 3 are 'filled on demand'. @@ -1515,6 +1585,7 @@ void SearchWorker::PickNodesToExtendTask( current_path.push_back(-1); while (current_path.size() > 0) { + assert(full_path.size() >= path.size()); // First prepare visits_to_perform. if (current_path.back() == -1) { // Need to do n visits, where n is either collision_limit, or comes from @@ -1526,31 +1597,37 @@ void SearchWorker::PickNodesToExtendTask( } // First check if node is terminal or not-expanded. If either than create // a collision of appropriate size and pop current_path. - if (node->GetN() == 0 || node->IsTerminal()) { + if (ShouldStopPickingHere(node, is_root_node, repetitions)) { if (is_root_node) { // Root node is special - since its not reached from anywhere else, so // it needs its own logic. Still need to create the collision to // ensure the outer gather loop gives up. if (node->TryStartScoreUpdate()) { cur_limit -= 1; - minibatch_.push_back(NodeToProcess::Visit( - node, static_cast(current_path.size() + base_depth))); + minibatch_.push_back( + NodeToProcess::Visit(full_path, search_->played_history_)); completed_visits++; } } // Visits are created elsewhere, just need the collisions here. if (cur_limit > 0) { int max_count = 0; - if (cur_limit == collision_limit && base_depth == 0 && + if (cur_limit == collision_limit && path.size() == 1 && max_limit > cur_limit) { max_count = max_limit; } - receiver->push_back(NodeToProcess::Collision( - node, static_cast(current_path.size() + base_depth), - cur_limit, max_count)); + receiver->push_back( + NodeToProcess::Collision(full_path, cur_limit, max_count)); completed_visits += cur_limit; } - node = node->GetParent(); + history.Pop(); + full_path.pop_back(); + if (full_path.size() > 0) { + std::tie(node, repetitions, moves_left) = full_path.back(); + } else { + node = nullptr; + repetitions = 0; + } current_path.pop_back(); continue; } @@ -1570,34 +1647,20 @@ void SearchWorker::PickNodesToExtendTask( vtp_last_filled.push_back(-1); // Cache all constant UCT parameters. - // When we're near the leaves we can copy less of the policy, since there - // is no way iteration will ever reach it. - // TODO: This is a very conservative formula. It assumes every visit we're - // aiming to add is going to trigger a new child, and that any visits - // we've already had have also done so and then a couple extra since we go - // to 2 unvisited to get second best in worst case. - // Unclear we can do better without having already walked the children. - // Which we are putting off until after policy is copied so we can create - // visited policy without having to cache it in the node (allowing the - // node to stay at 64 bytes). + int max_needed = node->GetNumEdges(); - if (!is_root_node || root_move_filter.empty()) { - max_needed = std::min(max_needed, node->GetNStarted() + cur_limit + 2); - } - node->CopyPolicy(max_needed, current_pol.data()); for (int i = 0; i < max_needed; i++) { current_util[i] = std::numeric_limits::lowest(); } // Root depth is 1 here, while for GetDrawScore() it's 0-based, that's why // the weirdness. - const float draw_score = ((current_path.size() + base_depth) % 2 == 0) - ? odd_draw_score - : even_draw_score; + const float draw_score = + (full_path.size() % 2 == 0) ? odd_draw_score : even_draw_score; m_evaluator.SetParent(node); float visited_pol = 0.0f; for (Node* child : node->VisitedNodes()) { int index = child->Index(); - visited_pol += current_pol[index]; + visited_pol += child->GetP(); float q = child->GetQ(draw_score); current_util[index] = q + m_evaluator.GetM(child, q); } @@ -1609,7 +1672,8 @@ void SearchWorker::PickNodesToExtendTask( } } - const float cpuct = ComputeCpuct(params_, node->GetN(), is_root_node); + const float cpuct = + ComputeCpuct(params_, node->GetTotalVisits(), is_root_node); const float puct_mult = cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); int cache_filled_idx = -1; @@ -1635,7 +1699,7 @@ void SearchWorker::PickNodesToExtendTask( const float util = current_util[idx]; if (idx > cache_filled_idx) { current_score[idx] = - current_pol[idx] * puct_mult / (1 + nstarted) + util; + cur_iters[idx].GetP() * puct_mult / (1 + nstarted) + util; cache_filled_idx++; } if (is_root_node) { @@ -1683,7 +1747,7 @@ void SearchWorker::PickNodesToExtendTask( if (best_without_u < second_best) { const auto n1 = current_nstarted[best_idx] + 1; estimated_visits_to_change_best = static_cast( - std::max(1.0f, std::min(current_pol[best_idx] * puct_mult / + std::max(1.0f, std::min(cur_iters[best_idx].GetP() * puct_mult / (second_best - best_without_u) - n1 + 1, 1e9f))); @@ -1702,43 +1766,35 @@ void SearchWorker::PickNodesToExtendTask( } (*visits_to_perform.back())[best_idx] += new_visits; cur_limit -= new_visits; - Node* child_node = best_edge.GetOrSpawnNode(/* parent */ node); - - // Probably best place to check for two-fold draws consistently. - // Depth starts with 1 at root, so real depth is depth - 1. - EnsureNodeTwoFoldCorrectForDepth( - child_node, current_path.size() + base_depth + 1 - 1); - bool decremented = false; + Node* child_node = best_edge.GetOrSpawnNode(/* parent */ node); + history.Append(best_edge.GetMove()); + auto [child_repetitions, child_moves_left] = + GetRepetitions(full_path.size(), history.Last()); + full_path.push_back({child_node, child_repetitions, child_moves_left}); if (child_node->TryStartScoreUpdate()) { current_nstarted[best_idx]++; new_visits -= 1; - decremented = true; - if (child_node->GetN() > 0 && !child_node->IsTerminal()) { + if (ShouldStopPickingHere(child_node, false, child_repetitions)) { + // Reduce 1 for the visits_to_perform to ensure the collision + // created doesn't include this visit. + (*visits_to_perform.back())[best_idx] -= 1; + receiver->push_back(NodeToProcess::Visit(full_path, history)); + completed_visits++; + } else { child_node->IncrementNInFlight(new_visits); current_nstarted[best_idx] += new_visits; } - current_score[best_idx] = current_pol[best_idx] * puct_mult / + current_score[best_idx] = cur_iters[best_idx].GetP() * puct_mult / (1 + current_nstarted[best_idx]) + current_util[best_idx]; } - if ((decremented && - (child_node->GetN() == 0 || child_node->IsTerminal()))) { - // Reduce 1 for the visits_to_perform to ensure the collision created - // doesn't include this visit. - (*visits_to_perform.back())[best_idx] -= 1; - receiver->push_back(NodeToProcess::Visit( - child_node, - static_cast(current_path.size() + 1 + base_depth))); - completed_visits++; - receiver->back().moves_to_visit.reserve(moves_to_path.size() + 1); - receiver->back().moves_to_visit = moves_to_path; - receiver->back().moves_to_visit.push_back(best_edge.GetMove()); - } if (best_idx > vtp_last_filled.back() && (*visits_to_perform.back())[best_idx] > 0) { vtp_last_filled.back() = best_idx; } + history.Pop(); + full_path.pop_back(); } is_root_node = false; // Actively do any splits now rather than waiting for potentially long @@ -1753,29 +1809,32 @@ void SearchWorker::PickNodesToExtendTask( collision_limit - params_.GetMinimumRemainingWorkSizeForPicking()) { Node* child_node = cur_iters[i].GetOrSpawnNode(/* parent */ node); + history.Append(cur_iters[i].GetMove()); + auto [child_repetitions, child_moves_left] = + GetRepetitions(full_path.size(), history.Last()); + full_path.push_back( + {child_node, child_repetitions, child_moves_left}); // Don't split if not expanded or terminal. - if (child_node->GetN() == 0 || child_node->IsTerminal()) continue; - - bool passed = false; - { - // Multiple writers, so need mutex here. - Mutex::Lock lock(picking_tasks_mutex_); - // Ensure not to exceed size of reservation. - if (picking_tasks_.size() < MAX_TASKS) { - moves_to_path.push_back(cur_iters[i].GetMove()); - picking_tasks_.emplace_back( - child_node, current_path.size() - 1 + base_depth + 1, - moves_to_path, child_limit); - moves_to_path.pop_back(); - task_count_.fetch_add(1, std::memory_order_acq_rel); - task_added_.notify_all(); - passed = true; - passed_off += child_limit; + if (!ShouldStopPickingHere(child_node, false, child_repetitions)) { + bool passed = false; + { + // Multiple writers, so need mutex here. + Mutex::Lock lock(picking_tasks_mutex_); + // Ensure not to exceed size of reservation. + if (picking_tasks_.size() < MAX_TASKS) { + picking_tasks_.emplace_back(full_path, history, child_limit); + task_count_.fetch_add(1, std::memory_order_acq_rel); + task_added_.notify_all(); + passed = true; + passed_off += child_limit; + } + } + if (passed) { + (*visits_to_perform.back())[i] = 0; } } - if (passed) { - (*visits_to_perform.back())[i] = 0; - } + history.Pop(); + full_path.pop_back(); } } // Fall through to select the first child. @@ -1787,14 +1846,13 @@ void SearchWorker::PickNodesToExtendTask( for (auto& child : node->Edges()) { idx++; if (idx > min_idx && (*visits_to_perform.back())[idx] > 0) { - if (moves_to_path.size() != current_path.size() + base_depth) { - moves_to_path.push_back(child.GetMove()); - } else { - moves_to_path.back() = child.GetMove(); - } current_path.back() = idx; current_path.push_back(-1); node = child.GetOrSpawnNode(/* parent */ node); + history.Append(child.GetMove()); + std::tie(repetitions, moves_left) = + GetRepetitions(full_path.size(), history.Last()); + full_path.push_back({node, repetitions, moves_left}); found_child = true; break; } @@ -1802,8 +1860,14 @@ void SearchWorker::PickNodesToExtendTask( } } if (!found_child) { - node = node->GetParent(); - if (!moves_to_path.empty()) moves_to_path.pop_back(); + history.Pop(); + full_path.pop_back(); + if (full_path.size() > 0) { + std::tie(node, repetitions, moves_left) = full_path.back(); + } else { + node = nullptr; + repetitions = 0; + } current_path.pop_back(); vtp_buffer.push_back(std::move(visits_to_perform.back())); visits_to_perform.pop_back(); @@ -1812,22 +1876,20 @@ void SearchWorker::PickNodesToExtendTask( } } -void SearchWorker::ExtendNode(Node* node, int depth, - const std::vector& moves_to_node, - PositionHistory* history) { - // Initialize position sequence with pre-move position. - history->Trim(search_->played_history_.GetLength()); - for (size_t i = 0; i < moves_to_node.size(); i++) { - history->Append(moves_to_node[i]); - } +void SearchWorker::ExtendNode(NodeToProcess& picked_node) { + const auto path = picked_node.path; + assert(!std::get<0>(path.back())->GetLowNode()); + + const PositionHistory& history = picked_node.history; // We don't need the mutex because other threads will see that N=0 and // N-in-flight=1 and will not touch this node. - const auto& board = history->Last().GetBoard(); - auto legal_moves = board.GenerateLegalMoves(); + const auto& board = history.Last().GetBoard(); + std::vector legal_moves = board.GenerateLegalMoves(); // Check whether it's a draw/lose by position. Importantly, we must check // these before doing the by-rule checks below. + auto node = picked_node.node; if (legal_moves.empty()) { // Could be a checkmate or a stalemate if (board.IsUnderCheck()) { @@ -1846,58 +1908,41 @@ void SearchWorker::ExtendNode(Node* node, int depth, return; } - if (history->Last().GetRule50Ply() >= 100) { + if (history.Last().GetRule50Ply() >= 100) { node->MakeTerminal(GameResult::DRAW); return; } - const auto repetitions = history->Last().GetRepetitions(); - // Mark two-fold repetitions as draws according to settings. - // Depth starts with 1 at root, so number of plies in PV is depth - 1. - if (repetitions >= 2) { - node->MakeTerminal(GameResult::DRAW); - return; - } else if (repetitions == 1 && depth - 1 >= 4 && - params_.GetTwoFoldDraws() && - depth - 1 >= history->Last().GetPliesSincePrevRepetition()) { - const auto cycle_length = history->Last().GetPliesSincePrevRepetition(); - // use plies since first repetition as moves left; exact if forced draw. - node->MakeTerminal(GameResult::DRAW, (float)cycle_length, - Node::Terminal::TwoFold); - return; + // Handle repetition draws as pseudo-terminals. + if (picked_node.repetitions >= 2) { + // Not a real terminal, set low node. } - - // Neither by-position or by-rule termination, but maybe it's a TB position. - if (search_->syzygy_tb_ && !search_->root_is_in_dtz_ && - board.castlings().no_legal_castle() && - history->Last().GetRule50Ply() == 0 && - (board.ours() | board.theirs()).count() <= - search_->syzygy_tb_->max_cardinality()) { + // Neither by-position or by-rule termination, but maybe it's a TB + // position. + else if (search_->syzygy_tb_ && !search_->root_is_in_dtz_ && + board.castlings().no_legal_castle() && + history.Last().GetRule50Ply() == 0 && + (board.ours() | board.theirs()).count() <= + search_->syzygy_tb_->max_cardinality()) { ProbeState state; const WDLScore wdl = - search_->syzygy_tb_->probe_wdl(history->Last(), &state); + search_->syzygy_tb_->probe_wdl(history.Last(), &state); // Only fail state means the WDL is wrong, probe_wdl may produce correct // result with a stat other than OK. if (state != FAIL) { // TB nodes don't have NN evaluation, assign M from parent node. float m = 0.0f; - // Need a lock to access parent, in case MakeSolid is in progress. - { - SharedMutex::SharedLock lock(search_->nodes_mutex_); - auto parent = node->GetParent(); - if (parent) { - m = std::max(0.0f, parent->GetM() - 1.0f); - } + if (path.size() > 1) { + auto parent = std::get<0>(path[path.size() - 2]); + m = std::max(0.0f, parent->GetM() - 1.0f); } // If the colors seem backwards, check the checkmate check above. if (wdl == WDL_WIN) { - node->MakeTerminal(GameResult::BLACK_WON, m, - Node::Terminal::Tablebase); + node->MakeTerminal(GameResult::BLACK_WON, m, Terminal::Tablebase); } else if (wdl == WDL_LOSS) { - node->MakeTerminal(GameResult::WHITE_WON, m, - Node::Terminal::Tablebase); + node->MakeTerminal(GameResult::WHITE_WON, m, Terminal::Tablebase); } else { // Cursed wins and blessed losses count as draws. - node->MakeTerminal(GameResult::DRAW, m, Node::Terminal::Tablebase); + node->MakeTerminal(GameResult::DRAW, m, Terminal::Tablebase); } search_->tb_hits_.fetch_add(1, std::memory_order_acq_rel); return; @@ -1905,42 +1950,19 @@ void SearchWorker::ExtendNode(Node* node, int depth, } } - // Add legal moves as edges of this node. - node->CreateEdges(legal_moves); -} + picked_node.nn_queried = true; // Node::SetLowNode() required. -// Returns whether node was already in cache. -bool SearchWorker::AddNodeToComputation(Node* node) { - const auto hash = history_.HashLast(params_.GetCacheHistoryLength() + 1); - if (search_->cache_->ContainsKey(hash)) { - return true; - } - int transform; - auto planes = - EncodePositionForNN(search_->network_->GetCapabilities().input_format, - history_, 8, params_.GetHistoryFill(), &transform); - - std::vector moves; - - if (node && node->HasChildren()) { - // Legal moves are known, use them. - moves.reserve(node->GetNumEdges()); - for (const auto& edge : node->Edges()) { - moves.emplace_back(edge.GetMove().as_nn_index(transform)); - } + // Check the transposition table first and NN cache second before asking for + // NN evaluation. + picked_node.hash = search_->dag_->GetHistoryHash(history); + auto tt_low_node = search_->dag_->TTFind(picked_node.hash); + if (tt_low_node != nullptr) { + picked_node.tt_low_node = tt_low_node; + picked_node.is_tt_hit = true; } else { - // Cache pseudolegal moves. A bit of a waste, but faster. - const auto& pseudolegal_moves = - history_.Last().GetBoard().GeneratePseudolegalMoves(); - moves.reserve(pseudolegal_moves.size()); - for (auto iter = pseudolegal_moves.begin(), end = pseudolegal_moves.end(); - iter != end; ++iter) { - moves.emplace_back(iter->as_nn_index(transform)); - } + picked_node.lock = NNCacheLock(search_->cache_, picked_node.hash); + picked_node.is_cache_hit = picked_node.lock; } - - computation_->AddInput(hash, std::move(planes), std::move(moves)); - return false; } // 2b. Copy collisions into shared collisions. @@ -1949,187 +1971,81 @@ void SearchWorker::CollectCollisions() { for (const NodeToProcess& node_to_process : minibatch_) { if (node_to_process.IsCollision()) { - search_->shared_collisions_.emplace_back(node_to_process.node, + search_->shared_collisions_.emplace_back(node_to_process.path, node_to_process.multivisit); } } } -// 3. Prefetch into cache. -// ~~~~~~~~~~~~~~~~~~~~~~~ -void SearchWorker::MaybePrefetchIntoCache() { - // TODO(mooskagh) Remove prefetch into cache if node collisions work well. - // If there are requests to NN, but the batch is not full, try to prefetch - // nodes which are likely useful in future. - if (search_->stop_.load(std::memory_order_acquire)) return; - if (computation_->GetCacheMisses() > 0 && - computation_->GetCacheMisses() < params_.GetMaxPrefetchBatch()) { - history_.Trim(search_->played_history_.GetLength()); - SharedMutex::SharedLock lock(search_->nodes_mutex_); - PrefetchIntoCache( - search_->root_node_, - params_.GetMaxPrefetchBatch() - computation_->GetCacheMisses(), false); - } -} - -// Prefetches up to @budget nodes into cache. Returns number of nodes -// prefetched. -int SearchWorker::PrefetchIntoCache(Node* node, int budget, bool is_odd_depth) { - const float draw_score = search_->GetDrawScore(is_odd_depth); - if (budget <= 0) return 0; - - // We are in a leaf, which is not yet being processed. - if (!node || node->GetNStarted() == 0) { - if (AddNodeToComputation(node)) { - // Make it return 0 to make it not use the slot, so that the function - // tries hard to find something to cache even among unpopular moves. - // In practice that slows things down a lot though, as it's not always - // easy to find what to cache. - return 1; - } - return 1; - } - - assert(node); - // n = 0 and n_in_flight_ > 0, that means the node is being extended. - if (node->GetN() == 0) return 0; - // The node is terminal; don't prefetch it. - if (node->IsTerminal()) return 0; - - // Populate all subnodes and their scores. - typedef std::pair ScoredEdge; - std::vector scores; - const float cpuct = - ComputeCpuct(params_, node->GetN(), node == search_->root_node_); - const float puct_mult = - cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); - const float fpu = - GetFpu(params_, node, node == search_->root_node_, draw_score); - for (auto& edge : node->Edges()) { - if (edge.GetP() == 0.0f) continue; - // Flip the sign of a score to be able to easily sort. - // TODO: should this use logit_q if set?? - scores.emplace_back(-edge.GetU(puct_mult) - edge.GetQ(fpu, draw_score), - edge); - } - - size_t first_unsorted_index = 0; - int total_budget_spent = 0; - int budget_to_spend = budget; // Initialize for the case where there's only - // one child. - for (size_t i = 0; i < scores.size(); ++i) { - if (search_->stop_.load(std::memory_order_acquire)) break; - if (budget <= 0) break; - - // Sort next chunk of a vector. 3 at a time. Most of the time it's fine. - if (first_unsorted_index != scores.size() && - i + 2 >= first_unsorted_index) { - const int new_unsorted_index = - std::min(scores.size(), budget < 2 ? first_unsorted_index + 2 - : first_unsorted_index + 3); - std::partial_sort(scores.begin() + first_unsorted_index, - scores.begin() + new_unsorted_index, scores.end(), - [](const ScoredEdge& a, const ScoredEdge& b) { - return a.first < b.first; - }); - first_unsorted_index = new_unsorted_index; - } - - auto edge = scores[i].second; - // Last node gets the same budget as prev-to-last node. - if (i != scores.size() - 1) { - // Sign of the score was flipped for sorting, so flip it back. - const float next_score = -scores[i + 1].first; - // TODO: As above - should this use logit_q if set? - const float q = edge.GetQ(-fpu, draw_score); - if (next_score > q) { - budget_to_spend = - std::min(budget, int(edge.GetP() * puct_mult / (next_score - q) - - edge.GetNStarted()) + - 1); - } else { - budget_to_spend = budget; - } - } - history_.Append(edge.GetMove()); - const int budget_spent = - PrefetchIntoCache(edge.node(), budget_to_spend, !is_odd_depth); - history_.Pop(); - budget -= budget_spent; - total_budget_spent += budget_spent; - } - return total_budget_spent; -} - // 4. Run NN computation. // ~~~~~~~~~~~~~~~~~~~~~~ -void SearchWorker::RunNNComputation() { computation_->ComputeBlocking(); } +void SearchWorker::RunNNComputation() { + computation_->ComputeBlocking(params_.GetPolicySoftmaxTemp()); +} // 5. Retrieve NN computations (and terminal values) into nodes. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ void SearchWorker::FetchMinibatchResults() { + SharedMutex::Lock nodes_lock(search_->nodes_mutex_); // Populate NN/cached results, or terminal results, into nodes. int idx_in_computation = 0; for (auto& node_to_process : minibatch_) { FetchSingleNodeResult(&node_to_process, *computation_, idx_in_computation); - if (node_to_process.nn_queried) ++idx_in_computation; + if (node_to_process.ShouldAddToInput()) ++idx_in_computation; } } template void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process, const Computation& computation, - int idx_in_computation) { - if (node_to_process->IsCollision()) return; - Node* node = node_to_process->node; - if (!node_to_process->nn_queried) { - // Terminal nodes don't involve the neural NetworkComputation, nor do - // they require any further processing after value retrieval. - node_to_process->v = node->GetWL(); - node_to_process->d = node->GetD(); - node_to_process->m = node->GetM(); - return; - } - // For NN results, we need to populate policy as well as value. - // First the value... - node_to_process->v = -computation.GetQVal(idx_in_computation); - node_to_process->d = computation.GetDVal(idx_in_computation); - node_to_process->m = computation.GetMVal(idx_in_computation); - // ...and secondly, the policy data. - // Calculate maximum first. - float max_p = -std::numeric_limits::infinity(); - // Intermediate array to store values when processing policy. - // There are never more than 256 valid legal moves in any legal position. - std::array intermediate; - int counter = 0; - for (auto& edge : node->Edges()) { - float p = computation.GetPVal( - idx_in_computation, - edge.GetMove().as_nn_index(node_to_process->probability_transform)); - intermediate[counter++] = p; - max_p = std::max(max_p, p); - } - float total = 0.0; - for (int i = 0; i < counter; i++) { - // Perform softmax and take into account policy softmax temperature T. - // Note that we want to calculate (exp(p-max_p))^(1/T) = exp((p-max_p)/T). - float p = - FastExp((intermediate[i] - max_p) / params_.GetPolicySoftmaxTemp()); - intermediate[i] = p; - total += p; - } - counter = 0; - // Normalize P values to add up to 1.0. - const float scale = total > 0.0f ? 1.0f / total : 1.0f; - for (auto& edge : node->Edges()) { - edge.edge()->SetP(intermediate[counter++] * scale); + int idx_in_computation) + REQUIRES(search_->nodes_mutex_) { + if (!node_to_process->nn_queried) return; + + if (!node_to_process->is_tt_hit) { + auto [tt_low_node, is_tt_miss] = + search_->dag_->TTGetOrCreate(node_to_process->hash); + + assert(tt_low_node != nullptr); + node_to_process->tt_low_node = tt_low_node; + if (is_tt_miss) { + auto nn_eval = computation.GetNNEval(idx_in_computation).get(); + if (params_.GetContemptPerspective() != ContemptPerspective::NONE) { + bool root_stm = + (params_.GetContemptPerspective() == ContemptPerspective::STM + ? !(search_->played_history_.Last().IsBlackToMove()) + : (params_.GetContemptPerspective() == + ContemptPerspective::WHITE)); + auto sign = (root_stm ^ node_to_process->history.IsBlackToMove()) + ? 1.0f + : -1.0f; + if (params_.GetWDLRescaleRatio() != 1.0f || + params_.GetWDLRescaleDiff() != 0.0f) { + float v = nn_eval->q; + float d = nn_eval->d; + WDLRescale(v, d, nullptr, params_.GetWDLRescaleRatio(), + params_.GetWDLRescaleDiff(), sign, false); + nn_eval->q = v; + nn_eval->d = d; + } + } + node_to_process->tt_low_node->SetNNEval(nn_eval); + } } + + // Add NN results to node. + Node* node = node_to_process->node; // Add Dirichlet noise if enabled and at root. if (params_.GetNoiseEpsilon() && node == search_->root_node_) { + auto low_node = search_->dag_->NonTTAddClone(*node_to_process->tt_low_node); + assert(low_node != nullptr); + node->SetLowNode(low_node); ApplyDirichletNoise(node, params_.GetNoiseEpsilon(), params_.GetNoiseAlpha()); + node->SortEdges(); + } else { + node->SetLowNode(node_to_process->tt_low_node); } - node->SortEdges(); } // 6. Propagate the new nodes' information to all their parents in the tree. @@ -2150,60 +2066,152 @@ void SearchWorker::DoBackupUpdate() { search_->total_batches_ += 1; } +bool SearchWorker::MaybeAdjustForTerminalOrTransposition( + Node* n, const LowNode* nl, float& v, float& d, float& m, + uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta, + bool& update_parent_bounds) const { + if (n->IsTerminal()) { + v = n->GetWL(); + d = n->GetD(); + m = n->GetM(); + + return true; + } + + // Use information from transposition or a new terminal. + if (nl->IsTransposition() || nl->IsTerminal()) { + // Adapt information from low node to node by flipping Q sign, bounds, + // result and incrementing m. + v = -nl->GetWL(); + d = nl->GetD(); + m = nl->GetM() + 1; + // When starting at or going through a transposition/terminal, make sure to + // use the information it has already acquired. + n_to_fix = n->GetN(); + v_delta = v - n->GetWL(); + d_delta = d - n->GetD(); + m_delta = m - n->GetM(); + // Update bounds. + if (params_.GetStickyEndgames()) { + auto tt = nl->GetTerminalType(); + if (tt != Terminal::NonTerminal) { + GameResult r; + if (v == 1.0f) { + r = GameResult::WHITE_WON; + } else if (v == -1.0f) { + r = GameResult::BLACK_WON; + } else { + r = GameResult::DRAW; + } + + n->MakeTerminal(r, m, tt); + update_parent_bounds = true; + } else { + auto [lower, upper] = nl->GetBounds(); + n->SetBounds(-upper, -lower); + } + } + + return true; + } + + return false; +} + +// Use information from terminal status or low node to update node and node's +// parent low node and so on until the root is reached. Low node may become a +// transposition and/or get more information even during this batch. Both low +// node and node may adjust bounds and become a terminal during this batch. void SearchWorker::DoBackupUpdateSingleNode( const NodeToProcess& node_to_process) REQUIRES(search_->nodes_mutex_) { - Node* node = node_to_process.node; if (node_to_process.IsCollision()) { // Collisions are handled via shared_collisions instead. return; } + auto path = node_to_process.path; + auto [n, nr, nm] = path.back(); // For the first visit to a terminal, maybe update parent bounds too. auto update_parent_bounds = - params_.GetStickyEndgames() && node->IsTerminal() && !node->GetN(); - - // Backup V value up to a root. After 1 visit, V = Q. - float v = node_to_process.v; - float d = node_to_process.d; - float m = node_to_process.m; - int n_to_fix = 0; + params_.GetStickyEndgames() && n->IsTerminal() && !n->GetN(); + auto nl = n->GetLowNode(); + float v = 0.0f; + float d = 0.0f; + float m = 0.0f; + uint32_t n_to_fix = 0; float v_delta = 0.0f; float d_delta = 0.0f; float m_delta = 0.0f; - uint32_t solid_threshold = - static_cast(params_.GetSolidTreeThreshold()); - for (Node *n = node, *p; n != search_->root_node_->GetParent(); n = p) { - p = n->GetParent(); - - // Current node might have become terminal from some other descendant, so - // backup the rest of the way with more accurate values. - if (n->IsTerminal()) { - v = n->GetWL(); - d = n->GetD(); - m = n->GetM(); - } + + // Update the low node at the start of the backup path first, but only visit + // it the first time that backup sees it. + if (nl && nl->GetN() == 0) { + nl->FinalizeScoreUpdate(nl->GetWL(), nl->GetD(), nl->GetM(), + node_to_process.multivisit); + } + + if (nr >= 2) { + // Three-fold itself has to be handled as a terminal to produce relevant + // results. Unlike two-folds that can keep updating their "real" values. + n->SetRepetition(); + v = 0.0f; + d = 1.0f; + m = 1; + } else if (!MaybeAdjustForTerminalOrTransposition(n, nl, v, d, m, n_to_fix, + v_delta, d_delta, m_delta, + update_parent_bounds)) { + // If there is nothing better, use original NN values adjusted for node. + v = -nl->GetWL(); + d = nl->GetD(); + m = nl->GetM() + 1; + } + + // Backup V value up to a root. After 1 visit, V = Q. + for (auto it = path.crbegin(); it != path.crend(); + /* ++it in the body */) { n->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit); if (n_to_fix > 0 && !n->IsTerminal()) { + // First part of the path might be never as it was removed and recreated. + n_to_fix = std::min(n_to_fix, n->GetN()); n->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix); } - if (n->GetN() >= solid_threshold) { - if (n->MakeSolid() && n == search_->root_node_) { - // If we make the root solid, the current_best_edge_ becomes invalid and - // we should repopulate it. - search_->current_best_edge_ = - search_->GetBestChildNoTemperature(search_->root_node_, 0); - } + + // Stop delta update on repetition "terminal" and propagate a draw above + // repetitions valid on the current path. + // Only do this after edge update to have good values if play goes here. + if (nr == 1 && !n->IsTerminal()) { + n->SetRepetition(); + v = 0.0f; + d = 1.0f; + m = nm + 1; } + if (n->IsRepetition()) n_to_fix = 0; // Nothing left to do without ancestors to update. - if (!p) break; + if (++it == path.crend()) break; + auto [p, pr, pm] = *it; + auto pl = p->GetLowNode(); + + assert(!p->IsTerminal() || + (p->IsTerminal() && pl->IsTerminal() && p->GetWL() == -pl->GetWL() && + p->GetD() == pl->GetD())); + // If parent low node is already a (new) terminal, then change propagated + // values and stop terminal adjustment. + if (pl->IsTerminal()) { + v = pl->GetWL(); + d = pl->GetD(); + m = pl->GetM(); + n_to_fix = 0; + } + pl->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit); + if (n_to_fix > 0) { + pl->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix); + } bool old_update_parent_bounds = update_parent_bounds; - // If parent already is terminal further adjustment is not required. - if (p->IsTerminal()) n_to_fix = 0; // Try setting parent bounds except the root or those already terminal. update_parent_bounds = - update_parent_bounds && p != search_->root_node_ && !p->IsTerminal() && + update_parent_bounds && p != search_->root_node_ && !pl->IsTerminal() && MaybeSetBounds(p, m, &n_to_fix, &v_delta, &d_delta, &m_delta); // Q will be flipped for opponent. @@ -2211,6 +2219,10 @@ void SearchWorker::DoBackupUpdateSingleNode( v_delta = -v_delta; m++; + MaybeAdjustForTerminalOrTransposition(p, pl, v, d, m, n_to_fix, v_delta, + d_delta, m_delta, + update_parent_bounds); + // Update the stats. // Best move. // If update_parent_bounds was set, we just adjusted bounds on the @@ -2226,19 +2238,25 @@ void SearchWorker::DoBackupUpdateSingleNode( search_->current_best_edge_ = search_->GetBestChildNoTemperature(search_->root_node_, 0); } + + n = p; + nr = pr; + nm = pm; } search_->total_playouts_ += node_to_process.multivisit; - search_->cum_depth_ += node_to_process.depth * node_to_process.multivisit; - search_->max_depth_ = std::max(search_->max_depth_, node_to_process.depth); + search_->cum_depth_ += + node_to_process.path.size() * node_to_process.multivisit; + search_->max_depth_ = + std::max(search_->max_depth_, (uint16_t)node_to_process.path.size()); } -bool SearchWorker::MaybeSetBounds(Node* p, float m, int* n_to_fix, +bool SearchWorker::MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix, float* v_delta, float* d_delta, float* m_delta) const { auto losing_m = 0.0f; auto prefer_tb = false; - // Determine the maximum (lower, upper) bounds across all children. + // Determine the maximum (lower, upper) bounds across all edges. // (-1,-1) Loss (initial and lowest bounds) // (-1, 0) Can't Win // (-1, 1) Regular node @@ -2274,6 +2292,7 @@ bool SearchWorker::MaybeSetBounds(Node* p, float m, int* n_to_fix, // Win ( 1, 1) -> (-1,-1) Loss // Nothing left to do for ancestors if the parent would be a regular node. + auto pl = p->GetLowNode(); if (lower == GameResult::BLACK_WON && upper == GameResult::WHITE_WON) { return false; } else if (lower == upper) { @@ -2281,19 +2300,19 @@ bool SearchWorker::MaybeSetBounds(Node* p, float m, int* n_to_fix, // it terminal preferring shorter wins and longer losses. *n_to_fix = p->GetN(); assert(*n_to_fix > 0); - float cur_v = p->GetWL(); - float cur_d = p->GetD(); - float cur_m = p->GetM(); + pl->MakeTerminal( + upper, (upper == GameResult::BLACK_WON ? std::max(losing_m, m) : m), + prefer_tb ? Terminal::Tablebase : Terminal::EndOfGame); + // v, d and m will be set in MaybeAdjustForTerminalOrTransposition. + *v_delta = pl->GetWL() + p->GetWL(); + *d_delta = pl->GetD() - p->GetD(); + *m_delta = pl->GetM() + 1 - p->GetM(); p->MakeTerminal( -upper, (upper == GameResult::BLACK_WON ? std::max(losing_m, m) : m) + 1.0f, - prefer_tb ? Node::Terminal::Tablebase : Node::Terminal::EndOfGame); - // Negate v_delta because we're calculating for the parent, but immediately - // afterwards we'll negate v_delta in case it has come from the child. - *v_delta = -(p->GetWL() - cur_v); - *d_delta = p->GetD() - cur_d; - *m_delta = p->GetM() - cur_m; + prefer_tb ? Terminal::Tablebase : Terminal::EndOfGame); } else { + pl->SetBounds(lower, upper); p->SetBounds(-upper, -lower); } diff --git a/src/mcts/search.h b/src/mcts/search.h index 1323320bca..b2547b9b03 100644 --- a/src/mcts/search.h +++ b/src/mcts/search.h @@ -33,6 +33,8 @@ #include #include #include +#include +#include #include "chess/callbacks.h" #include "chess/uciloop.h" @@ -40,16 +42,17 @@ #include "mcts/params.h" #include "mcts/stoppers/timemgr.h" #include "neural/cache.h" -#include "neural/network.h" #include "syzygy/syzygy.h" #include "utils/logging.h" #include "utils/mutex.h" namespace lczero { +typedef std::vector> BackupPath; + class Search { public: - Search(const NodeTree& tree, Network* network, + Search(NodeTree* dag, Network* network, std::unique_ptr uci_responder, const MoveList& searchmoves, std::chrono::steady_clock::time_point start_time, @@ -96,7 +99,7 @@ class Search { void ResetBestMove(); // Returns NN eval for a given node from cache, if that node is cached. - NNCacheLock GetCachedNNEval(const Node* node) const; + NNCacheLock GetCachedNNEval(const PositionHistory& history) const; private: // Computes the best move, maybe with temperature (according to the settings). @@ -163,6 +166,7 @@ class Search { Node* root_node_; NNCache* cache_; + NodeTree* dag_; SyzygyTablebase* syzygy_tb_; // Fixed positions which happened before the search. const PositionHistory& played_history_; @@ -196,7 +200,7 @@ class Search { std::atomic backend_waiting_counter_{0}; std::atomic thread_count_{0}; - std::vector> shared_collisions_ + std::vector> shared_collisions_ GUARDED_BY(nodes_mutex_); std::unique_ptr uci_responder_; @@ -255,7 +259,7 @@ class SearchWorker { // Does one full iteration of MCTS search: // 1. Initialize internal structures. // 2. Gather minibatch. - // 3. Prefetch into cache. + // 3. // 4. Run NN computation. // 5. Retrieve NN computations (and terminal values) into nodes. // 6. Propagate the new nodes' information to all their parents in the tree. @@ -273,9 +277,6 @@ class SearchWorker { // 2b. Copy collisions into shared_collisions_. void CollectCollisions(); - // 3. Prefetch into cache. - void MaybePrefetchIntoCache(); - // 4. Run NN computation. void RunNNComputation(); @@ -290,88 +291,94 @@ class SearchWorker { private: struct NodeToProcess { - bool IsExtendable() const { return !is_collision && !node->IsTerminal(); } + bool IsExtendable() const { + return !is_collision && !node->IsTerminal() && !node->GetLowNode(); + } bool IsCollision() const { return is_collision; } bool CanEvalOutOfOrder() const { - return is_cache_hit || node->IsTerminal(); + return is_tt_hit || is_cache_hit || node->IsTerminal() || + node->GetLowNode(); } + bool ShouldAddToInput() const { return nn_queried && !is_tt_hit; } + // The path to the node to extend. + BackupPath path; // The node to extend. Node* node; - // Value from NN's value head, or -1/0/1 for terminal nodes. - float v; - // Draw probability for NN's with WDL value head. - float d; - // Estimated remaining plies left. - float m; - int multivisit = 0; + uint32_t multivisit = 0; // If greater than multivisit, and other parameters don't imply a lower // limit, multivist could be increased to this value without additional // change in outcome of next selection. - int maxvisit = 0; - uint16_t depth; + uint32_t maxvisit = 0; bool nn_queried = false; + bool is_tt_hit = false; bool is_cache_hit = false; bool is_collision = false; - int probability_transform = 0; - - // Details only populated in the multigather path. - - // Only populated for visits, - std::vector moves_to_visit; // Details that are filled in as we go. uint64_t hash; + LowNode* tt_low_node; NNCacheLock lock; - std::vector probabilities_to_cache; - InputPlanes input_planes; - mutable int last_idx = 0; + PositionHistory history; bool ooo_completed = false; - static NodeToProcess Collision(Node* node, uint16_t depth, - int collision_count) { - return NodeToProcess(node, depth, true, collision_count, 0); - } - static NodeToProcess Collision(Node* node, uint16_t depth, - int collision_count, int max_count) { - return NodeToProcess(node, depth, true, collision_count, max_count); + // Repetition draws. + int repetitions = 0; + + static NodeToProcess Collision(const BackupPath& path, int collision_count, + int max_count) { + return NodeToProcess(path, collision_count, max_count); } - static NodeToProcess Visit(Node* node, uint16_t depth) { - return NodeToProcess(node, depth, false, 1, 0); + static NodeToProcess Visit(const BackupPath& path, + const PositionHistory& history) { + return NodeToProcess(path, history); } - // Methods to allow NodeToProcess to conform as a 'Computation'. Only safe + // Method to allow NodeToProcess to conform as a 'Computation'. Only safe // to call if is_cache_hit is true in the multigather path. - - float GetQVal(int) const { return lock->q; } - - float GetDVal(int) const { return lock->d; } - - float GetMVal(int) const { return lock->m; } - - float GetPVal(int, int move_id) const { - const auto& moves = lock->p; - - int total_count = 0; - while (total_count < moves.size()) { - // Optimization: usually moves are stored in the same order as queried. - const auto& move = moves[last_idx++]; - if (last_idx == moves.size()) last_idx = 0; - if (move.first == move_id) return move.second; - ++total_count; + std::shared_ptr GetNNEval(int) const { return lock->eval; } + + std::string DebugString() const { + std::ostringstream oss; + oss << " This:" << this << " Depth:" << path.size() + << " Node:" << node << " Multivisit:" << multivisit + << " Maxvisit:" << maxvisit << " NNQueried:" << nn_queried + << " TTHit:" << is_tt_hit << " CacheHit:" << is_cache_hit + << " Collision:" << is_collision << " OOO:" << ooo_completed + << " Repetitions:" << repetitions << " Path:"; + for (auto it = path.cbegin(); it != path.cend(); ++it) { + if (it != path.cbegin()) oss << "->"; + auto n = std::get<0>(*it); + auto nl = n->GetLowNode(); + oss << n << ":" << n->GetNInFlight(); + if (nl) { + oss << "(" << nl << ")"; + } } - assert(false); // Move not found. - return 0; + oss << " --- " << std::get<0>(path.back())->DebugString(); + if (node->GetLowNode()) + oss << " --- " << node->GetLowNode()->DebugString(); + + return oss.str(); } private: - NodeToProcess(Node* node, uint16_t depth, bool is_collision, int multivisit, - int max_count) - : node(node), + NodeToProcess(const BackupPath& path, uint32_t multivisit, + uint32_t max_count) + : path(path), + node(std::get<0>(path.back())), multivisit(multivisit), maxvisit(max_count), - depth(depth), - is_collision(is_collision) {} + is_collision(true), + repetitions(0) {} + NodeToProcess(const BackupPath& path, const PositionHistory& in_history) + : path(path), + node(std::get<0>(path.back())), + multivisit(1), + maxvisit(0), + is_collision(false), + history(in_history), + repetitions(std::get<1>(path.back())) {} }; // Holds per task worker scratch data @@ -381,15 +388,13 @@ class SearchWorker { std::vector>> visits_to_perform; std::vector vtp_last_filled; std::vector current_path; - std::vector moves_to_path; - PositionHistory history; + BackupPath full_path; TaskWorkspace() { vtp_buffer.reserve(30); visits_to_perform.reserve(30); vtp_last_filled.reserve(30); current_path.reserve(30); - moves_to_path.reserve(30); - history.Reserve(30); + full_path.reserve(30); } }; @@ -398,10 +403,10 @@ class SearchWorker { PickTaskType task_type; // For task type gathering. + BackupPath start_path; Node* start; - int base_depth; int collision_limit; - std::vector moves_to_base; + PositionHistory history; std::vector results; // Task type post gather processing. @@ -410,35 +415,45 @@ class SearchWorker { bool complete = false; - PickTask(Node* node, uint16_t depth, const std::vector& base_moves, + PickTask(const BackupPath& start_path, const PositionHistory& in_history, int collision_limit) : task_type(kGathering), - start(node), - base_depth(depth), + start_path(start_path), + start(std::get<0>(start_path.back())), collision_limit(collision_limit), - moves_to_base(base_moves) {} + history(in_history) {} PickTask(int start_idx, int end_idx) : task_type(kProcessing), start_idx(start_idx), end_idx(end_idx) {} }; NodeToProcess PickNodeToExtend(int collision_limit); - bool AddNodeToComputation(Node* node); - int PrefetchIntoCache(Node* node, int budget, bool is_odd_depth); + // Adjust parameters for updating node @n and its parent low node if node is + // terminal or its child low node is a transposition. Also update bounds and + // terminal status of node @n using information from its child low node. + // Return true if adjustment happened. + bool MaybeAdjustForTerminalOrTransposition(Node* n, const LowNode* nl, + float& v, float& d, float& m, + uint32_t& n_to_fix, float& v_delta, + float& d_delta, float& m_delta, + bool& update_parent_bounds) const; void DoBackupUpdateSingleNode(const NodeToProcess& node_to_process); // Returns whether a node's bounds were set based on its children. - bool MaybeSetBounds(Node* p, float m, int* n_to_fix, float* v_delta, + bool MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix, float* v_delta, float* d_delta, float* m_delta) const; void PickNodesToExtend(int collision_limit); - void PickNodesToExtendTask(Node* starting_point, int collision_limit, - int base_depth, - const std::vector& moves_to_base, + void PickNodesToExtendTask(const BackupPath& path, int collision_limit, + PositionHistory& history, std::vector* receiver, TaskWorkspace* workspace); - void EnsureNodeTwoFoldCorrectForDepth(Node* node, int depth); - void ProcessPickedTask(int batch_start, int batch_end, - TaskWorkspace* workspace); - void ExtendNode(Node* node, int depth, const std::vector& moves_to_add, - PositionHistory* history); + + // Check if the situation described by @depth under root and @position is a + // safe two-fold or a draw by repetition and return the number of safe + // repetitions and moves_left. + std::pair GetRepetitions(int depth, const Position& position); + // Check if there is a reason to stop picking and pick @node. + bool ShouldStopPickingHere(Node* node, bool is_root_node, int repetitions); + void ProcessPickedTask(int batch_start, int batch_end); + void ExtendNode(NodeToProcess& picked_node); template void FetchSingleNodeResult(NodeToProcess* node_to_process, const Computation& computation, @@ -454,7 +469,7 @@ class SearchWorker { std::unique_ptr computation_; // History is reset and extended by PickNodeToExtend(). PositionHistory history_; - int number_out_of_order_ = 0; + uint32_t number_out_of_order_ = 0; const SearchParams& params_; std::unique_ptr precached_node_; const bool moves_left_support_; diff --git a/src/mcts/stoppers/smooth.cc b/src/mcts/stoppers/smooth.cc index 2a1a196247..a60f649bb3 100644 --- a/src/mcts/stoppers/smooth.cc +++ b/src/mcts/stoppers/smooth.cc @@ -626,7 +626,7 @@ bool SmoothStopper::ShouldStop(const IterationStats& stats, void SmoothStopper::OnSearchDone(const IterationStats& stats) { manager_->UpdateEndOfMoveStats(stats.time_since_movestart, used_piggybank_.test_and_set(), deadline_ms_, - stats.total_nodes); + stats.total_visits); } } // namespace diff --git a/src/mcts/stoppers/stoppers.cc b/src/mcts/stoppers/stoppers.cc index 9f99a4269a..e64c5d9157 100644 --- a/src/mcts/stoppers/stoppers.cc +++ b/src/mcts/stoppers/stoppers.cc @@ -59,10 +59,10 @@ void ChainedSearchStopper::OnSearchDone(const IterationStats& stats) { bool VisitsStopper::ShouldStop(const IterationStats& stats, StoppersHints* hints) { if (populate_remaining_playouts_) { - hints->UpdateEstimatedRemainingPlayouts(nodes_limit_ - stats.total_nodes); + hints->UpdateEstimatedRemainingPlayouts(nodes_limit_ - stats.total_visits); } - if (stats.total_nodes >= nodes_limit_) { - LOGFILE << "Stopped search: Reached visits limit: " << stats.total_nodes + if (stats.total_visits >= nodes_limit_) { + LOGFILE << "Stopped search: Reached visits limit: " << stats.total_visits << ">=" << nodes_limit_; return true; } @@ -92,12 +92,14 @@ bool PlayoutsStopper::ShouldStop(const IterationStats& stats, /////////////////////////// namespace { +// FIXME: This is too conservative. const size_t kAvgNodeSize = - sizeof(Node) + MemoryWatchingStopper::kAvgMovesPerPosition * sizeof(Edge); + sizeof(Node) + sizeof(LowNode) + + sizeof(NodeTree::TranspositionTable::slot_type) + + MemoryWatchingStopper::kAvgMovesPerPosition * sizeof(Edge); const size_t kAvgCacheItemSize = - NNCache::GetItemStructSize() + sizeof(CachedNNRequest) + - sizeof(CachedNNRequest::IdxAndProb) * - MemoryWatchingStopper::kAvgMovesPerPosition; + NNCache::GetItemStructSize() + sizeof(CachedNNRequest) + sizeof(NNEval) + + sizeof(Edge) * MemoryWatchingStopper::kAvgMovesPerPosition; } // namespace MemoryWatchingStopper::MemoryWatchingStopper(int cache_size, int ram_limit_mb, @@ -109,7 +111,21 @@ MemoryWatchingStopper::MemoryWatchingStopper(int cache_size, int ram_limit_mb, LOGFILE << "RAM limit " << ram_limit_mb << "MB. Cache takes " << cache_size * kAvgCacheItemSize / 1000000 << "MB. Remaining memory is enough for " << GetVisitsLimit() - << " nodes."; + << " allocated nodes."; +} + +bool MemoryWatchingStopper::ShouldStop(const IterationStats& stats, + StoppersHints* hints) { + if (populate_remaining_playouts_) { + hints->UpdateEstimatedRemainingPlayouts(nodes_limit_ - + stats.total_allocated_nodes); + } + if (stats.total_allocated_nodes >= nodes_limit_) { + LOGFILE << "Stopped search: Reached allocated node limit: " + << stats.total_allocated_nodes << ">=" << nodes_limit_; + return true; + } + return false; } /////////////////////////// @@ -152,7 +168,7 @@ KldGainStopper::KldGainStopper(float min_gain, int average_interval) bool KldGainStopper::ShouldStop(const IterationStats& stats, StoppersHints*) { Mutex::Lock lock(mutex_); - const auto new_child_nodes = stats.total_nodes - 1.0; + const auto new_child_nodes = stats.total_visits - 1.0; if (new_child_nodes < prev_child_nodes_ + average_interval_) return false; const auto new_visits = stats.edge_n; diff --git a/src/mcts/stoppers/stoppers.h b/src/mcts/stoppers/stoppers.h index 22ce26c088..ed5120deab 100644 --- a/src/mcts/stoppers/stoppers.h +++ b/src/mcts/stoppers/stoppers.h @@ -59,7 +59,7 @@ class VisitsStopper : public SearchStopper { int64_t GetVisitsLimit() const { return nodes_limit_; } bool ShouldStop(const IterationStats&, StoppersHints*) override; - private: + protected: const int64_t nodes_limit_; const bool populate_remaining_playouts_; }; @@ -86,6 +86,7 @@ class MemoryWatchingStopper : public VisitsStopper { static constexpr size_t kAvgMovesPerPosition = 30; MemoryWatchingStopper(int cache_size, int ram_limit_mb, bool populate_remaining_playouts); + bool ShouldStop(const IterationStats&, StoppersHints*) override; }; // Stops after time budget is gone. diff --git a/src/mcts/stoppers/timemgr.h b/src/mcts/stoppers/timemgr.h index 64d0c2c05d..5149bba3da 100644 --- a/src/mcts/stoppers/timemgr.h +++ b/src/mcts/stoppers/timemgr.h @@ -43,7 +43,8 @@ namespace lczero { struct IterationStats { int64_t time_since_movestart = 0; int64_t time_since_first_batch = 0; - int64_t total_nodes = 0; + int64_t total_visits = 0; + int64_t total_allocated_nodes = 0; int64_t nodes_since_movestart = 0; int64_t batches_since_movestart = 0; int average_depth = 0; diff --git a/src/neural/blas/encoder.h b/src/neural/blas/encoder.h index 99d2752752..9a1e4e9f2d 100644 --- a/src/neural/blas/encoder.h +++ b/src/neural/blas/encoder.h @@ -1,6 +1,6 @@ /* This file is part of Leela Chess Zero. - Copyright (C) 2018-2019 The LCZero Authors + Copyright (C) 2022-2023 The LCZero Authors Leela Chess is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -18,37 +18,34 @@ #pragma once -#include #include -#include #include "neural/shared/activation.h" -#include "utils/exception.h" -namespace lczero { - -namespace { - -template -using EigenMatrixMap = - Eigen::Map>; - -template -using ConstEigenMatrixMap = - Eigen::Map>; +#ifdef USE_ISPC +#include "layer_norm_ispc.h" +#endif -} // namespace +namespace lczero { void LayerNorm2DWithSkipConnection(const size_t batch_size, const size_t channels, float* data, - const float* skip, const float* gammas, - const float* betas, float epsilon) { + const float alpha, const float* skip, + const float* gammas, const float* betas, + float epsilon) { for (size_t i = 0; i < batch_size; i++) { +#ifndef USE_ISPC // Mean taken in dimension C. float mean = 0; - for (size_t c = 0; c < channels; ++c) { - data[i * channels + c] += skip[i * channels + c]; - mean += data[i * channels + c]; + if (skip != nullptr) { + for (size_t c = 0; c < channels; ++c) { + data[i * channels + c] += alpha * skip[i * channels + c]; + mean += data[i * channels + c]; + } + } else { + for (size_t c = 0; c < channels; ++c) { + mean += data[i * channels + c]; + } } mean /= channels; @@ -61,51 +58,22 @@ void LayerNorm2DWithSkipConnection(const size_t batch_size, var /= channels; // Norm. + float den = 1.0f / std::sqrt(var + epsilon); for (size_t c = 0; c < channels; ++c) { - data[i * channels + c] = betas[c] + gammas[c] * - (data[i * channels + c] - mean) / - std::sqrt(var + epsilon); + data[i * channels + c] = + betas[c] + gammas[c] * (data[i * channels + c] - mean) * den; } - } -} - -template -void AttentionMatmul2D(const bool transpose_a, const bool transpose_b, - const size_t batch_size, const size_t M, const size_t N, - const size_t K, const float scaling, const float* input1, - const float* input2, float* output) { - for (auto batch = size_t{0}; batch < batch_size; batch++) { - const float* A = &input1[batch * M * K]; - const float* B = &input2[batch * N * K]; - float* C = &output[batch * M * N]; - if (use_eigen) { - auto C_mat = EigenMatrixMap(C, N, M); - - if (transpose_a && transpose_b) { - C_mat.noalias() = scaling * - ConstEigenMatrixMap(B, K, N).transpose() * - ConstEigenMatrixMap(A, M, K).transpose(); - } else if (transpose_a) { - C_mat.noalias() = scaling * ConstEigenMatrixMap(B, N, K) * - ConstEigenMatrixMap(A, M, K).transpose(); - } else if (transpose_b) { - C_mat.noalias() = scaling * - ConstEigenMatrixMap(B, K, N).transpose() * - ConstEigenMatrixMap(A, K, M); - } else { - C_mat.noalias() = scaling * ConstEigenMatrixMap(B, N, K) * - ConstEigenMatrixMap(A, K, M); - } - } else { -#ifdef USE_BLAS - cblas_sgemm(CblasRowMajor, transpose_a ? CblasTrans : CblasNoTrans, - transpose_b ? CblasTrans : CblasNoTrans, M, N, K, scaling, A, - transpose_a ? M : K, B, transpose_b ? K : N, 0.0f, C, N); #else - // Should never get here. - throw Exception("Blas backend internal error"); -#endif + if (skip != nullptr) { + ispc::LayerNorm2DWithSkipConnection(channels, data + i * channels, alpha, + skip + i * channels, gammas, betas, + epsilon); + } else { + ispc::LayerNorm2DWithSkipConnection(channels, data + i * channels, 0.0f, + nullptr, gammas, betas, epsilon); } + +#endif } } diff --git a/src/neural/blas/fully_connected_layer.cc b/src/neural/blas/fully_connected_layer.cc index 2465779ed1..d228d2523f 100644 --- a/src/neural/blas/fully_connected_layer.cc +++ b/src/neural/blas/fully_connected_layer.cc @@ -103,7 +103,9 @@ void FullyConnectedLayer::Forward1D( outputs, // C (int)output_size); // ldc, leading rank of C } - ApplyBias(batch_size, output_size, biases, activation, outputs); + if (biases != nullptr) { + ApplyBias(batch_size, output_size, biases, activation, outputs); + } } template <> @@ -134,7 +136,9 @@ void FullyConnectedLayer::Forward1D( .transpose() * ConstEigenMatrixMap(inputs, input_size, batch_size); } - ApplyBias(batch_size, output_size, biases, activation, outputs); + if (biases != nullptr) { + ApplyBias(batch_size, output_size, biases, activation, outputs); + } } template <> diff --git a/src/neural/blas/layer_norm.ispc b/src/neural/blas/layer_norm.ispc new file mode 100644 index 0000000000..063bb74476 --- /dev/null +++ b/src/neural/blas/layer_norm.ispc @@ -0,0 +1,80 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2023 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + */ + +export void LayerNorm2DWithSkipConnection(uniform const size_t channels, + uniform float data[], + const uniform float alpha, + const uniform float skip[], + const uniform float gammas[], + const uniform float betas[], + uniform const float epsilon) { +#if 0 + // Faster but potentially less stable version for future testing. + // One pass mean and variance taken in dimension C. Uses shifted variance calculation. + float imean = 0; + float ivar = 0; + float k = data[0]; + if (skip != NULL) { + k += alpha * skip[0]; + foreach (c = 0 ... channels) { + float t = data[c] + alpha * skip[c]; + data[c] = t; + t -= k; + imean += t; + ivar += t * t; + } + } else { + foreach (c = 0 ... channels) { + float t = data[c]; + t -= k; + imean += t; + ivar += t * t; + } + } + float mean = reduce_add(imean) / channels; + float var = (reduce_add(ivar) - channels * mean * mean) / channels; + mean += k; +#else + // Mean taken in dimension C. + float imean = 0; + if (skip != NULL) { + foreach (c = 0 ... channels) { + data[c] += alpha * skip[c]; + imean += data[c]; + } + } else { + foreach (c = 0 ... channels) { + imean += data[c]; + } + } + float mean = reduce_add(imean) / channels; + + // Variance. + float ivar = 0; + foreach (c = 0 ... channels) { + float diff = data[c] - mean; + ivar += diff * diff; + } + float var = reduce_add(ivar) / channels; +#endif + + float den = rsqrt(var + epsilon); + foreach (c = 0 ... channels) { + data[c] = betas[c] + gammas[c] * (data[c] - mean) * den; + } +} diff --git a/src/neural/blas/network_blas.cc b/src/neural/blas/network_blas.cc index 3edc1f15b0..d609ee097c 100644 --- a/src/neural/blas/network_blas.cc +++ b/src/neural/blas/network_blas.cc @@ -41,6 +41,10 @@ #include #endif +#ifdef USE_ISPC +#include "activation_ispc.h" +#endif + namespace lczero { namespace { @@ -50,7 +54,9 @@ class BlasComputation : public NetworkComputation { BlasComputation(const LegacyWeights& weights, const size_t max_batch_size, const bool wdl, const bool moves_left, const bool conv_policy, const ActivationFunction default_activation, - const bool attn_policy); + const ActivationFunction smolgen_activation, + const ActivationFunction ffn_activation, + const bool attn_policy, const bool attn_body); virtual ~BlasComputation() {} @@ -98,6 +104,13 @@ class BlasComputation : public NetworkComputation { private: void EncodePlanes(const InputPlanes& sample, float* buffer); + void MakeEncoderLayer(std::vector& head_buffer, + std::vector& head_buffer2, + std::vector& head_buffer3, size_t batch_size, + const LegacyWeights::EncoderLayer& layer, + int embedding_size, int heads, + ActivationFunction smolgen_activation, + ActivationFunction ffn_activation, float alpha); static constexpr auto kWidth = 8; static constexpr auto kHeight = 8; @@ -117,7 +130,10 @@ class BlasComputation : public NetworkComputation { bool moves_left_; bool conv_policy_; ActivationFunction default_activation_; + ActivationFunction smolgen_activation_; + ActivationFunction ffn_activation_; bool attn_policy_; + bool attn_body_; }; template @@ -129,7 +145,8 @@ class BlasNetwork : public Network { std::unique_ptr NewComputation() override { return std::make_unique>( weights_, max_batch_size_, wdl_, moves_left_, conv_policy_, - default_activation_, attn_policy_); + default_activation_, smolgen_activation_, ffn_activation_, attn_policy_, + attn_body_); } const NetworkCapabilities& GetCapabilities() const override { @@ -149,14 +166,20 @@ class BlasNetwork : public Network { bool moves_left_; bool conv_policy_; ActivationFunction default_activation_; + ActivationFunction smolgen_activation_; + ActivationFunction ffn_activation_; bool attn_policy_; + bool attn_body_; }; template BlasComputation::BlasComputation( const LegacyWeights& weights, const size_t max_batch_size, const bool wdl, const bool moves_left, const bool conv_policy, - const ActivationFunction default_activation, const bool attn_policy) + const ActivationFunction default_activation, + const ActivationFunction smolgen_activation, + const ActivationFunction ffn_activation, const bool attn_policy, + const bool attn_body) : weights_(weights), max_batch_size_(max_batch_size), policies_(0), @@ -165,7 +188,10 @@ BlasComputation::BlasComputation( moves_left_(moves_left), conv_policy_(conv_policy), default_activation_(default_activation), - attn_policy_(attn_policy) { + smolgen_activation_(smolgen_activation), + ffn_activation_(ffn_activation), + attn_policy_(attn_policy), + attn_body_(attn_body) { #ifdef USE_DNNL omp_set_num_threads(1); #endif @@ -177,23 +203,225 @@ using EigenMatrixMap = template using ConstEigenMatrixMap = Eigen::Map>; +template +using EigenStridedMatrixMap = + Eigen::Map, 0, + Eigen::OuterStride<>>; +template +using ConstEigenStridedMatrixMap = + Eigen::Map, 0, + Eigen::OuterStride<>>; + +template +void BlasComputation::MakeEncoderLayer( + std::vector& head_buffer, std::vector& head_buffer2, + std::vector& head_buffer3, size_t batch_size, + const LegacyWeights::EncoderLayer& layer, int embedding_size, int heads, + ActivationFunction smolgen_activation, ActivationFunction ffn_activation, + float alpha) { + const int d_model = layer.mha.q_b.size(); + const int dff_size = layer.ffn.dense1_b.size(); + std::vector head_buffer4(batch_size * kSquares * + std::max(kSquares * heads, dff_size)); + + // Smolgen. + if (layer.mha.has_smolgen) { + const float* input = &head_buffer[0]; + float* QK = &head_buffer4[0]; + + // Compress. + const auto hidden_channels = + layer.mha.smolgen.compress.size() / embedding_size; + std::vector temp1(batch_size * kSquares * hidden_channels); + FullyConnectedLayer::Forward1D( + batch_size * kSquares, embedding_size, hidden_channels, input, + layer.mha.smolgen.compress.data(), (const float*)nullptr, + ACTIVATION_NONE, temp1.data()); + + // Dense 1. + const auto hidden_sz = layer.mha.smolgen.dense1_b.size(); + std::vector temp2(batch_size * hidden_sz); + FullyConnectedLayer::Forward1D( + batch_size, kSquares * hidden_channels, hidden_sz, temp1.data(), + layer.mha.smolgen.dense1_w.data(), layer.mha.smolgen.dense1_b.data(), + smolgen_activation, temp2.data()); + // Layer Norm + skip connection. + LayerNorm2DWithSkipConnection(batch_size, hidden_sz, temp2.data(), 0.0f, + (const float*)nullptr, + layer.mha.smolgen.ln1_gammas.data(), + layer.mha.smolgen.ln1_betas.data(), 1e-3); + + // Dense 2. + const auto gen_sz_outputs = layer.mha.smolgen.dense2_b.size(); + std::vector temp3(batch_size * gen_sz_outputs); + FullyConnectedLayer::Forward1D( + batch_size, hidden_sz, gen_sz_outputs, temp2.data(), + layer.mha.smolgen.dense2_w.data(), layer.mha.smolgen.dense2_b.data(), + smolgen_activation, temp3.data()); + // Layer Norm + skip connection. + LayerNorm2DWithSkipConnection(batch_size, gen_sz_outputs, temp3.data(), + 0.0f, (const float*)nullptr, + layer.mha.smolgen.ln2_gammas.data(), + layer.mha.smolgen.ln2_betas.data(), 1e-3); + + // Global smolgen weights. + FullyConnectedLayer::Forward1D( + batch_size * heads, gen_sz_outputs / heads, kSquares * kSquares, + temp3.data(), weights_.smolgen_w.data(), (const float*)nullptr, + ACTIVATION_NONE, QK); + } + + // Q + FullyConnectedLayer::Forward1D( + batch_size * kSquares, embedding_size, d_model, head_buffer.data(), + layer.mha.q_w.data(), layer.mha.q_b.data(), ACTIVATION_NONE, + head_buffer2.data()); + // K + FullyConnectedLayer::Forward1D( + batch_size * kSquares, embedding_size, d_model, head_buffer.data(), + layer.mha.k_w.data(), layer.mha.k_b.data(), ACTIVATION_NONE, + head_buffer3.data()); + + // MHA (Q, K, V) + const int depth = d_model / heads; + const float scaling = 1.0f / sqrtf(depth); + + // MHA is done per batch since there's a fourth dimension introduced. + for (auto batch = size_t{0}; batch < batch_size; batch++) { + auto batchStart = batch * kSquares * d_model; + + float* QK = &head_buffer4[batch * kSquares * kSquares * heads]; + + const float* Q = &head_buffer2[batchStart]; + const float* K = &head_buffer3[batchStart]; + + // matmul(Q, K) for all heads per batch. + + for (auto h = 0; h < heads; h++) { + const float* A = &Q[h * depth]; + const float* B = &K[h * depth]; + float* C = &QK[h * kSquares * kSquares]; + const float beta = layer.mha.has_smolgen ? 1.0f : 0.0f; + if (use_eigen) { + auto C_mat = EigenMatrixMap(C, kSquares, kSquares); + C_mat.noalias() = + beta * C_mat + + scaling * + ConstEigenStridedMatrixMap( + B, depth, kSquares, Eigen::OuterStride<>(heads * depth)) + .transpose() * + ConstEigenStridedMatrixMap( + A, depth, kSquares, Eigen::OuterStride<>(heads * depth)); + } else { +#ifdef USE_BLAS + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, kSquares, kSquares, + depth, scaling, A, heads * depth, B, heads * depth, beta, C, + kSquares); +#else + // Should never get here. + throw Exception("Blas backend internal error"); +#endif + } + } + } + + // Apply Softmax. + float* QK = &head_buffer4[0]; + for (size_t h = 0; h < batch_size * heads * kSquares * kSquares; + h += kSquares) { +#if defined(USE_ISPC) + if (!use_eigen) { + ispc::SoftmaxActivation(kSquares, QK + h, QK + h); + continue; + } +#endif + SoftmaxActivation(kSquares, QK + h, QK + h); + } + + // V + FullyConnectedLayer::Forward1D( + batch_size * kSquares, embedding_size, d_model, head_buffer.data(), + layer.mha.v_w.data(), layer.mha.v_b.data(), ACTIVATION_NONE, + head_buffer3.data()); + + for (auto batch = size_t{0}; batch < batch_size; batch++) { + auto batchStart = batch * kSquares * d_model; + // matmul(softmax(QK), V) for all heads per batch. + float* attn = &head_buffer2[batchStart]; + const float* V = &head_buffer3[batchStart]; + const float* QK = &head_buffer4[batch * kSquares * kSquares * heads]; + for (auto h = 0; h < heads; h++) { + const float* A = &QK[h * kSquares * kSquares]; + const float* B = &V[h * depth]; + float* C = &attn[h * depth]; + if (use_eigen) { + auto C_mat = EigenStridedMatrixMap( + C, depth, kSquares, Eigen::OuterStride<>(heads * depth)); + C_mat.noalias() = + ConstEigenStridedMatrixMap( + B, depth, kSquares, Eigen::OuterStride<>(heads * depth)) * + ConstEigenMatrixMap(A, kSquares, kSquares); + } else { +#ifdef USE_BLAS + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, kSquares, depth, + kSquares, 1.0f, A, kSquares, B, heads * depth, 0.0f, C, + heads * depth); +#endif + } + } + } + + // Fully connected final MHA layer. + FullyConnectedLayer::Forward1D( + batch_size * kSquares, d_model, embedding_size, head_buffer2.data(), + layer.mha.dense_w.data(), layer.mha.dense_b.data(), ACTIVATION_NONE, + head_buffer3.data()); + + // Layer Norm + skip connection. + LayerNorm2DWithSkipConnection(batch_size * kSquares, embedding_size, + head_buffer.data(), 1.0f / alpha, + head_buffer3.data(), layer.ln1_gammas.data(), + layer.ln1_betas.data(), 1e-6); + + // FFN. + FullyConnectedLayer::Forward1D( + batch_size * kSquares, embedding_size, dff_size, head_buffer.data(), + layer.ffn.dense1_w.data(), layer.ffn.dense1_b.data(), ffn_activation, + head_buffer4.data()); + + FullyConnectedLayer::Forward1D( + batch_size * kSquares, dff_size, layer.ffn.dense2_b.size(), + head_buffer4.data(), layer.ffn.dense2_w.data(), layer.ffn.dense2_b.data(), + ACTIVATION_NONE, head_buffer3.data()); + + // Layer Norm + skip connection. + LayerNorm2DWithSkipConnection(batch_size * kSquares, embedding_size, + head_buffer.data(), 1.0f / alpha, + head_buffer3.data(), layer.ln2_gammas.data(), + layer.ln2_betas.data(), 1e-6); +} template void BlasComputation::ComputeBlocking() { // Retrieve network key dimensions from the weights structure. const auto num_value_channels = weights_.ip1_val_b.size(); const auto num_moves_channels = weights_.ip1_mov_b.size(); - const auto num_value_input_planes = weights_.value.biases.size(); + const auto num_value_input_planes = + attn_body_ ? weights_.ip_val_b.size() : weights_.value.biases.size(); const auto num_policy_input_planes = weights_.policy.biases.size(); - const auto num_moves_input_planes = weights_.moves_left.biases.size(); + const auto num_moves_input_planes = + attn_body_ ? weights_.ip_mov_b.size() : weights_.moves_left.biases.size(); const auto num_output_policy = static_cast(kPolicyOutputs); - const auto output_channels = weights_.input.biases.size(); + const auto output_channels = + attn_body_ ? weights_.ip_emb_b.size() : weights_.input.biases.size(); + const auto num_res_blocks = weights_.residual.size(); // max_channels is the maximum number of input channels of any // convolution. // Residual blocks are identical, but the first convolution might be bigger // when the network has very few filters - const auto input_channels = static_cast(kInputPlanes); + const auto input_channels = static_cast( + kInputPlanes + (attn_body_ ? kNumPosEncodingChannels : 0)); const auto max_channels = std::max(output_channels, input_channels); // The policy head may increase convolution max output size. @@ -203,8 +431,8 @@ void BlasComputation::ComputeBlocking() { : output_channels; // Determine the largest batch for allocations. - const auto plane_count = planes_.size(); - const auto largest_batch_size = std::min(max_batch_size_, plane_count); + const auto total_batches = planes_.size(); + const auto largest_batch_size = std::min(max_batch_size_, total_batches); /* Typically input_channels = 112 @@ -222,10 +450,8 @@ void BlasComputation::ComputeBlocking() { std::vector output_fc(largest_batch_size * max_fc_channels); std::vector res_buffer1(largest_batch_size * max_channels * kSquares); - std::vector res_buffer2(largest_batch_size * output_channels * - kSquares); - std::vector res_buffer3(largest_batch_size * output_channels * - kSquares); + std::vector res_buffer2(largest_batch_size * max_channels * kSquares); + std::vector res_buffer3(largest_batch_size * max_channels * kSquares); WinogradConvolution3 convolve3(largest_batch_size, max_channels, max_output_channels); @@ -245,73 +471,140 @@ void BlasComputation::ComputeBlocking() { float* conv_out = res_buffer2.data(); float* res = res_buffer3.data(); - for (size_t i = 0; i < plane_count; i += largest_batch_size) { - const auto batch_size = std::min(plane_count - i, largest_batch_size); + for (size_t i = 0; i < total_batches; i += largest_batch_size) { + const auto batch_size = std::min(total_batches - i, largest_batch_size); for (size_t j = 0; j < batch_size; j++) { EncodePlanes(planes_[i + j], &conv_in[j * kSquares * kInputPlanes]); } - // Input convolution + if (num_res_blocks > 0) { + // Input convolution - convolve3.Forward(batch_size, kInputPlanes, output_channels, conv_in, - weights_.input.weights.data(), conv_out); + convolve3.Forward(batch_size, kInputPlanes, output_channels, conv_in, + weights_.input.weights.data(), conv_out); - BiasActivate(batch_size, output_channels, conv_out, - weights_.input.biases.data(), default_activation_); + BiasActivate(batch_size, output_channels, conv_out, + weights_.input.biases.data(), default_activation_); - // Residual tower + // Residual tower - for (auto& residual : weights_.residual) { - const auto& conv1 = residual.conv1; - const auto& conv2 = residual.conv2; - const auto& se = residual.se; + for (auto& residual : weights_.residual) { + const auto& conv1 = residual.conv1; + const auto& conv2 = residual.conv2; + const auto& se = residual.se; - std::swap(conv_out, conv_in); + std::swap(conv_out, conv_in); - convolve3.Forward(batch_size, output_channels, output_channels, conv_in, - conv1.weights.data(), conv_out); + convolve3.Forward(batch_size, output_channels, output_channels, conv_in, + conv1.weights.data(), conv_out); - BiasActivate(batch_size, output_channels, &conv_out[0], - conv1.biases.data(), default_activation_); + BiasActivate(batch_size, output_channels, &conv_out[0], + conv1.biases.data(), default_activation_); - std::swap(conv_in, res); - std::swap(conv_out, conv_in); + std::swap(conv_in, res); + std::swap(conv_out, conv_in); - convolve3.Forward(batch_size, output_channels, output_channels, conv_in, - conv2.weights.data(), conv_out); + convolve3.Forward(batch_size, output_channels, output_channels, conv_in, + conv2.weights.data(), conv_out); - if (residual.has_se) { - // No relu if followed by SE-unit and residual/bias is added later - std::swap(conv_out, conv_in); + if (residual.has_se) { + // No relu if followed by SE-unit and residual/bias is added later + std::swap(conv_out, conv_in); - auto se_fc_outputs = se.b1.size(); - ApplySEUnit(batch_size, output_channels, se_fc_outputs, - conv_in, conv2.biases.data(), res, se.w1.data(), - se.b1.data(), se.w2.data(), se.b2.data(), - conv_out, default_activation_); - } else { - BiasResidual(batch_size, output_channels, &conv_out[0], - conv2.biases.data(), res, default_activation_); + auto se_fc_outputs = se.b1.size(); + ApplySEUnit(batch_size, output_channels, se_fc_outputs, + conv_in, conv2.biases.data(), res, + se.w1.data(), se.b1.data(), se.w2.data(), + se.b2.data(), conv_out, default_activation_); + } else { + BiasResidual(batch_size, output_channels, &conv_out[0], + conv2.biases.data(), res, default_activation_); + } } } + if (attn_body_) { + const auto embedding_size = weights_.ip_emb_b.size(); + assert(embedding_size > 0); + const auto input_size = + num_res_blocks == 0 ? input_channels : weights_.input.biases.size(); + + if (num_res_blocks == 0) { + // No residual means pure transformer, so process input position + // encoding. + // Preprocess for attention body. + for (auto batch = size_t{0}; batch < batch_size; batch++) { + for (auto i = 0; i < kSquares; i++) { + // NCHW to NHWC conversion. + for (size_t j = 0; j < kInputPlanes; j++) { + res[batch * kSquares * input_size + i * input_size + j] = + conv_in[batch * kSquares * kInputPlanes + j * kSquares + i]; + } + // Position encoding. + for (size_t j = kInputPlanes; j < input_size; j++) { + res[batch * kSquares * input_size + i * input_size + j] = + kPosEncoding[i][j - kInputPlanes]; + } + } + } + } + + // Input embedding. + FullyConnectedLayer::Forward1D( + batch_size * kSquares, input_size, embedding_size, res_buffer3.data(), + weights_.ip_emb_w.data(), weights_.ip_emb_b.data(), + default_activation_, res_buffer1.data()); + + // Input gating + if (weights_.ip_mult_gate.size() > 0 && weights_.ip_add_gate.size() > 0) { + int idx; + for (auto batch = size_t{0}; batch < batch_size; batch++) { + for (auto i = 0; i < kSquares; i++) { + for (size_t j = 0; j < embedding_size; j++) { + idx = batch * kSquares * embedding_size + i * embedding_size + j; + res_buffer1[idx] = + res_buffer1[idx] * weights_.ip_mult_gate[j * kSquares + i] + + weights_.ip_add_gate[j * kSquares + i]; + } + } + }; + } + + // Attention body encoders. + float alpha = (float)pow(2.0 * weights_.encoder.size(), 0.25); + for (auto& layer : weights_.encoder) { + MakeEncoderLayer(res_buffer1, res_buffer2, res_buffer3, batch_size, + layer, embedding_size, weights_.encoder_head_count, + smolgen_activation_, ffn_activation_, alpha); + } + + res = res_buffer1.data(); + conv_in = res_buffer2.data(); + conv_out = res_buffer3.data(); + } + // Need to preserve conv_out which is used for value and moves left heads. if (attn_policy_) { - // NCHW to NHWC conversion. - for (auto batch = size_t{0}; batch < batch_size; batch++) { - for (auto i = 0; i < kSquares; i++) { - for (size_t j = 0; j < output_channels; j++) { - res[batch * kSquares * output_channels + i * output_channels + j] = - conv_out[batch * kSquares * output_channels + j * kSquares + i]; + if (!attn_body_) { + // NCHW to NHWC conversion. + for (auto batch = size_t{0}; batch < batch_size; batch++) { + for (auto i = 0; i < kSquares; i++) { + for (size_t j = 0; j < output_channels; j++) { + res[batch * kSquares * output_channels + i * output_channels + + j] = conv_out[batch * kSquares * output_channels + + j * kSquares + i]; + } } } } - const size_t embedding_size = weights_.ip_pol_b.size(); - // Embedding. + const size_t policy_embedding_size = weights_.ip_pol_b.size(); + // Policy Embedding. FullyConnectedLayer::Forward1D( - batch_size * kSquares, output_channels, embedding_size, res, + batch_size * kSquares, output_channels, policy_embedding_size, res, weights_.ip_pol_w.data(), weights_.ip_pol_b.data(), - SELU, // SELU activation for attention head. + attn_body_ + ? default_activation_ + : ACTIVATION_SELU, // SELU activation hardcoded for apmish nets. head_buffer.data()); const size_t policy_d_model = weights_.ip2_pol_b.size(); @@ -324,131 +617,24 @@ void BlasComputation::ComputeBlocking() { std::vector head_buffer3(largest_batch_size * max_channel_size * kSquares); - if (weights_.pol_encoder.size() > 0) { - std::vector head_buffer4(largest_batch_size * max_channel_size * - kSquares); - std::vector temp_buffer1(policy_d_model * kSquares); - std::vector temp_buffer2(policy_d_model * kSquares); - std::vector temp_buffer3(policy_d_model * kSquares); - - for (auto layer : weights_.pol_encoder) { - // Q - FullyConnectedLayer::Forward1D( - batch_size * kSquares, embedding_size, layer.mha.q_b.size(), - head_buffer.data(), layer.mha.q_w.data(), layer.mha.q_b.data(), - NONE, head_buffer2.data()); - // K - FullyConnectedLayer::Forward1D( - batch_size * kSquares, embedding_size, layer.mha.k_b.size(), - head_buffer.data(), layer.mha.k_w.data(), layer.mha.k_b.data(), - NONE, head_buffer3.data()); - // V - FullyConnectedLayer::Forward1D( - batch_size * kSquares, embedding_size, layer.mha.v_b.size(), - head_buffer.data(), layer.mha.v_w.data(), layer.mha.v_b.data(), - NONE, head_buffer4.data()); - - // MHA (Q, K, V) - const int d_model = layer.mha.q_b.size(); - const int heads = weights_.pol_encoder_head_count; - const int depth = d_model / heads; - const float scaling = 1.0f / sqrtf(depth); - - // MHA is done per batch since there's a fourth dimension introduced. - for (auto batch = size_t{0}; batch < batch_size; batch++) { - auto batchStart = batch * kSquares * d_model; - - // Reshape and transpose for each head. - const float* Q = temp_buffer1.data(); - const float* K = temp_buffer2.data(); - const float* V = temp_buffer3.data(); - - for (int head = 0; head < heads; head++) { - for (int j = 0; j < kSquares; j++) { - auto channelStart = batchStart + j * d_model + head * depth; - auto transposeStart = head * kSquares * depth + j * depth; - std::copy(head_buffer2.begin() + channelStart, - head_buffer2.begin() + channelStart + depth, - temp_buffer1.begin() + transposeStart); - std::copy(head_buffer3.begin() + channelStart, - head_buffer3.begin() + channelStart + depth, - temp_buffer2.begin() + transposeStart); - std::copy(head_buffer4.begin() + channelStart, - head_buffer4.begin() + channelStart + depth, - temp_buffer3.begin() + transposeStart); - } - } - - // matmul(Q, K) for all heads per batch. - float* QK = &head_buffer2[batchStart]; - AttentionMatmul2D(false, true, heads, kSquares, kSquares, - depth, scaling, Q, K, QK); - - // Apply Softmax. - for (int h = 0; h < heads * kSquares * kSquares; h += kSquares) { - SoftmaxActivation(kSquares, QK + h, QK + h); - } - - // matmul(softmax(QK), V) for all heads per batch. - float* attn = &head_buffer3[batchStart]; - AttentionMatmul2D(false, false, heads, kSquares, depth, - kSquares, 1.0, QK, V, attn); - - // Transpose back into N x 64 x H x D. - for (int j = 0; j < kSquares; j++) { - for (int head = 0; head < heads; head++) { - auto transposeStart = - batchStart + head * kSquares * depth + j * depth; - std::copy(head_buffer3.begin() + transposeStart, - head_buffer3.begin() + transposeStart + depth, - head_buffer2.begin() + batchStart + j * d_model + - head * depth); - } - } - } - - // Fully connected final MHA layer. - FullyConnectedLayer::Forward1D( - batch_size * kSquares, d_model, embedding_size, - head_buffer2.data(), layer.mha.dense_w.data(), - layer.mha.dense_b.data(), NONE, head_buffer3.data()); - - // Layer Norm + skip connection. - LayerNorm2DWithSkipConnection(batch_size * kSquares, embedding_size, - head_buffer.data(), head_buffer3.data(), - layer.ln1_gammas.data(), - layer.ln1_betas.data(), 1e-6); - - // FFN. - const size_t dff_size = layer.ffn.dense1_b.size(); - FullyConnectedLayer::Forward1D( - batch_size * kSquares, embedding_size, dff_size, - head_buffer.data(), layer.ffn.dense1_w.data(), - layer.ffn.dense1_b.data(), SELU, head_buffer2.data()); - - FullyConnectedLayer::Forward1D( - batch_size * kSquares, dff_size, layer.ffn.dense2_b.size(), - head_buffer2.data(), layer.ffn.dense2_w.data(), - layer.ffn.dense2_b.data(), NONE, head_buffer3.data()); - - // Layer Norm + skip connection. - LayerNorm2DWithSkipConnection(batch_size * kSquares, embedding_size, - head_buffer.data(), head_buffer3.data(), - layer.ln2_gammas.data(), - layer.ln2_betas.data(), 1e-6); - } + for (auto& layer : weights_.pol_encoder) { + MakeEncoderLayer(head_buffer, head_buffer2, head_buffer3, batch_size, + layer, policy_embedding_size, + weights_.pol_encoder_head_count, + attn_body_ ? smolgen_activation_ : ACTIVATION_NONE, + attn_body_ ? ffn_activation_ : ACTIVATION_SELU, 1.0f); } // Q FullyConnectedLayer::Forward1D( - batch_size * kSquares, embedding_size, policy_d_model, + batch_size * kSquares, policy_embedding_size, policy_d_model, head_buffer.data(), weights_.ip2_pol_w.data(), - weights_.ip2_pol_b.data(), NONE, head_buffer2.data()); + weights_.ip2_pol_b.data(), ACTIVATION_NONE, head_buffer2.data()); // K FullyConnectedLayer::Forward1D( - batch_size * kSquares, embedding_size, policy_d_model, + batch_size * kSquares, policy_embedding_size, policy_d_model, head_buffer.data(), weights_.ip3_pol_w.data(), - weights_.ip3_pol_b.data(), NONE, head_buffer3.data()); + weights_.ip3_pol_b.data(), ACTIVATION_NONE, head_buffer3.data()); const float scaling = 1.0f / sqrtf(policy_d_model); for (auto batch = size_t{0}; batch < batch_size; batch++) { const float* A = &head_buffer2[batch * 64 * policy_d_model]; @@ -515,6 +701,7 @@ void BlasComputation::ComputeBlocking() { } } } else if (conv_policy_) { + assert(!attn_body_); // not supported with attention body convolve3.Forward(batch_size, output_channels, output_channels, conv_out, weights_.policy1.weights.data(), res); @@ -526,7 +713,7 @@ void BlasComputation::ComputeBlocking() { head_buffer.data()); BiasActivate(batch_size, num_policy_input_planes, &head_buffer.data()[0], - weights_.policy.biases.data(), NONE); + weights_.policy.biases.data(), ACTIVATION_NONE); // Mapping from convolutional policy to lc0 policy for (auto batch = size_t{0}; batch < batch_size; batch++) { @@ -540,6 +727,7 @@ void BlasComputation::ComputeBlocking() { } } else { + assert(!attn_body_); // not supported with attention body Convolution1::Forward( batch_size, output_channels, num_policy_input_planes, conv_out, weights_.policy.weights.data(), head_buffer.data()); @@ -551,7 +739,7 @@ void BlasComputation::ComputeBlocking() { batch_size, num_policy_input_planes * kSquares, num_output_policy, head_buffer.data(), weights_.ip_pol_w.data(), weights_.ip_pol_b.data(), - NONE, // Activation Off + ACTIVATION_NONE, // Activation Off output_fc.data()); } @@ -565,12 +753,19 @@ void BlasComputation::ComputeBlocking() { } // Value head - Convolution1::Forward( - batch_size, output_channels, num_value_input_planes, conv_out, - weights_.value.weights.data(), head_buffer.data()); + if (attn_body_) { + FullyConnectedLayer::Forward1D( + batch_size * kSquares, weights_.ip_emb_b.size(), + num_value_input_planes, res, weights_.ip_val_w.data(), + weights_.ip_val_b.data(), default_activation_, head_buffer.data()); + } else { + Convolution1::Forward( + batch_size, output_channels, num_value_input_planes, conv_out, + weights_.value.weights.data(), head_buffer.data()); - BiasActivate(batch_size, num_value_input_planes, &head_buffer[0], - weights_.value.biases.data(), default_activation_); + BiasActivate(batch_size, num_value_input_planes, &head_buffer[0], + weights_.value.biases.data(), default_activation_); + } FullyConnectedLayer::Forward1D( batch_size, num_value_input_planes * kSquares, num_value_channels, @@ -585,7 +780,7 @@ void BlasComputation::ComputeBlocking() { FullyConnectedLayer::Forward1D( batch_size, num_value_channels, 3, output_fc.data(), weights_.ip2_val_w.data(), weights_.ip2_val_b.data(), - NONE, // Activation Off + ACTIVATION_NONE, // Activation Off wdl.data()); for (size_t j = 0; j < batch_size; j++) { @@ -607,12 +802,19 @@ void BlasComputation::ComputeBlocking() { } } if (moves_left_) { - Convolution1::Forward( - batch_size, output_channels, num_moves_input_planes, conv_out, - weights_.moves_left.weights.data(), head_buffer.data()); + if (attn_body_) { + FullyConnectedLayer::Forward1D( + batch_size * kSquares, weights_.ip_emb_b.size(), + num_moves_input_planes, res, weights_.ip_mov_w.data(), + weights_.ip_mov_b.data(), default_activation_, head_buffer.data()); + } else { + Convolution1::Forward( + batch_size, output_channels, num_moves_input_planes, conv_out, + weights_.moves_left.weights.data(), head_buffer.data()); - BiasActivate(batch_size, num_moves_input_planes, &head_buffer[0], - weights_.moves_left.biases.data(), default_activation_); + BiasActivate(batch_size, num_moves_input_planes, &head_buffer[0], + weights_.moves_left.biases.data(), default_activation_); + } FullyConnectedLayer::Forward1D( batch_size, num_moves_input_planes * kSquares, num_moves_channels, @@ -625,7 +827,7 @@ void BlasComputation::ComputeBlocking() { FullyConnectedLayer::Forward1D( batch_size, num_moves_channels, 1, output_fc.data(), weights_.ip2_mov_w.data(), weights_.ip2_mov_b.data(), - RELU, // Specifically Relu + ACTIVATION_RELU, // Specifically Relu output_moves_left.data()); for (size_t j = 0; j < batch_size; j++) { @@ -651,7 +853,7 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, : capabilities_{file.format().network_format().input(), file.format().network_format().moves_left()}, weights_(file.weights()) { - Numa::Init(); + Numa::Init(); max_batch_size_ = static_cast(options.GetOrDefault("batch_size", 256)); @@ -669,10 +871,25 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, attn_policy_ = file.format().network_format().policy() == pblczero::NetworkFormat::POLICY_ATTENTION; + attn_body_ = file.format().network_format().network() == + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; + default_activation_ = file.format().network_format().default_activation() == pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH - ? MISH - : RELU; + ? ACTIVATION_MISH + : ACTIVATION_RELU; + + if (attn_body_) { + const auto smol_act = file.format().network_format().smolgen_activation(); + smolgen_activation_ = + smol_act == pblczero::NetworkFormat::ACTIVATION_DEFAULT + ? default_activation_ + : static_cast(smol_act); + const auto ffn_act = file.format().network_format().ffn_activation(); + ffn_activation_ = ffn_act == pblczero::NetworkFormat::ACTIVATION_DEFAULT + ? default_activation_ + : static_cast(ffn_act); + } if (max_batch_size_ > kHardMaxBatchSize) { max_batch_size_ = kHardMaxBatchSize; @@ -755,7 +972,9 @@ std::unique_ptr MakeBlasNetwork(const std::optional& w, if (weights.format().network_format().network() != pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && weights.format().network_format().network() != - pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT) { + pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && + weights.format().network_format().network() != + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT) { throw Exception("Network format " + pblczero::NetworkFormat::NetworkStructure_Name( weights.format().network_format().network()) + diff --git a/src/neural/blas/se_unit.cc b/src/neural/blas/se_unit.cc index c11a95cae8..a4f24f0e0b 100644 --- a/src/neural/blas/se_unit.cc +++ b/src/neural/blas/se_unit.cc @@ -82,7 +82,7 @@ void ApplySEUnit(const size_t batch_size, const size_t channels, FullyConnectedLayer::Forward1D(batch_size, se_fc_outputs, 2 * channels, fc_out1.data(), weights_w2, weights_b2, - NONE, // Activation Off + ACTIVATION_NONE, // Activation Off pool.data()); // Sigmoid, scale and add residual diff --git a/src/neural/cache.cc b/src/neural/cache.cc index d729a562f0..9c11fe5e1c 100644 --- a/src/neural/cache.cc +++ b/src/neural/cache.cc @@ -25,13 +25,22 @@ Program grant you additional permission to convey the resulting work. */ #include "neural/cache.h" + +#include #include #include +#include "utils/fastmath.h" + namespace lczero { CachingComputation::CachingComputation( - std::unique_ptr parent, NNCache* cache) - : parent_(std::move(parent)), cache_(cache) {} + std::unique_ptr parent, + pblczero::NetworkFormat::InputFormat input_format, + lczero::FillEmptyHistory history_fill, NNCache* cache) + : parent_(std::move(parent)), + input_format_(input_format), + history_fill_(history_fill), + cache_(cache) {} int CachingComputation::GetCacheMisses() const { return parent_->GetBatchSize(); @@ -60,15 +69,25 @@ void CachingComputation::PopCacheHit() { batch_.pop_back(); } -void CachingComputation::AddInput( - uint64_t hash, InputPlanes&& input, - std::vector&& probabilities_to_cache) { - if (AddInputByHash(hash)) return; +void CachingComputation::AddInput(uint64_t hash, + const PositionHistory& history) { + if (AddInputByHash(hash)) { + return; + } + int transform; + auto input = + EncodePositionForNN(input_format_, history, 8, history_fill_, &transform); batch_.emplace_back(); batch_.back().hash = hash; batch_.back().idx_in_parent = parent_->GetBatchSize(); - batch_.back().probabilities_to_cache = probabilities_to_cache; + // Cache legal moves. + std::vector moves = history.Last().GetBoard().GenerateLegalMoves(); + batch_.back().eval = std::make_shared(); + batch_.back().eval->edges = Edge::FromMovelist(moves); + batch_.back().eval->num_edges = moves.size(); + batch_.back().transform = transform; parent_->AddInput(std::move(input)); + return; } void CachingComputation::PopLastInputHit() { @@ -77,61 +96,58 @@ void CachingComputation::PopLastInputHit() { batch_.pop_back(); } -void CachingComputation::ComputeBlocking() { +void CachingComputation::ComputeBlocking(float softmax_temp) { if (parent_->GetBatchSize() == 0) return; parent_->ComputeBlocking(); // Fill cache with data from NN. - for (const auto& item : batch_) { + for (auto& item : batch_) { if (item.idx_in_parent == -1) continue; - auto req = - std::make_unique(item.probabilities_to_cache.size()); - req->q = parent_->GetQVal(item.idx_in_parent); - req->d = parent_->GetDVal(item.idx_in_parent); - req->m = parent_->GetMVal(item.idx_in_parent); - int idx = 0; - for (auto x : item.probabilities_to_cache) { - req->p[idx++] = - std::make_pair(x, parent_->GetPVal(item.idx_in_parent, x)); + item.eval->q = parent_->GetQVal(item.idx_in_parent); + item.eval->d = parent_->GetDVal(item.idx_in_parent); + item.eval->m = parent_->GetMVal(item.idx_in_parent); + + // Calculate maximum first. + float max_p = -std::numeric_limits::infinity(); + // Intermediate array to store values when processing policy. + // There are never more than 256 valid legal moves in any legal position. + std::array intermediate; + int transform = item.transform; + int num_edges = item.eval->num_edges; + auto edges = item.eval->edges.get(); + for (int ct = 0; ct < num_edges; ct++) { + auto move = edges[ct].GetMove(); + float p = + parent_->GetPVal(item.idx_in_parent, move.as_nn_index(transform)); + intermediate[ct] = p; + max_p = std::max(max_p, p); + } + float total = 0.0; + for (int ct = 0; ct < num_edges; ct++) { + // Perform softmax and take into account policy softmax temperature T. + // Note that we want to calculate (exp(p-max_p))^(1/T) = exp((p-max_p)/T). + float p = FastExp((intermediate[ct] - max_p) / softmax_temp); + intermediate[ct] = p; + total += p; + } + // Normalize P values to add up to 1.0. + const float scale = total > 0.0f ? 1.0f / total : 1.0f; + for (int ct = 0; ct < num_edges; ct++) { + edges[ct].SetP(intermediate[ct] * scale); } - cache_->Insert(item.hash, std::move(req)); - } -} -float CachingComputation::GetQVal(int sample) const { - const auto& item = batch_[sample]; - if (item.idx_in_parent >= 0) return parent_->GetQVal(item.idx_in_parent); - return item.lock->q; -} + Edge::SortEdges(item.eval->edges.get(), item.eval->num_edges); -float CachingComputation::GetDVal(int sample) const { - const auto& item = batch_[sample]; - if (item.idx_in_parent >= 0) return parent_->GetDVal(item.idx_in_parent); - return item.lock->d; + auto req = std::make_unique(); + req->eval = item.eval; + cache_->Insert(item.hash, std::move(req)); + } } -float CachingComputation::GetMVal(int sample) const { +std::shared_ptr CachingComputation::GetNNEval(int sample) const { const auto& item = batch_[sample]; - if (item.idx_in_parent >= 0) return parent_->GetMVal(item.idx_in_parent); - return item.lock->m; -} - -float CachingComputation::GetPVal(int sample, int move_id) const { - auto& item = batch_[sample]; - if (item.idx_in_parent >= 0) - return parent_->GetPVal(item.idx_in_parent, move_id); - const auto& moves = item.lock->p; - - int total_count = 0; - while (total_count < moves.size()) { - // Optimization: usually moves are stored in the same order as queried. - const auto& move = moves[item.last_idx++]; - if (item.last_idx == moves.size()) item.last_idx = 0; - if (move.first == move_id) return move.second; - ++total_count; - } - assert(false); // Move not found. - return 0; + if (item.idx_in_parent >= 0) return item.eval; + return item.lock->eval; } } // namespace lczero diff --git a/src/neural/cache.h b/src/neural/cache.h index 207e0fe6e4..1d0aaeef97 100644 --- a/src/neural/cache.h +++ b/src/neural/cache.h @@ -26,20 +26,15 @@ */ #pragma once +#include "mcts/node.h" +#include "neural/encoder.h" #include "neural/network.h" #include "utils/cache.h" -#include "utils/smallarray.h" namespace lczero { struct CachedNNRequest { - CachedNNRequest(size_t size) : p(size) {} - typedef std::pair IdxAndProb; - float q; - float d; - float m; - // TODO(mooskagh) Don't really need index if using perfect hash. - SmallArray p; + std::shared_ptr eval; }; typedef HashKeyedCache NNCache; @@ -51,7 +46,8 @@ typedef HashKeyedCacheLock NNCacheLock; class CachingComputation { public: CachingComputation(std::unique_ptr parent, - NNCache* cache); + pblczero::NetworkFormat::InputFormat input_format, + FillEmptyHistory history_fill, NNCache* cache); // How many inputs are not found in cache and will be forwarded to a wrapped // computation. @@ -64,24 +60,20 @@ class CachingComputation { // Adds input by hash with existing lock. Assumes the given lock holds a real // reference. void AddInputByHash(uint64_t hash, NNCacheLock&& lock); - // Adds a sample to the batch. + // Adds a sample to the batch. Also calls EncodePositionForNN() if needed. // @hash is a hash to store/lookup it in the cache. - // @probabilities_to_cache is which indices of policy head to store. - void AddInput(uint64_t hash, InputPlanes&& input, - std::vector&& probabilities_to_cache); - // Undos last AddInput. If it was a cache miss, the it's actually not removed + void AddInput(uint64_t hash, const PositionHistory& history); + // Undos last AddInput. If it was a cache miss, then it's actually not removed // from parent's batch. void PopLastInputHit(); // Do the computation. - void ComputeBlocking(); - // Returns Q value of @sample. - float GetQVal(int sample) const; - // Returns probability of draw if NN has WDL value head. - float GetDVal(int sample) const; - // Returns estimated remaining moves. - float GetMVal(int sample) const; - // Returns P value @move_id of @sample. - float GetPVal(int sample, int move_id) const; + void ComputeBlocking(float softmax_temp); + // Returns NN eval of @sample. + std::shared_ptr GetNNEval(int sample) const; + + // Returns compressed P value @move_ct of @sample. + uint16_t GetPVal(int sample, int move_ct) const; + Move GetMove(int sample, int move_ct) const; // Pops last input from the computation. Only allowed for inputs which were // cached. void PopCacheHit(); @@ -94,11 +86,13 @@ class CachingComputation { uint64_t hash; NNCacheLock lock; int idx_in_parent = -1; - std::vector probabilities_to_cache; - mutable int last_idx = 0; + std::shared_ptr eval; + int transform; }; std::unique_ptr parent_; + pblczero::NetworkFormat::InputFormat input_format_; + FillEmptyHistory history_fill_; NNCache* cache_; std::vector batch_; }; diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 684d8d1e8a..2959cfc9ec 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -25,9 +25,12 @@ Program grant you additional permission to convey the resulting work. */ -#include #include +#include + #include "cuda_common.h" +#include "neural/shared/activation.h" +#include "neural/shared/attention_policy_map.h" #include "winograd_helper.inc" namespace lczero { @@ -71,6 +74,34 @@ void addVectors(T* c, T* a, T* b, int size, int asize, int bsize, ReportCUDAErrors(cudaGetLastError()); } +template +__global__ void addVectorsHNC_NHC_kernel(T* a, T* b, int N, int H, int C) { + int i = threadIdx.x + blockDim.x * blockIdx.x; + if (i < N * H * C) { + int orig_i = i; + int c = i % C; + i /= C; + int n = i % N; + i /= N; + int h = i; + float aVal = (float)a[orig_i]; + float bVal = (float)b[n * H * C + h * C + c]; + + float cVal = aVal + bVal; + + a[orig_i] = (T)cVal; + } +} + +template +void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream) { + const int kBlockSize = 256; + int blocks = DivUp(N * H * C, kBlockSize); + addVectorsHNC_NHC_kernel<<>>(a, b, N, H, C); + + ReportCUDAErrors(cudaGetLastError()); +} + template __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, int N, int C) { @@ -100,7 +131,7 @@ __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, copyAs(&val[0], &input[tensorIndex]); copyAs(&b[0], &bias[biasIndex]); } - + // Perform bias add and activation #pragma unroll for (int i = 0; i < 4; i++) { @@ -131,20 +162,142 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, dim3 blockDim, gridDim; blockDim.x = C / 4; - blockDim.y = std::min(std::max(512 / blockDim.x, 1u), (unsigned int) N); + blockDim.y = std::min(std::max(512 / blockDim.x, 1u), (unsigned int)N); + blockDim.z = 1; + gridDim.x = DivUp(N, blockDim.y); + gridDim.y = Batch; + gridDim.z = 1; + + switch (activation) { + case ACTIVATION_NONE: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; + case ACTIVATION_SELU: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; + case ACTIVATION_MISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; + case ACTIVATION_RELU: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; + case ACTIVATION_SWISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; + case ACTIVATION_RELU_2: // square relu + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; + default: + throw Exception( + "unsupported activation in addBiasBatched. Add in switch-case here"); + } + + ReportCUDAErrors(cudaGetLastError()); +} + +template +__global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, + int N, int C, int Nstride) { + int batch = blockIdx.y; + int n = blockIdx.x * blockDim.y + threadIdx.y; + if (n >= N) return; + int c = threadIdx.x * 4; + + int biasIndex = batch * C + c; + int tensorIndex = batch * Nstride * C + n * C + c; + + float val[4]; + float b[4]; + + // Load from memory + const bool fp16 = std::is_same::value; + if (fp16) { + half inp[4]; + copyAs(&inp[0], &input[tensorIndex]); +#pragma unroll + for (int i = 0; i < 4; i++) val[i] = (float)inp[i]; + + copyAs(&inp[0], &bias[biasIndex]); +#pragma unroll + for (int i = 0; i < 4; i++) b[i] = (float)inp[i]; + } else { + copyAs(&val[0], &input[tensorIndex]); + copyAs(&b[0], &bias[biasIndex]); + } + + // Perform bias add and activation +#pragma unroll + for (int i = 0; i < 4; i++) { + float x = val[i] + b[i]; + x = activate(x, act); + val[i] = x; + } + + // write to memory + if (fp16) { + half op[4]; +#pragma unroll + for (int i = 0; i < 4; i++) op[i] = (half)val[i]; + copyAs(&output[tensorIndex], &op[0]); + } else { + copyAs(&output[tensorIndex], &val[0]); + } +} + +// Input/output tensors are Batch * N * C +// bias tensor is N * C (i.e, different bias for each Batch dimension) +template +void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, + int C, int Nstride, ActivationFunction activation, + cudaStream_t stream) { + // process 4 elements per thread to achieve close to peak memory bandwidth + if (C % 4 != 0) throw Exception("unsupported filter size"); + if (C > 4096) throw Exception("unsupported filter size"); + + dim3 blockDim, gridDim; + blockDim.x = C / 4; + blockDim.y = std::min(std::max(512 / blockDim.x, 1u), (unsigned int)N); blockDim.z = 1; gridDim.x = DivUp(N, blockDim.y); gridDim.y = Batch; gridDim.z = 1; switch (activation) { - case NONE: - addBiasBatched_kernel<<>>( - output, input, bias, N, C); + case ACTIVATION_NONE: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, + Nstride); + break; + case ACTIVATION_SELU: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, + Nstride); + break; + case ACTIVATION_MISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, + Nstride); break; - case SELU: - addBiasBatched_kernel<<>>( - output, input, bias, N, C); + case ACTIVATION_RELU: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, + Nstride); + break; + case ACTIVATION_SWISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, + Nstride); + break; + case ACTIVATION_RELU_2: // square relu + addBiasBatched_kernel + <<>>(output, input, bias, N, C, + Nstride); break; default: throw Exception( @@ -641,7 +794,8 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, // Each thread processes entire chess board if (use_se == false) { dim3 grid_dim(DivUp(C, kOpInpTransformBlockSize), N, 1); - OutputTransform_relu_InputTransform_kernel + OutputTransform_relu_InputTransform_kernel <<>>(N, C, output, input, (float*)skip, bias); } else if (C > kMaxResBlockFusingChannels) { @@ -649,8 +803,8 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, "res block fusing opt not supported for the given data type and no " "of filters\n"); } else { - OutputTransform_SE_relu_InputTransform_kernel + OutputTransform_SE_relu_InputTransform_kernel <<>>(N, C, se_K, output, input, (float*)skip, bias, w1, b1, w2, b2); } @@ -658,17 +812,16 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, ReportCUDAErrors(cudaGetLastError()); } - // softmax along C dimension which is assumed to be 64 // each thread processes two elements. Each warp computes a sum (over 64 // elements) template -__global__ void softmax_opt_64_kernel(T* output, const T* input, int N) { - +__global__ void softmax_opt_64_kernel(T* output, const T* input, + const T* input2, int N) { int index = blockDim.x * blockIdx.x + threadIdx.x; if (index >= N) return; - float x[2]; + float x[4]; float ex[2]; // Load from memory @@ -678,10 +831,22 @@ __global__ void softmax_opt_64_kernel(T* output, const T* input, int N) { copyAs(&inp[0], &input[index * 2]); x[0] = (float)inp[0]; x[1] = (float)inp[1]; + if (input2 != nullptr) { + copyAs(&inp[0], &input2[index * 2]); + x[2] = (float)inp[0]; + x[3] = (float)inp[1]; + } } else { copyAs(&x[0], &input[index * 2]); + if (input2 != nullptr) { + copyAs(&x[2], &input2[index * 2]); + } } + if (input2 != nullptr) { + x[0] += x[2]; + x[1] += x[3]; + } float threadMax = max(x[0], x[1]); float maxval = warpMax(threadMax); maxval = __shfl_sync(0xFFFFFFFF, maxval, 0); @@ -707,14 +872,13 @@ __global__ void softmax_opt_64_kernel(T* output, const T* input, int N) { } } - // N * C Tensors // performs softmax along the C dimension // Each thread processes one element // Sums are computed in shared memory // C threads per block, N blocks template -__global__ void softmax_kernel(T* output, const T* input) { +__global__ void softmax_kernel(T* output, const T* input, const T* input2) { int n = blockIdx.x; int c = threadIdx.x; int C = blockDim.x; @@ -723,6 +887,7 @@ __global__ void softmax_kernel(T* output, const T* input) { // softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis) float x = (float)input[index]; + if (input2 != nullptr) x += (float)input2[index]; __shared__ float sum, maxval; if (c == 0) { @@ -754,14 +919,16 @@ __global__ void softmax_kernel(T* output, const T* input) { } template -void Softmax(int N, int C, T* output, const T* input, cudaStream_t stream) { +void Softmax(int N, int C, T* output, const T* input, const T* input2, + cudaStream_t stream) { if (C == 64) { - int size = N * 32; // Total no of threads needed + int size = N * 32; // Total no of threads needed const int kBlockSize = 256; int blocks = DivUp(size, kBlockSize); - softmax_opt_64_kernel<<>>(output, input, size); + softmax_opt_64_kernel + <<>>(output, input, input2, size); } else { - softmax_kernel<<>>(output, input); + softmax_kernel<<>>(output, input, input2); } ReportCUDAErrors(cudaGetLastError()); @@ -797,62 +964,79 @@ __device__ __forceinline__ float shared_sum_for_layer_norm(float x) { // 1. Perform Bias add, and skip add // 2. Perform layer norm (normalize across C dimension) template -__global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const T* bias, - const T* skip, const T* gammas, - const T* betas, float ep) { +__global__ void layer_norm_kernel(int N, int C, T* output, const T* input, + const T* bias, const T* skip, const T* gammas, + const T* betas, float ep, float alpha, + ActivationFunction act) { int n = blockIdx.x * blockDim.z + threadIdx.z; if (n >= N) return; - int c = (threadIdx.y * 32 + threadIdx.x) * 4; + int c = (threadIdx.y * 32 + threadIdx.x) * 16; bool oobThread = c >= C; int biasIndex = c; int tensorIndex = n * C + c; - float val[4] = {0, 0, 0, 0}; - float b[4] = {0, 0, 0, 0}; - float sk[4] = {0, 0, 0, 0}; - float bts[4] = {0, 0, 0, 0}; - float gms[4] = {0, 0, 0, 0}; + float val[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + float oth[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; const bool fp16 = std::is_same::value; if (!oobThread) { - // Load from memory (4 elements a time) + // Load from memory (16 elements a time) if (fp16) { - half inp[4]; - copyAs(&inp[0], &input[tensorIndex]); - for (int i = 0; i < 4; i++) val[i] = (float)inp[i]; - copyAs(&inp[0], &skip[tensorIndex]); - for (int i = 0; i < 4; i++) sk[i] = (float)inp[i]; - copyAs(&inp[0], &bias[biasIndex]); - for (int i = 0; i < 4; i++) b[i] = (float)inp[i]; - copyAs(&inp[0], &betas[biasIndex]); - for (int i = 0; i < 4; i++) bts[i] = (float)inp[i]; - copyAs(&inp[0], &gammas[biasIndex]); - for (int i = 0; i < 4; i++) gms[i] = (float)inp[i]; + half inp[8]; + copyAs(&inp[0], &input[tensorIndex]); + for (int i = 0; i < 8; i++) val[i] = (float)inp[i]; + copyAs(&inp[0], &input[tensorIndex + 8]); + for (int i = 0; i < 8; i++) val[i + 8] = (float)inp[i]; + copyAs(&inp[0], &bias[biasIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &bias[biasIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + for (int i = 0; i < 16; i++) val[i] += oth[i]; } else { copyAs(&val[0], &input[tensorIndex]); - copyAs(&sk[0], &skip[tensorIndex]); - copyAs(&b[0], &bias[biasIndex]); - copyAs(&bts[0], &betas[biasIndex]); - copyAs(&gms[0], &gammas[biasIndex]); + copyAs(&val[4], &input[tensorIndex + 4]); + copyAs(&val[8], &input[tensorIndex + 8]); + copyAs(&val[12], &input[tensorIndex + 12]); + copyAs(&oth[0], &bias[biasIndex]); + copyAs(&oth[4], &bias[biasIndex + 4]); + copyAs(&oth[8], &bias[biasIndex + 8]); + copyAs(&oth[12], &bias[biasIndex + 12]); + for (int i = 0; i < 16; i++) val[i] += oth[i]; + } + } + + if (!oobThread) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &skip[tensorIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &skip[tensorIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + } else { + copyAs(&oth[0], &skip[tensorIndex]); + copyAs(&oth[4], &skip[tensorIndex + 4]); + copyAs(&oth[8], &skip[tensorIndex + 8]); + copyAs(&oth[12], &skip[tensorIndex + 12]); } } // 1. Compute mean float s = 0; if (!oobThread) - for (int i = 0; i < 4; i++) { - val[i] += b[i] + sk[i]; + for (int i = 0; i < 16; i++) { + val[i] = activate(val[i], act) + oth[i] * alpha; s += val[i]; } - + s = shared_sum_for_layer_norm(s); float mean = s / C; // 2. Compute varience s = 0; if (!oobThread) - for (int i = 0; i < 4; i++) { + for (int i = 0; i < 16; i++) { float d = val[i] - mean; float d_sq = d * d; s += d_sq; @@ -860,39 +1044,81 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const s = shared_sum_for_layer_norm(s); float var = s / C; + if (!oobThread) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &gammas[biasIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &gammas[biasIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + } else { + copyAs(&oth[0], &gammas[biasIndex]); + copyAs(&oth[4], &gammas[biasIndex + 4]); + copyAs(&oth[8], &gammas[biasIndex + 8]); + copyAs(&oth[12], &gammas[biasIndex + 12]); + } + } + // 3. Normalize - for (int i = 0; i < 4; i++) { + for (int i = 0; i < 16; i++) { float d = val[i] - mean; float norm = d / sqrt(var + ep); - float op = norm * gms[i] + bts[i]; + float op = norm * oth[i]; val[i] = op; } + if (!oobThread) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &betas[biasIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &betas[biasIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + } else { + copyAs(&oth[0], &betas[biasIndex]); + copyAs(&oth[4], &betas[biasIndex + 4]); + copyAs(&oth[8], &betas[biasIndex + 8]); + copyAs(&oth[12], &betas[biasIndex + 12]); + } + } + + for (int i = 0; i < 16; i++) { + val[i] += oth[i]; + } + if (!oobThread) { // Write to memory if (fp16) { - half op[4]; - for (int i = 0; i < 4; i++) op[i] = (half)val[i]; - copyAs(&output[tensorIndex], &op[0]); + half op[8]; + for (int i = 0; i < 8; i++) op[i] = (half)val[i]; + copyAs(&output[tensorIndex], &op[0]); + for (int i = 0; i < 8; i++) op[i] = (half)val[i + 8]; + copyAs(&output[tensorIndex + 8], &op[0]); } else { copyAs(&output[tensorIndex], &val[0]); + copyAs(&output[tensorIndex + 4], &val[4]); + copyAs(&output[tensorIndex + 8], &val[8]); + copyAs(&output[tensorIndex + 12], &val[12]); } } } // add (optional) skip connection to input, and then perform Layer normalization -// normalization is done across C dimension (i.e, sums and std deviations taken over elements in C dim) +// normalization is done across C dimension (i.e, sums and std deviations taken +// over elements in C dim) template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, - cudaStream_t stream) { + float alpha, ActivationFunction act, cudaStream_t stream) { // process 4 elements per thread to achieve close to peak memory bandwidth - if (C % 4 != 0) throw Exception("unsupported filter size"); - if (C > 4096) throw Exception("unsupported filter size"); + if (C % 16 != 0) throw Exception("unsupported filter size"); + if (C > 16384) throw Exception("unsupported filter size"); dim3 blockDim, gridDim; blockDim.x = 32; - blockDim.y = DivUp(C / 4, 32); + blockDim.y = DivUp(C / 16, 32); blockDim.z = std::min(std::max(512 / (blockDim.x * blockDim.y), 1u), (unsigned int)N); gridDim.x = DivUp(N, blockDim.z); @@ -900,7 +1126,7 @@ void LayerNorm(int N, int C, T* output, const T* input, const T* bias, gridDim.z = 1; layer_norm_kernel<<>>( - N, C, output, input, bias, skip, gammas, betas, ep); + N, C, output, input, bias, skip, gammas, betas, ep, alpha, act); ReportCUDAErrors(cudaGetLastError()); } @@ -958,7 +1184,7 @@ __global__ void promotion_logits_kernel(int C, T* output, const T* keys, // phase 2: add the last "row" to the other 3 // #knight offset is added to the other three // promotion_offsets = promotion_offsets[:, :3, :] + promotion_offsets[:, 3:4, - // :] + // :] // Only 24 threads in the group are active in this phase if (threadInGroup < 32) { int x = threadInGroup % 4; @@ -999,6 +1225,77 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys, <<>>(C, output, keys, ppo, policy_attn_logits); } +template +__global__ void preprocess_for_attention_body_kernel(T* output, const T* input, + const float* encoding) { + int n = blockIdx.x; + int hw = blockIdx.y; + int c = threadIdx.x; + + T op; + if (c >= kInputPlanes) { + // concatenate from fixed pos encoding array + op = (T)(encoding[64 * hw + (c - kInputPlanes)]); + } else { + op = input[n * kInputPlanes * 64 + c * 64 + hw]; // nchw + } + + constexpr int outputC = kInputPlanes + kNumPosEncodingChannels; + + // convert to nhwc + output[n * 64 * outputC + hw * outputC + c] = op; +} + +template +void inputPreprocessForAttentionBody(T* output, const T* input, + const float* encoding, int N, + cudaStream_t stream) { + // N * 64 blocks + // (kInputPlanes + kNumPosEncodingChannels) threads + // Each thread computes a single output element + dim3 gridSize = dim3(N, 64); + int blockSize = kInputPlanes + kNumPosEncodingChannels; + preprocess_for_attention_body_kernel + <<>>(output, input, encoding); +} + +template +__global__ void input_gating_kernel(T* output, const T* input, const T* mult, + const T* add, int HW, int C) { + int n_offset = blockIdx.z * HW * C; + int idx = threadIdx.y * C + blockIdx.x * blockDim.x + + threadIdx.x; // index in input + int idxT = (blockIdx.x * blockDim.x + threadIdx.x) * HW + + threadIdx.y; // index in transposed weights arrays mult and add. + + if (idx < HW * C) { + // Combine multiply gating, add gating and weights transpose. + float op = + (float)input[n_offset + idx] * (float)mult[idxT] + (float)add[idxT]; + output[n_offset + idx] = (T)op; + } +} + +template +void applyInputGating(T* output, const T* input, const T* mult, const T* add, + int N, int HW, int C, cudaStream_t stream) { + // Multiple blocks to fit into each input area / volume + // Block x position indicates horizontal section of area + // Block y position indicates batch + // Each thread computes a single output element + dim3 blockSize, gridSize; + blockSize.x = DivUp(1024, HW); + blockSize.y = HW; + blockSize.z = 1; + gridSize.x = DivUp(C, blockSize.x); + gridSize.y = 1; + gridSize.z = N; + input_gating_kernel + <<>>(output, input, mult, add, HW, C); + + ReportCUDAErrors(cudaGetLastError()); +} + // Template instantiation. template void copyTypeConverted(half* op, float* ip, int N, cudaStream_t stream); @@ -1025,6 +1322,11 @@ template void addVectors(half* c, half* a, half* b, int size, int asize, int bsize, ActivationFunction act, cudaStream_t stream); +template void addVectorsHNC_NHC(float* a, float* b, int N, int H, int C, + cudaStream_t stream); +template void addVectorsHNC_NHC(half* a, half* b, int N, int H, int C, + cudaStream_t stream); + template void addBiasBatched(float* output, const float* input, const float* bias, int Batch, int N, int C, ActivationFunction activation, @@ -1034,6 +1336,15 @@ template void addBiasBatched(half* output, const half* input, ActivationFunction activation, cudaStream_t stream); +template void addBiasBatched(float* output, const float* input, + const float* bias, int Batch, int N, int C, + int Nstride, ActivationFunction activation, + cudaStream_t stream); +template void addBiasBatched(half* output, const half* input, + const half* bias, int Batch, int N, int C, + int Nstride, ActivationFunction activation, + cudaStream_t stream); + template void addBias_NCHW(float* c, float* a, float* b, int N, int C, int H, int W, ActivationFunction activation, cudaStream_t stream); @@ -1079,125 +1390,156 @@ template void InputTransform(int N, int C, const float* input, cudaStream_t stream); -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); -template void OutputTransform( +template void +OutputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); -template void OutputTransform( +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, float* output, + const float* input, const float* skip, + const float* bias, const float* w1, + const float* b1, const float* w2, + const float* b2, cudaStream_t stream); + +template void OutputInputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); -template void OutputTransform( +template void OutputInputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); -template void OutputTransform( +template void OutputInputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); -template void OutputTransform( +template void OutputInputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); -template void OutputTransform( +template void OutputInputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputInputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputInputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputInputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputInputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputInputTransform( - int N, int C, int se_K, float* output, const float* input, - const float* skip, const float* bias, const float* w1, const float* b1, - const float* w2, const float* b2, cudaStream_t stream); - -template void OutputInputTransform( +template void OutputInputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); template void Softmax(int N, int C, half* output, const half* input, - cudaStream_t stream); + const half* input2, cudaStream_t stream); template void Softmax(int N, int C, float* output, const float* input, - cudaStream_t stream); + const float* input2, cudaStream_t stream); template void LayerNorm(int N, int C, half* output, const half* input, const half* bias, const half* skip, const half* gammas, const half* betas, float ep, + float alpha, ActivationFunction act, cudaStream_t stream); template void LayerNorm(int N, int C, float* output, const float* input, const float* bias, const float* skip, const float* gammas, const float* betas, - float ep, cudaStream_t stream); + float ep, float alpha, ActivationFunction act, + cudaStream_t stream); template void ComputePromotionLogits(int N, int C, half* output, const half* keys, const half* ppo, @@ -1220,5 +1562,26 @@ template void convertNCHWtoNHWC(half* output_tensor, const half* input_tensor, int Nin, int Cin, int Nout, int Cout, int H, int W); + +template void inputPreprocessForAttentionBody(half* output, + const half* input, + const float* encoding, + int N, cudaStream_t stream); + +template void inputPreprocessForAttentionBody(float* output, + const float* input, + const float* encoding, + int N, + cudaStream_t stream); + +template void applyInputGating(half* output, const half* input, + const half* mult, const half* add, int N, + int C, int output_size, + cudaStream_t stream); + +template void applyInputGating(float* output, const float* input, + const float* mult, const float* add, + int N, int C, int output_size, + cudaStream_t stream); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/cuda_common.h b/src/neural/cuda/cuda_common.h index 759238cd4e..ca91f0e91b 100644 --- a/src/neural/cuda/cuda_common.h +++ b/src/neural/cuda/cuda_common.h @@ -74,7 +74,5 @@ void CudaError(cudaError_t status, const char* file, const int& line); inline int DivUp(int a, int b) { return (a + b - 1) / b; } -enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH }; - } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/fp16_kernels.cu b/src/neural/cuda/fp16_kernels.cu index a2645c7846..cb433ded9c 100644 --- a/src/neural/cuda/fp16_kernels.cu +++ b/src/neural/cuda/fp16_kernels.cu @@ -26,6 +26,7 @@ */ #include "cuda_common.h" +#include "neural/shared/activation.h" #include "winograd_helper.inc" namespace lczero { @@ -207,13 +208,18 @@ bool Se_Fp16_NHWC(int N, int C, int numFc1Out, half* output, const half* skip, // 'C' threads per block // 'N' blocks // Every thread generates an entire board/plane (8x8 elements). -template -__global__ __launch_bounds__(kMaxResBlockFusingSeKFp16Ampere,1) -void OutputInputTransformKernel_fp16_shmem_board( - int N, int C, int se_K, half* output, const half* input, half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2) { +template +__global__ __launch_bounds__( + kMaxResBlockFusingSeKFp16Ampere, + 1) void OutputInputTransformKernel_fp16_shmem_board(int N, int C, int se_K, + half* output, + const half* input, + half* skip, + const half* bias, + const half* w1, + const half* b1, + const half* w2, + const half* b2) { int k = threadIdx.x; int n = blockIdx.x; @@ -324,7 +330,7 @@ void OutputInputTransformKernel_fp16_shmem_board( for (int w = 0; w < 8; w++) boardRow[w] += skipInp[w]; } - if (activation != NONE) { + if (activation != ACTIVATION_NONE) { #pragma unroll for (int w = 0; w < 8; w++) boardRow[w] = (half)activate((float)boardRow[w], activation); @@ -338,7 +344,6 @@ void OutputInputTransformKernel_fp16_shmem_board( copyAs(&BOARD(h, 0), &boardRow); } - // Perform input transform. int c = k; @@ -437,10 +442,9 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, // and only for fp16. if (C <= kMaxResBlockFusingSeKFp16Ampere) { cudaFuncSetAttribute( - OutputInputTransformKernel_fp16_shmem_board, - cudaFuncAttributeMaxDynamicSharedMemorySize, - 72 * C * sizeof(half)); + OutputInputTransformKernel_fp16_shmem_board, + cudaFuncAttributeMaxDynamicSharedMemorySize, 72 * C * sizeof(half)); OutputInputTransformKernel_fp16_shmem_board <<>>( @@ -452,8 +456,8 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, "of filters\n"); } } else { - OutputTransform_SE_relu_InputTransform_kernel + OutputTransform_SE_relu_InputTransform_kernel <<>>(N, C, se_K, output, input, (half*)skip, bias, w1, b1, w2, b2); } @@ -470,107 +474,137 @@ template void InputTransform(int N, int C, half* transformed_input, const half* input, cudaStream_t stream); -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputTransform( - int N, int C, int se_K, half* output, const half* input, const half* skip, - const half* bias, const half* w1, const half* b1, const half* w2, - const half* b2, cudaStream_t stream); - -template void OutputInputTransform( +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputTransform(int N, int C, int se_K, half* output, + const half* input, const half* skip, + const half* bias, const half* w1, + const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + +template void OutputInputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); -template void OutputInputTransform( +template void OutputInputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); -template void OutputInputTransform( +template void OutputInputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); -template void OutputInputTransform( +template void OutputInputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); -template void OutputInputTransform( +template void OutputInputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); -template void OutputInputTransform( +template void OutputInputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); diff --git a/src/neural/cuda/inputs_outputs.h b/src/neural/cuda/inputs_outputs.h index d8677b4d91..da3b5a6b00 100644 --- a/src/neural/cuda/inputs_outputs.h +++ b/src/neural/cuda/inputs_outputs.h @@ -98,11 +98,13 @@ struct InputsOutputs { if (mem) ReportCUDAErrors(cudaFree(mem)); } if (scratch_mem_) ReportCUDAErrors(cudaFree(scratch_mem_)); - + if (offset_pointers_) ReportCUDAErrors(cudaFree(offset_pointers_)); + if (head_offset_pointers_) { + ReportCUDAErrors(cudaFree(head_offset_pointers_)); + } cudaStreamDestroy(stream_); cublasDestroy(cublas_); } - } uint64_t* input_masks_mem_; float* input_val_mem_; @@ -124,13 +126,14 @@ struct InputsOutputs { bool multi_stream_; void* tensor_mem_[3]; void* scratch_mem_; + void** offset_pointers_ = nullptr; + void** head_offset_pointers_ = nullptr; // cuda stream used to run the network cudaStream_t stream_; - cublasHandle_t cublas_; // cublas handle used to run the network - + cublasHandle_t cublas_; }; } // namespace cudnn_backend diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 2763d09c1e..fa405c1946 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -26,6 +26,7 @@ */ #include "cuda_common.h" +#include "neural/shared/activation.h" namespace lczero { namespace cudnn_backend { @@ -36,12 +37,25 @@ template void addVectors(T* c, T* a, T* b, int size, int asize, int bsize, ActivationFunction activation, cudaStream_t stream); +// Adds two vectors of equal size overwriting the first with the sum. +// This specialisation performs a transposition of the first 2 indexes +// of the second while performing the addition. +template +void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream); + // Optimized kernel to add bias to innermost dimension // and perform optional activation (to be used with GEMMs/fully connected) template void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, int C, ActivationFunction activation, cudaStream_t stream); +// Optimized kernel to add bias to innermost dimension +// and perform optional activation (to be used with GEMMs/fully connected) +template +void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, + int C, int Nstride, ActivationFunction activation, + cudaStream_t stream); + // Add bias to convolution's output. template void addBias_NCHW(T* c, T* a, T* b, int N, int C, int H, int W, @@ -118,17 +132,26 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, cudaStream_t stream); template -void Softmax(int N, int C, T* output, const T* input, cudaStream_t stream); +void Softmax(int N, int C, T* output, const T* input, const T* input2, + cudaStream_t stream); template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, - cudaStream_t stream); + float alpha, ActivationFunction act, cudaStream_t stream); template void ComputePromotionLogits(int N, int C, T* output, const T* keys, const T* ppo, const T* policy_attn_logits, cudaStream_t stream); +template +void inputPreprocessForAttentionBody(T* output, const T* input, + const float* encoding, int N, + cudaStream_t stream); + +template +void applyInputGating(T* output, const T* input, const T* mult, const T* add, + int N, int HW, int C, cudaStream_t stream); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 0415b24a92..d4e02b3a46 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -32,11 +32,68 @@ #include "cuda_common.h" #include "kernels.h" +#include "neural/network.h" +#include "neural/shared/activation.h" +#include "neural/shared/attention_policy_map.h" #include "utils/fp16_utils.h" namespace lczero { -// void dumpTensor(void* memory, int elements, const char* message, bool fp16 = -// false); + +#if 0 +// debug code to dump allocation in GPU memory +template +void dumpTensor(T* memory, int elements, const char* message, bool only_summary = false) { + const bool fp16 = std::is_same::value; + printf("\n%s\n", message); + int elementSize = (int) (fp16 ? sizeof(half) : sizeof(float)); + int bytes = elements * elementSize; + void *temp = malloc(bytes); + cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); + float maxval = -std::numeric_limits::max(); + float minval = std::numeric_limits::max(); + int nans = 0; + int nanss[10] {}; + + for (int i = 0; i < elements; i++) + { + float val; + if (fp16) + { + half *arr = (half*)temp; + val = (float)arr[i]; + } + else + { + float *arr = (float *)temp; + val = arr[i]; + } + maxval = std::max(maxval, val); + minval = std::min(minval, val); + + if (std::isnan(val)) { + if (nans < 10) nanss[nans] = i; + nans++; + } + + if (!only_summary || i < 2 || i == elements - 1) { + // printf("%8.4f ", val); + // if ((i % 8) == 7) printf("\n"); + printf("%i;%.6f\n", i, val); + } + } + free(temp); + if (maxval == -std::numeric_limits::max()) + maxval = std::numeric_limits::quiet_NaN(); + if (minval == std::numeric_limits::max()) + minval = std::numeric_limits::quiet_NaN(); + + printf("Max: %.6f, Min: %.6f, NaNs: %i of %i", maxval, minval, nans, elements); + printf("\nNaN indices: "); + for (int i=0; i 10) printf("......"); + printf("\n"); +} +#endif namespace cudnn_backend { @@ -56,7 +113,12 @@ BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip, bool nhwc, template BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip) - : input_(ip), C(c), H(h), W(w), nhwc_(ip->nhwc_), use_gemm_ex_(false) {} + : input_(ip), + C(c), + H(h), + W(w), + nhwc_(ip ? ip->nhwc_ : false), + use_gemm_ex_(false) {} #ifdef USE_CUDNN template @@ -108,7 +170,7 @@ void ConvLayer::init() { conv_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; } - if (act_ == RELU) { + if (act_ == ACTIVATION_RELU) { cudnnSetActivationDescriptor(activation_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0); } @@ -193,7 +255,8 @@ template void ConvLayer::Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, cudnnHandle_t cudnn, - cublasHandle_t /*cublas*/, cudaStream_t stream) { + cublasHandle_t /*cublas*/, cudaStream_t stream, + DataType***) { const cudnnDataType_t dataType = std::is_same::value ? CUDNN_DATA_HALF : CUDNN_DATA_FLOAT; @@ -208,14 +271,15 @@ void ConvLayer::Eval(int N, DataType* output, const DataType* input, float alpha = 1.0f, beta = 0.0f; - if (!(act_ != NONE || use_bias_ || input2)) { + if (!(act_ != ACTIVATION_NONE || use_bias_ || input2)) { ReportCUDNNErrors(cudnnConvolutionForward( cudnn, &alpha, in_tensor_desc_, input, filter_desc_, weights, conv_desc_, conv_algo_, scratch, scratch_size, &beta, out_tensor_desc_, output)); } #if CUDNN_MAJOR != 7 || CUDNN_MINOR != 0 - else if (input2 && (act_ == RELU || act_ == NONE) && use_bias_) { + else if (input2 && (act_ == ACTIVATION_RELU || act_ == ACTIVATION_NONE) && + use_bias_) { // fused bias + sum + relu! ReportCUDNNErrors(cudnnConvolutionBiasActivationForward( cudnn, &alpha, in_tensor_desc_, input, filter_desc_, weights, @@ -225,7 +289,8 @@ void ConvLayer::Eval(int N, DataType* output, const DataType* input, // For some reason cudnn doesn't support just Convolution + Bias with nchw // (winograd algorithm) it works fine when RELU is also needed which is // somewhat strange. - if ((act_ == RELU || (act_ == NONE && nhwc_)) && !input2 && use_bias_) { + if ((act_ == ACTIVATION_RELU || (act_ == ACTIVATION_NONE && nhwc_)) && + !input2 && use_bias_) { ReportCUDNNErrors(cudnnConvolutionBiasActivationForward( cudnn, &alpha, in_tensor_desc_, input, filter_desc_, weights, conv_desc_, conv_algo_, scratch, scratch_size, &beta, @@ -241,8 +306,8 @@ void ConvLayer::Eval(int N, DataType* output, const DataType* input, if (input2 && input2 != output) { // Merge act with residual add unless there is bias. addVectors(output, output, (DataType*)input2, N * C * H * W, - N * C * H * W, N * C * H * W, use_bias_ ? NONE : act_, - stream); + N * C * H * W, N * C * H * W, + use_bias_ ? ACTIVATION_NONE : act_, stream); act_done = !use_bias_; } // Merge act with bias. @@ -254,7 +319,7 @@ void ConvLayer::Eval(int N, DataType* output, const DataType* input, addVectors(output, output, biases, N * C * H * W, N * C * H * W, C, act_, stream); } - } else if (!act_done && act_ != NONE) { + } else if (!act_done && act_ != ACTIVATION_NONE) { addVectors(output, output, (DataType*)nullptr, N * C * H * W, N * C * H * W, 0, act_, stream); } @@ -274,12 +339,12 @@ void ConvLayer::Eval(int N, DataType* output, const DataType* input, ReportCUDNNErrors(cudnnAddTensor(cudnn, &alpha, bias_desc_, biases, &alpha, out_tensor_desc_, output)); } - if (act_ == RELU) { + if (act_ == ACTIVATION_RELU) { ReportCUDNNErrors(cudnnActivationForward(cudnn, activation_, &alpha, out_tensor_desc_, output, &beta, out_tensor_desc_, output)); } - if (act_ != RELU && act_ != NONE) { + if (act_ != ACTIVATION_RELU && act_ != ACTIVATION_NONE) { addVectors(output, output, nullptr, N * C * H * W, N * C * H * W, 0, act_, stream); // TODO: check this actually compiles? @@ -423,7 +488,8 @@ template <> void SELayer::Eval(int N, float* output, const float* input, const float* /*input2*/, void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, - cublasHandle_t cublas, cudaStream_t stream) { + cublasHandle_t cublas, cudaStream_t stream, + float***) { // Ping-pong between 'op1' and 'op2' (parts of scratch memory). float* op1 = (float*)scratch; float* op2 = (float*)scratch + scratch_size / sizeof(float) / 2; @@ -444,7 +510,8 @@ void SELayer::Eval(int N, float* output, const float* input, ReportCUBLASErrors(cublasSgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, 2 * C, N, numFc1Out_, &alpha, w2_, numFc1Out_, op1, numFc1Out_, &beta, op2, 2 * C)); - addVectors(op2, b2_, op2, 2 * C * N, 2 * C, 2 * C * N, NONE, stream); + addVectors(op2, b2_, op2, 2 * C * N, 2 * C, 2 * C * N, ACTIVATION_NONE, + stream); // 4. (Optional prev layer bias add), Global scale, residual add, relu and // bias. @@ -455,7 +522,7 @@ template <> void SELayer::Eval(int N, half* output, const half* input, const half* input2, void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, - cudaStream_t stream) { + cudaStream_t stream, half***) { bool se_done = false; if (kUseFusedSELayer && nhwc_) { se_done = Se_Fp16_NHWC(N, C, numFc1Out_, output, input2, input, w1_t_, b1_, @@ -486,7 +553,8 @@ void SELayer::Eval(int N, half* output, const half* input, ReportCUBLASErrors(cublasHgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, 2 * C, N, numFc1Out_, &alpha, w2_, numFc1Out_, op1, numFc1Out_, &beta, op2, 2 * C)); - addVectors(op2, b2_, op2, 2 * C * N, 2 * C, 2 * C * N, NONE, stream); + addVectors(op2, b2_, op2, 2 * C * N, 2 * C, 2 * C * N, ACTIVATION_NONE, + stream); // 4. (Optional prev layer bias add), Global scale, residual add, relu and // bias. @@ -559,7 +627,7 @@ template <> void FCLayer::Eval(int N, half* output_tensor, const half* input_tensor, const half* /*input2*/, void* /*scratch*/, size_t /*scratch_size*/, cudnnHandle_t /*cudnn*/, - cublasHandle_t cublas, cudaStream_t stream) { + cublasHandle_t cublas, cudaStream_t stream, half***) { const int num_outputs = C * H * W; const int num_inputs = input_->GetC() * input_->GetH() * input_->GetW(); @@ -573,7 +641,7 @@ void FCLayer::Eval(int N, half* output_tensor, const half* input_tensor, input_tensor, num_inputs, &beta, output_tensor, num_outputs)); - if (use_bias_ || (act_ != NONE)) { + if (use_bias_ || (act_ != ACTIVATION_NONE)) { addVectors(output_tensor, biases_, output_tensor, num_outputs * N, num_outputs, num_outputs * N, act_, stream); } @@ -584,7 +652,7 @@ void FCLayer::Eval(int N, float* output_tensor, const float* input_tensor, const float* /*input2*/, void* /*scratch*/, size_t /*scratch_size*/, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, - cudaStream_t stream) { + cudaStream_t stream, float***) { const int num_outputs = C * H * W; const int num_inputs = input_->GetC() * input_->GetH() * input_->GetW(); @@ -594,7 +662,7 @@ void FCLayer::Eval(int N, float* output_tensor, input_tensor, num_inputs, &beta, output_tensor, num_outputs)); - if (use_bias_ || (act_ != NONE)) { + if (use_bias_ || (act_ != ACTIVATION_NONE)) { addVectors(output_tensor, biases_, output_tensor, num_outputs * N, num_outputs, num_outputs * N, act_, stream); } @@ -693,10 +761,13 @@ void PolicyMapLayer::LoadWeights(const short* cpuWeight, } template -void PolicyMapLayer::Eval( - int N, DataType* output_tensor, const DataType* input_tensor, - const DataType* /*input2*/, void* /*scratch*/, size_t /*scratch_size*/, - cudnnHandle_t /*cudnn*/, cublasHandle_t /*cublas*/, cudaStream_t stream) { +void PolicyMapLayer::Eval(int N, DataType* output_tensor, + const DataType* input_tensor, + const DataType* /*input2*/, + void* /*scratch*/, size_t /*scratch_size*/, + cudnnHandle_t /*cudnn*/, + cublasHandle_t /*cublas*/, + cudaStream_t stream, DataType***) { int inputSize = this->input_->GetC() * this->input_->GetH() * this->input_->GetW(); if (attention_map_) inputSize = used_size_; @@ -723,7 +794,8 @@ FusedWinogradConvSELayer::FusedWinogradConvSELayer( has_se_(se), se_k_(se_k), op_nhcw_(op_nhcw) { - if (act_ != RELU && act_ != MISH && act_ != NONE) { + if (act_ != ACTIVATION_RELU && act_ != ACTIVATION_MISH && + act_ != ACTIVATION_NONE) { throw Exception("Unsupported activation for fused winograd conv SE layer."); } // Allocate memory for weights (filter tensor) and biases. @@ -870,7 +942,7 @@ template void FusedWinogradConvSELayer::Eval( int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, - cublasHandle_t cublas, cudaStream_t stream) { + cublasHandle_t cublas, cudaStream_t stream, DataType***) { // Split the scratch space into two parts - use first part for holding // transformed input and second part for transformed output. DataType* transformed_input = (DataType*)scratch; @@ -883,51 +955,58 @@ void FusedWinogradConvSELayer::Eval( transformed_input, transformed_weights_, transformed_output, N * 4, C, c_input_, 36, cublas); - if (act_ == NONE) { + if (act_ == ACTIVATION_NONE) { if (!has_se_ && use_bias_ && !skip_add_) - OutputTransform( - N, C, 0, output, transformed_output, nullptr, biases_, nullptr, - nullptr, nullptr, nullptr, stream); + OutputTransform(N, C, 0, output, transformed_output, nullptr, + biases_, nullptr, nullptr, nullptr, nullptr, + stream); else throw Exception("unsupported network type!"); - } else if (act_ == RELU) { + } else if (act_ == ACTIVATION_RELU) { if (has_se_ && use_bias_ && skip_add_) - OutputTransform( - N, C, se_k_, output, transformed_output, input2, biases_, w1_, b1_, - w2_, b2_, stream); + OutputTransform(N, C, se_k_, output, transformed_output, input2, + biases_, w1_, b1_, w2_, b2_, stream); else if (!has_se_ && use_bias_ && !skip_add_) { if (op_nhcw_) - OutputTransform( - N, C, 0, output, transformed_output, nullptr, biases_, nullptr, - nullptr, nullptr, nullptr, stream); + OutputTransform(N, C, 0, output, transformed_output, nullptr, + biases_, nullptr, nullptr, nullptr, nullptr, + stream); else - OutputTransform( - N, C, 0, output, transformed_output, nullptr, biases_, nullptr, - nullptr, nullptr, nullptr, stream); + OutputTransform(N, C, 0, output, transformed_output, nullptr, + biases_, nullptr, nullptr, nullptr, nullptr, + stream); } else if (!has_se_ && use_bias_ && skip_add_) - OutputTransform( - N, C, 0, output, transformed_output, input2, biases_, nullptr, - nullptr, nullptr, nullptr, stream); + OutputTransform(N, C, 0, output, transformed_output, input2, + biases_, nullptr, nullptr, nullptr, nullptr, + stream); else throw Exception("unsupported network type!"); - } else if (act_ == MISH) { + } else if (act_ == ACTIVATION_MISH) { if (has_se_ && use_bias_ && skip_add_) - OutputTransform( - N, C, se_k_, output, transformed_output, input2, biases_, w1_, b1_, - w2_, b2_, stream); + OutputTransform(N, C, se_k_, output, transformed_output, input2, + biases_, w1_, b1_, w2_, b2_, stream); else if (!has_se_ && use_bias_ && !skip_add_) { if (op_nhcw_) - OutputTransform( - N, C, 0, output, transformed_output, nullptr, biases_, nullptr, - nullptr, nullptr, nullptr, stream); + OutputTransform(N, C, 0, output, transformed_output, nullptr, + biases_, nullptr, nullptr, nullptr, nullptr, + stream); else - OutputTransform( - N, C, 0, output, transformed_output, nullptr, biases_, nullptr, - nullptr, nullptr, nullptr, stream); + OutputTransform(N, C, 0, output, transformed_output, nullptr, + biases_, nullptr, nullptr, nullptr, nullptr, + stream); } else if (!has_se_ && use_bias_ && skip_add_) - OutputTransform( - N, C, 0, output, transformed_output, input2, biases_, nullptr, - nullptr, nullptr, nullptr, stream); + OutputTransform(N, C, 0, output, transformed_output, input2, + biases_, nullptr, nullptr, nullptr, nullptr, + stream); else throw Exception("unsupported network type!"); } else @@ -1035,13 +1114,13 @@ void Conv1Layer::Eval(int N, DataType* output, const DataType* input, const DataType* /*input2*/, void* /*scratch*/, size_t /*scratch_size*/, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, - cudaStream_t stream) { + cudaStream_t stream, DataType***) { cublasSpecialMatrixMul(weights_, input, output, C, H * W, c_input_, N, cublas); if (use_bias_) addBias_NCHW(output, output, biases_, N, C, H, W, act_, stream); - else if (act_ != NONE) + else if (act_ != ACTIVATION_NONE) addVectors(output, output, (DataType*)nullptr, N * C * H * W, N * C * H * W, 0, act_, stream); } @@ -1066,7 +1145,7 @@ ResidualBlock::ResidualBlock(BaseLayer* ip, int C, bool se, last_block_(last), shared_mem_size_(shared_mem_size), act_(activation) { - if (act_ != RELU && act_ != MISH) { + if (act_ != ACTIVATION_RELU && act_ != ACTIVATION_MISH) { throw Exception("Unsupported activation for residual block."); } // Allocate memory for weights (filter tensor) and biases. @@ -1188,7 +1267,8 @@ void ResidualBlock::Eval(int N, DataType* output, const DataType* input, const DataType* /*input2*/, void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, - cublasHandle_t cublas, cudaStream_t stream) { + cublasHandle_t cublas, cudaStream_t stream, + DataType***) { // normally: // - "output" initially contains the transformed input, // and after this layer, it contains the transformed input for next layer @@ -1224,12 +1304,12 @@ void ResidualBlock::Eval(int N, DataType* output, c_input_, 36, cublas); } - if (act_ == RELU) { - OutputInputTransform( + if (act_ == ACTIVATION_RELU) { + OutputInputTransform( N, C, 0, transformed_input, transformed_output, nullptr, biases0_, nullptr, nullptr, nullptr, nullptr, stream); - } else if (act_ == MISH) { - OutputInputTransform( + } else if (act_ == ACTIVATION_MISH) { + OutputInputTransform( N, C, 0, transformed_input, transformed_output, nullptr, biases0_, nullptr, nullptr, nullptr, nullptr, stream); } @@ -1246,59 +1326,61 @@ void ResidualBlock::Eval(int N, DataType* output, (fp16 && (shared_mem_size_ >= kMaxResBlockFusingSeFp16AmpereSmem) && (C <= kMaxResBlockFusingSeKFp16Ampere)); - if (act_ == RELU) { + if (act_ == ACTIVATION_RELU) { if (last_block_) { if (has_se_) - OutputTransform( - N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, - w2_, b2_, stream); + OutputTransform(N, C, se_k_, output, transformed_output, input, + biases1_, w1_, b1_, w2_, b2_, stream); else - OutputTransform( - N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, - w2_, b2_, stream); + OutputTransform(N, C, se_k_, output, transformed_output, input, + biases1_, w1_, b1_, w2_, b2_, stream); } else { if (has_se_) { if (allowFusing) { - OutputInputTransform( + OutputInputTransform( N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); } else { - OutputTransform( - N, C, se_k_, (DataType*)input, transformed_output, input, - biases1_, w1_, b1_, w2_, b2_, stream); + OutputTransform(N, C, se_k_, (DataType*)input, + transformed_output, input, biases1_, w1_, b1_, + w2_, b2_, stream); InputTransform(N, C, output, (DataType*)input, stream); } } else - OutputInputTransform( + OutputInputTransform( N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); } - } else if (act_ == MISH) { + } else if (act_ == ACTIVATION_MISH) { if (last_block_) { if (has_se_) - OutputTransform( - N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, - w2_, b2_, stream); + OutputTransform(N, C, se_k_, output, transformed_output, input, + biases1_, w1_, b1_, w2_, b2_, stream); else - OutputTransform( - N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, - w2_, b2_, stream); + OutputTransform(N, C, se_k_, output, transformed_output, input, + biases1_, w1_, b1_, w2_, b2_, stream); } else { if (has_se_) { if (allowFusing) { - OutputInputTransform( + OutputInputTransform( N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); } else { - OutputTransform( - N, C, se_k_, (DataType*)input, transformed_output, input, - biases1_, w1_, b1_, w2_, b2_, stream); + OutputTransform(N, C, se_k_, (DataType*)input, + transformed_output, input, biases1_, w1_, b1_, + w2_, b2_, stream); InputTransform(N, C, output, (DataType*)input, stream); } } else - OutputInputTransform( + OutputInputTransform( N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); } @@ -1338,10 +1420,14 @@ void allocAndUpload(DataType** gpu_dest, std::vector cpu_src, } template -AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, - const LegacyWeights& weights, - void* scratch) - : BaseLayer(64 * 64 + 24 * 8, 1, 1, ip) { +AttentionPolicyHead::AttentionPolicyHead( + BaseLayer* ip, const LegacyWeights& weights, void* scratch, + bool attention_body, ActivationFunction act, int max_batch_size) + : BaseLayer(64 * 64 + 24 * 8, 1, 1, ip), + attention_body_(attention_body), + // Old networks without attention body (e.g. T79) use hardcoded SELU + // activations. + act_(attention_body ? act : ACTIVATION_SELU) { embedding_op_size_ = weights.ip_pol_b.size(); wq_op_size_ = weights.ip2_pol_b.size(); wk_op_size_ = weights.ip3_pol_b.size(); @@ -1382,14 +1468,28 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, allocAndUpload(&ip4_pol_w_, weights.ip4_pol_w, scratch); for (const auto& enc : weights.pol_encoder) { - EncoderWeights* pW = new EncoderWeights(enc, scratch); + EncoderBlock* pW = new EncoderBlock( + enc, scratch, encoder_heads_, embedding_op_size_, + 1.0f, // using alpha = 1 for now (TODO: may change?) + nullptr, 0, max_batch_size, ACTIVATION_SWISH, + act_); // smolgen weights not implemented in policy encoder heads yet. encoder_weights_.emplace_back(pW); } } template -AttentionPolicyHead::EncoderWeights::EncoderWeights( - const LegacyWeights::EncoderLayer& cpu_weights, void* scratch) { +EncoderBlock::EncoderBlock( + const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, + int size, float alpha, DataType* smolgen_global_scratch, + int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, + ActivationFunction ffn_act) + : embedding_op_size_(size), + encoder_heads_(heads), + alpha_(alpha), + has_smolgen_(cpu_weights.mha.has_smolgen), + smolgen_activation_(smolgen_act), + ffn_activation_(ffn_act), + max_batch_size_(max_batch_size) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); mha_v_size_ = cpu_weights.mha.v_b.size(); @@ -1443,6 +1543,37 @@ AttentionPolicyHead::EncoderWeights::EncoderWeights( allocAndUpload(&ln2_gammas, cpu_weights.ln2_gammas, scratch); allocAndUpload(&ln2_betas, cpu_weights.ln2_betas, scratch); + + // Smolgen weights. + if (has_smolgen_) { + smol_compress_size_ = cpu_weights.mha.smolgen.compress.size() / mha_q_size_; + smol_dense_1_size_ = cpu_weights.mha.smolgen.dense1_b.size(); + smol_dense_2_size_ = cpu_weights.mha.smolgen.dense2_b.size(); + smol_global_size_ = smolgen_global_size; + + allocAndUpload(&smol_compress, cpu_weights.mha.smolgen.compress, + scratch); + allocAndUpload(&smol_dense1_w, cpu_weights.mha.smolgen.dense1_w, + scratch); + allocAndUpload(&smol_dense1_b, cpu_weights.mha.smolgen.dense1_b, + scratch); + allocAndUpload(&smol_dense2_w, cpu_weights.mha.smolgen.dense2_w, + scratch); + allocAndUpload(&smol_dense2_b, cpu_weights.mha.smolgen.dense2_b, + scratch); + + allocAndUpload(&smol_ln1_gammas, + cpu_weights.mha.smolgen.ln1_gammas, scratch); + allocAndUpload(&smol_ln1_betas, cpu_weights.mha.smolgen.ln1_betas, + scratch); + allocAndUpload(&smol_ln2_gammas, + cpu_weights.mha.smolgen.ln2_gammas, scratch); + allocAndUpload(&smol_ln2_betas, cpu_weights.mha.smolgen.ln2_betas, + scratch); + + // GPU memory already allocated in AttentionBody. + smol_global = smolgen_global_scratch; + } } template @@ -1488,169 +1619,300 @@ static void cublasXGemmStridedBatched( } template -void AttentionPolicyHead::Eval( - int N, DataType* output, const DataType* input, const DataType* input2, - void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, - cublasHandle_t cublas, cudaStream_t stream) { - DataType* scratch0 = (DataType*)scratch; - DataType* scratch1 = (DataType*)input2; - DataType* scratch2 = output + scratch_size / (2 * sizeof(DataType)); - DataType* scratch3 = scratch1 + scratch_size / (2 * sizeof(DataType)); - - int inputC = this->input_->GetC(); - convertNCHWtoNHWC(scratch0, input, N, inputC, N, inputC, 8, 8); - - // 1. Policy embedding (fully connected layer) - // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ - DataType* pol_embedding = scratch1; - { - const int num_outputs = embedding_op_size_; - const int num_inputs = inputC; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_pol_w_, - num_inputs, scratch0, num_inputs, 0.0f, pol_embedding, - num_outputs); - addBiasBatched(pol_embedding, pol_embedding, ip_pol_b_, 1, batch, - num_outputs, SELU, stream); +static void cublasXGemmBatched(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + float alpha, DataType** A, int lda, DataType** B, + int ldb, float beta, DataType** C, int ldc, + int batchCount) { + const bool fp16 = std::is_same::value; + if (fp16) { + unsigned short alpha_h = FP32toFP16(alpha); + unsigned short beta_h = FP32toFP16(beta); + ReportCUBLASErrors(cublasHgemmBatched( + handle, transa, transb, m, n, k, (const half*)&alpha_h, (half**)A, lda, + (half**)B, ldb, (const half*)&beta_h, (half**)C, ldc, batchCount)); + } else { + ReportCUBLASErrors(cublasSgemmBatched( + handle, transa, transb, m, n, k, &alpha, (float**)A, lda, (float**)B, + ldb, &beta, (float**)C, ldc, batchCount)); } +} - // 2. Encoder layers - for (const auto pEnc : encoder_weights_) { - const auto& enc = *pEnc; - const int d_model = enc.mha_q_size_; - const int depth = d_model / encoder_heads_; - - DataType* mha_q; - DataType* mha_k; - DataType* mha_v; - +// input/output tensor is in_out_tensor, others are used as scratch. +template +void EncoderBlock::Eval(int N, DataType* in_out_tensor, + DataType* scratch, DataType* buffer1, + DataType* buffer2, cublasHandle_t cublas, + cudaStream_t stream, + DataType*** offset_pointers) const { + const int d_model = mha_q_size_; + const int depth = d_model / encoder_heads_; + + // Calculate smolgen weights. Do this first so we can make use of + // scratch, buffer1 and buffer2. + if (has_smolgen_) { { - const int num_inputs = embedding_op_size_; - const int num_outputs = d_model; + // Compress. + // input shape: N, 64, d_model + // output shape: N, 64, hidden_channels + const int num_inputs = d_model; + const int num_outputs = smol_compress_size_; const int batch = N * 64; - - mha_q = scratch0; - mha_k = mha_q + num_outputs * batch; - mha_v = mha_k + num_outputs * batch; - - cublasXGemmStridedBatched( + cublasXgemm( cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, - 1.0f, enc.mha_qkv_w, num_inputs, num_inputs * num_outputs, - pol_embedding, num_inputs, 0, 0.0f, mha_q, num_outputs, - num_outputs * batch, 3); - addBiasBatched(mha_q, mha_q, enc.mha_qkv_b, 3, batch, - num_outputs, NONE, stream); + 1.0f, (const DataType*)smol_compress, num_inputs, in_out_tensor, + num_inputs, 0.0f, scratch, num_outputs); } - // Apply split_heads() to q, k and v - // which basically transposes (batch_size, 64, num_heads, depth) - // to (batch_size, num_heads, 64, depth) - // Do we really need to transpose here? - // (Maybe not, we can play with strides of the gemm and do independent gemms - // for each encoder head) - - // Apply scaled dot product attention: - /* - matmul_qk = tf.matmul(q, k, transpose_b=True) - dk = tf.cast(tf.shape(k)[-1], self.model_dtype) - scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) - attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) - output = tf.matmul(attention_weights, v) - */ - - // shape(k)[-1] = depth - float factor = 1.0f / sqrt((float)depth); - - // matmul_qk = tf.matmul(q, k, transpose_b=True) - for (int i = 0; i < encoder_heads_; i++) { - int offset = i * depth; - // layout of the output: encoder_heads_ * Batch * 64 * 64 - int outOffset = i * N * 64 * 64; - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/, - depth /*K*/, // A/B, and M/N are swapped for row-major to col-major - // transform - factor, // to handle "/ tf.math.sqrt(dk)" - mha_k + offset /*A*/, - d_model /*LDA*/, // (d_model = depth * encoder_heads_) to skip over - // other "depth" slices / heads - 64 * d_model, /*strideA*/ - mha_q + offset /*B*/, - d_model /*LDB*/, // to skip over other other "depth" slices / heads - 64 * d_model, /*strideB*/ - 0.0f, - scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 - 64 /*LDC*/, 64 * 64 /*strideC*/, N); + { + // Hidden 1 dense. + // input shape: N, 64 * hidden_channels + // output shape: N, hidden_sz + const int num_inputs = 64 * smol_compress_size_; + const int num_outputs = smol_dense_1_size_; + const int batch = N; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, + batch, num_inputs, 1.0f, + (const DataType*)smol_dense1_w, num_inputs, scratch, + num_inputs, 0.0f, buffer1, num_outputs); + + LayerNorm(batch, num_outputs, scratch, buffer1, smol_dense1_b, + buffer1, smol_ln1_gammas, smol_ln1_betas, 1e-6, + 0.0, /* alpha = 0 since we don't need skip */ + smolgen_activation_, stream); } - // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) - // attention_weights -> scratch2 - Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); - - // output = tf.matmul(attention_weights, v) - for (int i = 0; i < encoder_heads_; i++) { - int offset = i * depth; // for output and "v" matrix - // layout: encoder_heads_ * Batch*64*64 - int weightsOffset = i * N * 64 * 64; - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/, - 1.0f, mha_v + offset /*A*/, // "v" matrix - d_model /*LDA*/, // to skip over other "depth" slices / heads - 64 * d_model, /*strideA*/ - scratch2 + weightsOffset /*B*/, 64 /*LDB*/, 64 * 64, /*strideB*/ - 0.0f, scratch3 + offset /*C*/, // output goes to scratch3 - d_model /*LDC*/, 64 * d_model /*strideC*/, N); + { + // Hidden 2 dense (gen_from) + // input shape: N, hidden_sz + // output shape: N, heads * gen_sz + const int num_inputs = smol_dense_1_size_; + const int num_outputs = smol_dense_2_size_; + const int batch = N; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, + batch, num_inputs, 1.0f, + (const DataType*)smol_dense2_w, num_inputs, scratch, + num_inputs, 0.0f, buffer1, num_outputs); + + LayerNorm(batch, num_outputs, scratch, buffer1, smol_dense2_b, + buffer1, smol_ln2_gammas, smol_ln2_betas, 1e-6, + 0.0, /* alpha = 0 since we don't need skip */ + smolgen_activation_, stream); } - // #final dense layer (mha_dense), scratch3 -> scratch2 { - const int num_inputs = d_model; - const int num_outputs = embedding_op_size_; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)enc.mha_dense_w, - num_inputs, scratch3, num_inputs, 0.0f, scratch2, - num_outputs); + // Final smolgen weights generation. + /* + gen_from = tf.reshape(gen_from, [-1, heads, gen_sz]) + out = self.smol_weight_gen_dense(gen_from) + */ + const int num_inputs = + smol_dense_2_size_ / encoder_heads_; /* num_inputs == gen_sz == 256 */ + const int num_outputs = smol_global_size_; /* hwhw: 64 * 64 */ + const int batch = N * encoder_heads_; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, + batch, num_inputs, 1.0f, + (const DataType*)smol_global, num_inputs, scratch, + num_inputs, 0.0f, buffer2, num_outputs); } + } - // LN1: skip connection and layer normalization (also bias add of prev gemm) - // scratch2/scratch1 -> scratch0 - LayerNorm(N * 64, embedding_op_size_, scratch0, scratch2, - enc.mha_dense_b, scratch1, enc.ln1_gammas, - enc.ln1_betas, 1e-6, stream); + DataType* mha_q; + DataType* mha_k; + DataType* mha_v; - // #FFN dense 1, scratch0 -> scratch1 - const int encoder_dff = enc.ffn_dense1_size_; - { - const int num_inputs = embedding_op_size_; - const int num_outputs = encoder_dff; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)enc.ffn_dense1_w, - num_inputs, scratch0, num_inputs, 0.0f, scratch1, - num_outputs); - addBiasBatched(scratch1, scratch1, enc.ffn_dense1_b, 1, batch, - num_outputs, SELU, stream); - } + { + const int num_inputs = embedding_op_size_; + const int num_outputs = d_model; + const int batch = N * 64; + const int max_batch = max_batch_size_ * 64; - // #FFN dense 2, scratch1 -> scratch2 - { - const int num_inputs = encoder_dff; - const int num_outputs = embedding_op_size_; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)enc.ffn_dense2_w, - num_inputs, scratch1, num_inputs, 0.0f, scratch2, - num_outputs); + mha_q = scratch; + mha_k = mha_q + num_outputs * max_batch; + mha_v = mha_k + num_outputs * max_batch; + + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, + mha_qkv_w, num_inputs, num_inputs * num_outputs, in_out_tensor, + num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * max_batch, 3); + addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, + max_batch, ACTIVATION_NONE, stream); + } + + // Apply split_heads() to q, k and v + // which basically transposes (batch_size, 64, num_heads, depth) + // to (batch_size, num_heads, 64, depth) + // Do we really need to transpose here? + // (Maybe not, we can play with strides of the gemm and do independent gemms + // for each encoder head) + + // Apply scaled dot product attention: + /* + matmul_qk = tf.matmul(q, k, transpose_b=True) + dk = tf.cast(tf.shape(k)[-1], self.model_dtype) + scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) + attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) + output = tf.matmul(attention_weights, v) + */ + + // shape(k)[-1] = depth + float factor = 1.0f / sqrt((float)depth); + + // matmul_qk = tf.matmul(q, k, transpose_b=True) + { + if (*offset_pointers == nullptr) { + std::vector offsets(encoder_heads_ * max_batch_size_ * 5); + for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { + int h = i % encoder_heads_; + int n = i / encoder_heads_; + offsets[i] = mha_k + h * depth + 64 * d_model * n; + offsets[i + encoder_heads_ * max_batch_size_] = + mha_q + h * depth + 64 * d_model * n; + offsets[i + 2 * encoder_heads_ * max_batch_size_] = + buffer1 + i * 64 * 64; + offsets[i + 3 * encoder_heads_ * max_batch_size_] = + mha_v + h * depth + 64 * d_model * n; + offsets[i + 4 * encoder_heads_ * max_batch_size_] = + buffer2 + h * depth + 64 * d_model * n; + } + ReportCUDAErrors( + cudaMalloc((void**)offset_pointers, + encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*))); + ReportCUDAErrors( + cudaMemcpy(*offset_pointers, offsets.data(), + encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*), + cudaMemcpyHostToDevice)); } + cublasXGemmBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/, + depth /*K*/, // A/B, and M/N are swapped for row-major to col-major + // transform + factor, // to handle "/ tf.math.sqrt(dk)" + *offset_pointers, // mha_k + offset /*A*/, + d_model /*LDA*/, // (d_model = depth * encoder_heads_) to skip over + // other "depth" slices / heads + // 64 * d_model, /*strideA*/ + *offset_pointers + + encoder_heads_ * max_batch_size_, // mha_q + offset /*B*/, + d_model /*LDB*/, // to skip over other other "depth" slices / heads + // 64 * d_model, /*strideB*/ + 0.0f, + *offset_pointers + encoder_heads_ * max_batch_size_ * + 2, // buffer1 + outOffset /*C*/, // output + // (matmul_qk) goes to buffer1 + 64 /*LDC*/, + // 64 * 64 /*strideC*/, + N * encoder_heads_); + } + + // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) + // attention_weights -> buffer1 + if (has_smolgen_) { + // Add smolgen weights to the scaled matmul_qk attention logits before + // softmax. + Softmax(encoder_heads_ * N * 64, 64, buffer1, buffer1, buffer2, stream); + } else { + Softmax(encoder_heads_ * N * 64, 64, buffer1, buffer1, + (const DataType*)nullptr, stream); + } - // LN2: skip connection and layer normilization (also bias add of prev gemm) - // scratch2/scratch0 -> scratch1 - LayerNorm(N * 64, embedding_op_size_, scratch1, scratch2, - enc.ffn_dense2_b, scratch0, enc.ln2_gammas, - enc.ln2_betas, 1e-6, stream); + { + cublasXGemmBatched( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/, 1.0f, + *offset_pointers + encoder_heads_ * max_batch_size_ * + 3, // mha_v + offset /*A*/, // "v" matrix + d_model /*LDA*/, // to skip over other "depth" slices / heads + // 64 * d_model, /*strideA*/ + *offset_pointers + encoder_heads_ * max_batch_size_ * + 2, // buffer1 + weightsOffset /*B*/, + 64 /*LDB*/, // 64 * 64, /*strideB*/ + 0.0f, + *offset_pointers + + encoder_heads_ * max_batch_size_ * + 4, // buffer2 + offset /*C*/, // output goes to buffer2 + d_model /*LDC*/, + // 64 * d_model /*strideC*/, + N * encoder_heads_); + } + + // #final dense layer (mha_dense), buffer2 -> buffer1 + { + const int num_inputs = d_model; + const int num_outputs = embedding_op_size_; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)mha_dense_w, num_inputs, + buffer2, num_inputs, 0.0f, buffer1, num_outputs); + } + // LN1: skip connection and layer normalization (also bias add of prev gemm) + // buffer1/in_out_tensor -> scratch + LayerNorm(N * 64, embedding_op_size_, scratch, buffer1, mha_dense_b, + in_out_tensor, ln1_gammas, ln1_betas, 1e-6, alpha_, + ACTIVATION_NONE, stream); + + // #FFN dense 1, scratch -> in_out_tensor + { + const int num_inputs = embedding_op_size_; + const int num_outputs = ffn_dense1_size_; // encoder_dff + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, + scratch, num_inputs, 0.0f, in_out_tensor, num_outputs); + addBiasBatched(in_out_tensor, in_out_tensor, ffn_dense1_b, 1, batch, + num_outputs, ffn_activation_, stream); + } + + // #FFN dense 2, in_out_tensor -> buffer1 + { + const int num_inputs = ffn_dense1_size_; // encoder_dff + const int num_outputs = embedding_op_size_; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, + in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); + } + + // LN2: skip connection and layer normilization (also bias add of prev gemm) + // buffer1/scratch -> in_out_tensor + LayerNorm(N * 64, embedding_op_size_, in_out_tensor, buffer1, + ffn_dense2_b, scratch, ln2_gammas, ln2_betas, 1e-6, + alpha_, ACTIVATION_NONE, stream); +} + +template +void AttentionPolicyHead::Eval( + int N, DataType* output, const DataType* input, const DataType* input2, + void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, + cublasHandle_t cublas, cudaStream_t stream, DataType*** offset_pointers) { + DataType* input2_tensor = (DataType*)input2; + DataType* buffer1 = output + scratch_size / (2 * sizeof(DataType)); + DataType* buffer2 = input2_tensor + scratch_size / (2 * sizeof(DataType)); + + int inputC = this->input_->GetC(); + if (!attention_body_) + convertNCHWtoNHWC((DataType*)scratch, input, N, inputC, N, inputC, 8, 8); + + // 1. Policy embedding (fully connected layer) + // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ + DataType* pol_embedding = input2_tensor; + { + const int num_outputs = embedding_op_size_; + const int num_inputs = inputC; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_pol_w_, + num_inputs, + attention_body_ ? input : (DataType*)scratch, + num_inputs, 0.0f, pol_embedding, num_outputs); + addBiasBatched(pol_embedding, pol_embedding, ip_pol_b_, 1, batch, + num_outputs, act_, stream); + } + + // 2. Encoder layers + for (const auto pEnc : encoder_weights_) { + pEnc->Eval(N, input2_tensor, (DataType*)scratch, buffer1, buffer2, cublas, + stream, offset_pointers); } // End of encoder blocks DataType* wq; @@ -1659,16 +1921,16 @@ void AttentionPolicyHead::Eval( const int num_inputs = embedding_op_size_; const int num_outputs = policy_d_model_; const int batch = N * 64; - wq = scratch0; + wq = (DataType*)scratch; wk = wq + num_outputs * batch; cublasXGemmStridedBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, - wqk_w_, num_inputs, num_inputs * num_outputs, scratch1, num_inputs, 0, - 0.0f, wq, num_outputs, num_outputs * batch, 2); + wqk_w_, num_inputs, num_inputs * num_outputs, input2_tensor, num_inputs, + 0, 0.0f, wq, num_outputs, num_outputs * batch, 2); - addBiasBatched(wq, wq, wqk_b_, 2, batch, num_outputs, NONE, - stream); + addBiasBatched(wq, wq, wqk_b_, 2, batch, num_outputs, + ACTIVATION_NONE, stream); } // dk = tf.math.sqrt(tf.cast(tf.shape(keys)[-1], self.model_dtype)) @@ -1714,7 +1976,7 @@ AttentionPolicyHead::~AttentionPolicyHead() { } template -AttentionPolicyHead::EncoderWeights::~EncoderWeights() { +EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(mha_q_w)); ReportCUDAErrors(cudaFree(mha_q_b)); ReportCUDAErrors(cudaFree(mha_k_w)); @@ -1733,6 +1995,170 @@ AttentionPolicyHead::EncoderWeights::~EncoderWeights() { ReportCUDAErrors(cudaFree(ffn_dense2_b)); ReportCUDAErrors(cudaFree(ln2_gammas)); ReportCUDAErrors(cudaFree(ln2_betas)); + if (has_smolgen_) { + ReportCUDAErrors(cudaFree(smol_compress)); + ReportCUDAErrors(cudaFree(smol_dense1_w)); + ReportCUDAErrors(cudaFree(smol_dense1_b)); + ReportCUDAErrors(cudaFree(smol_dense2_w)); + ReportCUDAErrors(cudaFree(smol_dense2_b)); + ReportCUDAErrors(cudaFree(smol_ln1_gammas)); + ReportCUDAErrors(cudaFree(smol_ln1_betas)); + ReportCUDAErrors(cudaFree(smol_ln2_gammas)); + ReportCUDAErrors(cudaFree(smol_ln2_betas)); + } +} + +template +EmbeddingLayer::EmbeddingLayer(BaseLayer* ip, + const std::vector& weights, + const std::vector& biases, + void* scratch, ActivationFunction act) + : BaseLayer(biases.size(), 8, 8, ip), act_(act) { + allocAndUpload(&weights_, weights, scratch); + allocAndUpload(&biases_, biases, scratch); +} + +template +EmbeddingLayer::~EmbeddingLayer() { + ReportCUDAErrors(cudaFree(weights_)); + ReportCUDAErrors(cudaFree(biases_)); +} + +template +void EmbeddingLayer::Eval( + int N, DataType* output, const DataType* input, const DataType* /*input2*/, + void* /*scratch*/, size_t /*scratch_size*/, cudnnHandle_t /*cudnn*/, + cublasHandle_t cublas, cudaStream_t stream, DataType***) { + const int num_outputs = this->GetC(); + const int num_inputs = this->input_->GetC(); + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, weights_, num_inputs, input, + num_inputs, 0.0f, output, num_outputs); + addBiasBatched(output, output, biases_, 1, batch, num_outputs, act_, stream); +} + +template +AttentionBody::AttentionBody(const LegacyWeights& weights, + void* scratch, Activations activations, + int num_res_blocks, int input_c, + int max_batch_size) + : BaseLayer(weights.ip_emb_b.size(), 8, 8, nullptr), + embedding_op_size_(weights.ip_emb_b.size()), + encoder_head_count_(weights.encoder_head_count), + activations_(activations), + num_resi_blocks_(num_res_blocks), + input_c_(input_c), + has_gating_(weights.ip_mult_gate.size() > 0 && + weights.ip_add_gate.size() > 0), + has_smolgen_(weights.has_smolgen) { + allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); + allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); + + { + size_t size = 64 * kNumPosEncodingChannels * sizeof(float); + ReportCUDAErrors(cudaMalloc(&pos_encoding_, size)); + ReportCUDAErrors( + cudaMemcpy(pos_encoding_, kPosEncoding, size, cudaMemcpyHostToDevice)); + } + + if (has_gating_) { + allocAndUpload(&ip_mult_gate_, weights.ip_mult_gate, scratch); + allocAndUpload(&ip_add_gate_, weights.ip_add_gate, scratch); + } + + if (has_smolgen_) { + allocAndUpload(&smolgen_global_, weights.smolgen_w, scratch); + smolgen_global_size_ = 64 * 64; + } + + int num_encoders = weights.encoder.size(); + float alpha = (float)pow(2.0 * num_encoders, 0.25); + for (const auto& enc : weights.encoder) { + EncoderBlock* pW = new EncoderBlock( + enc, scratch, encoder_head_count_, embedding_op_size_, alpha, + smolgen_global_, smolgen_global_size_, max_batch_size, + activations_.smolgen_activation, activations_.ffn_activation); + encoder_weights_.emplace_back(pW); + } +} + +template +AttentionBody::~AttentionBody() { + ReportCUDAErrors(cudaFree(ip_emb_w_)); + ReportCUDAErrors(cudaFree(ip_emb_b_)); + ReportCUDAErrors(cudaFree(pos_encoding_)); + if (has_gating_) { + ReportCUDAErrors(cudaFree(ip_mult_gate_)); + ReportCUDAErrors(cudaFree(ip_add_gate_)); + } + if (has_smolgen_) { + ReportCUDAErrors(cudaFree(smolgen_global_)); + } + for (const auto pEnc : encoder_weights_) delete pEnc; +} + +template +void AttentionBody::Eval(int N, DataType* output, + const DataType* input, + const DataType* input2, void* scratch, + size_t scratch_size, cudnnHandle_t /*cudnn*/, + cublasHandle_t cublas, cudaStream_t stream, + DataType*** offset_pointers) { + DataType* output_tensor = (DataType*)output; + DataType* buffer1 = (DataType*)input2; + DataType* buffer2 = buffer1 + scratch_size / (2 * sizeof(DataType)); + + int inputC = input_c_; + if (num_resi_blocks_ == 0) { + assert(inputC == kInputPlanes); + /* + # if there are no residual blocks (pure transformer), do some input + processing + flow = tf.transpose(inputs, perm=[0, 2, 3, 1]) + flow = tf.reshape(flow, [-1, 64, tf.shape(inputs)[1]]) + # add positional encoding for each square to the input + positional_encoding = tf.broadcast_to(tf.convert_to_tensor(self.POS_ENC, + dtype=self.model_dtype), [tf.shape(flow)[0], 64, + tf.shape(self.POS_ENC)[2]]) flow = tf.concat([flow, positional_encoding], + axis=2) + */ + inputPreprocessForAttentionBody((DataType*)scratch, input, pos_encoding_, N, + stream); + inputC += kNumPosEncodingChannels; + } else { + // #redirect flow through encoder blocks + // flow = tf.transpose(flow, perm = [ 0, 2, 3, 1 ]) + // flow = tf.reshape(flow, [ -1, 64, self.RESIDUAL_FILTERS ]) + convertNCHWtoNHWC((DataType*)scratch, input, N, inputC, N, inputC, 8, 8); + } + + // 1. square embedding (fully connected layer) + // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ + DataType* embedding = output_tensor; + { + const int num_outputs = embedding_op_size_; + const int num_inputs = inputC; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_w_, + num_inputs, (DataType*)scratch, num_inputs, 0.0f, + embedding, num_outputs); + addBiasBatched(embedding, embedding, ip_emb_b_, 1, batch, num_outputs, + activations_.default_activation, stream); + } + + // Input gating + if (has_gating_) { + applyInputGating(embedding, embedding, ip_mult_gate_, + ip_add_gate_, N, 64, embedding_op_size_, stream); + } + + // 2. Encoder blocks + for (const auto pEnc : encoder_weights_) { + pEnc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, + stream, offset_pointers); + } // End of encoder blocks } // Template instantiation. @@ -1762,6 +2188,15 @@ template class ResidualBlock; template class AttentionPolicyHead; template class AttentionPolicyHead; +template class EncoderBlock; +template class EncoderBlock; + +template class AttentionBody; +template class AttentionBody; + +template class EmbeddingLayer; +template class EmbeddingLayer; + // Misc error handling stuff. #ifdef USE_CUDNN void CudnnError(cudnnStatus_t status, const char* file, const int& line) { diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 2bb56ce15b..174b18ec15 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -32,6 +32,7 @@ #include "cuda_common.h" #include "neural/network_legacy.h" +#include "neural/shared/activation.h" #ifdef USE_CUDNN #include @@ -63,7 +64,7 @@ class BaseLayer { virtual void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) = 0; + cudaStream_t stream, DataType*** = nullptr) = 0; protected: BaseLayer* input_; @@ -93,17 +94,17 @@ class ConvLayer : public BaseLayer { public: ConvLayer(BaseLayer* ip, int C, int H, int W, int size, int Cin, - ActivationFunction activation = NONE, bool bias = false); + ActivationFunction activation = ACTIVATION_NONE, bool bias = false); ConvLayer(bool nhwc, int C, int H, int W, int size, int Cin, - ActivationFunction activation = NONE, bool bias = false); + ActivationFunction activation = ACTIVATION_NONE, bool bias = false); ~ConvLayer(); void LoadWeights(float* pfilter, float* pBias, void* scratch); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: const int c_input_; @@ -140,8 +141,8 @@ class FCLayer : public BaseLayer { void LoadWeights(float* cpuWeight, float* cpuBias, void* scratch); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: const bool use_bias_; @@ -162,8 +163,8 @@ class PolicyMapLayer : public BaseLayer { void LoadWeights(const short* cpuWeight, void* scratch); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: int used_size_; // Size of the input without padding (typically 73x64). @@ -191,8 +192,8 @@ class SELayer : public BaseLayer { void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: DataType* w1_ = nullptr; @@ -229,8 +230,8 @@ class FusedWinogradConvSELayer : public BaseLayer { void LoadSEWeights(float* w1, float* b1, float* w2, float* b2, void* scratch); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: const int c_input_; @@ -269,8 +270,8 @@ class Conv1Layer : public BaseLayer { void LoadWeights(float* pfilter, float* pBias, void* scratch); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: const int c_input_; @@ -308,8 +309,8 @@ class ResidualBlock : public BaseLayer { void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: const bool has_se_; @@ -332,6 +333,67 @@ class ResidualBlock : public BaseLayer { DataType* b2_; }; +template +class EncoderBlock { + public: + EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, + int heads, int size, float alpha, + DataType* smolgen_global_scratch, int smolgen_global_size, + int max_batch_size, ActivationFunction smolgen_act, + ActivationFunction ffn_act); + ~EncoderBlock(); + + void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, + DataType* scratch2, cublasHandle_t cublas, cudaStream_t stream, + DataType*** offset_pointers) const; + + // all GPU side pointers + DataType *mha_q_w, *mha_q_b; + DataType *mha_k_w, *mha_k_b; + DataType *mha_v_w, *mha_v_b; + DataType *mha_qkv_w, *mha_qkv_b; + DataType *mha_dense_w, *mha_dense_b; + + DataType *ln1_gammas, *ln1_betas; + + DataType *ffn_dense1_w, *ffn_dense1_b; + DataType *ffn_dense2_w, *ffn_dense2_b; + + DataType *ln2_gammas, *ln2_betas; + + DataType *smol_compress; + DataType *smol_dense1_w, *smol_dense1_b; + DataType *smol_dense2_w, *smol_dense2_b; + DataType *smol_ln1_gammas, *smol_ln1_betas; + DataType *smol_ln2_gammas, *smol_ln2_betas; + DataType *smol_global; + + int mha_q_size_; + int mha_k_size_; + int mha_v_size_; + int mha_dense_size_; + + int ffn_dense1_size_; + int ffn_dense2_size_; + + int embedding_op_size_; + int encoder_heads_; + + float alpha_; // scale to apply to skip connection add + + const bool has_smolgen_; + const ActivationFunction smolgen_activation_; + const ActivationFunction ffn_activation_; + + // Output sizes for smolgen layers. + int smol_compress_size_; + int smol_dense_1_size_; + int smol_dense_2_size_; + int smol_global_size_; + + const int max_batch_size_; +}; + // The Attention policy head implementation // Responsible for loading weights into GPU memory, and evaluating the entire // policy head @@ -346,48 +408,22 @@ class AttentionPolicyHead : public BaseLayer { public: AttentionPolicyHead(BaseLayer* ip, const LegacyWeights& weights, - void* scratch); + void* scratch, bool attention_body, + ActivationFunction act, int max_batch_size); ~AttentionPolicyHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, - cudnnHandle_t cudnn, cublasHandle_t cublas, - cudaStream_t stream) override; + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; private: - struct EncoderWeights { - EncoderWeights(const LegacyWeights::EncoderLayer& cpu_weights, - void* scratch); - ~EncoderWeights(); - // all GPU side pointers - DataType *mha_q_w, *mha_q_b; - DataType *mha_k_w, *mha_k_b; - DataType *mha_v_w, *mha_v_b; - DataType *mha_qkv_w, *mha_qkv_b; - DataType *mha_dense_w, *mha_dense_b; - - DataType *ln1_gammas, *ln1_betas; - - DataType *ffn_dense1_w, *ffn_dense1_b; - DataType *ffn_dense2_w, *ffn_dense2_b; - - DataType *ln2_gammas, *ln2_betas; - - int mha_q_size_; - int mha_k_size_; - int mha_v_size_; - int mha_dense_size_; - - int ffn_dense1_size_; - int ffn_dense2_size_; - }; - // GPU allocations to hold various weights used by the attention policy head DataType *ip_pol_w_, *ip_pol_b_; // "embedding" in policy attention DataType *ip2_pol_w_, *ip2_pol_b_; // "wq" in policy attention DataType *ip3_pol_w_, *ip3_pol_b_; // "wk" in policy attention - DataType* ip4_pol_w_; // "ppo" in policy attention + DataType *ip4_pol_w_; // "ppo" in policy attention - DataType *wqk_w_, *wqk_b_; // allocation containing both "wq" and "wq" + DataType *wqk_w_, *wqk_b_; // allocation containing both "wq" and "wq" int embedding_op_size_; int wq_op_size_; @@ -395,8 +431,71 @@ class AttentionPolicyHead : public BaseLayer { int encoder_heads_; int policy_d_model_; + bool attention_body_; + ActivationFunction act_; + + std::vector*> encoder_weights_; +}; + +template +class EmbeddingLayer : public BaseLayer { + using BaseLayer::C; + using BaseLayer::H; + using BaseLayer::W; - std::vector encoder_weights_; + public: + EmbeddingLayer(BaseLayer* ip, const std::vector& weights, + const std::vector& biases, void* scratch, + ActivationFunction activation); + ~EmbeddingLayer(); + + void Eval(int N, DataType* output, const DataType* input, + const DataType* input2, void* scratch, size_t scratch_size, + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; + + private: + DataType *weights_, *biases_; + ActivationFunction act_; +}; + +// The Attention body implementation +// Responsible for loading weights into GPU memory, and evaluating the entire +// attention network part of the body including the stack of encoder layers +template +class AttentionBody : public BaseLayer { + using BaseLayer::C; + using BaseLayer::H; + using BaseLayer::W; + using BaseLayer::GetC; + using BaseLayer::GetH; + using BaseLayer::GetW; + + public: + AttentionBody(const LegacyWeights& weights, void* scratch, + Activations activations, int num_res_blocks, int input_c, + int max_batch_size); + ~AttentionBody(); + void Eval(int N, DataType* output, const DataType* input, + const DataType* input2, void* scratch, size_t scratch_size, + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; + + private: + // GPU allocations to hold various weights used by the attention policy head + DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body + DataType *ip_mult_gate_, *ip_add_gate_; // input gating + DataType *smolgen_global_; // global smolgen weights for all encoder layers + float* pos_encoding_; + int embedding_op_size_; + int encoder_head_count_; + std::vector*> encoder_weights_; + Activations activations_; + int num_resi_blocks_; + int input_c_; + int smolgen_global_size_; + const bool has_gating_; + const bool has_smolgen_; }; } // namespace cudnn_backend diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index dce3b28265..275a332e6e 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -37,6 +37,7 @@ #include "layers.h" #include "neural/factory.h" #include "neural/network_legacy.h" +#include "neural/shared/activation.h" #include "neural/shared/attention_policy_map.h" #include "neural/shared/policy_map.h" #include "utils/bititer.h" @@ -67,8 +68,42 @@ static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { const size_t encoder_heads = weights.pol_encoder_head_count; - size_t size = N * 64 * std::max(std::max(embedding_op_size, encoder_dff), - policy_d_model); + size_t size = + N * 64 * + std::max(std::max(embedding_op_size, encoder_dff), policy_d_model); + + // size of matmul_qk matrix = encoder_heads_ * Batch * 64 * 64 + const size_t matmul_qk_size = encoder_heads * N * 64 * 64; + const size_t output_size = N * (64 * 64 + 8 * 24); + size = std::max(size, std::max(matmul_qk_size, output_size)); + + size_t qkv_size = N * 64 * encoder_d_model; + // We store qkv in single allocation, and other intermediate tensors are + // sometimes stored by splitting an allocation into two halves. + size = std::max(2 * size, 3 * qkv_size); + return size; +} + +static size_t getMaxAttentionBodySize(const LegacyWeights& weights, int N) { + const size_t embedding_op_size = weights.ip_emb_b.size(); + + size_t encoder_d_model = 0; + size_t encoder_dff = 0; + + if (weights.encoder.size() > 0) { + encoder_d_model = weights.encoder[0].mha.q_b.size(); + encoder_dff = weights.encoder[0].ffn.dense1_b.size(); + + assert(encoder_d_model == weights.encoder[0].mha.k_b.size()); + assert(encoder_d_model == weights.encoder[0].mha.v_b.size()); + assert(embedding_op_size == weights.encoder[0].ffn.dense2_b.size()); + } + + const size_t encoder_heads = weights.encoder_head_count; + + size_t size = + N * 64 * + std::max(std::max(embedding_op_size, encoder_dff), encoder_d_model); // size of matmul_qk matrix = encoder_heads_ * Batch * 64 * 64 const size_t matmul_qk_size = encoder_heads * N * 64 * 64; @@ -164,6 +199,9 @@ class CudaNetwork : public Network { attn_policy_ = file.format().network_format().policy() == pblczero::NetworkFormat::POLICY_ATTENTION; + attn_body_ = file.format().network_format().network() == + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; + max_batch_size_ = options.GetOrDefault("max_batch", 1024); showInfo(); @@ -231,6 +269,11 @@ class CudaNetwork : public Network { numBlocks_ = (int)weights.residual.size(); numFilters_ = kNumFilters; + num_encoder_blocks_ = (int)weights.encoder.size(); + if (attn_body_) { + assert(weights.ip_emb_b.size() > 0); + } + // Warn if the memory required for storing transformed weights is // going to exceed 40% of total video memory, force custom_winograd off // if it's going to exceed 50% of memory. @@ -266,7 +309,7 @@ class CudaNetwork : public Network { // 0. Check for SE. has_se_ = false; - if (weights.residual[0].has_se) { + if (numBlocks_ && weights.residual[0].has_se) { has_se_ = true; } @@ -284,84 +327,122 @@ class CudaNetwork : public Network { // Need additional space for transformed input/outputs which are 36/16 // times size (4x4 block transformed into 6x6). - const size_t transformed_tensor_size = (size_t)( - max_batch_size_ * kNumFilters * 64 * (36.0 / 16.0) * sizeof(DataType)); - scratch_size_ = std::max(scratch_size_, 2 * transformed_tensor_size); + if (numBlocks_ > 0) { + const size_t transformed_tensor_size = + (size_t)(max_batch_size_ * kNumFilters * 64 * (36.0 / 16.0) * + sizeof(DataType)); + scratch_size_ = std::max(scratch_size_, 2 * transformed_tensor_size); + } - // Attention policy head may need more memory - const size_t attentionSize = + // Attention policy head or body may need more memory + const size_t attentionPolicySize = getMaxAttentionHeadSize(weights, max_batch_size_) * sizeof(DataType); - scratch_size_ = std::max(scratch_size_, attentionSize); + + const size_t attentionBodySize = + getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(DataType); + scratch_size_ = std::max(scratch_size_, + std::max(attentionPolicySize, attentionBodySize)); ReportCUDAErrors(cudaMalloc(&scratch_mem_, scratch_size_)); const bool mish_net = file.format().network_format().default_activation() == pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH; + ActivationFunction act = mish_net ? ACTIVATION_MISH : ACTIVATION_RELU; + // 2. Build the network, and copy the weights to GPU memory. - // Input. - { - auto inputConv = std::make_unique>( - nullptr, kNumFilters, 8, 8, kNumInputPlanes, mish_net ? MISH : RELU, - true, false, false, 0, use_gemm_ex, use_res_block_winograd_fuse_opt_); - inputConv->LoadWeights(&weights.input.weights[0], - &weights.input.biases[0], scratch_mem_); - network_.emplace_back(std::move(inputConv)); - } - - // Residual block. - for (int block = 0; block < numBlocks_; block++) { - bool has_se = weights.residual[block].has_se; - int se_k = (int)weights.residual[block].se.b1.size(); - - if (use_res_block_winograd_fuse_opt_) { - auto layer = std::make_unique>( - getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, block == 0, - block == (numBlocks_ - 1), mish_net ? MISH : RELU, - deviceProp.sharedMemPerBlockOptin); - layer->LoadWeights0(&weights.residual[block].conv1.weights[0], - &weights.residual[block].conv1.biases[0], - scratch_mem_); - layer->LoadWeights1(&weights.residual[block].conv2.weights[0], - &weights.residual[block].conv2.biases[0], - scratch_mem_); - if (has_se) - layer->LoadSEWeights(&weights.residual[block].se.w1[0], - &weights.residual[block].se.b1[0], - &weights.residual[block].se.w2[0], - &weights.residual[block].se.b2[0], scratch_mem_); - network_.emplace_back(std::move(layer)); - } else { - auto conv1 = std::make_unique>( - getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, false, false, 0, use_gemm_ex); - conv1->LoadWeights(&weights.residual[block].conv1.weights[0], - &weights.residual[block].conv1.biases[0], - scratch_mem_); - network_.emplace_back(std::move(conv1)); - - auto conv2 = std::make_unique>( - getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, true, has_se, se_k, use_gemm_ex); - conv2->LoadWeights(&weights.residual[block].conv2.weights[0], - &weights.residual[block].conv2.biases[0], - scratch_mem_); - if (has_se) - conv2->LoadSEWeights(&weights.residual[block].se.w1[0], - &weights.residual[block].se.b1[0], - &weights.residual[block].se.w2[0], - &weights.residual[block].se.b2[0], scratch_mem_); - network_.emplace_back(std::move(conv2)); + // Input conv only used if there are residual blocks in the network + if (numBlocks_ > 0) { + // Input. + { + auto inputConv = std::make_unique>( + nullptr, kNumFilters, 8, 8, kNumInputPlanes, act, true, false, + false, 0, use_gemm_ex, use_res_block_winograd_fuse_opt_); + inputConv->LoadWeights(&weights.input.weights[0], + &weights.input.biases[0], scratch_mem_); + network_.emplace_back(std::move(inputConv)); + } + + // Residual block. + for (int block = 0; block < numBlocks_; block++) { + bool has_se = weights.residual[block].has_se; + int se_k = (int)weights.residual[block].se.b1.size(); + + if (use_res_block_winograd_fuse_opt_) { + auto layer = std::make_unique>( + getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, + block == 0, block == (numBlocks_ - 1), act, + deviceProp.sharedMemPerBlockOptin); + layer->LoadWeights0(&weights.residual[block].conv1.weights[0], + &weights.residual[block].conv1.biases[0], + scratch_mem_); + layer->LoadWeights1(&weights.residual[block].conv2.weights[0], + &weights.residual[block].conv2.biases[0], + scratch_mem_); + if (has_se) + layer->LoadSEWeights(&weights.residual[block].se.w1[0], + &weights.residual[block].se.b1[0], + &weights.residual[block].se.w2[0], + &weights.residual[block].se.b2[0], + scratch_mem_); + network_.emplace_back(std::move(layer)); + } else { + auto conv1 = std::make_unique>( + getLastLayer(), kNumFilters, 8, 8, kNumFilters, act, true, false, + false, 0, use_gemm_ex); + conv1->LoadWeights(&weights.residual[block].conv1.weights[0], + &weights.residual[block].conv1.biases[0], + scratch_mem_); + network_.emplace_back(std::move(conv1)); + + auto conv2 = std::make_unique>( + getLastLayer(), kNumFilters, 8, 8, kNumFilters, act, true, true, + has_se, se_k, use_gemm_ex); + conv2->LoadWeights(&weights.residual[block].conv2.weights[0], + &weights.residual[block].conv2.biases[0], + scratch_mem_); + if (has_se) + conv2->LoadSEWeights(&weights.residual[block].se.w1[0], + &weights.residual[block].se.b1[0], + &weights.residual[block].se.w2[0], + &weights.residual[block].se.b2[0], + scratch_mem_); + network_.emplace_back(std::move(conv2)); + } } + resi_last_ = getLastLayer(); } - resi_last_ = getLastLayer(); + if (attn_body_) { + Activations activations; + const auto smolgen_activation = + file.format().network_format().smolgen_activation(); + activations.smolgen_activation = + smolgen_activation == pblczero::NetworkFormat::ACTIVATION_DEFAULT + ? act + : static_cast(smolgen_activation); + const auto ffn_activation = + file.format().network_format().ffn_activation(); + activations.ffn_activation = + ffn_activation == pblczero::NetworkFormat::ACTIVATION_DEFAULT + ? act + : static_cast(ffn_activation); + activations.default_activation = act; + + auto attention_body = std::make_unique>( + weights, scratch_mem_, activations, numBlocks_, + numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_); + network_.emplace_back(std::move(attention_body)); + + encoder_last_ = getLastLayer(); + } // Policy head. if (attn_policy_) { auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_); + getLastLayer(), weights, scratch_mem_, attn_body_, act, + max_batch_size_); network_.emplace_back(std::move(AttentionPolicy)); auto policymap = std::make_unique>( @@ -370,9 +451,10 @@ class CudaNetwork : public Network { network_.emplace_back(std::move(policymap)); } else if (conv_policy_) { + assert(!attn_body_); // not supported with attention body auto conv1 = std::make_unique>( - resi_last_, kNumFilters, 8, 8, kNumFilters, mish_net ? MISH : RELU, - true, false, false, 0, use_gemm_ex); + resi_last_, kNumFilters, 8, 8, kNumFilters, act, true, false, false, + 0, use_gemm_ex); conv1->LoadWeights(&weights.policy1.weights[0], &weights.policy1.biases[0], scratch_mem_); network_.emplace_back(std::move(conv1)); @@ -381,8 +463,8 @@ class CudaNetwork : public Network { // No relu auto conv2 = std::make_unique>( - getLastLayer(), pol_channels, 8, 8, kNumFilters, NONE, true, false, - false, 0, use_gemm_ex); + getLastLayer(), pol_channels, 8, 8, kNumFilters, ACTIVATION_NONE, + true, false, false, 0, use_gemm_ex); conv2->LoadWeights(&weights.policy.weights[0], &weights.policy.biases[0], scratch_mem_); network_.emplace_back(std::move(conv2)); @@ -393,33 +475,39 @@ class CudaNetwork : public Network { network_.emplace_back(std::move(policymap)); } else { + assert(!attn_body_); // not supported with attention body auto convPol = std::make_unique>( - resi_last_, weights.policy.biases.size(), 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, use_gemm_ex); + resi_last_, weights.policy.biases.size(), 8, 8, kNumFilters, act, + true, use_gemm_ex); convPol->LoadWeights(&weights.policy.weights[0], &weights.policy.biases[0], scratch_mem_); network_.emplace_back(std::move(convPol)); auto FCPol = std::make_unique>( - getLastLayer(), weights.ip_pol_b.size(), 1, 1, true, NONE); + getLastLayer(), weights.ip_pol_b.size(), 1, 1, true, ACTIVATION_NONE); FCPol->LoadWeights(&weights.ip_pol_w[0], &weights.ip_pol_b[0], scratch_mem_); network_.emplace_back(std::move(FCPol)); } - policy_out_ = getLastLayer(); // Value head. { - auto convVal = std::make_unique>( - resi_last_, weights.value.biases.size(), 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, use_gemm_ex); - convVal->LoadWeights(&weights.value.weights[0], &weights.value.biases[0], - scratch_mem_); - network_.emplace_back(std::move(convVal)); + if (attn_body_) { + auto embedded_val = std::make_unique>( + encoder_last_, weights.ip_val_w, weights.ip_val_b, scratch_mem_, + act); + network_.emplace_back(std::move(embedded_val)); + } else { + auto convVal = std::make_unique>( + resi_last_, weights.value.biases.size(), 8, 8, kNumFilters, act, + true, use_gemm_ex); + convVal->LoadWeights(&weights.value.weights[0], + &weights.value.biases[0], scratch_mem_); + network_.emplace_back(std::move(convVal)); + } auto FCVal1 = std::make_unique>( - getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, - mish_net ? MISH : RELU); + getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, act); FCVal1->LoadWeights(&weights.ip1_val_w[0], &weights.ip1_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal1)); @@ -430,39 +518,42 @@ class CudaNetwork : public Network { auto FCVal2 = std::make_unique>( getLastLayer(), weights.ip2_val_b.size(), 1, 1, true, - fc2_tanh ? TANH : NONE); + fc2_tanh ? ACTIVATION_TANH : ACTIVATION_NONE); FCVal2->LoadWeights(&weights.ip2_val_w[0], &weights.ip2_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal2)); } - value_out_ = getLastLayer(); // Moves left head moves_left_ = (file.format().network_format().moves_left() == pblczero::NetworkFormat::MOVES_LEFT_V1) && options.GetOrDefault("mlh", true); if (moves_left_) { - auto convMov = std::make_unique>( - resi_last_, weights.moves_left.biases.size(), 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, use_gemm_ex); - convMov->LoadWeights(&weights.moves_left.weights[0], - &weights.moves_left.biases[0], scratch_mem_); - network_.emplace_back(std::move(convMov)); - + if (attn_body_) { + auto embedded_mov = std::make_unique>( + encoder_last_, weights.ip_mov_w, weights.ip_mov_b, scratch_mem_, + act); + network_.emplace_back(std::move(embedded_mov)); + } else { + auto convMov = std::make_unique>( + resi_last_, weights.moves_left.biases.size(), 8, 8, kNumFilters, + act, true, use_gemm_ex); + convMov->LoadWeights(&weights.moves_left.weights[0], + &weights.moves_left.biases[0], scratch_mem_); + network_.emplace_back(std::move(convMov)); + } auto FCMov1 = std::make_unique>( - getLastLayer(), weights.ip1_mov_b.size(), 1, 1, true, - mish_net ? MISH : RELU); + getLastLayer(), weights.ip1_mov_b.size(), 1, 1, true, act); FCMov1->LoadWeights(&weights.ip1_mov_w[0], &weights.ip1_mov_b[0], scratch_mem_); network_.emplace_back(std::move(FCMov1)); auto FCMov2 = std::make_unique>(getLastLayer(), 1, 1, 1, - true, RELU); + true, ACTIVATION_RELU); FCMov2->LoadWeights(&weights.ip2_mov_w[0], &weights.ip2_mov_b[0], scratch_mem_); network_.emplace_back(std::move(FCMov2)); } - moves_left_out_ = getLastLayer(); // 3. Allocate GPU memory for running the network: // - three buffers of max size are enough (one to hold input, second to @@ -476,7 +567,7 @@ class CudaNetwork : public Network { maxSize = std::max(maxSize, layer->GetOutputSize(max_batch_size_)); } - if ((attn_policy_ || use_res_block_winograd_fuse_opt_) && + if ((attn_policy_ || use_res_block_winograd_fuse_opt_ || attn_body_) && (scratch_size_ > maxSize)) { maxSize = scratch_size_; } @@ -509,6 +600,8 @@ class CudaNetwork : public Network { DataType* tensor_mem[3]; void* scratch_mem; + DataType*** offset_pointers; + DataType*** head_offset_pointers; cudaStream_t stream; cublasHandle_t cublas; if (multi_stream_) { @@ -516,11 +609,15 @@ class CudaNetwork : public Network { // requests can run in parallel) for (int i = 0; i < 3; i++) tensor_mem[i] = (DataType*)io->tensor_mem_[i]; scratch_mem = io->scratch_mem_; + offset_pointers = (DataType***)&io->offset_pointers_; + head_offset_pointers = (DataType***)&io->head_offset_pointers_; stream = io->stream_; cublas = io->cublas_; } else { for (int i = 0; i < 3; i++) tensor_mem[i] = tensor_mem_[i]; scratch_mem = scratch_mem_; + offset_pointers = (DataType***)&offset_pointers_; + head_offset_pointers = (DataType***)&head_offset_pointers_; stream = 0; // default stream cublas = cublas_; } @@ -538,7 +635,6 @@ class CudaNetwork : public Network { float* opVal = io->op_value_mem_gpu_; float* opMov = io->op_moves_left_mem_gpu_; - // Figure out if the memory requirment for running the res block would fit // in the L2 cache. bool enableCacheOpt = false; @@ -573,27 +669,51 @@ class CudaNetwork : public Network { #endif int l = 0; - // Input. - network_[l++]->Eval(batchSize, skip_connection, tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // input conv - - // Residual block. - for (int block = 0; block < numBlocks_; block++) { - if (use_res_block_winograd_fuse_opt_) { - network_[l++]->Eval(batchSize, tensor_mem[2], skip_connection, nullptr, - enableCacheOpt ? nullptr : scratch_mem, - scratch_size_, nullptr, cublas, - stream); // block - } else { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // conv1 - network_[l++]->Eval(batchSize, tensor_mem[2], tensor_mem[0], - tensor_mem[2], scratch_mem, scratch_size_, nullptr, - cublas, stream); // conv2 + DataType* flow = tensor_mem[0]; + DataType* spare1 = tensor_mem[1]; + DataType* spare2 = tensor_mem[2]; + + if (numBlocks_ > 0) { + // Input. + network_[l++]->Eval(batchSize, skip_connection, tensor_mem[0], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // input conv + + // Residual block. + for (int block = 0; block < numBlocks_; block++) { + if (use_res_block_winograd_fuse_opt_) { + network_[l++]->Eval(batchSize, tensor_mem[2], skip_connection, + nullptr, enableCacheOpt ? nullptr : scratch_mem, + scratch_size_, nullptr, cublas, + stream); // block + } else { + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // conv1 + + network_[l++]->Eval(batchSize, tensor_mem[2], tensor_mem[0], + tensor_mem[2], scratch_mem, scratch_size_, + nullptr, cublas, stream); // conv2 + } } + + flow = tensor_mem[2]; + spare1 = tensor_mem[0]; + spare2 = tensor_mem[1]; + } + + if (attn_body_) { + network_[l++]->Eval( + batchSize, tensor_mem[1], + (numBlocks_ > 0) ? tensor_mem[2] : tensor_mem[0], + (numBlocks_ > 0) ? tensor_mem[0] : tensor_mem[2], scratch_mem, + scratch_size_, nullptr, cublas, stream, + offset_pointers); // Entire attention body of the network + + flow = tensor_mem[1]; + spare1 = tensor_mem[0]; + spare2 = tensor_mem[2]; } #if CUDART_VERSION >= 11000 @@ -609,57 +729,56 @@ class CudaNetwork : public Network { // Policy head. if (attn_policy_) { network_[l++]->Eval( - batchSize, tensor_mem[0], tensor_mem[2], tensor_mem[1], scratch_mem, - scratch_size_, nullptr, cublas, - stream); // Entire Attention policy head except for the policy map + batchSize, spare1, flow, spare2, scratch_mem, scratch_size_, nullptr, + cublas, stream, + head_offset_pointers); // Entire Attention policy head except for the + // policy map if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // policy map layer - copyTypeConverted(opPol, (half*)(tensor_mem[1]), - batchSize * kNumOutputPolicy, + copyTypeConverted(opPol, (half*)spare2, batchSize * kNumOutputPolicy, stream); // POLICY output } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], nullptr, + network_[l++]->Eval(batchSize, (DataType*)opPol, spare1, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); // policy map layer // POLICY output } } else if (conv_policy_) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // policy conv1 - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // policy conv2 if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare1, spare2, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // policy map layer - copyTypeConverted(opPol, (half*)(tensor_mem[0]), - batchSize * kNumOutputPolicy, + copyTypeConverted(opPol, (half*)(spare1), batchSize * kNumOutputPolicy, stream); // POLICY output } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[1], nullptr, + network_[l++]->Eval(batchSize, (DataType*)opPol, spare2, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); // policy map layer // POLICY output } } else { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // pol conv if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // pol FC - copyTypeConverted(opPol, (half*)(tensor_mem[1]), - batchSize * kNumOutputPolicy, stream); // POLICY + copyTypeConverted(opPol, (half*)(spare2), batchSize * kNumOutputPolicy, + stream); // POLICY } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], nullptr, + network_[l++]->Eval(batchSize, (DataType*)opPol, spare1, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); // pol FC // POLICY } @@ -672,36 +791,36 @@ class CudaNetwork : public Network { cudaMemcpyDeviceToHost, stream)); // value head - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value conv + network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, + stream); // value conv or embedding - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // value FC1 if (wdl_) { if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare1, spare2, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // value FC2 // VALUE - copyTypeConverted(opVal, (half*)(tensor_mem[0]), 3 * batchSize, + copyTypeConverted(opVal, (half*)spare1, 3 * batchSize, stream); // VALUE } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], nullptr, + network_[l++]->Eval(batchSize, (DataType*)opVal, spare2, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); // value FC2 // VALUE } } else { if (fp16) { // TODO: consider fusing the bias-add of FC2 with format conversion. - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare1, spare2, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // value FC2 - copyTypeConverted(opVal, (half*)(tensor_mem[0]), batchSize, + copyTypeConverted(opVal, (half*)(spare1), batchSize, stream); // VALUE } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], nullptr, + network_[l++]->Eval(batchSize, (DataType*)opVal, spare2, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); // value FC2 // VALUE } @@ -709,23 +828,22 @@ class CudaNetwork : public Network { if (moves_left_) { // Moves left head - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // moves conv + network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, + stream); // moves conv or embedding - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); // moves FC1 // Moves left FC2 if (fp16) { // TODO: consider fusing the bias-add of FC2 with format conversion. - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); - copyTypeConverted(opMov, (half*)(tensor_mem[0]), batchSize, stream); + network_[l++]->Eval(batchSize, spare1, spare2, nullptr, scratch_mem, + scratch_size_, nullptr, cublas, stream); + copyTypeConverted(opMov, (half*)(spare1), batchSize, stream); } else { - network_[l++]->Eval(batchSize, (DataType*)opMov, tensor_mem[1], nullptr, + network_[l++]->Eval(batchSize, (DataType*)opMov, spare2, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); } @@ -766,6 +884,9 @@ class CudaNetwork : public Network { for (auto mem : tensor_mem_) { if (mem) ReportCUDAErrors(cudaFree(mem)); } + if (offset_pointers_) ReportCUDAErrors(cudaFree(offset_pointers_)); + if (head_offset_pointers_) + ReportCUDAErrors(cudaFree(head_offset_pointers_)); cublasDestroy(cublas_); } } @@ -816,7 +937,7 @@ class CudaNetwork : public Network { bool use_res_block_winograd_fuse_opt_; // fuse operations inside the residual // tower bool multi_stream_; // run multiple parallel network evals - bool allow_cache_opt_; // try to fit residual block activations in L2 cache + bool allow_cache_opt_; // try to fit residual block activations in L2 cache // Currently only one NN Eval can happen a time (we can fix this if needed // by allocating more memory). @@ -827,19 +948,22 @@ class CudaNetwork : public Network { bool has_se_; bool conv_policy_; bool attn_policy_; + bool attn_body_; + int num_encoder_blocks_; std::vector>> network_; BaseLayer* getLastLayer() { return network_.back().get(); } BaseLayer* resi_last_; - BaseLayer* policy_out_; - BaseLayer* value_out_; - BaseLayer* moves_left_out_; + BaseLayer* encoder_last_; size_t tensor_mem_size_; size_t scratch_size_; // this copy is used only for initialization when multi-stream is enabled void* scratch_mem_; + // this is only used when multi-stream is disabled + void** offset_pointers_ = nullptr; + void** head_offset_pointers_ = nullptr; bool has_tensor_cores_; @@ -930,7 +1054,9 @@ std::unique_ptr MakeCudaNetwork(const std::optional& w, if (weights.format().network_format().network() != pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && weights.format().network_format().network() != - pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT) { + pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && + weights.format().network_format().network() != + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT) { throw Exception("Network format " + pblczero::NetworkFormat::NetworkStructure_Name( weights.format().network_format().network()) + diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index a17a1290c0..d68d280e72 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -37,47 +37,17 @@ #include "layers.h" #include "neural/factory.h" #include "neural/network_legacy.h" +#include "neural/shared/activation.h" #include "neural/shared/attention_policy_map.h" #include "neural/shared/policy_map.h" #include "utils/bititer.h" #include "utils/exception.h" -//#define DEBUG_RAW_NPS +// #define DEBUG_RAW_NPS namespace lczero { using namespace cudnn_backend; -#if 0 -// debug code to dump allocation in GPU memory -void dumpTensor(void *memory, int elements, const char *message, bool fp16 = false) -{ - printf("\n%s\n", message); - int elementSize = (int) (fp16 ? sizeof(half) : sizeof(float)); - int bytes = elements * elementSize; - void *temp = malloc(bytes); - cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); - - for (int i = 0; i < elements; i++) - { - float val; - if (fp16) - { - half *arr = (half*)temp; - val = (float)arr[i]; - } - else - { - float *arr = (float *)temp; - val = arr[i]; - } - printf("%8.4f ", val); - if ((i % 8) == 7) printf("\n"); - } - free(temp); - printf("\n"); -} -#endif - template class CudnnNetwork; @@ -100,8 +70,9 @@ static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { const size_t encoder_heads = weights.pol_encoder_head_count; - size_t size = N * 64 * std::max(std::max(embedding_op_size, encoder_dff), - policy_d_model); + size_t size = + N * 64 * + std::max(std::max(embedding_op_size, encoder_dff), policy_d_model); // size of matmul_qk matrix = encoder_heads_ * Batch * 64 * 64 const size_t matmul_qk_size = encoder_heads * N * 64 * 64; @@ -432,15 +403,16 @@ class CudnnNetwork : public Network { // Input. if (use_custom_winograd_) { auto inputConv = std::make_unique>( - nullptr, kNumFilters, 8, 8, kNumInputPlanes, mish_net ? MISH : RELU, - true, false, false, 0, use_gemm_ex, use_res_block_winograd_fuse_opt_); + nullptr, kNumFilters, 8, 8, kNumInputPlanes, + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true, false, false, 0, + use_gemm_ex, use_res_block_winograd_fuse_opt_); inputConv->LoadWeights(&weights.input.weights[0], &weights.input.biases[0], scratch_mem_); network_.emplace_back(std::move(inputConv)); } else { auto inputConv = std::make_unique>( - nhwc_, kNumFilters, 8, 8, 3, kNumInputPlanes, mish_net ? MISH : RELU, - true); + nhwc_, kNumFilters, 8, 8, 3, kNumInputPlanes, + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); inputConv->LoadWeights(&weights.input.weights[0], &weights.input.biases[0], scratch_mem_); network_.emplace_back(std::move(inputConv)); @@ -455,7 +427,8 @@ class CudnnNetwork : public Network { if (use_res_block_winograd_fuse_opt_) { auto layer = std::make_unique>( getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, - block == 0, block == (numBlocks_ - 1), mish_net ? MISH : RELU, + block == 0, block == (numBlocks_ - 1), + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, deviceProp.sharedMemPerBlockOptin); layer->LoadWeights0(&weights.residual[block].conv1.weights[0], &weights.residual[block].conv1.biases[0], @@ -473,7 +446,8 @@ class CudnnNetwork : public Network { } else { auto conv1 = std::make_unique>( getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, false, false, 0, use_gemm_ex); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true, false, false, + 0, use_gemm_ex); conv1->LoadWeights(&weights.residual[block].conv1.weights[0], &weights.residual[block].conv1.biases[0], scratch_mem_); @@ -481,7 +455,8 @@ class CudnnNetwork : public Network { auto conv2 = std::make_unique>( getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, true, has_se, se_k, use_gemm_ex); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true, true, has_se, + se_k, use_gemm_ex); conv2->LoadWeights(&weights.residual[block].conv2.weights[0], &weights.residual[block].conv2.biases[0], scratch_mem_); @@ -497,7 +472,7 @@ class CudnnNetwork : public Network { } else { auto conv1 = std::make_unique>( getLastLayer(), kNumFilters, 8, 8, 3, kNumFilters, - mish_net ? MISH : RELU, true); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); conv1->LoadWeights(&weights.residual[block].conv1.weights[0], &weights.residual[block].conv1.biases[0], scratch_mem_); @@ -508,7 +483,9 @@ class CudnnNetwork : public Network { auto conv2 = std::make_unique>( getLastLayer(), kNumFilters, 8, 8, 3, kNumFilters, - useReluAndBias ? (mish_net ? MISH : RELU) : NONE, useReluAndBias); + useReluAndBias ? (mish_net ? ACTIVATION_MISH : ACTIVATION_RELU) + : ACTIVATION_NONE, + useReluAndBias); conv2->LoadWeights( &weights.residual[block].conv2.weights[0], useReluAndBias ? &weights.residual[block].conv2.biases[0] : nullptr, @@ -518,7 +495,8 @@ class CudnnNetwork : public Network { if (weights.residual[block].has_se) { int numFCOut = (int)weights.residual[block].se.b1.size(); auto se = std::make_unique>( - getLastLayer(), numFCOut, false, mish_net ? MISH : RELU); + getLastLayer(), numFCOut, false, + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU); se->LoadWeights(&weights.residual[block].se.w1[0], &weights.residual[block].se.b1[0], &weights.residual[block].se.w2[0], @@ -535,7 +513,8 @@ class CudnnNetwork : public Network { // Policy head. if (attn_policy_) { auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_); + getLastLayer(), weights, scratch_mem_, false, ACTIVATION_SELU, + max_batch_size_); network_.emplace_back(std::move(AttentionPolicy)); auto policymap = std::make_unique>( @@ -544,8 +523,8 @@ class CudnnNetwork : public Network { network_.emplace_back(std::move(policymap)); } else if (conv_policy_) { auto conv1 = std::make_unique>( - resi_last_, kNumFilters, 8, 8, 3, kNumFilters, mish_net ? MISH : RELU, - true); + resi_last_, kNumFilters, 8, 8, 3, kNumFilters, + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); conv1->LoadWeights(&weights.policy1.weights[0], &weights.policy1.biases[0], scratch_mem_); network_.emplace_back(std::move(conv1)); @@ -554,7 +533,8 @@ class CudnnNetwork : public Network { // No relu auto conv2 = std::make_unique>( - getLastLayer(), pol_channels, 8, 8, 3, kNumFilters, NONE, true); + getLastLayer(), pol_channels, 8, 8, 3, kNumFilters, ACTIVATION_NONE, + true); conv2->LoadWeights(&weights.policy.weights[0], &weights.policy.biases[0], scratch_mem_); network_.emplace_back(std::move(conv2)); @@ -567,13 +547,13 @@ class CudnnNetwork : public Network { } else { auto convPol = std::make_unique>( resi_last_, weights.policy.biases.size(), 8, 8, 1, kNumFilters, - mish_net ? MISH : RELU, true); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); convPol->LoadWeights(&weights.policy.weights[0], &weights.policy.biases[0], scratch_mem_); network_.emplace_back(std::move(convPol)); auto FCPol = std::make_unique>( - getLastLayer(), weights.ip_pol_b.size(), 1, 1, true, NONE); + getLastLayer(), weights.ip_pol_b.size(), 1, 1, true, ACTIVATION_NONE); FCPol->LoadWeights(&weights.ip_pol_w[0], &weights.ip_pol_b[0], scratch_mem_); network_.emplace_back(std::move(FCPol)); @@ -584,14 +564,14 @@ class CudnnNetwork : public Network { { auto convVal = std::make_unique>( resi_last_, weights.value.biases.size(), 8, 8, 1, kNumFilters, - mish_net ? MISH : RELU, true); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); convVal->LoadWeights(&weights.value.weights[0], &weights.value.biases[0], scratch_mem_); network_.emplace_back(std::move(convVal)); auto FCVal1 = std::make_unique>( getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, - mish_net ? MISH : RELU); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU); FCVal1->LoadWeights(&weights.ip1_val_w[0], &weights.ip1_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal1)); @@ -602,7 +582,7 @@ class CudnnNetwork : public Network { auto FCVal2 = std::make_unique>( getLastLayer(), weights.ip2_val_b.size(), 1, 1, true, - fc2_tanh ? TANH : NONE); + fc2_tanh ? ACTIVATION_TANH : ACTIVATION_NONE); FCVal2->LoadWeights(&weights.ip2_val_w[0], &weights.ip2_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal2)); @@ -616,20 +596,20 @@ class CudnnNetwork : public Network { if (moves_left_) { auto convMov = std::make_unique>( resi_last_, weights.moves_left.biases.size(), 8, 8, 1, kNumFilters, - mish_net ? MISH : RELU, true); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); convMov->LoadWeights(&weights.moves_left.weights[0], &weights.moves_left.biases[0], scratch_mem_); network_.emplace_back(std::move(convMov)); auto FCMov1 = std::make_unique>( getLastLayer(), weights.ip1_mov_b.size(), 1, 1, true, - mish_net ? MISH : RELU); + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU); FCMov1->LoadWeights(&weights.ip1_mov_w[0], &weights.ip1_mov_b[0], scratch_mem_); network_.emplace_back(std::move(FCMov1)); auto FCMov2 = std::make_unique>(getLastLayer(), 1, 1, 1, - true, RELU); + true, ACTIVATION_RELU); FCMov2->LoadWeights(&weights.ip2_mov_w[0], &weights.ip2_mov_b[0], scratch_mem_); network_.emplace_back(std::move(FCMov2)); @@ -755,8 +735,9 @@ class CudnnNetwork : public Network { if (attn_policy_) { network_[l++]->Eval( batchSize, tensor_mem_[0], tensor_mem_[2], tensor_mem_[1], - scratch_mem_, scratch_size_, nullptr, cublas_, - stream); // Entire Attention policy head except for the policy map + scratch_mem_, scratch_size_, nullptr, cublas_, stream, + &head_offset_pointers_); // Entire Attention policy head except for + // the policy map if (fp16) { network_[l++]->Eval(batchSize, tensor_mem_[1], tensor_mem_[0], nullptr, scratch_mem_, scratch_size_, nullptr, cublas_, @@ -918,6 +899,8 @@ class CudnnNetwork : public Network { if (mem) ReportCUDAErrors(cudaFree(mem)); } if (scratch_mem_) ReportCUDAErrors(cudaFree(scratch_mem_)); + if (head_offset_pointers_) + ReportCUDAErrors(cudaFree(head_offset_pointers_)); cudnnDestroy(cudnn_); cublasDestroy(cublas_); } @@ -993,6 +976,7 @@ class CudnnNetwork : public Network { DataType* tensor_mem_[3]; void* scratch_mem_; + DataType** head_offset_pointers_ = nullptr; size_t scratch_size_; mutable std::mutex inputs_outputs_lock_; diff --git a/src/neural/cuda/winograd_helper.inc b/src/neural/cuda/winograd_helper.inc index 3f362f3074..09c5d0fcc9 100644 --- a/src/neural/cuda/winograd_helper.inc +++ b/src/neural/cuda/winograd_helper.inc @@ -41,16 +41,20 @@ __device__ __forceinline__ float mishActivate(float el) { __device__ __forceinline__ float activate(float cVal, ActivationFunction activation) { switch (activation) { - case RELU: + case ACTIVATION_RELU: if (cVal < 0) cVal = 0; break; - case TANH: + case ACTIVATION_RELU_2: + if (cVal < 0) cVal = 0; + cVal *= cVal; + break; + case ACTIVATION_TANH: cVal = tanh(cVal); break; - case SIGMOID: + case ACTIVATION_SIGMOID: cVal = 1.0f / (1.0f + __expf(-cVal)); break; - case SELU: { + case ACTIVATION_SELU: { float alpha = 1.67326324f, scale = 1.05070098f; if (cVal > 0) cVal = scale * cVal; @@ -58,9 +62,12 @@ __device__ __forceinline__ float activate(float cVal, cVal = scale * alpha * (__expf(cVal) - 1.0f); break; } - case MISH: + case ACTIVATION_MISH: cVal = mishActivate(cVal); break; + case ACTIVATION_SWISH: + cVal /= (1.0f + __expf(-cVal)); + break; } return cVal; } @@ -391,7 +398,7 @@ __global__ void OutputTransform_kernel(int N, int C, int se_K, T* output, } // relu - if (activation != NONE) { + if (activation != ACTIVATION_NONE) { #pragma unroll for (int w = 0; w < 8; w++) board[h][w] = (T)activate((float)board[h][w], activation); @@ -430,11 +437,13 @@ __device__ __forceinline__ float warpMax(float x) { // atomic max implementation for floats __device__ __forceinline__ float atomicMaxFloat(float* addr, float val) { - float max; - max = !signbit(val) ? __int_as_float(atomicMax((int*)addr, __float_as_int(val))) : - __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(val))); + float max; + max = !signbit(val) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(val))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(val))); - return max; + return max; } // Helper fuction to do vector loads/stores @@ -451,10 +460,17 @@ __device__ __forceinline__ void copyAs(void* dst, const void* src) { // every thread generates an entire board/plane (8x8 elements) template -__global__ __launch_bounds__(kMaxResBlockFusingChannels, 1) -void OutputTransform_SE_relu_InputTransform_kernel( - int N, int C, int se_K, T* output, const T* input, T* skip, - const T* bias, const T* w1, const T* b1, const T* w2, const T* b2) { +__global__ __launch_bounds__( + kMaxResBlockFusingChannels, + 1) void OutputTransform_SE_relu_InputTransform_kernel(int N, int C, + int se_K, T* output, + const T* input, + T* skip, + const T* bias, + const T* w1, + const T* b1, + const T* w2, + const T* b2) { const bool fp16 = std::is_same::value; int k = threadIdx.x; @@ -561,7 +577,7 @@ void OutputTransform_SE_relu_InputTransform_kernel( } // relu - if (activation != NONE) { + if (activation != ACTIVATION_NONE) { #pragma unroll for (int w = 0; w < 8; w++) board[h][w] = (T)activate((float)board[h][w], activation); @@ -655,17 +671,18 @@ void OutputTransform_SE_relu_InputTransform_kernel( } } - constexpr int kOpInpTransformBlockSize = 64; -template -__global__ __launch_bounds__(kOpInpTransformBlockSize, 4) -void OutputTransform_relu_InputTransform_kernel(int N, int C, - T* output, const T* input, - T* skip, const T* bias) { +template +__global__ __launch_bounds__( + kOpInpTransformBlockSize, + 4) void OutputTransform_relu_InputTransform_kernel(int N, int C, T* output, + const T* input, T* skip, + const T* bias) { const bool fp16 = std::is_same::value; int k = threadIdx.x + blockIdx.x * kOpInpTransformBlockSize; - if (k >= C) return; // wasted threads (for non-multiple of 64 channel counts) + if (k >= C) return; // wasted threads (for non-multiple of 64 channel counts) int n = blockIdx.y; T board[8][8]; @@ -708,7 +725,6 @@ void OutputTransform_relu_InputTransform_kernel(int N, int C, for (int x = 0; x < 8; x++) if (use_bias) board[y][x] += b; - // Add skip connection, perform relu, and write to output. for (int h = 0; h < 8; h++) { // residual add @@ -718,10 +734,10 @@ void OutputTransform_relu_InputTransform_kernel(int N, int C, } // activation - if (activation != NONE) { + if (activation != ACTIVATION_NONE) { #pragma unroll for (int w = 0; w < 8; w++) - board[h][w] = (T) activate((float)board[h][w], activation); + board[h][w] = (T)activate((float)board[h][w], activation); } // write un-transformed output to 'skip' if required @@ -812,7 +828,6 @@ void OutputTransform_relu_InputTransform_kernel(int N, int C, } } - template void FilterTransform(int N, int C, T* transformedFilter, const T* filter) { // Each thread processes entire filter block (input 3x3 elements -> output 6x6 diff --git a/src/neural/decoder.cc b/src/neural/decoder.cc index 1798523730..34f78466bf 100644 --- a/src/neural/decoder.cc +++ b/src/neural/decoder.cc @@ -72,7 +72,7 @@ void PopulateBoard(pblczero::NetworkFormat::InputFormat input_format, auto kingTheirs = BitBoard(planes[11].mask); ChessBoard::Castlings castlings; switch (input_format) { - case pblczero::NetworkFormat::InputFormat::INPUT_CLASSICAL_112_PLANE: { + case pblczero::NetworkFormat::INPUT_CLASSICAL_112_PLANE: { if (planes[kAuxPlaneBase + 0].mask != 0) { castlings.set_we_can_000(); } diff --git a/src/neural/encoder.cc b/src/neural/encoder.cc index 63851fad48..b459aec75b 100644 --- a/src/neural/encoder.cc +++ b/src/neural/encoder.cc @@ -247,7 +247,12 @@ InputPlanes EncodePositionForNN( !board.en_passant().empty()) { break; } - if (history_idx < 0 && fill_empty_history == FillEmptyHistory::NO) break; + // If en-passant is possible we know the previous move. + if (fill_empty_history == FillEmptyHistory::NO && + (history_idx < -1 || + (history_idx == -1 && board.en_passant().empty()))) { + break; + } // Board may be flipped so compare with position.GetBoard(). if (history_idx < 0 && fill_empty_history == FillEmptyHistory::FEN_ONLY && position.GetBoard() == ChessBoard::kStartposBoard) { diff --git a/src/neural/loader.cc b/src/neural/loader.cc index 7f3ff3e5fc..4e8e76ec57 100644 --- a/src/neural/loader.cc +++ b/src/neural/loader.cc @@ -108,11 +108,6 @@ void FixOlderWeightsFile(WeightsFile* file) { using nf = pblczero::NetworkFormat; auto network_format = file->format().network_format().network(); const auto has_network_format = file->format().has_network_format(); - if (has_network_format && network_format != nf::NETWORK_CLASSICAL && - network_format != nf::NETWORK_SE) { - // Already in a new format, return unchanged. - return; - } auto* net = file->mutable_format()->mutable_network_format(); if (!has_network_format) { @@ -132,6 +127,18 @@ void FixOlderWeightsFile(WeightsFile* file) { net->set_network(nf::NETWORK_SE_WITH_HEADFORMAT); net->set_value(nf::VALUE_CLASSICAL); net->set_policy(nf::POLICY_CLASSICAL); + } else if (network_format == + pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && + file->weights().encoder().size() > 0) { + // Attention body network made with old protobuf. + auto* net = file->mutable_format()->mutable_network_format(); + net->set_network( + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT); + if (file->weights().has_smolgen_w()) { + // Need to override activation defaults for smolgen. + net->set_ffn_activation(pblczero::NetworkFormat::ACTIVATION_RELU_2); + net->set_smolgen_activation(pblczero::NetworkFormat::ACTIVATION_SWISH); + } } } diff --git a/src/neural/metal/metal_common.h b/src/neural/metal/metal_common.h index 07b8e64d0c..a42c00dcac 100644 --- a/src/neural/metal/metal_common.h +++ b/src/neural/metal/metal_common.h @@ -34,7 +34,8 @@ static int kNumOutputPolicy = 1858; static int kInputPlanes = 112; struct InputsOutputs { - InputsOutputs(int maxBatchSize, bool wdl, bool moves_left, bool conv_policy, bool attn_policy) { + InputsOutputs(int maxBatchSize, bool wdl, bool moves_left, bool conv_policy, + bool attn_policy) { input_masks_mem_.reserve(maxBatchSize * kInputPlanes); input_val_mem_.reserve(maxBatchSize * kInputPlanes); input_val_mem_expanded_.reserve(maxBatchSize * kInputPlanes * 64); @@ -46,15 +47,14 @@ struct InputsOutputs { }; /** - * @todo policy map implementation has bug in MPSGraph (GatherND not working in graph). - * Implementation of policy map to be done in CPU for now. + * @todo policy map implementation has bug in MPSGraph (GatherND not working + * in graph). Implementation of policy map to be done in CPU for now. * * Remove this op_policy_raw_mem_ memory allocation when bug is fixed. */ if (attn_policy) { op_policy_raw_mem_.reserve(maxBatchSize * (64 * 64 + 8 * 24)); - } - else if (conv_policy) { + } else if (conv_policy) { op_policy_raw_mem_.reserve(maxBatchSize * 73 * 64); } } diff --git a/src/neural/metal/mps/MetalNetworkBuilder.h b/src/neural/metal/mps/MetalNetworkBuilder.h index f36fa36605..e2822187c3 100644 --- a/src/neural/metal/mps/MetalNetworkBuilder.h +++ b/src/neural/metal/mps/MetalNetworkBuilder.h @@ -32,23 +32,28 @@ namespace lczero { namespace metal_backend { -class MetalNetworkBuilder { -public: - MetalNetworkBuilder(void); - ~MetalNetworkBuilder(void); - - std::string init(int gpu_id); +struct Activations { + std::string default_activation = "relu"; + std::string smolgen_activation = "swish"; + std::string ffn_activation = "relu_2"; +}; - void build(int kInputPlanes, int channelSize, int kernelSize, LegacyWeights& weights, bool attn_policy, bool conv_policy, bool wdl, bool moves_left, std::string default_activation); +class MetalNetworkBuilder { + public: + MetalNetworkBuilder(void); + ~MetalNetworkBuilder(void); - void forwardEval(float * inputs, int batchSize, std::vector output_mems); + std::string init(int gpu_id); - void saveVariables(std::vector names); + void build(int kInputPlanes, LegacyWeights& weights, bool attn_body, + bool attn_policy, bool conv_policy, bool wdl, bool moves_left, + Activations activations); - void dumpVariables(std::vector names, int batches); + void forwardEval(float* inputs, int batchSize, + std::vector output_mems); -private: - int gpu_id; + private: + int gpu_id; }; } // namespace metal_backend diff --git a/src/neural/metal/mps/MetalNetworkBuilder.mm b/src/neural/metal/mps/MetalNetworkBuilder.mm index 2a25accdfd..e3353f499e 100644 --- a/src/neural/metal/mps/MetalNetworkBuilder.mm +++ b/src/neural/metal/mps/MetalNetworkBuilder.mm @@ -26,6 +26,7 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a */ #import "neural/network_legacy.h" +#import "neural/shared/attention_policy_map.h" #import "MetalNetworkBuilder.h" #import "NetworkGraph.h" @@ -56,148 +57,156 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a return std::string([devices[gpu_id].name UTF8String]); } -void MetalNetworkBuilder::build(int kInputPlanes, int channelSize, int kernelSize, LegacyWeights& weights, bool attn_policy, bool conv_policy, bool wdl, bool moves_left, std::string default_activation) +void MetalNetworkBuilder::build(int kInputPlanes, LegacyWeights& weights, bool attn_body, bool attn_policy, bool conv_policy, bool wdl, bool moves_left, Activations activations) { Lc0NetworkGraph * graph = [Lc0NetworkGraph getGraphAt:[NSNumber numberWithInt:this->gpu_id]]; - MPSGraphTensor * layer; - NSString * defaultActivation = [NSString stringWithUTF8String:default_activation.c_str()]; + NSString * defaultActivation = [NSString stringWithUTF8String:activations.default_activation.c_str()]; + NSString * smolgenActivation = [NSString stringWithUTF8String:activations.smolgen_activation.c_str()]; + NSString * ffnActivation = [NSString stringWithUTF8String:activations.ffn_activation.c_str()]; // 0. Input placeholder. - layer = [graph inputPlaceholderWithInputChannels:kInputPlanes - height:8 - width:8 - label:@"inputs"]; - - // 1. Input layer - layer = [graph addConvolutionBlockWithParent:layer - inputChannels:kInputPlanes - outputChannels:channelSize - kernelSize:kernelSize - weights:&weights.input.weights[0] - biases:&weights.input.biases[0] - activation:defaultActivation - label:@"input/conv"]; - - // 2. Residual blocks - for (size_t i = 0; i < weights.residual.size(); i++) { - layer = [graph addResidualBlockWithParent:layer - inputChannels:channelSize - outputChannels:channelSize - kernelSize:kernelSize - weights1:&weights.residual[i].conv1.weights[0] - biases1:&weights.residual[i].conv1.biases[0] - weights2:&weights.residual[i].conv2.weights[0] - biases2:&weights.residual[i].conv2.biases[0] - label:[NSString stringWithFormat:@"block_%zu", i] - hasSe:weights.residual[i].has_se ? YES : NO - seWeights1:&weights.residual[i].se.w1[0] - seBiases1:&weights.residual[i].se.b1[0] - seWeights2:&weights.residual[i].se.w2[0] - seBiases2:&weights.residual[i].se.b2[0] - seFcOutputs:weights.residual[i].se.b1.size() - activation:defaultActivation]; + // @todo - placeholder can be made directly as NHWC to avoid transposes. + MPSGraphTensor * layer = [graph inputPlaceholderWithInputChannels:kInputPlanes + height:8 + width:8 + label:@"inputs"]; + + const NSUInteger kernelSize = 3; + + // Initialize global smolgen weights. + if (weights.has_smolgen) { + [graph setGlobalSmolgenWeights:&weights.smolgen_w[0]]; + } + + // Input conv layer only when there are residual blocks. + if (weights.residual.size() > 0) { + + const NSUInteger channelSize = weights.input.weights.size() / (kInputPlanes * kernelSize * kernelSize); + + // 1. Input layer + layer = [graph addConvolutionBlockWithParent:layer + outputChannels:channelSize + kernelSize:kernelSize + weights:&weights.input.weights[0] + biases:&weights.input.biases[0] + activation:defaultActivation + label:@"input/conv"]; + + // 2. Residual blocks + for (size_t i = 0; i < weights.residual.size(); i++) { + layer = [graph addResidualBlockWithParent:layer + outputChannels:channelSize + kernelSize:kernelSize + weights1:&weights.residual[i].conv1.weights[0] + biases1:&weights.residual[i].conv1.biases[0] + weights2:&weights.residual[i].conv2.weights[0] + biases2:&weights.residual[i].conv2.biases[0] + label:[NSString stringWithFormat:@"block_%zu", i] + hasSe:weights.residual[i].has_se ? YES : NO + seWeights1:&weights.residual[i].se.w1[0] + seBiases1:&weights.residual[i].se.b1[0] + seWeights2:&weights.residual[i].se.w2[0] + seBiases2:&weights.residual[i].se.b2[0] + seFcOutputs:weights.residual[i].se.b1.size() + activation:defaultActivation]; + } + + } + + // Attention body. + if (attn_body) { + + assert(weights.ip_emb_b.size() > 0); + + // 1. NCHW -> NHWC + // @todo if input placeholder is NHWC, then this is not needed for attn_body, but kPosEncoding has to be reordered. + layer = [graph transposeChannelsWithTensor:layer withShape:@[@(-1), @64, layer.shape[1]] label:@"input/nchw_nhwc"]; + + if (weights.residual.size() == 0) { + // No residual means pure transformer, so process input position encoding. + layer = [graph positionEncodingWithTensor:layer + withShape:@[@64, @64] + weights:&kPosEncoding[0][0] + type:nil + label:@"input/position_encoding"]; + } + + // 2a. Input embedding for attention body. + // if self.arc_encoding: @todo needs to be implemented + layer = [graph addFullyConnectedLayerWithParent:layer + outputChannels:weights.ip_emb_b.size() + weights:&weights.ip_emb_w[0] + biases:&weights.ip_emb_b[0] + activation:defaultActivation + label:@"input/embedding"]; + + // # !!! input gate + // flow = ma_gating(flow, name=name+'embedding') + // def ma_gating(inputs, name): + // out = Gating(name=name+'/mult_gate', additive=False)(inputs) + // out = Gating(name=name+'/add_gate', additive=True)(out) + if (weights.ip_mult_gate.size() > 0) { + layer = [graph addGatingLayerWithParent:layer + weights:&weights.ip_mult_gate[0] + withOperation:@"mult" + label:@"input/mult_gate"]; + } + if (weights.ip_add_gate.size() > 0) { + layer = [graph addGatingLayerWithParent:layer + weights:&weights.ip_add_gate[0] + withOperation:@"add" + label:@"input/add_gate"]; + } + // 2b. Attention body encoder layers. + float alpha = (float) pow(2.0 * weights.encoder.size(), 0.25); + for (size_t i = 0; i < weights.encoder.size(); i++) { + layer = [graph addEncoderLayerWithParent:layer + legacyWeights:weights.encoder[i] + heads:weights.encoder_head_count + embeddingSize:weights.ip_emb_b.size() + smolgenActivation:smolgenActivation + ffnActivation:ffnActivation + alpha:alpha + label:[NSString stringWithFormat:@"encoder_%zu", i]]; + } } // 3. Policy head. MPSGraphTensor * policy; if (attn_policy) { // 1. NCHW -> NHWC - policy = [graph transposeChannelsWithTensor:layer label:@"policy/nchw_nhwc"]; + if (!attn_body) { + policy = [graph transposeChannelsWithTensor:layer withShape:@[@(-1), @64, layer.shape[1]] label:@"policy/nchw_nhwc"]; + } + else { + policy = layer; + } + + // 2. Square Embedding: Dense with default activation (or SELU for old ap-mish nets). NSUInteger embeddingSize = weights.ip_pol_b.size(); NSUInteger policyDModel = weights.ip2_pol_b.size(); - - // 2. Square Embedding: Dense with SELU + // ap-mish uses hardcoded SELU policy = [graph addFullyConnectedLayerWithParent:policy - inputChannels:channelSize outputChannels:embeddingSize weights:&weights.ip_pol_w[0] biases:&weights.ip_pol_b[0] - activation:@"selu" + activation:attn_body ? defaultActivation : @"selu" label:@"policy/fc_embed"]; // 3. Encoder layers - MPSGraphTensor * mhaQ, * mhaK, * mhaV; - NSUInteger dModel; for (NSUInteger i = 0; i < weights.pol_encoder.size(); i++) { - dModel = weights.pol_encoder[i].mha.q_b.size(); - mhaQ = [graph addFullyConnectedLayerWithParent:policy - inputChannels:embeddingSize - outputChannels:weights.pol_encoder[i].mha.q_b.size() - weights:&weights.pol_encoder[i].mha.q_w[0] - biases:&weights.pol_encoder[i].mha.q_b[0] - activation:nil - label:[NSString stringWithFormat:@"policy/encoder_%zu/mhaq/fc", i]]; - - mhaK = [graph addFullyConnectedLayerWithParent:policy - inputChannels:embeddingSize - outputChannels:weights.pol_encoder[i].mha.k_b.size() - weights:&weights.pol_encoder[i].mha.k_w[0] - biases:&weights.pol_encoder[i].mha.k_b[0] - activation:nil - label:[NSString stringWithFormat:@"policy/encoder_%zu/mhak/fc", i]]; - - mhaV = [graph addFullyConnectedLayerWithParent:policy - inputChannels:embeddingSize - outputChannels:weights.pol_encoder[i].mha.v_b.size() - weights:&weights.pol_encoder[i].mha.v_w[0] - biases:&weights.pol_encoder[i].mha.v_b[0] - activation:nil - label:[NSString stringWithFormat:@"policy/encoder_%zu/mhav/fc", i]]; - - MPSGraphTensor * mha = [graph scaledMHAMatmulWithQueries:mhaQ - withKeys:mhaK - withValues:mhaV - heads:weights.pol_encoder_head_count - label:[NSString stringWithFormat:@"policy/encoder_%zu/mha", i]]; - - // MHA final dense layer. - mha = [graph addFullyConnectedLayerWithParent:mha - inputChannels:dModel - outputChannels:embeddingSize - weights:&weights.pol_encoder[i].mha.dense_w[0] - biases:&weights.pol_encoder[i].mha.dense_b[0] - activation:nil - label:[NSString stringWithFormat:@"policy/encoder_%zu/mha/fc", i]]; - - // Skip connection + Layer Norm 1. - policy = [graph addLayerNormalizationWithSkipParent:policy - secondaryInput:mha - gammas:&weights.pol_encoder[i].ln1_gammas[0] - betas:&weights.pol_encoder[i].ln1_betas[0] - channelSize:weights.pol_encoder[i].ln1_gammas.size() - epsilon:1e-6 - label:[NSString stringWithFormat:@"policy/encoder_%zu/ln1", i]]; - - // Feedforward network (FFN). - MPSGraphTensor * ffn = [graph addFullyConnectedLayerWithParent:policy - inputChannels:dModel - outputChannels:weights.pol_encoder[i].ffn.dense1_b.size() - weights:&weights.pol_encoder[i].ffn.dense1_w[0] - biases:&weights.pol_encoder[i].ffn.dense1_b[0] - activation:@"selu" - label:[NSString stringWithFormat:@"policy/encoder_%zu/ffn1", i]]; - - ffn = [graph addFullyConnectedLayerWithParent:ffn - inputChannels:weights.pol_encoder[i].ffn.dense1_b.size() - outputChannels:weights.pol_encoder[i].ffn.dense2_b.size() - weights:&weights.pol_encoder[i].ffn.dense2_w[0] - biases:&weights.pol_encoder[i].ffn.dense2_b[0] - activation:nil - label:[NSString stringWithFormat:@"policy/encoder_%zu/ffn2", i]]; - - // Skip connection + Layer Norm 2. - policy = [graph addLayerNormalizationWithSkipParent:policy - secondaryInput:ffn - gammas:&weights.pol_encoder[i].ln2_gammas[0] - betas:&weights.pol_encoder[i].ln2_betas[0] - channelSize:weights.pol_encoder[i].ln2_gammas.size() - epsilon:1e-6 - label:[NSString stringWithFormat:@"policy/encoder_%zu/ln2", i]]; - } // End of encoder layers. + policy = [graph addEncoderLayerWithParent:policy + legacyWeights:weights.pol_encoder[i] + heads:weights.pol_encoder_head_count + embeddingSize:embeddingSize + smolgenActivation:attn_body ? smolgenActivation : nil + ffnActivation:attn_body ? ffnActivation : @"selu" + alpha:1.0 + label:[NSString stringWithFormat:@"policy/encoder_%zu", i]]; + } // 4. Self-attention q and k. MPSGraphTensor * queries = [graph addFullyConnectedLayerWithParent:policy - inputChannels:embeddingSize outputChannels:policyDModel weights:&weights.ip2_pol_w[0] biases:&weights.ip2_pol_b[0] @@ -205,7 +214,6 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a label:@"policy/self_attention/q"]; MPSGraphTensor * keys = [graph addFullyConnectedLayerWithParent:policy - inputChannels:embeddingSize outputChannels:policyDModel weights:&weights.ip3_pol_w[0] biases:&weights.ip3_pol_b[0] @@ -218,8 +226,6 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a scale:1.0f / sqrt(policyDModel) label:@"policy/self_attention/kq"]; - [graph setVariable:@"policy/self_attention/kq" tensor:policy]; - // 6. Slice last 8 keys (k[:, 56:, :]) and matmul with policy promotion weights, then concat to matmul_qk. policy = [graph attentionPolicyPromoMatmulConcatWithParent:policy withKeys:keys @@ -231,9 +237,12 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a label:@"policy/promo_logits"]; } else if (conv_policy) { + if (attn_body) { + [NSException raise:@"Unsupported architecture." + format:@"Convolutional policy not supported with attention body."]; + } policy = [graph addConvolutionBlockWithParent:layer - inputChannels:channelSize - outputChannels:channelSize + outputChannels:weights.policy1.biases.size() kernelSize:kernelSize weights:&weights.policy1.weights[0] biases:&weights.policy1.biases[0] @@ -242,8 +251,7 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a // No activation. policy = [graph addConvolutionBlockWithParent:policy - inputChannels:channelSize - outputChannels:80 + outputChannels:weights.policy.biases.size() kernelSize:kernelSize weights:&weights.policy.weights[0] biases:&weights.policy.biases[0] @@ -277,10 +285,14 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a */ } else { + if (attn_body) { + [NSException raise:@"Unsupported architecture." + format:@"Classical policy not supported with attention body."]; + } + const int policySize = weights.policy.biases.size(); policy = [graph addConvolutionBlockWithParent:layer - inputChannels:channelSize outputChannels:policySize kernelSize:1 weights:&weights.policy.weights[0] @@ -288,9 +300,12 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a activation:defaultActivation label:@"policy/conv"]; + policy = [graph flatten2DTensor:policy + axis:1 + name:@"policy/conv/flatten"]; + policy = [graph addFullyConnectedLayerWithParent:policy - inputChannels:policySize * 8 * 8 - outputChannels:1858 + outputChannels:weights.ip_pol_b.size() weights:&weights.ip_pol_w[0] biases:&weights.ip_pol_b[0] activation:nil @@ -299,18 +314,30 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a // 4. Value head. MPSGraphTensor * value; - value = [graph addConvolutionBlockWithParent:layer - inputChannels:channelSize - outputChannels:32 - kernelSize:1 - weights:&weights.value.weights[0] - biases:&weights.value.biases[0] - activation:defaultActivation - label:@"value/conv"]; + if (attn_body) { + value = [graph addFullyConnectedLayerWithParent:layer + outputChannels:weights.ip_val_b.size() + weights:&weights.ip_val_w[0] + biases:&weights.ip_val_b[0] + activation:defaultActivation + label:@"value/embedding"]; + } + else { + value = [graph addConvolutionBlockWithParent:layer + outputChannels:weights.value.biases.size() + kernelSize:1 + weights:&weights.value.weights[0] + biases:&weights.value.biases[0] + activation:defaultActivation + label:@"value/conv"]; + } + + value = [graph flatten2DTensor:value + axis:1 + name:@"value/flatten"]; value = [graph addFullyConnectedLayerWithParent:value - inputChannels:32 * 8 * 8 - outputChannels:128 + outputChannels:weights.ip1_val_b.size() weights:&weights.ip1_val_w[0] biases:&weights.ip1_val_b[0] activation:defaultActivation @@ -318,8 +345,7 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a if (wdl) { value = [graph addFullyConnectedLayerWithParent:value - inputChannels:128 - outputChannels:3 + outputChannels:weights.ip2_val_b.size() weights:&weights.ip2_val_w[0] biases:&weights.ip2_val_b[0] activation:@"softmax" @@ -327,8 +353,7 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a } else { value = [graph addFullyConnectedLayerWithParent:value - inputChannels:128 - outputChannels:1 + outputChannels:weights.ip2_val_b.size() weights:&weights.ip2_val_w[0] biases:&weights.ip2_val_b[0] activation:@"tanh" @@ -338,32 +363,41 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a // 5. Moves left head. MPSGraphTensor * mlh; if (moves_left) { - const int mlhChannels = weights.moves_left.biases.size(); + if (attn_body) { + mlh = [graph addFullyConnectedLayerWithParent:layer + outputChannels:weights.ip_mov_b.size() + weights:&weights.ip_mov_w[0] + biases:&weights.ip_mov_b[0] + activation:defaultActivation + label:@"moves_left/embedding"]; + } + else { + mlh = [graph addConvolutionBlockWithParent:layer + outputChannels:weights.moves_left.biases.size() + kernelSize:1 + weights:&weights.moves_left.weights[0] + biases:&weights.moves_left.biases[0] + activation:defaultActivation + label:@"moves_left/conv"]; + } - mlh = [graph addConvolutionBlockWithParent:layer - inputChannels:channelSize - outputChannels:mlhChannels - kernelSize:1 - weights:&weights.moves_left.weights[0] - biases:&weights.moves_left.biases[0] - activation:defaultActivation - label:@"mlh/conv"]; + mlh = [graph flatten2DTensor:mlh + axis:1 + name:@"moves_left/flatten"]; mlh = [graph addFullyConnectedLayerWithParent:mlh - inputChannels:mlhChannels * 8 * 8 outputChannels:weights.ip1_mov_b.size() weights:&weights.ip1_mov_w[0] biases:&weights.ip1_mov_b[0] activation:defaultActivation - label:@"mlh/fc1"]; + label:@"moves_left/fc1"]; mlh = [graph addFullyConnectedLayerWithParent:mlh - inputChannels:weights.ip1_mov_b.size() - outputChannels:1 + outputChannels:weights.ip2_mov_b.size() weights:&weights.ip2_mov_w[0] biases:&weights.ip2_mov_b[0] activation:@"relu" - label:@"mlh/fc2"]; + label:@"moves_left/fc2"]; } // Select the outputs to be run through the inference graph. @@ -383,23 +417,5 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a } } -void MetalNetworkBuilder::saveVariables(std::vector names) -{ - Lc0NetworkGraph * graph = [Lc0NetworkGraph getGraphAt:[NSNumber numberWithInt:this->gpu_id]]; - - for (const std::string name : names) { - [graph trackVariable:[NSString stringWithUTF8String:name.c_str()]]; - } -} - -void MetalNetworkBuilder::dumpVariables(std::vector names, int batches) -{ - Lc0NetworkGraph * graph = [Lc0NetworkGraph getGraphAt:[NSNumber numberWithInt:this->gpu_id]]; - - for (const std::string name : names) { - [graph dumpVariable:[NSString stringWithUTF8String:name.c_str()] batches:batches]; - } -} - } // namespace metal_backend } // namespace lczero diff --git a/src/neural/metal/mps/NetworkGraph.h b/src/neural/metal/mps/NetworkGraph.h index 8df9a1058c..657b828253 100644 --- a/src/neural/metal/mps/NetworkGraph.h +++ b/src/neural/metal/mps/NetworkGraph.h @@ -30,6 +30,8 @@ #import #import +#import "neural/network_legacy.h" + @interface MPSGraphTensor(Lc0Extensions) -(NSUInteger) size; @@ -40,15 +42,12 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat16; -@interface Lc0NetworkGraph : NSObject { +@interface Lc0NetworkGraph : MPSGraph { @public // Keep the device and command queue objects around for ease of use. MPSGraphDevice * _device; id _queue; - // MPSGraph implementation. - MPSGraph * _graph; - // Input tensor and tensor data placeholders. MPSGraphTensor * _inputTensor; @@ -60,6 +59,9 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat // Variables for triple buffering dispatch_semaphore_t _doubleBufferingSemaphore; + + // Global smolgen weights. + float * __nullable _globalSmolgenWeights; } +(Lc0NetworkGraph * _Nonnull) getGraphAt:(NSNumber * _Nonnull)index; @@ -75,7 +77,6 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat label:(NSString * __nullable)label; -(nonnull MPSGraphTensor *) addConvolutionBlockWithParent:(MPSGraphTensor * __nonnull)parent - inputChannels:(NSUInteger)inputChannels outputChannels:(NSUInteger)outputChannels kernelSize:(NSUInteger)kernelSize weights:(float * __nonnull)weights @@ -84,7 +85,6 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat label:(NSString * __nonnull)label; -(nonnull MPSGraphTensor *) addResidualBlockWithParent:(MPSGraphTensor * __nonnull)parent - inputChannels:(NSUInteger)inputChannels outputChannels:(NSUInteger)outputChannels kernelSize:(NSUInteger)kernelSize weights1:(float * __nonnull)weights1 @@ -101,25 +101,36 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat activation:(NSString * __nullable)activation; -(nonnull MPSGraphTensor *) addFullyConnectedLayerWithParent:(MPSGraphTensor * __nonnull)parent - inputChannels:(NSUInteger)inputChannels outputChannels:(NSUInteger)outputChannels weights:(float * __nonnull)weights - biases:(float * __nonnull)biases + biases:(float * __nullable)biases activation:(NSString * __nullable)activation label:(NSString * __nonnull)label; --(nonnull MPSGraphTensor *) addLayerNormalizationWithSkipParent:(MPSGraphTensor * __nonnull)parent - secondaryInput:(MPSGraphTensor * __nonnull)secondary - gammas:(float * __nonnull)gammas - betas:(float * __nonnull)betas - channelSize:(NSUInteger)channelSize - epsilon:(float)epsilon - label:(NSString * __nonnull)label; +-(nonnull MPSGraphTensor *) addEncoderLayerWithParent:parent + legacyWeights:(lczero::LegacyWeights::EncoderLayer &)weights + heads:(NSUInteger)heads + embeddingSize:(NSUInteger)embeddingSize + smolgenActivation:(NSString * __nullable)smolgenActivation + ffnActivation:(NSString * __nonnull)ffnActivation + alpha:(float)alpha + label:(NSString * __nonnull)label; + +-(nonnull MPSGraphTensor *) addLayerNormalizationWithParent:(MPSGraphTensor * __nonnull)parent + scaledSecondaryTensor:(MPSGraphTensor * __nullable)secondary + gammas:(float * __nonnull)gammas + betas:(float * __nonnull)betas + alpha:(float)alpha + epsilon:(float)epsilon + label:(NSString * __nonnull)label; -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnull)queries withKeys:(MPSGraphTensor * __nonnull)keys withValues:(MPSGraphTensor * __nonnull)values heads:(NSUInteger)heads + parent:(MPSGraphTensor * __nonnull)parent + smolgen:(lczero::LegacyWeights::Smolgen * __nullable)smolgen + smolgenActivation:(NSString * __nullable)smolgenActivation label:(NSString * __nonnull)label; -(nonnull MPSGraphTensor *) scaledQKMatmulWithQueries:(MPSGraphTensor * __nonnull)queries @@ -137,8 +148,22 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat label:(NSString * __nonnull)label; -(nonnull MPSGraphTensor *) transposeChannelsWithTensor:(MPSGraphTensor * __nonnull)tensor + withShape:(MPSShape * __nonnull)withShape label:(NSString * __nonnull)label; +-(nonnull MPSGraphTensor *) positionEncodingWithTensor:(MPSGraphTensor * __nonnull)tensor + withShape:(MPSShape * __nonnull)shape + weights:(const float * __nonnull)encodings + type:(NSString * __nullable)type + label:(NSString * __nonnull)label; + +-(nonnull MPSGraphTensor *) addGatingLayerWithParent:(MPSGraphTensor * __nonnull)parent + weights:(const float * __nonnull)weights + withOperation:(NSString * __nonnull)op + label:(NSString * __nonnull)label; + +-(void) setGlobalSmolgenWeights:(float * __nonnull)weights; + -(void) setResultTensors:(NSArray * __nonnull)results; -(nonnull NSArray *) runInferenceWithBatchSize:(NSUInteger)batchSize @@ -152,12 +177,4 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat -(void) copyResultsToBuffers:(float * __nonnull * __nonnull)outputBuffers subBatchSize:(NSUInteger)subBatchSize; --(void) setVariable:(NSString * __nonnull)name - tensor:(MPSGraphTensor * __nonnull)tensor; - --(void) trackVariable:(NSString * __nonnull)name; - --(void) dumpVariable:(NSString * __nonnull)name - batches:(NSUInteger)batches; - @end diff --git a/src/neural/metal/mps/NetworkGraph.mm b/src/neural/metal/mps/NetworkGraph.mm index 2a437ffe23..49384fbe54 100644 --- a/src/neural/metal/mps/NetworkGraph.mm +++ b/src/neural/metal/mps/NetworkGraph.mm @@ -25,8 +25,8 @@ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a Program grant you additional permission to convey the resulting work. */ +#import "neural/network_legacy.h" #import "NetworkGraph.h" - #import static MPSGraphConvolution2DOpDescriptor * __nonnull convolution2DDescriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:1 @@ -72,6 +72,15 @@ -(NSUInteger) sizeOfDimensions:(NSArray *)dimensions { return size; } + +-(NSUInteger) sizeOfDimensionsFrom:(NSNumber *)dimension { + NSUInteger size = 1; + for (NSUInteger dim = [dimension intValue]; dim < [self.shape count]; dim++) { + size *= [self.shape[dim] intValue]; + } + return size; +} + @end @implementation Lc0NetworkGraph @@ -93,9 +102,9 @@ +(NSMutableDictionary * _Nonnull) getGraphs { // This is the Lc0NetworkGraph getter method. +(Lc0NetworkGraph * _Nonnull) getGraphAt:(NSNumber * _Nonnull)index { - NSMutableDictionary * graphs = [Lc0NetworkGraph getGraphs]; + NSMutableDictionary * graphs = [Lc0NetworkGraph getGraphs]; - return graphs[index]; + return graphs[index]; } // This is the Lc0NetworkGraph factory method. @@ -118,7 +127,6 @@ -(nonnull instancetype) initWithDevice:(id __nonnull)device self = [super init]; _device = [MPSGraphDevice deviceWithMTLDevice:device]; _queue = [device newCommandQueue]; - _graph = [[MPSGraph alloc] init]; _resultTensors = @[]; _readVariables = [[NSMutableDictionary alloc] init]; _doubleBufferingSemaphore = dispatch_semaphore_create(kMaxInflightBuffers); @@ -191,7 +199,7 @@ -(nonnull MPSCommandBuffer *) runCommandSubBatchWithInputs:(float * __nonnull)in dispatch_semaphore_signal(_doubleBufferingSemaphore); }; - [_graph encodeToCommandBuffer:commandBuffer + [self encodeToCommandBuffer:commandBuffer feeds:@{_inputTensor : inputTensorData} targetTensors:_targetTensors targetOperations:nil @@ -236,13 +244,12 @@ -(nonnull MPSGraphTensor *) inputPlaceholderWithInputChannels:(NSUInteger)channe label:(NSString * __nullable)label { // Create a placeholder tensor that can hold the specified number of sub-batches. - _inputTensor = [_graph placeholderWithShape:@[@(-1), @(channels), @(height), @(width)] name:label]; + _inputTensor = [self placeholderWithShape:@[@(-1), @(channels), @(height), @(width)] name:label]; return _inputTensor; } -(nonnull MPSGraphTensor *) addConvolutionBlockWithParent:(MPSGraphTensor * __nonnull)parent - inputChannels:(NSUInteger)inputChannels outputChannels:(NSUInteger)outputChannels kernelSize:(NSUInteger)kernelSize weights:(float * __nonnull)weights @@ -250,11 +257,13 @@ -(nonnull MPSGraphTensor *) addConvolutionBlockWithParent:(MPSGraphTensor * __no activation:(NSString * __nullable)activation label:(NSString * __nonnull)label { + NSUInteger inputChannels = [parent.shape[1] intValue]; + NSData * weightsData = [NSData dataWithBytesNoCopy:weights length:outputChannels * inputChannels * kernelSize * kernelSize * sizeof(float) freeWhenDone:NO]; - MPSGraphTensor * weightsTensor = [_graph variableWithData:weightsData + MPSGraphTensor * weightsTensor = [self variableWithData:weightsData shape:@[@(outputChannels), @(inputChannels), @(kernelSize), @(kernelSize)] dataType:MPSDataTypeFloat32 name:[NSString stringWithFormat:@"%@/weights", label]]; @@ -263,33 +272,24 @@ -(nonnull MPSGraphTensor *) addConvolutionBlockWithParent:(MPSGraphTensor * __no length:outputChannels * sizeof(float) freeWhenDone:NO]; - MPSGraphTensor * biasTensor = [_graph variableWithData:biasData + MPSGraphTensor * biasTensor = [self variableWithData:biasData shape:@[@(outputChannels), @1, @1] dataType:MPSDataTypeFloat32 name:[NSString stringWithFormat:@"%@/biases", label]]; - MPSGraphTensor * convTensor = [_graph convolution2DWithSourceTensor:parent + MPSGraphTensor * convTensor = [self convolution2DWithSourceTensor:parent weightsTensor:weightsTensor descriptor:convolution2DDescriptor name:[NSString stringWithFormat:@"%@/conv", label]]; - MPSGraphTensor * convBiasTensor = [_graph additionWithPrimaryTensor:convTensor - secondaryTensor:biasTensor - name:[NSString stringWithFormat:@"%@/bias_add", label]]; - - [self setVariable:[NSString stringWithFormat:@"%@/weights", label] tensor:weightsTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/biases", label] tensor:biasTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/conv", label] tensor:convTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/bias_add", label] tensor:convBiasTensor]; - - convBiasTensor = [self applyActivationWithTensor:convBiasTensor activation:activation label:label]; - [self setVariable:label tensor:convBiasTensor]; + MPSGraphTensor * convBiasTensor = [self additionWithPrimaryTensor:convTensor + secondaryTensor:biasTensor + name:[NSString stringWithFormat:@"%@/bias_add", label]]; - return convBiasTensor; + return [self applyActivationWithTensor:convBiasTensor activation:activation label:label]; } -(nonnull MPSGraphTensor *) addResidualBlockWithParent:(MPSGraphTensor * __nonnull)parent - inputChannels:(NSUInteger)inputChannels outputChannels:(NSUInteger)outputChannels kernelSize:(NSUInteger)kernelSize weights1:(float * __nonnull)weights1 @@ -305,9 +305,7 @@ -(nonnull MPSGraphTensor *) addResidualBlockWithParent:(MPSGraphTensor * __nonnu seFcOutputs:(NSUInteger)seFcOutputs activation:(NSString * __nullable)activation { - MPSGraphTensor * conv1Tensor = [self addConvolutionBlockWithParent:parent - inputChannels:inputChannels outputChannels:outputChannels kernelSize:kernelSize weights:weights1 @@ -316,7 +314,6 @@ -(nonnull MPSGraphTensor *) addResidualBlockWithParent:(MPSGraphTensor * __nonnu label:[NSString stringWithFormat:@"%@/conv1", label]]; MPSGraphTensor * conv2Tensor = [self addConvolutionBlockWithParent:conv1Tensor - inputChannels:inputChannels outputChannels:outputChannels kernelSize:kernelSize weights:weights2 @@ -324,100 +321,77 @@ -(nonnull MPSGraphTensor *) addResidualBlockWithParent:(MPSGraphTensor * __nonnu activation:nil label:[NSString stringWithFormat:@"%@/conv2", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/conv1", label] tensor:conv1Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/conv2", label] tensor:conv2Tensor]; - if (hasSe) { // SE Unit. - MPSGraphTensor * seUnit = [self addSEUnitWithParent:conv2Tensor - skipNode:parent - inputChannels:inputChannels - outputChannels:outputChannels - seFcOutputs:seFcOutputs - weights1:seWeights1 - biases1:seBiases1 - weights2:seWeights2 - biases2:seBiases2 - activation:activation - label:[NSString stringWithFormat:@"%@/se", label]]; - - [self setVariable:label tensor:seUnit]; - return seUnit; + return [self addSEUnitWithParent:conv2Tensor + skipNode:parent + outputChannels:outputChannels + seFcOutputs:seFcOutputs + weights1:seWeights1 + biases1:seBiases1 + weights2:seWeights2 + biases2:seBiases2 + activation:activation + label:[NSString stringWithFormat:@"%@/se", label]]; } else { - MPSGraphTensor * residualTensor = [_graph additionWithPrimaryTensor:parent - secondaryTensor:conv2Tensor - name:[NSString stringWithFormat:@"%@/add", label]]; - - MPSGraphTensor * activationTensor = [self applyActivationWithTensor:residualTensor - activation:activation - label:label]; - [self setVariable:[NSString stringWithFormat:@"%@/add", label] tensor:residualTensor]; - [self setVariable:label tensor:activationTensor]; - return activationTensor; + MPSGraphTensor * residualTensor = [self additionWithPrimaryTensor:parent + secondaryTensor:conv2Tensor + name:[NSString stringWithFormat:@"%@/add", label]]; + + return [self applyActivationWithTensor:residualTensor + activation:activation + label:label]; } } -(nonnull MPSGraphTensor *) addFullyConnectedLayerWithParent:(MPSGraphTensor * __nonnull)parent - inputChannels:(NSUInteger)inputChannels outputChannels:(NSUInteger)outputChannels weights:(float * __nonnull)weights - biases:(float * __nonnull)biases + biases:(float * __nullable)biases activation:(NSString * __nullable)activation label:(NSString * __nonnull)label { + NSUInteger inputChannels = [[parent.shape lastObject] intValue]; + NSData * weightData = [NSData dataWithBytesNoCopy:weights length:outputChannels * inputChannels * sizeof(float) freeWhenDone:NO]; - MPSGraphTensor * weightTensor = [_graph variableWithData:weightData + MPSGraphTensor * weightTensor = [self variableWithData:weightData shape:@[@(outputChannels), @(inputChannels)] dataType:MPSDataTypeFloat32 name:[NSString stringWithFormat:@"%@/weights", label]]; // Leela weights are OIHW, need to be transposed to IO** to allow matmul. - MPSGraphTensor * transposeTensor = [_graph transposeTensor:weightTensor - dimension:0 - withDimension:1 - name:[NSString stringWithFormat:@"%@/weights_transpose", label]]; - - MPSGraphTensor * reshaped = [_graph reshapeTensor:parent - withShape:@[@(-1), @([parent sizeOfDimensions:@[@1, @2, @3]])] - name:[NSString stringWithFormat:@"%@/reshape", label]]; - - MPSGraphTensor * fcTensor = [_graph matrixMultiplicationWithPrimaryTensor:reshaped - secondaryTensor:transposeTensor - name:[NSString stringWithFormat:@"%@/matmul", label]]; - - NSData * biasData = [NSData dataWithBytesNoCopy:biases - length:outputChannels * sizeof(float) - freeWhenDone:NO]; + weightTensor = [self transposeTensor:weightTensor + dimension:0 + withDimension:1 + name:[NSString stringWithFormat:@"%@/weights_transpose", label]]; - MPSGraphTensor * biasTensor = [_graph variableWithData:biasData - shape:@[@(outputChannels)] - dataType:MPSDataTypeFloat32 - name:[NSString stringWithFormat:@"%@/biases", label]]; - - MPSGraphTensor * addTensor = [_graph additionWithPrimaryTensor:fcTensor - secondaryTensor:biasTensor - name:[NSString stringWithFormat:@"%@/bias_add", label]]; + parent = [self matrixMultiplicationWithPrimaryTensor:parent + secondaryTensor:weightTensor + name:[NSString stringWithFormat:@"%@/matmul", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/weights", label] tensor:weightTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/weights_transpose", label] tensor:transposeTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/reshape", label] tensor:reshaped]; - [self setVariable:[NSString stringWithFormat:@"%@/matmul", label] tensor:fcTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/biases", label] tensor:biasTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/bias_add", label] tensor:addTensor]; + if (biases != nil) { + NSData * biasData = [NSData dataWithBytesNoCopy:biases + length:outputChannels * sizeof(float) + freeWhenDone:NO]; - addTensor = [self applyActivationWithTensor:addTensor activation:activation label:label]; + MPSGraphTensor * biasTensor = [self variableWithData:biasData + shape:@[@(outputChannels)] + dataType:MPSDataTypeFloat32 + name:[NSString stringWithFormat:@"%@/biases", label]]; - [self setVariable:label tensor:addTensor]; - return addTensor; + parent = [self additionWithPrimaryTensor:parent + secondaryTensor:biasTensor + name:[NSString stringWithFormat:@"%@/bias_add", label]]; + } + return [self applyActivationWithTensor:parent activation:activation label:label]; } -(nonnull MPSGraphTensor *) addSEUnitWithParent:(MPSGraphTensor * __nonnull)parent skipNode:(MPSGraphTensor * __nonnull)skipTensor - inputChannels:(NSUInteger)inputChannels outputChannels:(NSUInteger)outputChannels seFcOutputs:(NSUInteger)seFcOutputs weights1:(float * __nonnull)weights1 @@ -429,88 +403,72 @@ -(nonnull MPSGraphTensor *) addSEUnitWithParent:(MPSGraphTensor * __nonnull)pare { // 1. Global Average Pooling 2D - MPSGraphTensor * poolTensor = [_graph avgPooling2DWithSourceTensor:parent - descriptor:averagePoolingDescriptor - name:[NSString stringWithFormat:@"%@/pool", label]]; + MPSGraphTensor * seunit = [self avgPooling2DWithSourceTensor:parent + descriptor:averagePoolingDescriptor + name:[NSString stringWithFormat:@"%@/pool", label]]; // 2. FC Layer 1. - MPSGraphTensor * fc1Tensor = [self addFullyConnectedLayerWithParent:poolTensor - inputChannels:inputChannels - outputChannels:seFcOutputs - weights:weights1 - biases:biases1 - activation:activation - label:[NSString stringWithFormat:@"%@/fc1", label]]; + seunit = [self flatten2DTensor:seunit + axis:1 + name:[NSString stringWithFormat:@"%@/flatten", label]]; + + seunit = [self addFullyConnectedLayerWithParent:seunit + outputChannels:seFcOutputs + weights:weights1 + biases:biases1 + activation:activation + label:[NSString stringWithFormat:@"%@/fc1", label]]; // 3. FC Layer 2. - MPSGraphTensor * fc2Tensor = [self addFullyConnectedLayerWithParent:fc1Tensor - inputChannels:seFcOutputs - outputChannels:2 * inputChannels - weights:weights2 - biases:biases2 - activation:nil - label:[NSString stringWithFormat:@"%@/fc2", label]]; - - // 4. Slice 1 and gamma. - MPSGraphTensor * slice1Tensor = [_graph sliceTensor:fc2Tensor - dimension:1 - start:0 - length:inputChannels - name:[NSString stringWithFormat:@"%@/slice1", label]]; - - MPSGraphTensor * gammaTensor = [_graph sigmoidWithTensor:slice1Tensor - name:[NSString stringWithFormat:@"%@/sigmoid", label]]; - - // 5. Slice 2 - MPSGraphTensor * slice2Tensor = [_graph sliceTensor:fc2Tensor - dimension:1 - start:inputChannels - length:inputChannels - name:[NSString stringWithFormat:@"%@/slice2", label]]; - - // 5. Multiply and add. - MPSGraphTensor * reshape1Tensor = [_graph reshapeTensor:gammaTensor - withShape:@[@(-1), gammaTensor.shape[1], @1, @1] - name:[NSString stringWithFormat:@"%@/reshape1", label]]; - - - MPSGraphTensor * multiplyTensor = [_graph multiplicationWithPrimaryTensor:parent - secondaryTensor:reshape1Tensor - name:[NSString stringWithFormat:@"%@/multiply", label]]; - - MPSGraphTensor * reshape2Tensor = [_graph reshapeTensor:slice2Tensor - withShape:@[@(-1), slice2Tensor.shape[1], @1, @1] - name:[NSString stringWithFormat:@"%@/reshape2", label]]; - - MPSGraphTensor * add1Tensor = [_graph additionWithPrimaryTensor:multiplyTensor - secondaryTensor:reshape2Tensor - name:[NSString stringWithFormat:@"%@/add1", label]]; - - MPSGraphTensor * add2Tensor = [_graph additionWithPrimaryTensor:add1Tensor - secondaryTensor:skipTensor - name:[NSString stringWithFormat:@"%@/add2", label]]; + NSUInteger inputChannels = [parent.shape[1] intValue]; + seunit = [self addFullyConnectedLayerWithParent:seunit + outputChannels:2 * inputChannels + weights:weights2 + biases:biases2 + activation:nil + label:[NSString stringWithFormat:@"%@/fc2", label]]; + + // 4. Slice 1, gamma and multiply. + MPSGraphTensor * gamma = [self sliceTensor:seunit + dimension:1 + start:0 + length:inputChannels + name:[NSString stringWithFormat:@"%@/slice1", label]]; + + gamma = [self sigmoidWithTensor:gamma + name:[NSString stringWithFormat:@"%@/sigmoid", label]]; + + gamma = [self reshapeTensor:gamma + withShape:@[@(-1), gamma.shape[1], @1, @1] + name:[NSString stringWithFormat:@"%@/reshape1", label]]; + + gamma = [self multiplicationWithPrimaryTensor:parent + secondaryTensor:gamma + name:[NSString stringWithFormat:@"%@/multiply", label]]; + + // 5. Slice 2 and add. + seunit = [self sliceTensor:seunit + dimension:1 + start:inputChannels + length:inputChannels + name:[NSString stringWithFormat:@"%@/slice2", label]]; + + seunit = [self reshapeTensor:seunit + withShape:@[@(-1), seunit.shape[1], @1, @1] + name:[NSString stringWithFormat:@"%@/reshape2", label]]; + + seunit = [self additionWithPrimaryTensor:gamma + secondaryTensor:seunit + name:[NSString stringWithFormat:@"%@/add1", label]]; + + seunit = [self additionWithPrimaryTensor:seunit + secondaryTensor:skipTensor + name:[NSString stringWithFormat:@"%@/add2", label]]; // 6. Default activation. - MPSGraphTensor * activationTensor = [self applyActivationWithTensor:add2Tensor - activation:activation - label:label]; - - // Add all the variables if specified. - [self setVariable:[NSString stringWithFormat:@"%@/pool", label] tensor:poolTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/fc1", label] tensor:fc1Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/fc2", label] tensor:fc2Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/slice1", label] tensor:slice1Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/slice2", label] tensor:slice2Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/sigmoid", label] tensor:gammaTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/reshape1", label] tensor:reshape1Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/reshape2", label] tensor:reshape2Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/multiply", label] tensor:multiplyTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/add1", label] tensor:add1Tensor]; - [self setVariable:[NSString stringWithFormat:@"%@/add2", label] tensor:add2Tensor]; - - [self setVariable:label tensor:activationTensor]; - - return activationTensor; + return [self applyActivationWithTensor:seunit + activation:activation + label:label]; } -(nonnull MPSGraphTensor *) addPolicyMapLayerWithParent:(MPSGraphTensor * __nonnull)parent @@ -521,53 +479,142 @@ -(nonnull MPSGraphTensor *) addPolicyMapLayerWithParent:(MPSGraphTensor * __nonn length:kNumPolicyOutputs * sizeof(uint32_t) freeWhenDone:NO]; - MPSGraphTensor * mappingTensor = [_graph constantWithData:policyMapData + MPSGraphTensor * mappingTensor = [self constantWithData:policyMapData shape:@[@(kNumPolicyOutputs)] dataType:MPSDataTypeUInt32]; - MPSGraphTensor * flatConvTensor = [_graph reshapeTensor:parent - withShape:@[parent.shape[0], @([parent sizeOfDimensions:@[@1, @2, @3]])] + MPSGraphTensor * flatConvTensor = [self flatten2DTensor:parent + axis:1 name:[NSString stringWithFormat:@"%@/flatten", label]]; - MPSGraphTensor * policyTensor = [_graph gatherWithUpdatesTensor:flatConvTensor + MPSGraphTensor * policyTensor = [self gatherWithUpdatesTensor:flatConvTensor indicesTensor:mappingTensor axis:1 batchDimensions:0 name:[NSString stringWithFormat:@"%@/gather", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/constant", label] tensor:mappingTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/flatten", label] tensor:flatConvTensor]; - [self setVariable:[NSString stringWithFormat:@"%@/gather", label] tensor:policyTensor]; - [self setVariable:label tensor:policyTensor]; - return policyTensor; } --(nonnull MPSGraphTensor *) addLayerNormalizationWithSkipParent:(MPSGraphTensor * __nonnull)parent - secondaryInput:(MPSGraphTensor * __nonnull)secondary - gammas:(float * __nonnull)gammas - betas:(float * __nonnull)betas - channelSize:(NSUInteger)channelSize - epsilon:(float)epsilon - label:(NSString * __nonnull)label +-(nonnull MPSGraphTensor *) addEncoderLayerWithParent:parent + legacyWeights:(lczero::LegacyWeights::EncoderLayer &)encoder + heads:(NSUInteger)heads + embeddingSize:(NSUInteger)embeddingSize + smolgenActivation:(NSString * __nullable)smolgenActivation + ffnActivation:(NSString * __nonnull)ffnActivation + alpha:(float)alpha + label:(NSString * __nonnull)label { - parent = [_graph additionWithPrimaryTensor:parent - secondaryTensor:secondary - name:[NSString stringWithFormat:@"%@/add", label]]; + NSUInteger dModel = encoder.mha.q_b.size(); + MPSGraphTensor * mhaQ = [self addFullyConnectedLayerWithParent:parent + outputChannels:encoder.mha.q_b.size() + weights:&encoder.mha.q_w[0] + biases:&encoder.mha.q_b[0] + activation:nil + label:[NSString stringWithFormat:@"%@/mhaq/fc", label]]; + + MPSGraphTensor * mhaK = [self addFullyConnectedLayerWithParent:parent + outputChannels:encoder.mha.k_b.size() + weights:&encoder.mha.k_w[0] + biases:&encoder.mha.k_b[0] + activation:nil + label:[NSString stringWithFormat:@"%@/mhak/fc", label]]; + + MPSGraphTensor * mhaV = [self addFullyConnectedLayerWithParent:parent + outputChannels:encoder.mha.v_b.size() + weights:&encoder.mha.v_w[0] + biases:&encoder.mha.v_b[0] + activation:nil + label:[NSString stringWithFormat:@"%@/mhav/fc", label]]; + + MPSGraphTensor * mha = [self scaledMHAMatmulWithQueries:mhaQ + withKeys:mhaK + withValues:mhaV + heads:heads + parent:parent + smolgen:encoder.mha.has_smolgen ? &encoder.mha.smolgen : nil + smolgenActivation:smolgenActivation + label:[NSString stringWithFormat:@"%@/mha", label]]; + + // MHA final dense layer. + mha = [self addFullyConnectedLayerWithParent:mha + outputChannels:embeddingSize + weights:&encoder.mha.dense_w[0] + biases:&encoder.mha.dense_b[0] + activation:nil + label:[NSString stringWithFormat:@"%@/mha/fc", label]]; + + // Skip connection + Layer Norm 1. + MPSGraphTensor * enc = [self addLayerNormalizationWithParent:mha + scaledSecondaryTensor:parent + gammas:&encoder.ln1_gammas[0] + betas:&encoder.ln1_betas[0] + alpha:alpha + epsilon:1e-6 + label:[NSString stringWithFormat:@"%@/ln1", label]]; + + // Feedforward network (FFN). + MPSGraphTensor * ffn = [self addFullyConnectedLayerWithParent:enc + outputChannels:encoder.ffn.dense1_b.size() + weights:&encoder.ffn.dense1_w[0] + biases:&encoder.ffn.dense1_b[0] + activation:ffnActivation + label:[NSString stringWithFormat:@"%@/ffn1", label]]; + + ffn = [self addFullyConnectedLayerWithParent:ffn + outputChannels:encoder.ffn.dense2_b.size() + weights:&encoder.ffn.dense2_w[0] + biases:&encoder.ffn.dense2_b[0] + activation:nil + label:[NSString stringWithFormat:@"%@/ffn2", label]]; + + // Skip connection + Layer Norm 2. + return [self addLayerNormalizationWithParent:ffn + scaledSecondaryTensor:enc + gammas:&encoder.ln2_gammas[0] + betas:&encoder.ln2_betas[0] + alpha:alpha + epsilon:1e-6 + label:[NSString stringWithFormat:@"%@/ln2", label]]; +} - MPSGraphTensor * means = [_graph meanOfTensor:parent - axes:@[@1] +-(nonnull MPSGraphTensor *) addLayerNormalizationWithParent:(MPSGraphTensor * __nonnull)parent + scaledSecondaryTensor:(MPSGraphTensor * __nullable)secondary + gammas:(float * __nonnull)gammas + betas:(float * __nonnull)betas + alpha:(float)alpha + epsilon:(float)epsilon + label:(NSString * __nonnull)label +{ + if (secondary != nil) { + if (alpha != 1.0) { + MPSGraphTensor * alphaTensor = [self constantWithScalar:alpha shape:@[@1] dataType:parent.dataType]; + secondary = [self multiplicationWithPrimaryTensor:secondary + secondaryTensor:alphaTensor + name:[NSString stringWithFormat:@"%@/multiply", label]]; + } + + parent = [self additionWithPrimaryTensor:parent + secondaryTensor:secondary + name:[NSString stringWithFormat:@"%@/add", label]]; + } + + NSUInteger axis = [parent.shape count] - 1; + NSUInteger channelSize = [[parent.shape lastObject] intValue]; + + MPSGraphTensor * means = [self meanOfTensor:parent + axes:@[@(axis)] name:[NSString stringWithFormat:@"%@/mean", label]]; - MPSGraphTensor * variances = [_graph varianceOfTensor:parent - axes:@[@1] + MPSGraphTensor * variances = [self varianceOfTensor:parent + axes:@[@(axis)] name:[NSString stringWithFormat:@"%@/variance", label]]; NSData * gammaData = [NSData dataWithBytesNoCopy:gammas length:channelSize * sizeof(float) freeWhenDone:NO]; - MPSGraphTensor * gammaTensor = [_graph variableWithData:gammaData + MPSGraphTensor * gammaTensor = [self variableWithData:gammaData shape:@[@(channelSize)] dataType:MPSDataTypeFloat32 name:[NSString stringWithFormat:@"%@/gamma", label]]; @@ -576,12 +623,12 @@ -(nonnull MPSGraphTensor *) addLayerNormalizationWithSkipParent:(MPSGraphTensor length:channelSize * sizeof(float) freeWhenDone:NO]; - MPSGraphTensor * betaTensor = [_graph variableWithData:betaData + MPSGraphTensor * betaTensor = [self variableWithData:betaData shape:@[@(channelSize)] dataType:MPSDataTypeFloat32 name:[NSString stringWithFormat:@"%@/beta", label]]; - return [_graph normalizationWithTensor:parent + return [self normalizationWithTensor:parent meanTensor:means varianceTensor:variances gammaTensor:gammaTensor @@ -591,62 +638,133 @@ -(nonnull MPSGraphTensor *) addLayerNormalizationWithSkipParent:(MPSGraphTensor } -(nonnull MPSGraphTensor *) transposeChannelsWithTensor:(MPSGraphTensor * __nonnull)tensor - label:(NSString * __nonnull)label + withShape:(MPSShape * __nonnull)withShape + label:(NSString * __nonnull)label { - MPSGraphTensor * transposeTensor = [_graph transposeTensor:tensor + MPSGraphTensor * transposeTensor = [self transposeTensor:tensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/weights_transpose_1", label]]; - transposeTensor = [_graph transposeTensor:transposeTensor + transposeTensor = [self transposeTensor:transposeTensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/weights_transpose_2", label]]; - return [_graph reshapeTensor:transposeTensor - withShape:@[@(-1), [transposeTensor.shape lastObject]] + return [self reshapeTensor:transposeTensor + withShape:withShape name:[NSString stringWithFormat:@"%@/reshape", label]]; } -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnull)queries - withKeys:(MPSGraphTensor * __nonnull)keys - withValues:(MPSGraphTensor * __nonnull)values - heads:(NSUInteger)heads - label:(NSString * __nonnull)label + withKeys:(MPSGraphTensor * __nonnull)keys + withValues:(MPSGraphTensor * __nonnull)values + heads:(NSUInteger)heads + parent:(MPSGraphTensor * __nonnull)parent + smolgen:(lczero::LegacyWeights::Smolgen * __nullable)smolgen + smolgenActivation:(NSString * __nullable)smolgenActivation + label:(NSString * __nonnull)label { // Split heads. const NSUInteger dmodel = [[queries.shape lastObject] intValue]; const NSUInteger depth = dmodel / heads; - queries = [_graph reshapeTensor:queries withShape:@[@(-1), @64, @(heads), @(depth)] name:[NSString stringWithFormat:@"%@/reshape_q", label]]; - queries = [_graph transposeTensor:queries dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_q", label]]; + queries = [self reshapeTensor:queries withShape:@[@(-1), @64, @(heads), @(depth)] name:[NSString stringWithFormat:@"%@/reshape_q", label]]; + queries = [self transposeTensor:queries dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_q", label]]; - keys = [_graph reshapeTensor:keys withShape:@[@(-1), @64, @(heads), @(depth)] name:[NSString stringWithFormat:@"%@/reshape_k", label]]; - keys = [_graph transposeTensor:keys dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_k", label]]; + keys = [self reshapeTensor:keys withShape:@[@(-1), @64, @(heads), @(depth)] name:[NSString stringWithFormat:@"%@/reshape_k", label]]; + keys = [self transposeTensor:keys dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_k", label]]; - values = [_graph reshapeTensor:values withShape:@[@(-1), @64, @(heads), @(depth)] name:[NSString stringWithFormat:@"%@/reshape_v", label]]; - values = [_graph transposeTensor:values dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_v", label]]; + values = [self reshapeTensor:values withShape:@[@(-1), @64, @(heads), @(depth)] name:[NSString stringWithFormat:@"%@/reshape_v", label]]; + values = [self transposeTensor:values dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_v", label]]; // Scaled attention matmul. - keys = [_graph transposeTensor:keys dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_k_2", label]]; - MPSGraphTensor * attn = [_graph matrixMultiplicationWithPrimaryTensor:queries + keys = [self transposeTensor:keys dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_k_2", label]]; + MPSGraphTensor * attn = [self matrixMultiplicationWithPrimaryTensor:queries secondaryTensor:keys name:[NSString stringWithFormat:@"%@/matmul_qk", label]]; - attn = [_graph divisionWithPrimaryTensor:attn - secondaryTensor:[_graph constantWithScalar:sqrt(depth) - shape:@[@1] - dataType:attn.dataType] + attn = [self divisionWithPrimaryTensor:attn + secondaryTensor:[self constantWithScalar:sqrt(depth) + shape:@[@1] + dataType:attn.dataType] name:[NSString stringWithFormat:@"%@/scale", label]]; + // Smolgen. + if (smolgen != nil) { + // Smolgen weights. + // 1. Compressed fully connected layer and reshape. + NSUInteger hidden_channels = smolgen->compress.size() / [[parent.shape lastObject] intValue]; + MPSGraphTensor * smolgenWeights = [self addFullyConnectedLayerWithParent:parent + outputChannels:hidden_channels + weights:&smolgen->compress[0] + biases:nil + activation:nil + label:[NSString stringWithFormat:@"%@/smolgen/compress", label]]; + smolgenWeights = [self flatten2DTensor:smolgenWeights + axis:1 + name:[NSString stringWithFormat:@"%@/smolgen/flatten", label]]; + + // 2. Dense 1 with layer norm. + smolgenWeights = [self addFullyConnectedLayerWithParent:smolgenWeights + outputChannels:smolgen->dense1_b.size() + weights:&smolgen->dense1_w[0] + biases:&smolgen->dense1_b[0] + activation:smolgenActivation + label:[NSString stringWithFormat:@"%@/smolgen/dense_1", label]]; + + smolgenWeights = [self addLayerNormalizationWithParent:smolgenWeights + scaledSecondaryTensor:nil + gammas:&smolgen->ln1_gammas[0] + betas:&smolgen->ln1_betas[0] + alpha:0.0 + epsilon:1e-6 + label:[NSString stringWithFormat:@"%@/smolgen/ln1", label]]; + + // 3. Dense 2 with layer norm. + smolgenWeights = [self addFullyConnectedLayerWithParent:smolgenWeights + outputChannels:smolgen->dense2_b.size() + weights:&smolgen->dense2_w[0] + biases:&smolgen->dense2_b[0] + activation:smolgenActivation + label:[NSString stringWithFormat:@"%@/smolgen/dense_2", label]]; + + smolgenWeights = [self addLayerNormalizationWithParent:smolgenWeights + scaledSecondaryTensor:nil + gammas:&smolgen->ln2_gammas[0] + betas:&smolgen->ln2_betas[0] + alpha:0.0 + epsilon:1e-6 + label:[NSString stringWithFormat:@"%@/smolgen/ln2", label]]; + + smolgenWeights = [self reshapeTensor:smolgenWeights + withShape:@[@(-1), @(heads), @(smolgen->dense2_b.size() / heads)] + name:[NSString stringWithFormat:@"%@/smolgen/reshape_1", label]]; + + // 4. Global smolgen weights + smolgenWeights = [self addFullyConnectedLayerWithParent:smolgenWeights + outputChannels:64 * 64 + weights:_globalSmolgenWeights + biases:nil + activation:nil + label:[NSString stringWithFormat:@"%@/smolgen/global", label]]; + + smolgenWeights = [self reshapeTensor:smolgenWeights + withShape:@[@(-1), @(heads), @64, @64] + name:[NSString stringWithFormat:@"%@/smolgen/reshape_2", label]]; + + attn = [self additionWithPrimaryTensor:attn + secondaryTensor:smolgenWeights + name:[NSString stringWithFormat:@"%@/smolgen_add", label]]; + } attn = [self applyActivationWithTensor:attn activation:@"softmax" label:label]; // matmul(scaled_attention_weights, v). - attn = [_graph matrixMultiplicationWithPrimaryTensor:attn + attn = [self matrixMultiplicationWithPrimaryTensor:attn secondaryTensor:values name:[NSString stringWithFormat:@"%@/matmul_v", label]]; - attn = [_graph transposeTensor:attn dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_a", label]]; + attn = [self transposeTensor:attn dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_a", label]]; - return [_graph reshapeTensor:attn withShape:@[@(-1), @(dmodel)] name:[NSString stringWithFormat:@"%@/reshape_a", label]]; + return [self reshapeTensor:attn withShape:@[@(-1), @64, @(dmodel)] name:[NSString stringWithFormat:@"%@/reshape_a", label]]; } -(nonnull MPSGraphTensor *) scaledQKMatmulWithQueries:(MPSGraphTensor * __nonnull)queries @@ -654,25 +772,25 @@ -(nonnull MPSGraphTensor *) scaledQKMatmulWithQueries:(MPSGraphTensor * __nonnul scale:(float)scale label:(NSString * __nonnull)label { - queries = [_graph reshapeTensor:queries + queries = [self reshapeTensor:queries withShape:@[@(-1), @64, [queries.shape lastObject]] name:[NSString stringWithFormat:@"%@/reshape_q", label]]; - keys = [_graph reshapeTensor:keys + keys = [self reshapeTensor:keys withShape:@[@(-1), @64, [keys.shape lastObject]] name:[NSString stringWithFormat:@"%@/reshape_k", label]]; - keys = [_graph transposeTensor:keys + keys = [self transposeTensor:keys dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_k", label]]; - MPSGraphTensor * qkMatmul = [_graph matrixMultiplicationWithPrimaryTensor:queries + MPSGraphTensor * qkMatmul = [self matrixMultiplicationWithPrimaryTensor:queries secondaryTensor:keys name:[NSString stringWithFormat:@"%@/matmul", label]]; - qkMatmul = [_graph multiplicationWithPrimaryTensor:qkMatmul - secondaryTensor:[_graph constantWithScalar:scale + qkMatmul = [self multiplicationWithPrimaryTensor:qkMatmul + secondaryTensor:[self constantWithScalar:scale shape:@[@1] dataType:qkMatmul.dataType] name:[NSString stringWithFormat:@"%@/scale", label]]; return qkMatmul; @@ -687,48 +805,150 @@ -(nonnull MPSGraphTensor *) attentionPolicyPromoMatmulConcatWithParent:(MPSGraph channelSize:(NSUInteger)channelSize label:(NSString * __nonnull)label { - keys = [_graph reshapeTensor:keys withShape:@[@(-1), @64, @(channelSize)] name:[NSString stringWithFormat:@"%@/slice", label]]; + keys = [self reshapeTensor:keys withShape:@[@(-1), @64, @(channelSize)] name:[NSString stringWithFormat:@"%@/slice", label]]; - keys = [_graph sliceTensor:keys dimension:1 start:sliceFrom length:inputSize name:[NSString stringWithFormat:@"%@/slice", label]]; + keys = [self sliceTensor:keys dimension:1 start:sliceFrom length:inputSize name:[NSString stringWithFormat:@"%@/slice", label]]; NSData * weightData = [NSData dataWithBytesNoCopy:weights length:outputSize * channelSize * sizeof(float) freeWhenDone:NO]; - MPSGraphTensor * weightTensor = [_graph variableWithData:weightData + MPSGraphTensor * weightTensor = [self variableWithData:weightData shape:@[@(outputSize), @(channelSize)] dataType:parent.dataType name:[NSString stringWithFormat:@"%@/weights", label]]; - keys = [_graph transposeTensor:keys dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose", label]]; + keys = [self transposeTensor:keys dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose", label]]; - keys = [_graph matrixMultiplicationWithPrimaryTensor:weightTensor - secondaryTensor:keys - name:[NSString stringWithFormat:@"%@/matmul", label]]; + keys = [self matrixMultiplicationWithPrimaryTensor:weightTensor + secondaryTensor:keys + name:[NSString stringWithFormat:@"%@/matmul", label]]; - NSArray * offsets = [_graph splitTensor:keys - splitSizes:@[@3, @1] - axis:1 - name:[NSString stringWithFormat:@"%@/offset_split", label]]; + MPSGraphTensor * offset1 = [self sliceTensor:keys + dimension:1 + start:0 + length:3 + name:[NSString stringWithFormat:@"%@/offset_slice_1", label]]; - MPSGraphTensor * promo = [_graph additionWithPrimaryTensor:offsets[0] - secondaryTensor:offsets[1] - name:[NSString stringWithFormat:@"%@/offset_add", label]]; + MPSGraphTensor * offset2 = [self sliceTensor:keys + dimension:1 + start:3 + length:1 + name:[NSString stringWithFormat:@"%@/offset_slice_2", label]]; + + MPSGraphTensor * promo = [self additionWithPrimaryTensor:offset1 + secondaryTensor:offset2 + name:[NSString stringWithFormat:@"%@/offset_add", label]]; NSMutableArray * stack = [NSMutableArray arrayWithCapacity:inputSize]; for (NSUInteger i = 0; i < inputSize; i++) { [stack addObject:promo]; } - promo = [_graph stackTensors:stack axis:3 name:[NSString stringWithFormat:@"%@/offset_broadcast", label]]; + promo = [self stackTensors:stack axis:3 name:[NSString stringWithFormat:@"%@/offset_broadcast", label]]; + + promo = [self transposeTensor:promo dimension:1 withDimension:3 name:[NSString stringWithFormat:@"%@/offset_transpose", label]]; + + promo = [self reshapeTensor:promo withShape:@[@(-1), @3, @64] name:[NSString stringWithFormat:@"%@/offset_reshape", label]]; - promo = [_graph transposeTensor:promo dimension:1 withDimension:3 name:[NSString stringWithFormat:@"%@/offset_transpose", label]]; + parent = [self reshapeTensor:parent withShape:@[@(-1), @64, @64] name:[NSString stringWithFormat:@"%@/parent_reshape", label]]; - promo = [_graph reshapeTensor:promo withShape:@[@(-1), @3, @64] name:[NSString stringWithFormat:@"%@/offset_reshape", label]]; + return [self concatTensor:parent withTensor:promo dimension:1 name:[NSString stringWithFormat:@"%@/concat", label]]; +} + +-(nonnull MPSGraphTensor *) positionEncodingWithTensor:(MPSGraphTensor * __nonnull)tensor + withShape:(MPSShape * __nonnull)shape + weights:(const float * __nonnull)encodings + type:(NSString * __nullable)type + label:(NSString * __nonnull)label +{ + assert([shape count] == 2 && shape[0] == tensor.shape[1]); + + NSData * encodingData = [NSData dataWithBytesNoCopy:(void *)encodings + length:[shape[0] intValue] * [shape[1] intValue] * sizeof(float) + freeWhenDone:NO]; + + MPSGraphTensor * encodingTensor = [self variableWithData:encodingData + shape:shape + dataType:MPSDataTypeFloat32 + name:[NSString stringWithFormat:@"%@/weights", label]]; + + MPSGraphTensor * shapeTensor = [self shapeOfTensor:tensor + name:[NSString stringWithFormat:@"%@/shape", label]]; + + // # add positional encoding for each square to the input + // positional_encoding = tf.broadcast_to(tf.convert_to_tensor(self.POS_ENC, dtype=self.model_dtype), + // [tf.shape(flow)[0], 64, tf.shape(self.POS_ENC)[2]]) + // flow = tf.concat([flow, positional_encoding], axis=2) + + // shapeTensor is (b, hw, c) and we want to make it (b, hw, hw). Since we don't know b yet, we have to manipulate this + // tensor and use it for the broadcast op. + // @todo look for a better way to do this. + shapeTensor = [self sliceTensor:shapeTensor + dimension:0 + start:0 + length:2 + name:[NSString stringWithFormat:@"%@/shape/slice", label]]; + + shapeTensor = [self concatTensor:shapeTensor + withTensor:[self constantWithScalar:[[shape lastObject] intValue] + shape:@[@1] + dataType:shapeTensor.dataType] + dimension:0 + name:[NSString stringWithFormat:@"%@/shape/concat", label]]; + + encodingTensor = [self broadcastTensor:encodingTensor + toShapeTensor:shapeTensor + name:[NSString stringWithFormat:@"%@/weights/broadcast", label]]; + + encodingTensor = [self reshapeTensor:encodingTensor + withShape:@[@(-1), shape[0], shape[1]] + name:[NSString stringWithFormat:@"%@/weights/reshape", label]]; + + return [self concatTensor:tensor + withTensor:encodingTensor + dimension:[tensor.shape count] - 1 + name:[NSString stringWithFormat:@"%@/concat", label]]; +} + +-(nonnull MPSGraphTensor *) addGatingLayerWithParent:(MPSGraphTensor * __nonnull)parent + weights:(const float * __nonnull)weights + withOperation:(NSString * __nonnull)op + label:(NSString * __nonnull)label +{ + NSData * weightsData = [NSData dataWithBytesNoCopy:(void *)weights + length:[parent sizeOfDimensionsFrom:@1] * sizeof(float) + freeWhenDone:NO]; + + MPSGraphTensor * weightsTensor = [self variableWithData:weightsData + shape:@[parent.shape[2], parent.shape[1]] + dataType:MPSDataTypeFloat32 + name:[NSString stringWithFormat:@"%@/weights", label]]; + + // Leela weights are transposed. + weightsTensor = [self transposeTensor:weightsTensor + dimension:0 + withDimension:1 + name:[NSString stringWithFormat:@"%@/weights_transpose", label]]; + + if ([op isEqual:@"add"]) { + return [self additionWithPrimaryTensor:parent + secondaryTensor:weightsTensor + name:[NSString stringWithFormat:@"%@/add", label]]; + } + else if ([op isEqual:@"mult"]) { + return [self multiplicationWithPrimaryTensor:parent + secondaryTensor:weightsTensor + name:[NSString stringWithFormat:@"%@/multiply", label]]; + } + + return parent; +} - parent = [_graph reshapeTensor:parent withShape:@[@(-1), @64, @64] name:[NSString stringWithFormat:@"%@/parent_reshape", label]]; - return [_graph concatTensor:parent withTensor:promo dimension:1 name:[NSString stringWithFormat:@"%@/concat", label]]; +-(void) setGlobalSmolgenWeights:(float * __nonnull)weights +{ + _globalSmolgenWeights = weights; } -(nonnull MPSGraphTensor *) applyActivationWithTensor:(MPSGraphTensor * __nonnull)tensor @@ -736,29 +956,31 @@ -(nonnull MPSGraphTensor *) applyActivationWithTensor:(MPSGraphTensor * __nonnul label:(NSString * __nullable)label { if ([activation isEqual:@"relu"]) { - tensor = [_graph reLUWithTensor:tensor name:[NSString stringWithFormat:@"%@/relu", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/relu", label] tensor:tensor]; + return [self reLUWithTensor:tensor name:[NSString stringWithFormat:@"%@/relu", label]]; + } + if ([activation isEqual:@"relu_2"]) { + tensor = [self reLUWithTensor:tensor name:[NSString stringWithFormat:@"%@/relu", label]]; + return [self multiplicationWithPrimaryTensor:tensor + secondaryTensor:tensor + name:[NSString stringWithFormat:@"%@/square", label]]; } else if ([activation isEqual:@"tanh"]) { - tensor = [_graph tanhWithTensor:tensor name:[NSString stringWithFormat:@"%@/tanh", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/tanh", label] tensor:tensor]; + return [self tanhWithTensor:tensor name:[NSString stringWithFormat:@"%@/tanh", label]]; } else if ([activation isEqual:@"sigmoid"]) { - tensor = [_graph sigmoidWithTensor:tensor name:[NSString stringWithFormat:@"%@/sigmoid", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/sigmoid", label] tensor:tensor]; + return [self sigmoidWithTensor:tensor name:[NSString stringWithFormat:@"%@/sigmoid", label]]; } else if ([activation isEqual:@"softmax"]) { - tensor = [_graph softMaxWithTensor:tensor axis:([tensor.shape count] - 1) name:[NSString stringWithFormat:@"%@/softmax", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/softmax", label] tensor:tensor]; + return [self softMaxWithTensor:tensor axis:([tensor.shape count] - 1) name:[NSString stringWithFormat:@"%@/softmax", label]]; } else if ([activation isEqual:@"selu"]) { - tensor = [self seluWithTensor:tensor label:[NSString stringWithFormat:@"%@/mish", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/selu", label] tensor:tensor]; + return [self seluWithTensor:tensor label:[NSString stringWithFormat:@"%@/mish", label]]; } else if ([activation isEqual:@"mish"]) { - tensor = [self mishWithTensor:tensor label:[NSString stringWithFormat:@"%@/mish", label]]; - //tensor = [self fastMishWithTensor:tensor label:[NSString stringWithFormat:@"%@/mish", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/mish", label] tensor:tensor]; + return [self mishWithTensor:tensor label:[NSString stringWithFormat:@"%@/mish", label]]; + } + else if ([activation isEqual:@"swish"]) { + return [self swishWithTensor:tensor beta:1.0 label:[NSString stringWithFormat:@"%@/swish", label]]; } return tensor; @@ -768,86 +990,40 @@ -(nonnull MPSGraphTensor *) mishWithTensor:(MPSGraphTensor * __nonnull)tensor label:(NSString * __nonnull)label { // mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) - MPSGraphTensor * mishTensor = [_graph exponentWithTensor:tensor + MPSGraphTensor * mishTensor = [self exponentWithTensor:tensor name:[NSString stringWithFormat:@"%@/exp", label]]; - MPSGraphTensor * oneTensor = [_graph constantWithScalar:1.0 shape:@[@1] dataType:mishTensor.dataType]; - mishTensor = [_graph additionWithPrimaryTensor:mishTensor - secondaryTensor:oneTensor - name:[NSString stringWithFormat:@"%@/add", label]]; + MPSGraphTensor * oneTensor = [self constantWithScalar:1.0 shape:@[@1] dataType:mishTensor.dataType]; + mishTensor = [self additionWithPrimaryTensor:mishTensor + secondaryTensor:oneTensor + name:[NSString stringWithFormat:@"%@/add", label]]; - mishTensor = [_graph logarithmWithTensor:mishTensor name:[NSString stringWithFormat:@"%@/ln", label]]; + mishTensor = [self logarithmWithTensor:mishTensor name:[NSString stringWithFormat:@"%@/ln", label]]; - mishTensor = [_graph tanhWithTensor:mishTensor name:[NSString stringWithFormat:@"%@/tanh", label]]; + mishTensor = [self tanhWithTensor:mishTensor name:[NSString stringWithFormat:@"%@/tanh", label]]; - mishTensor = [_graph multiplicationWithPrimaryTensor:mishTensor - secondaryTensor:tensor - name:[NSString stringWithFormat:@"%@/multiply", label]]; + mishTensor = [self multiplicationWithPrimaryTensor:mishTensor + secondaryTensor:tensor + name:[NSString stringWithFormat:@"%@/multiply", label]]; return mishTensor; } --(nonnull MPSGraphTensor *) fastMishWithTensor:(MPSGraphTensor * __nonnull)tensor - label:(NSString * __nonnull)label +-(nonnull MPSGraphTensor *) swishWithTensor:(MPSGraphTensor * __nonnull)tensor + beta:(float)beta + label:(NSString * __nonnull)label { - // @todo: currently hangs for multiple resnet stacks. - // Faster mish implementation using approximate functions. - // __device__ __forceinline__ float mishActivate(float el) { - // auto e = __expf(el); - // auto n = e * e + 2 * e; - // if (el <= -0.6f) { - // return n * __fdividef(el, n + 2); - // } else { - // return el - 2 * __fdividef(el, n + 2); - // } - // } - MPSGraphTensor * c_0_6 = [_graph constantWithScalar:-0.6 shape:@[@1] dataType:tensor.dataType]; - MPSGraphTensor * c_2_0 = [_graph constantWithScalar:2.0 shape:@[@1] dataType:tensor.dataType]; - MPSGraphTensor * exp = [_graph exponentWithTensor:tensor - name:[NSString stringWithFormat:@"%@/exp", label]]; - MPSGraphTensor * nExp = [_graph additionWithPrimaryTensor:exp - secondaryTensor:c_2_0 - name:[NSString stringWithFormat:@"%@/add", label]]; - nExp = [_graph multiplicationWithPrimaryTensor:exp - secondaryTensor:nExp - name:[NSString stringWithFormat:@"%@/multiply", label]]; - - - MPSGraphTensor * lessOrEqual = [_graph lessThanOrEqualToWithPrimaryTensor:tensor - secondaryTensor:c_0_6 - name:[NSString stringWithFormat:@"%@/le", label]]; - MPSGraphTensor * greater = [_graph greaterThanWithPrimaryTensor:tensor - secondaryTensor:c_0_6 - name:[NSString stringWithFormat:@"%@/gt", label]]; - - MPSGraphTensor * fdiv = [_graph additionWithPrimaryTensor:nExp - secondaryTensor:c_2_0 - name:[NSString stringWithFormat:@"%@/fdiv_add", label]]; - fdiv = [_graph divisionWithPrimaryTensor:tensor - secondaryTensor:fdiv - name:[NSString stringWithFormat:@"%@/fdiv", label]]; - - lessOrEqual = [_graph multiplicationWithPrimaryTensor:lessOrEqual - secondaryTensor:nExp - name:[NSString stringWithFormat:@"%@/le_multiply", label]]; - lessOrEqual = [_graph multiplicationWithPrimaryTensor:lessOrEqual - secondaryTensor:fdiv - name:[NSString stringWithFormat:@"%@/le_fdiv", label]]; - - MPSGraphTensor * fastMish = [_graph multiplicationWithPrimaryTensor:c_2_0 - secondaryTensor:fdiv - name:[NSString stringWithFormat:@"%@/fdiv_multiply", label]]; - fastMish = [_graph subtractionWithPrimaryTensor:tensor - secondaryTensor:fastMish - name:[NSString stringWithFormat:@"%@/fdiv_subtract", label]]; - - greater = [_graph multiplicationWithPrimaryTensor:greater - secondaryTensor:fastMish - name:[NSString stringWithFormat:@"%@/gt_multiply", label]]; - - return [_graph additionWithPrimaryTensor:lessOrEqual - secondaryTensor:greater - name:[NSString stringWithFormat:@"%@/condition_add", label]]; + // swish(x) = x * sigmoid(β * x) + MPSGraphTensor * betaTensor = [self constantWithScalar:beta shape:@[@1] dataType:tensor.dataType]; + MPSGraphTensor * swish = [self multiplicationWithPrimaryTensor:tensor + secondaryTensor:betaTensor + name:[NSString stringWithFormat:@"%@/multiply", label]]; + swish = [self sigmoidWithTensor:swish + name:[NSString stringWithFormat:@"%@/sigmoid", label]]; + + return [self multiplicationWithPrimaryTensor:tensor + secondaryTensor:swish + name:[NSString stringWithFormat:@"%@/multiply_2", label]]; } @@ -858,100 +1034,47 @@ -(nonnull MPSGraphTensor *) seluWithTensor:(MPSGraphTensor * __nonnull)tensor // if x > 0: return scale * x // if x < 0: return scale * alpha * (exp(x) - 1) // alpha=1.67326324, scale=1.05070098 - MPSGraphTensor * zero = [_graph constantWithScalar:0.0 shape:@[@1] dataType:tensor.dataType]; - MPSGraphTensor * scale = [_graph constantWithScalar:1.05070098 shape:@[@1] dataType:tensor.dataType]; - MPSGraphTensor * alpha = [_graph constantWithScalar:1.67326324 shape:@[@1] dataType:tensor.dataType]; + MPSGraphTensor * zero = [self constantWithScalar:0.0 shape:@[@1] dataType:tensor.dataType]; + MPSGraphTensor * scale = [self constantWithScalar:1.05070098 shape:@[@1] dataType:tensor.dataType]; + MPSGraphTensor * alpha = [self constantWithScalar:1.67326324 shape:@[@1] dataType:tensor.dataType]; - MPSGraphTensor * lessThanZero = [_graph lessThanWithPrimaryTensor:tensor + MPSGraphTensor * lessThanZero = [self lessThanWithPrimaryTensor:tensor secondaryTensor:zero name:[NSString stringWithFormat:@"%@/ltzero", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/ltzero", label] tensor:lessThanZero]; - MPSGraphTensor * greaterThanZero = [_graph greaterThanOrEqualToWithPrimaryTensor:tensor + MPSGraphTensor * greaterThanZero = [self greaterThanOrEqualToWithPrimaryTensor:tensor secondaryTensor:zero name:[NSString stringWithFormat:@"%@/gtzero", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/gtzero", label] tensor:greaterThanZero]; - MPSGraphTensor * scaled = [_graph multiplicationWithPrimaryTensor:tensor + MPSGraphTensor * scaled = [self multiplicationWithPrimaryTensor:tensor secondaryTensor:scale name:[NSString stringWithFormat:@"%@/scale", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/scale", label] tensor:scaled]; - scaled = [_graph multiplicationWithPrimaryTensor:scaled + scaled = [self multiplicationWithPrimaryTensor:scaled secondaryTensor:greaterThanZero name:[NSString stringWithFormat:@"%@/scale_mask", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/scale_mask", label] tensor:scaled]; - MPSGraphTensor * exp = [_graph exponentWithTensor:tensor + MPSGraphTensor * exp = [self exponentWithTensor:tensor name:[NSString stringWithFormat:@"%@/exp", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/exp", label] tensor:exp]; - MPSGraphTensor * one = [_graph constantWithScalar:1.0 shape:@[@1] dataType:tensor.dataType]; - exp = [_graph subtractionWithPrimaryTensor:exp + MPSGraphTensor * one = [self constantWithScalar:1.0 shape:@[@1] dataType:tensor.dataType]; + exp = [self subtractionWithPrimaryTensor:exp secondaryTensor:one name:[NSString stringWithFormat:@"%@/exp_1", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/exp_1", label] tensor:exp]; - exp = [_graph multiplicationWithPrimaryTensor:exp + exp = [self multiplicationWithPrimaryTensor:exp secondaryTensor:alpha name:[NSString stringWithFormat:@"%@/exp_alpha", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/exp_alpha", label] tensor:exp]; - exp = [_graph multiplicationWithPrimaryTensor:exp + exp = [self multiplicationWithPrimaryTensor:exp secondaryTensor:scale name:[NSString stringWithFormat:@"%@/exp_scale", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/exp_scale", label] tensor:exp]; - exp = [_graph multiplicationWithPrimaryTensor:exp + exp = [self multiplicationWithPrimaryTensor:exp secondaryTensor:lessThanZero name:[NSString stringWithFormat:@"%@/exp_mask", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/exp_mask", label] tensor:exp]; - - exp = [_graph additionWithPrimaryTensor:scaled secondaryTensor:exp name:[NSString stringWithFormat:@"%@/sum", label]]; - [self setVariable:[NSString stringWithFormat:@"%@/sum", label] tensor:exp]; - - return exp; -} - --(void) setVariable:(NSString * __nonnull)name - tensor:(MPSGraphTensor *)tensor -{ - if (![[_readVariables allKeys] containsObject:name]) return; - - _readVariables[name] = tensor; -} --(void) trackVariable:(NSString * __nonnull)name -{ - _readVariables[name] = [NSNull null]; -} - --(void) dumpVariable:(NSString * __nonnull)name - batches:(NSUInteger)batches -{ - if (!_readVariables[name] || _readVariables[name] == [NSNull null]) { - NSLog(@"No variable '%@' found", name); - return; - } - - MPSGraphTensor * variable = (MPSGraphTensor *) _readVariables[name]; - NSUInteger size = [variable.shape[0] intValue] > 0 ? [variable size] : batches * [variable sizeOfDimensions:@[@1, @2, @3]]; - - if (variable.dataType == MPSDataTypeUInt32) { - uint32_t * dumpArray = (uint32_t *)malloc(size * sizeof(uint32_t)); - [[_resultDataDicts[@0][_readVariables[name]] mpsndarray] readBytes:dumpArray strideBytes:nil]; - NSLog(@"Dumping: '%@', size: %i, type: %i", name, size, variable.dataType); - for (NSUInteger i = 0; i < (size > 100 ? 100 : size); i++) { - NSLog(@";%i;%i", i, dumpArray[i]); - } - } else { - float * dumpArray = (float *)malloc(size * sizeof(float)); - [[_resultDataDicts[@0][_readVariables[name]] mpsndarray] readBytes:dumpArray strideBytes:nil]; - NSLog(@"Dumping: '%@', size: %i, type: %i", name, size, variable.dataType); - for (NSUInteger i = 0; i < (size > 100 ? 100 : size); i++) { - NSLog(@";%i;%f", i, dumpArray[i]); - } - } + return [self additionWithPrimaryTensor:scaled secondaryTensor:exp name:[NSString stringWithFormat:@"%@/sum", label]]; } @end diff --git a/src/neural/metal/network_metal.cc b/src/neural/metal/network_metal.cc index 73a3dabc94..68dac8fb0f 100644 --- a/src/neural/metal/network_metal.cc +++ b/src/neural/metal/network_metal.cc @@ -25,7 +25,6 @@ Program grant you additional permission to convey the resulting work. */ #include "network_metal.h" -#include "mps/MetalNetworkBuilder.h" #include #include @@ -35,17 +34,19 @@ #include #include +#include "mps/MetalNetworkBuilder.h" #include "neural/factory.h" #include "neural/network_legacy.h" -#include "neural/shared/policy_map.h" #include "neural/shared/attention_policy_map.h" +#include "neural/shared/policy_map.h" #include "utils/bititer.h" #include "utils/exception.h" namespace lczero { namespace metal_backend { -MetalNetworkComputation::MetalNetworkComputation(MetalNetwork* network, bool wdl, bool moves_left) +MetalNetworkComputation::MetalNetworkComputation(MetalNetwork* network, + bool wdl, bool moves_left) : wdl_(wdl), moves_left_(moves_left), network_(network) { batch_size_ = 0; inputs_outputs_ = network_->GetInputsOutputs(); @@ -59,10 +60,34 @@ void MetalNetworkComputation::ComputeBlocking() { network_->forwardEval(inputs_outputs_.get(), GetBatchSize()); } +std::string activationString(pblczero::NetworkFormat::ActivationFunction act) { + switch (act) { + case pblczero::NetworkFormat::ACTIVATION_RELU: + return "relu"; + case pblczero::NetworkFormat::ACTIVATION_MISH: + return "mish"; + case pblczero::NetworkFormat::ACTIVATION_NONE: + return "none"; + case pblczero::NetworkFormat::ACTIVATION_TANH: + return "tanh"; + case pblczero::NetworkFormat::ACTIVATION_SIGMOID: + return "sigmoid"; + case pblczero::NetworkFormat::ACTIVATION_SELU: + return "selu"; + case pblczero::NetworkFormat::ACTIVATION_SWISH: + return "swish"; + case pblczero::NetworkFormat::ACTIVATION_RELU_2: + return "relu_2"; + case pblczero::NetworkFormat::ACTIVATION_SOFTMAX: + return "softmax"; + default: + return ""; + } +} + MetalNetwork::MetalNetwork(const WeightsFile& file, const OptionsDict& options) : capabilities_{file.format().network_format().input(), file.format().network_format().moves_left()} { - LegacyWeights weights(file.weights()); try { @@ -74,9 +99,6 @@ MetalNetwork::MetalNetwork(const WeightsFile& file, const OptionsDict& options) throw Exception("There was an error initializing the GPU device."); } - const int channelSize = weights.input.weights.size() / kInputPlanes / 9; - const int kernelSize = 3; - max_batch_size_ = options.GetOrDefault("max_batch", 1024); batch_size_ = options.GetOrDefault("batch", 64); @@ -86,31 +108,54 @@ MetalNetwork::MetalNetwork(const WeightsFile& file, const OptionsDict& options) attn_policy_ = file.format().network_format().policy() == pblczero::NetworkFormat::POLICY_ATTENTION; - wdl_ = file.format().network_format().value() == pblczero::NetworkFormat::VALUE_WDL; + wdl_ = file.format().network_format().value() == + pblczero::NetworkFormat::VALUE_WDL; moves_left_ = (file.format().network_format().moves_left() == pblczero::NetworkFormat::MOVES_LEFT_V1) && options.GetOrDefault("mlh", true); - policy_d_model_ = weights.ip2_pol_b.size(); + bool attn_body = + file.format().network_format().network() == + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; // Build MPS Graph. - builder_->build(kInputPlanes, channelSize, kernelSize, weights, attn_policy_, conv_policy_, wdl_, moves_left_, - file.format().network_format().default_activation() == pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH ? "mish" : "relu" - ); + Activations activations; + activations.default_activation = + file.format().network_format().default_activation() == + pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH + ? "mish" + : "relu"; + const auto smolgen_activation = + file.format().network_format().smolgen_activation(); + activations.smolgen_activation = + smolgen_activation == pblczero::NetworkFormat::ACTIVATION_DEFAULT + ? activations.default_activation + : activationString( + static_cast( + smolgen_activation)); + const auto ffn_activation = file.format().network_format().ffn_activation(); + activations.ffn_activation = + ffn_activation == pblczero::NetworkFormat::ACTIVATION_DEFAULT + ? activations.default_activation + : activationString( + static_cast( + ffn_activation)); + builder_->build(kInputPlanes, weights, attn_body, attn_policy_, conv_policy_, + wdl_, moves_left_, activations); } void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) { // Expand encoded input into N x 112 x 8 x 8. - float * dptr = &io->input_val_mem_expanded_[0]; + float* dptr = &io->input_val_mem_expanded_[0]; for (size_t i = 0; i < batchSize; i++) { - for (size_t j = 0; j < kInputPlanes; j++) { - const float value = io->input_val_mem_[j + i * kInputPlanes]; - const uint64_t mask = io->input_masks_mem_[j + i * kInputPlanes]; - for (auto k = 0; k < 64; k++) { - *(dptr++) = (mask & (((uint64_t)1) << k)) != 0 ? value : 0; - } + for (size_t j = 0; j < kInputPlanes; j++) { + const float value = io->input_val_mem_[j + i * kInputPlanes]; + const uint64_t mask = io->input_masks_mem_[j + i * kInputPlanes]; + for (auto k = 0; k < 64; k++) { + *(dptr++) = (mask & (((uint64_t)1) << k)) != 0 ? value : 0; } + } } // Metal is not thread-safe, so lock is needed. @@ -118,18 +163,17 @@ void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) { if (attn_policy_ || conv_policy_) { /** - * @todo policy map implementation has bug in MPSGraph (GatherND not working in graph). - * Implementation of policy map to be done in CPU for now. + * @todo policy map implementation has bug in MPSGraph (GatherND not working + * in graph). Implementation of policy map to be done in CPU for now. * * Remove this if-branch when bug is fixed. See comments above. */ if (moves_left_) { - builder_->forwardEval( - &io->input_val_mem_expanded_[0], batchSize, - {&io->op_policy_raw_mem_[0], &io->op_value_mem_[0], &io->op_moves_left_mem_[0]}); - } - else { + builder_->forwardEval(&io->input_val_mem_expanded_[0], batchSize, + {&io->op_policy_raw_mem_[0], &io->op_value_mem_[0], + &io->op_moves_left_mem_[0]}); + } else { builder_->forwardEval( &io->input_val_mem_expanded_[0], batchSize, {&io->op_policy_raw_mem_[0], &io->op_value_mem_[0]}); @@ -145,10 +189,10 @@ void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) { for (int i = 0; i < 3; i++) { // c in cuda // Promotion offsets already precalculated and stored in GPU. // Just the main policy offsets need to be added here. - io->op_policy_raw_mem_[batch * (64 * 64 + 8 * 24) + 64 * 64 + 24 * k + - 3 * j + i] += + io->op_policy_raw_mem_[batch * (64 * 64 + 8 * 24) + 64 * 64 + + 24 * k + 3 * j + i] += io->op_policy_raw_mem_[batch * (64 * 64 + 8 * 24) + - (48 + k) * 64 + 56 + j]; + (48 + k) * 64 + 56 + j]; } } } @@ -163,8 +207,7 @@ void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) { } } } - } - else if (conv_policy_) { + } else if (conv_policy_) { // Mapping from convolutional policy to lc0 policy for (size_t batch = 0; batch < batchSize; batch++) { for (size_t i = 0; i < 73 * 64; i++) { @@ -177,23 +220,19 @@ void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) { } } - } - else { + } else { if (moves_left_) { - builder_->forwardEval( - &io->input_val_mem_expanded_[0], batchSize, - {&io->op_policy_mem_[0], &io->op_value_mem_[0], &io->op_moves_left_mem_[0]}); - } - else { - builder_->forwardEval( - &io->input_val_mem_expanded_[0], batchSize, - {&io->op_policy_mem_[0], &io->op_value_mem_[0]}); + builder_->forwardEval(&io->input_val_mem_expanded_[0], batchSize, + {&io->op_policy_mem_[0], &io->op_value_mem_[0], + &io->op_moves_left_mem_[0]}); + } else { + builder_->forwardEval(&io->input_val_mem_expanded_[0], batchSize, + {&io->op_policy_mem_[0], &io->op_value_mem_[0]}); } // The next thread can start using the GPU now. lock_.unlock(); } - } std::unique_ptr MakeMetalNetwork(const std::optional& w, @@ -205,10 +244,12 @@ std::unique_ptr MakeMetalNetwork(const std::optional& w, if (weights.format().network_format().network() != pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && weights.format().network_format().network() != - pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT) { + pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && + weights.format().network_format().network() != + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT) { throw Exception("Network format " + pblczero::NetworkFormat::NetworkStructure_Name( - weights.format().network_format().network()) + + weights.format().network_format().network()) + " is not supported by the Metal backend."); } if (weights.format().network_format().policy() != @@ -219,7 +260,7 @@ std::unique_ptr MakeMetalNetwork(const std::optional& w, pblczero::NetworkFormat::POLICY_ATTENTION) { throw Exception("Policy format " + pblczero::NetworkFormat::PolicyFormat_Name( - weights.format().network_format().policy()) + + weights.format().network_format().policy()) + " is not supported by the Metal backend."); } if (weights.format().network_format().value() != @@ -228,7 +269,7 @@ std::unique_ptr MakeMetalNetwork(const std::optional& w, pblczero::NetworkFormat::VALUE_WDL) { throw Exception("Value format " + pblczero::NetworkFormat::ValueFormat_Name( - weights.format().network_format().value()) + + weights.format().network_format().value()) + " is not supported by the Metal backend."); } if (weights.format().network_format().moves_left() != @@ -237,22 +278,23 @@ std::unique_ptr MakeMetalNetwork(const std::optional& w, pblczero::NetworkFormat::MOVES_LEFT_V1) { throw Exception("Moves left head format " + pblczero::NetworkFormat::MovesLeftFormat_Name( - weights.format().network_format().moves_left()) + + weights.format().network_format().moves_left()) + " is not supported by the Metal backend."); } if (weights.format().network_format().default_activation() != pblczero::NetworkFormat::DEFAULT_ACTIVATION_RELU && weights.format().network_format().default_activation() != pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH) { - throw Exception("Default activation " + - pblczero::NetworkFormat::DefaultActivation_Name( - weights.format().network_format().default_activation()) + - " is not supported by the Metal backend."); + throw Exception( + "Default activation " + + pblczero::NetworkFormat::DefaultActivation_Name( + weights.format().network_format().default_activation()) + + " is not supported by the Metal backend."); } return std::make_unique(weights, options); } REGISTER_NETWORK("metal", MakeMetalNetwork, 105) -} // namespace backend_metal +} // namespace metal_backend } // namespace lczero diff --git a/src/neural/metal/network_metal.h b/src/neural/metal/network_metal.h index de2a7d5d49..b2e2df4b39 100644 --- a/src/neural/metal/network_metal.h +++ b/src/neural/metal/network_metal.h @@ -28,9 +28,9 @@ #include +#include "metal_common.h" #include "neural/factory.h" #include "neural/network_legacy.h" -#include "metal_common.h" namespace lczero { namespace metal_backend { @@ -107,7 +107,8 @@ class MetalNetwork : public Network { public: MetalNetwork(const WeightsFile& file, const OptionsDict& options); ~MetalNetwork() { - // if (builder_) { /** @todo clean-up delegate first */ delete builder; builder = NULL; } + // if (builder_) { /** @todo clean-up delegate first */ delete builder; + // builder = NULL; } } void forwardEval(InputsOutputs* io, int inputBatchSize); @@ -119,8 +120,8 @@ class MetalNetwork : public Network { std::unique_ptr GetInputsOutputs() { std::lock_guard lock(inputs_outputs_lock_); if (free_inputs_outputs_.empty()) { - return std::make_unique(max_batch_size_, wdl_, - moves_left_, conv_policy_, attn_policy_); + return std::make_unique(max_batch_size_, wdl_, moves_left_, + conv_policy_, attn_policy_); } else { std::unique_ptr resource = std::move(free_inputs_outputs_.front()); @@ -141,20 +142,19 @@ class MetalNetwork : public Network { private: NetworkCapabilities capabilities_{ pblczero::NetworkFormat::INPUT_CLASSICAL_112_PLANE, - pblczero::NetworkFormat::MOVES_LEFT_NONE - }; + pblczero::NetworkFormat::MOVES_LEFT_NONE}; int max_batch_size_; int batch_size_; bool wdl_; bool moves_left_; bool conv_policy_; bool attn_policy_; - int policy_d_model_; std::mutex inputs_outputs_lock_; std::list> free_inputs_outputs_; std::unique_ptr builder_; - // Metal not really good at multi-threading, so we need to do one NN eval at a time. + // Metal not really good at multi-threading, so we need to do one NN eval at a + // time. mutable std::mutex lock_; }; diff --git a/src/neural/network.h b/src/neural/network.h index fe6337711b..054b2ebd33 100644 --- a/src/neural/network.h +++ b/src/neural/network.h @@ -98,8 +98,7 @@ struct NetworkCapabilities { } bool has_mlh() const { - return moves_left != - pblczero::NetworkFormat::MovesLeftFormat::MOVES_LEFT_NONE; + return moves_left != pblczero::NetworkFormat::MOVES_LEFT_NONE; } }; diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index f4819d4952..387590de6b 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -30,6 +30,10 @@ static constexpr float kEpsilon = 1e-5f; LegacyWeights::LegacyWeights(const pblczero::Weights& weights) : input(weights.input()), + ip_emb_w(LayerAdapter(weights.ip_emb_w()).as_vector()), + ip_emb_b(LayerAdapter(weights.ip_emb_b()).as_vector()), + ip_mult_gate(LayerAdapter(weights.ip_mult_gate()).as_vector()), + ip_add_gate(LayerAdapter(weights.ip_add_gate()).as_vector()), policy1(weights.policy1()), policy(weights.policy()), ip_pol_w(LayerAdapter(weights.ip_pol_w()).as_vector()), @@ -40,18 +44,28 @@ LegacyWeights::LegacyWeights(const pblczero::Weights& weights) ip3_pol_b(LayerAdapter(weights.ip3_pol_b()).as_vector()), ip4_pol_w(LayerAdapter(weights.ip4_pol_w()).as_vector()), value(weights.value()), + ip_val_w(LayerAdapter(weights.ip_val_w()).as_vector()), + ip_val_b(LayerAdapter(weights.ip_val_b()).as_vector()), ip1_val_w(LayerAdapter(weights.ip1_val_w()).as_vector()), ip1_val_b(LayerAdapter(weights.ip1_val_b()).as_vector()), ip2_val_w(LayerAdapter(weights.ip2_val_w()).as_vector()), ip2_val_b(LayerAdapter(weights.ip2_val_b()).as_vector()), moves_left(weights.moves_left()), + ip_mov_w(LayerAdapter(weights.ip_mov_w()).as_vector()), + ip_mov_b(LayerAdapter(weights.ip_mov_b()).as_vector()), ip1_mov_w(LayerAdapter(weights.ip1_mov_w()).as_vector()), ip1_mov_b(LayerAdapter(weights.ip1_mov_b()).as_vector()), ip2_mov_w(LayerAdapter(weights.ip2_mov_w()).as_vector()), - ip2_mov_b(LayerAdapter(weights.ip2_mov_b()).as_vector()) { + ip2_mov_b(LayerAdapter(weights.ip2_mov_b()).as_vector()), + smolgen_w(LayerAdapter(weights.smolgen_w()).as_vector()), + has_smolgen(weights.has_smolgen_w()) { for (const auto& res : weights.residual()) { residual.emplace_back(res); } + encoder_head_count = weights.headcount(); + for (const auto& enc : weights.encoder()) { + encoder.emplace_back(enc); + } pol_encoder_head_count = weights.pol_headcount(); for (const auto& enc : weights.pol_encoder()) { pol_encoder.emplace_back(enc); @@ -135,7 +149,9 @@ LegacyWeights::MHA::MHA(const pblczero::Weights::MHA& mha) v_w(LayerAdapter(mha.v_w()).as_vector()), v_b(LayerAdapter(mha.v_b()).as_vector()), dense_w(LayerAdapter(mha.dense_w()).as_vector()), - dense_b(LayerAdapter(mha.dense_b()).as_vector()) {} + dense_b(LayerAdapter(mha.dense_b()).as_vector()), + smolgen(Smolgen(mha.smolgen())), + has_smolgen(mha.has_smolgen()) {} LegacyWeights::FFN::FFN(const pblczero::Weights::FFN& ffn) : dense1_w(LayerAdapter(ffn.dense1_w()).as_vector()), @@ -152,4 +168,16 @@ LegacyWeights::EncoderLayer::EncoderLayer( ln2_gammas(LayerAdapter(encoder.ln2_gammas()).as_vector()), ln2_betas(LayerAdapter(encoder.ln2_betas()).as_vector()) {} +LegacyWeights::Smolgen::Smolgen( + const pblczero::Weights::Smolgen& smolgen) + : compress(LayerAdapter(smolgen.compress()).as_vector()), + dense1_w(LayerAdapter(smolgen.dense1_w()).as_vector()), + dense1_b(LayerAdapter(smolgen.dense1_b()).as_vector()), + ln1_gammas(LayerAdapter(smolgen.ln1_gammas()).as_vector()), + ln1_betas(LayerAdapter(smolgen.ln1_betas()).as_vector()), + dense2_w(LayerAdapter(smolgen.dense2_w()).as_vector()), + dense2_b(LayerAdapter(smolgen.dense2_b()).as_vector()), + ln2_gammas(LayerAdapter(smolgen.ln2_gammas()).as_vector()), + ln2_betas(LayerAdapter(smolgen.ln2_betas()).as_vector()) {} + } // namespace lczero diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 3ba6028d5e..5715c40fbb 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -55,6 +55,19 @@ struct LegacyWeights { bool has_se; }; + struct Smolgen { + explicit Smolgen(const pblczero::Weights::Smolgen& smolgen); + Vec compress; + Vec dense1_w; + Vec dense1_b; + Vec ln1_gammas; + Vec ln1_betas; + Vec dense2_w; + Vec dense2_b; + Vec ln2_gammas; + Vec ln2_betas; + }; + struct MHA { explicit MHA(const pblczero::Weights::MHA& mha); Vec q_w; @@ -65,6 +78,8 @@ struct LegacyWeights { Vec v_b; Vec dense_w; Vec dense_b; + Smolgen smolgen; + bool has_smolgen; }; struct FFN { @@ -88,6 +103,18 @@ struct LegacyWeights { // Input convnet. ConvBlock input; + // Embedding layer + Vec ip_emb_w; + Vec ip_emb_b; + + // Input gating + Vec ip_mult_gate; + Vec ip_add_gate; + + // Encoder stack. + std::vector encoder; + int encoder_head_count; + // Residual tower. std::vector residual; @@ -109,6 +136,8 @@ struct LegacyWeights { // Value head ConvBlock value; + Vec ip_val_w; + Vec ip_val_b; Vec ip1_val_w; Vec ip1_val_b; Vec ip2_val_w; @@ -116,10 +145,17 @@ struct LegacyWeights { // Moves left head ConvBlock moves_left; + Vec ip_mov_w; + Vec ip_mov_b; Vec ip1_mov_w; Vec ip1_mov_b; Vec ip2_mov_w; Vec ip2_mov_b; + + // Smolgen global weights + Vec smolgen_w; + Vec smolgen_b; + bool has_smolgen; }; } // namespace lczero diff --git a/src/neural/onednn/layers.cc b/src/neural/onednn/layers.cc index 96e1c98349..5d89112451 100644 --- a/src/neural/onednn/layers.cc +++ b/src/neural/onednn/layers.cc @@ -93,9 +93,9 @@ void ConvLayer::Eval(int N, dnnl::memory& output, dnnl::memory& input, if (use_skip_) { conv_ops.append_sum(); } - if (activation_ == RELU) { + if (activation_ == ACTIVATION_RELU) { conv_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); - } else if (activation_ == TANH) { + } else if (activation_ == ACTIVATION_TANH) { conv_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_tanh, 0.0f, 0.0f); } dnnl::primitive_attr conv_attr; @@ -110,7 +110,7 @@ void ConvLayer::Eval(int N, dnnl::memory& output, dnnl::memory& input, out_md = conv_pd.dst_desc(); // Apparently convolution doesn't go well with mish post op. - if (activation_ == MISH) { + if (activation_ == ACTIVATION_MISH) { auto mish_d = dnnl::eltwise_forward::desc( dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_mish, out_md, 0.f, 0.f); @@ -183,7 +183,7 @@ void ConvLayer::Eval(int N, dnnl::memory& output, dnnl::memory& input, {DNNL_ARG_DST, output}, {DNNL_ARG_SCRATCHPAD, scratchpad_mem}}); - if (activation_ == MISH) { + if (activation_ == ACTIVATION_MISH) { mish_.execute(stream, {{DNNL_ARG_SRC, output}, {DNNL_ARG_DST, output}, {DNNL_ARG_SCRATCHPAD, scratchpad_mem}}); @@ -263,11 +263,11 @@ void SELayer::Eval(int N, dnnl::memory& output, dnnl::memory& input, dnnl::prop_kind::forward_inference, t_fc1_in_md, t_filter_md, bias_mem.get_desc(), t_fc1_out_md); dnnl::post_ops fc_ops; - if (activation_ == RELU) { + if (activation_ == ACTIVATION_RELU) { fc_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); - } else if (activation_ == MISH) { + } else if (activation_ == ACTIVATION_MISH) { fc_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_mish, 0.0f, 0.0f); - } else if (activation_ == TANH) { + } else if (activation_ == ACTIVATION_TANH) { fc_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_tanh, 0.0f, 0.0f); } dnnl::primitive_attr fc_attr; @@ -328,11 +328,11 @@ void SELayer::Eval(int N, dnnl::memory& output, dnnl::memory& input, if (eng.get_kind() == dnnl::engine::kind::gpu) { // Using binary post-ops is a gain on gpu but a huge loss on cpu. mul_ops.append_binary(dnnl::algorithm::binary_add, pool_out_md); - if (activation_ == RELU) { + if (activation_ == ACTIVATION_RELU) { mul_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); - } else if (activation_ == MISH) { + } else if (activation_ == ACTIVATION_MISH) { mul_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_mish, 0.0f, 0.0f); - } else if (activation_ == TANH) { + } else if (activation_ == ACTIVATION_TANH) { mul_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_tanh, 0.0f, 0.0f); } } @@ -350,11 +350,11 @@ void SELayer::Eval(int N, dnnl::memory& output, dnnl::memory& input, dnnl::binary::desc(dnnl::algorithm::binary_add, output.get_desc(), pool_out_md, output.get_desc()); dnnl::post_ops add_ops; - if (activation_ == RELU) { + if (activation_ == ACTIVATION_RELU) { add_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); - } else if (activation_ == MISH) { + } else if (activation_ == ACTIVATION_MISH) { add_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_mish, 0.0f, 0.0f); - } else if (activation_ == TANH) { + } else if (activation_ == ACTIVATION_TANH) { add_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_tanh, 0.0f, 0.0f); } dnnl::primitive_attr add_attr; @@ -507,11 +507,11 @@ void FCLayer::Eval(int N, dnnl::memory& output, dnnl::memory& input, dnnl::prop_kind::forward_inference, t_in_md, t_filter_md, bias_mem.get_desc(), t_out_md.reshape({N, num_outputs})); dnnl::post_ops fc_ops; - if (activation_ == RELU) { + if (activation_ == ACTIVATION_RELU) { fc_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); - } else if (activation_ == MISH) { + } else if (activation_ == ACTIVATION_MISH) { fc_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_mish, 0.0f, 0.0f); - } else if (activation_ == TANH) { + } else if (activation_ == ACTIVATION_TANH) { fc_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_tanh, 0.0f, 0.0f); } dnnl::primitive_attr fc_attr; diff --git a/src/neural/onednn/layers.h b/src/neural/onednn/layers.h index 90904ec26d..64b0096fc1 100644 --- a/src/neural/onednn/layers.h +++ b/src/neural/onednn/layers.h @@ -65,7 +65,7 @@ class BaseLayer { class ConvLayer : public BaseLayer { public: ConvLayer(BaseLayer* ip, int C, int H, int W, int size, int Cin, - ActivationFunction activation = NONE, bool skip = false); + ActivationFunction activation = ACTIVATION_NONE, bool skip = false); void LoadWeights(dnnl::memory& w1, dnnl::memory& b1, dnnl::engine& eng, dnnl::stream& stream); diff --git a/src/neural/onednn/network_onednn.cc b/src/neural/onednn/network_onednn.cc index 42b52a39b4..8587e12982 100644 --- a/src/neural/onednn/network_onednn.cc +++ b/src/neural/onednn/network_onednn.cc @@ -165,8 +165,8 @@ class OnednnNetwork : public Network { default_activation_ = file.format().network_format().default_activation() == pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH - ? MISH - : RELU; + ? ACTIVATION_MISH + : ACTIVATION_RELU; #if DNNL_VERSION_MAJOR * 100 + DNNL_VERSION_MINOR >= 105 dnnl::set_primitive_cache_capacity( @@ -287,7 +287,7 @@ class OnednnNetwork : public Network { auto conv2 = std::make_unique( getLastLayer(idx), numFilters_, 8, 8, 3, numFilters_, - has_se ? NONE : default_activation_, !has_se); + has_se ? ACTIVATION_NONE : default_activation_, !has_se); w_mem = dnnl::memory(w_md, cpu_eng_, &weights.residual[block].conv2.weights[0]); b_mem = dnnl::memory(b_md, cpu_eng_, @@ -385,8 +385,9 @@ class OnednnNetwork : public Network { layers_[idx].emplace_back(std::move(conv1)); // No Activation - auto conv2 = std::make_unique( - getLastLayer(idx), pol_channels_, 8, 8, 3, numFilters_, NONE); + auto conv2 = + std::make_unique(getLastLayer(idx), pol_channels_, 8, 8, + 3, numFilters_, ACTIVATION_NONE); w_md = dnnl::memory::desc({pol_channels_, numFilters_, 3, 3}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::oihw); @@ -411,8 +412,8 @@ class OnednnNetwork : public Network { convPol->LoadWeights(w_mem, b_mem, eng_, eng_stream_); layers_[idx].emplace_back(std::move(convPol)); - auto FCPol = std::make_unique(getLastLayer(idx), - kNumOutputPolicy, 1, 1, NONE); + auto FCPol = std::make_unique( + getLastLayer(idx), kNumOutputPolicy, 1, 1, ACTIVATION_NONE); w_md = dnnl::memory::desc({kNumOutputPolicy, pol_channels_, 8, 8}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abcd); @@ -461,8 +462,9 @@ class OnednnNetwork : public Network { pblczero::NetworkFormat::VALUE_WDL; auto fc2_tanh = !wdl_; - auto FCVal2 = std::make_unique(getLastLayer(idx), wdl_ ? 3 : 1, - 1, 1, fc2_tanh ? TANH : NONE); + auto FCVal2 = std::make_unique( + getLastLayer(idx), wdl_ ? 3 : 1, 1, 1, + fc2_tanh ? ACTIVATION_TANH : ACTIVATION_NONE); w_md = dnnl::memory::desc({wdl_ ? 3 : 1, value_channels_, 1, 1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abcd); diff --git a/src/neural/onnx/builder.cc b/src/neural/onnx/builder.cc index be97dfd7ac..7dcb04c19a 100644 --- a/src/neural/onnx/builder.cc +++ b/src/neural/onnx/builder.cc @@ -38,8 +38,8 @@ namespace lczero { OnnxBuilder::OnnxBuilder(int opset) : opset_(opset) { - if (opset < 7 || opset > 17) { - throw Exception("Only ONNX opsets between 7 and 17 are supported."); + if (opset < 7 || opset > 18) { + throw Exception("Only ONNX opsets between 7 and 18 are supported."); } model_.set_ir_version(4); model_.set_domain("org.lczero.models.*"); @@ -167,14 +167,17 @@ std::string OnnxBuilder::GlobalAveragePool(const std::string& name, } std::string OnnxBuilder::Squeeze(const std::string& name, - const std::string& input) { + const std::string& input, + std::initializer_list axes) { auto* node = model_.mutable_graph()->add_node(); auto out = PopulateStdNodeFields(node, name, input, "Squeeze"); if (opset_ < 13) { - AddIntsAttribute(node, "axes", {2, 3}); + AddIntsAttribute(node, "axes", axes); } else { - node->add_input( - AddInitializer(name + "/axes", Int64OnnxConst({2, 3}, {2}))); + node->add_input(AddInitializer( + name + "/axes", + Int64OnnxConst(std::vector(begin(axes), end(axes)), + {static_cast(axes.size())}))); } return out; } @@ -311,6 +314,7 @@ std::vector OnnxBuilder::Split(const std::string& name, } return out; } + if (opset_ >= 18) AddIntAttribute(node, "num_outputs", 2); node->add_output(name + "/out1"); node->add_output(name + "/out2"); return {name + "/out1", name + "/out2"}; @@ -371,4 +375,67 @@ std::string OnnxBuilder::LayerNormalization(const std::string& name, return out; } +std::string OnnxBuilder::Expand(const std::string& name, + const std::string& input, + const std::string& shape) { + auto* node = model_.mutable_graph()->add_node(); + auto out = PopulateStdNodeFields(node, name, input, "Expand"); + node->add_input(shape); + return out; +} + +std::string OnnxBuilder::Shape(const std::string& name, + const std::string& input) { + auto* node = model_.mutable_graph()->add_node(); + return PopulateStdNodeFields(node, name, input, "Shape"); +} + +std::string OnnxBuilder::Exp(const std::string& name, + const std::string& input) { + auto* node = model_.mutable_graph()->add_node(); + return PopulateStdNodeFields(node, name, input, "Exp"); +} + +std::string OnnxBuilder::Div(const std::string& name, const std::string& input1, + const std::string& input2) { + auto* node = model_.mutable_graph()->add_node(); + auto out = PopulateStdNodeFields(node, name, input1, "Div"); + node->add_input(input2); + return out; +} + +std::string OnnxBuilder::Sub(const std::string& name, const std::string& input1, + const std::string& input2) { + auto* node = model_.mutable_graph()->add_node(); + auto out = PopulateStdNodeFields(node, name, input1, "Sub"); + node->add_input(input2); + return out; +} + +std::string OnnxBuilder::Greater(const std::string& name, + const std::string& input1, + const OnnxConst& input2) { + auto* node = model_.mutable_graph()->add_node(); + auto out = PopulateStdNodeFields(node, name, input1, "Greater"); + node->add_input(AddInitializer(name + "/threshold", input2)); + return out; +} + +std::string OnnxBuilder::Where(const std::string& name, + const std::string& input1, + const std::string& input2, + const std::string& input3) { + auto* node = model_.mutable_graph()->add_node(); + auto out = PopulateStdNodeFields(node, name, input1, "Where"); + node->add_input(input2); + node->add_input(input3); + return out; +} + +std::string OnnxBuilder::Mish(const std::string& name, + const std::string& input) { + auto* node = model_.mutable_graph()->add_node(); + return PopulateStdNodeFields(node, name, input, "Mish"); +} + } // namespace lczero diff --git a/src/neural/onnx/builder.h b/src/neural/onnx/builder.h index 3b88cb27ec..c1f6c5e957 100644 --- a/src/neural/onnx/builder.h +++ b/src/neural/onnx/builder.h @@ -68,7 +68,8 @@ class OnnxBuilder { const OnnxConst&); std::string GlobalAveragePool(const std::string& name, const std::string& input); - std::string Squeeze(const std::string& name, const std::string& input); + std::string Squeeze(const std::string& name, const std::string& input, + std::initializer_list axes); std::string MatMul(const std::string& name, const std::string& input1, const OnnxConst& input2); std::string MatMul(const std::string& name, const std::string& input1, @@ -106,6 +107,19 @@ class OnnxBuilder { const std::string& input, const OnnxConst& scale, const OnnxConst& bias, int axis, float epsilon = 1e-6); + std::string Expand(const std::string& name, const std::string& input, + const std::string& shape); + std::string Shape(const std::string& name, const std::string& input); + std::string Exp(const std::string& name, const std::string& input); + std::string Div(const std::string& name, const std::string& input1, + const std::string& input2); + std::string Sub(const std::string& name, const std::string& input1, + const std::string& input2); + std::string Greater(const std::string& name, const std::string& input1, + const OnnxConst&); + std::string Where(const std::string& name, const std::string& input1, + const std::string& input2, const std::string& input3); + std::string Mish(const std::string& name, const std::string& input); // Returns ONNX model as protobuf. const pblczero::ModelProto& as_proto() const { return model_; } // Returns serialized model. diff --git a/src/neural/onnx/converter.cc b/src/neural/onnx/converter.cc index 6004e4d56f..549ee20af3 100644 --- a/src/neural/onnx/converter.cc +++ b/src/neural/onnx/converter.cc @@ -56,8 +56,8 @@ class Converter { default_activation_ = net.format().network_format().default_activation() == pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH - ? MISH - : RELU; + ? ACTIVATION_MISH + : ACTIVATION_RELU; } void Convert(pblczero::Net* dst); @@ -67,7 +67,8 @@ class Converter { return LayerAdapter(src_.weights().input().weights()).size() / kInputPlanes / 9; } - size_t NumBlocks() const { return src_.weights().residual_size(); } + size_t NumResBlocks() const { return src_.weights().residual_size(); } + size_t NumEncBlocks() const { return src_.weights().encoder().size(); } void CopyGenericFields(pblczero::Net* dst); void GenerateOnnx(pblczero::OnnxModel* onnx); void FillValueInfo(pblczero::ValueInfoProto* vip, const std::string& name, @@ -86,6 +87,9 @@ class Converter { const std::string& input, const std::string& name); + std::string MakeAttentionBody(OnnxBuilder* builder, const std::string& input, + const LegacyWeights& weights); + std::string MakeSqueezeAndExcite(OnnxBuilder* builder, const LegacyWeights::SEunit& se_unit, const std::string& input, @@ -94,15 +98,26 @@ class Converter { std::string MakeMish(OnnxBuilder* builder, const std::string& input, const std::string& name); + std::string MakeSwish(OnnxBuilder* builder, const std::string& input, + const std::string& name); + std::string MakeActivation(OnnxBuilder* builder, const std::string& input, const std::string& name, ActivationFunction activation); + std::string MakeSmolgen(OnnxBuilder* builder, + const LegacyWeights::EncoderLayer& layer, + int embedding_size, int heads, + const std::string& encoder_in, + const std::string& name); + std::string MakeEncoderLayer(OnnxBuilder* builder, const LegacyWeights::EncoderLayer& layer, int embedding_size, int heads, const std::string& encoder_in, - const std::string& name); + const std::string& name, + ActivationFunction activation, + float alpha = 1.0f); std::string MakeAttentionPolicy(OnnxBuilder* builder, const std::string& input, @@ -128,6 +143,7 @@ class Converter { const pblczero::Net& src_; const WeightsToOnnxConverterOptions& options_; ActivationFunction default_activation_; + bool se_reshape_init_ = false; }; pblczero::TensorProto::DataType Converter::GetDataType() const { @@ -157,8 +173,34 @@ std::unique_ptr Converter::GetWeghtsConverter( std::string Converter::MakeMish(OnnxBuilder* builder, const std::string& input, const std::string& name) { - auto flow = builder->Softplus(name + "/softplus", input); - flow = builder->Tanh(name + "/tanh", flow); + if (!options_.alt_mish || options_.opset < 9 || + options_.data_type_ != + WeightsToOnnxConverterOptions::DataType::kFloat32) { + if (options_.opset >= 18) return builder->Mish(name, input); + auto flow = builder->Softplus(name + "/softplus", input); + flow = builder->Tanh(name + "/tanh", flow); + return builder->Mul(name, flow, input); + } else { + const OnnxConst& two = + static_cast(FloatOnnxConst({2.0f}, {1})); + const OnnxConst& zero = + static_cast(FloatOnnxConst({0.0f}, {1})); + auto e = builder->Exp(name + "/exp", input); + auto flow = builder->Add(name + "/e+2", e, two); + auto n = builder->Mul(name + "/n", e, flow); + flow = builder->Add(name + "/n+2", n, two); + auto d = builder->Div(name + "/d", input, flow); + auto f = builder->Mul(name + "/n*d", n, d); + flow = builder->Mul(name + "/2*d", d, two); + auto t = builder->Sub(name + "/in-2*d", input, flow); + flow = builder->Greater(name + "/compare", input, zero); + return builder->Where(name, flow, t, f); + } +} + +std::string Converter::MakeSwish(OnnxBuilder* builder, const std::string& input, + const std::string& name) { + auto flow = builder->Sigmoid(name + "/sigmoid", input); return builder->Mul(name, flow, input); } @@ -167,12 +209,18 @@ std::string Converter::MakeActivation(OnnxBuilder* builder, const std::string& name, ActivationFunction activation) { switch (activation) { - case RELU: + case ACTIVATION_RELU: return builder->Relu(name + "/relu", input); - case MISH: + case ACTIVATION_MISH: return MakeMish(builder, input, name + "/mish"); - case SELU: + case ACTIVATION_SELU: return builder->Selu(name + "/selu", input); + case ACTIVATION_SWISH: + return MakeSwish(builder, input, name + "/swish"); + case ACTIVATION_RELU_2: { + auto flow = builder->Relu(name + "/sqrrelu/relu", input); + return builder->Mul(name + "/sqrrelu/sqr", flow, flow); + } default: throw Exception("Unsupposrted activation in " + name); } @@ -183,8 +231,13 @@ std::string Converter::MakeSqueezeAndExcite( const std::string& input, const std::string& name) { const int se_filters = se_unit.b1.size(); + if (!se_reshape_init_) { + builder->AddInitializer("/const/se_reshape", + Int64OnnxConst({-1, NumFilters() * 2, 1, 1}, {4})); + se_reshape_init_ = true; + } auto flow = builder->GlobalAveragePool(name + "/pooled", input); - flow = builder->Squeeze(name + "/squeeze", flow); + flow = builder->Squeeze(name + "/squeeze", flow, {2, 3}); flow = builder->MatMul( name + "/matmul1", flow, *GetWeghtsConverter(se_unit.w1, {NumFilters(), se_filters}, {1, 0})); @@ -235,15 +288,76 @@ std::string Converter::MakeResidualBlock(OnnxBuilder* builder, name + "/conv2", res.has_se ? &res.se : nullptr, input); } -void Converter::AddStdInitializers(OnnxBuilder* builder) { - builder->AddInitializer("/const/se_reshape", - Int64OnnxConst({-1, NumFilters() * 2, 1, 1}, {4})); +std::string Converter::MakeSmolgen(OnnxBuilder* builder, + const LegacyWeights::EncoderLayer& layer, + int embedding_size, int heads, + const std::string& encoder_in, + const std::string& name) { + const auto smolgen_activation = static_cast( + src_.format().network_format().smolgen_activation()); + const auto activation = smolgen_activation == ACTIVATION_DEFAULT + ? default_activation_ + : smolgen_activation; + const int smolgen_hidden_channels = + layer.mha.smolgen.compress.size() / embedding_size; + const int smolgen_hidden_sz = layer.mha.smolgen.dense1_b.size(); + const int smolgen_gen_sz = layer.mha.smolgen.dense2_b.size() / heads; + auto flow = builder->MatMul( + name + "/smolgen/compress", encoder_in, + *GetWeghtsConverter(layer.mha.smolgen.compress, + {embedding_size, smolgen_hidden_channels}, {1, 0})); + flow = builder->Reshape( + name + "/smolgen/compress/reshape", flow, + builder->AddInitializer( + "/const" + name + "/smolgen/compress/shape", + Int64OnnxConst({-1, 64 * smolgen_hidden_channels}, {2}))); + flow = builder->MatMul( + name + "/smolgen/dense1/w", flow, + *GetWeghtsConverter(layer.mha.smolgen.dense1_w, + {64 * smolgen_hidden_channels, smolgen_hidden_sz}, + {1, 0})); + flow = builder->Add( + name + "/smolgen/dense1/b", flow, + *GetWeghtsConverter(layer.mha.smolgen.dense1_b, {smolgen_hidden_sz})); + flow = MakeActivation(builder, flow, name + "/smolgen/dense1", activation); + flow = builder->LayerNormalization( + name + "/smolgen/ln1", flow, + *GetWeghtsConverter(layer.mha.smolgen.ln1_gammas, {smolgen_hidden_sz}), + *GetWeghtsConverter(layer.mha.smolgen.ln1_betas, {smolgen_hidden_sz}), 1, + 1e-3); + flow = builder->MatMul( + name + "/smolgen/dense2/w", flow, + *GetWeghtsConverter(layer.mha.smolgen.dense2_w, + {smolgen_hidden_sz, smolgen_gen_sz * heads}, {1, 0})); + flow = builder->Add(name + "/smolgen/dense2/b", flow, + *GetWeghtsConverter(layer.mha.smolgen.dense2_b, + {smolgen_gen_sz * heads})); + flow = MakeActivation(builder, flow, name + "/smolgen/dense2", activation); + flow = builder->LayerNormalization( + name + "/smolgen/ln2", flow, + *GetWeghtsConverter(layer.mha.smolgen.ln2_gammas, + {smolgen_gen_sz * heads}), + *GetWeghtsConverter(layer.mha.smolgen.ln2_betas, + {smolgen_gen_sz * heads}), + 1, 1e-3); + flow = + builder->Reshape(name + "/smolgen/gen_from/reshape", flow, + builder->AddInitializer( + "/const" + name + "/smolgen/gen_from/shape", + Int64OnnxConst({-1, heads, smolgen_gen_sz}, {3}))); + flow = builder->MatMul(name + "/smolgen/smol_weight_gen", flow, + "/const/smolgen_w"); + flow = builder->Reshape( + name + "/smolgen/out/reshape", flow, + builder->AddInitializer("/const" + name + "/smolgen/out/shape", + Int64OnnxConst({-1, heads, 64, 64}, {4}))); + return flow; } std::string Converter::MakeEncoderLayer( OnnxBuilder* builder, const LegacyWeights::EncoderLayer& layer, int embedding_size, int heads, const std::string& encoder_in, - const std::string& name) { + const std::string& name, ActivationFunction activation, float alpha) { const int d_model = layer.mha.q_b.size(); const int depth = d_model / heads; @@ -281,6 +395,11 @@ std::string Converter::MakeEncoderLayer( FloatOnnxConst({1.0f / sqrtf(depth)}, {1})); } flow = builder->Mul(name + "/mha/QK/scale", flow, *scale); + if (layer.mha.has_smolgen) { + auto smolgen_weights = + MakeSmolgen(builder, layer, embedding_size, heads, encoder_in, name); + flow = builder->Add(name + "/smolgen_weights", flow, smolgen_weights); + } flow = builder->Softmax(name + "/mha/QK/softmax", flow, 3); flow = builder->MatMul(name + "/mha/QKV/matmul", flow, V); if (heads > 1) { @@ -296,7 +415,22 @@ std::string Converter::MakeEncoderLayer( {d_model, embedding_size}, {1, 0})); flow = builder->Add(name + "/mha/out/dense/b", flow, *GetWeghtsConverter(layer.mha.dense_b, {embedding_size})); - flow = builder->Add(name + "/mha/out/skip", flow, encoder_in); + std::unique_ptr alpha_onnx; + std::string alpha_in; + if (alpha != 1.0) { + if (GetDataType() == pblczero::TensorProto::FLOAT16) { + alpha_onnx = std::make_unique( + Float16OnnxConst({FP32toFP16(alpha)}, {1})); + } else { + alpha_onnx = + std::make_unique(FloatOnnxConst({alpha}, {1})); + } + alpha_in = builder->Mul(name + "/alpha*input", encoder_in, *alpha_onnx); + } else { + alpha_in = encoder_in; + } + flow = builder->Add(name + "/mha/out/skip", flow, alpha_in); + auto ffn_in = builder->LayerNormalization( name + "/ln1", flow, *GetWeghtsConverter(layer.ln1_gammas, {embedding_size}), @@ -308,7 +442,12 @@ std::string Converter::MakeEncoderLayer( {embedding_size, dff_size}, {1, 0})); flow = builder->Add(name + "/ffn/dense1/b", flow, *GetWeghtsConverter(layer.ffn.dense1_b, {dff_size})); - flow = MakeActivation(builder, flow, name + "/ffn/dense1", SELU); + + const auto ffn_activation = static_cast( + src_.format().network_format().ffn_activation()); + flow = MakeActivation( + builder, flow, name + "/ffn/dense1", + ffn_activation == ACTIVATION_DEFAULT ? activation : ffn_activation); flow = builder->MatMul(name + "/ffn/dense2/w", flow, *GetWeghtsConverter(layer.ffn.dense2_w, @@ -316,7 +455,13 @@ std::string Converter::MakeEncoderLayer( flow = builder->Add(name + "/ffn/dense2/b", flow, *GetWeghtsConverter(layer.ffn.dense2_b, {embedding_size})); - flow = builder->Add(name + "/ffn/skip", flow, ffn_in); + std::string alpha_ffn_in; + if (alpha != 1.0) { + alpha_ffn_in = builder->Mul(name + "/alpha*out1", ffn_in, *alpha_onnx); + } else { + alpha_ffn_in = ffn_in; + } + flow = builder->Add(name + "/ffn/skip", flow, alpha_ffn_in); flow = builder->LayerNormalization( name + "/ln2", flow, *GetWeghtsConverter(layer.ln2_gammas, {embedding_size}), @@ -324,6 +469,98 @@ std::string Converter::MakeEncoderLayer( return flow; } +std::string Converter::MakeAttentionBody(OnnxBuilder* builder, + const std::string& input, + const LegacyWeights& weights) { + if (weights.has_smolgen) { + builder->AddInitializer( + "/const/smolgen_w", + *GetWeghtsConverter( + weights.smolgen_w, + {static_cast(weights.smolgen_w.size() / 4096), 4096}, {1, 0})); + } + + auto flow = builder->Transpose("/attn_body/transpose", input, {0, 2, 3, 1}); + + if (NumResBlocks() > 0) { + flow = builder->Reshape( + "/attn_body/reshape", flow, + builder->AddInitializer("/const/att_body_shape", + Int64OnnxConst({-1, NumFilters()}, {2}))); + } else { + flow = builder->Reshape( + "/attn_body/reshape", flow, + builder->AddInitializer("/const/att_body_shape", + Int64OnnxConst({-1, 64, 112}, {3}))); + std::string pad; + if (options_.batch_size < 0) { + pad = builder->Shape("/attn_body/shape", flow); + pad = builder->Slice("/attn_body/batch", pad, {0}, {1}); + pad = builder->Concat( + "/attn_body/pos_encoding_shape", + {pad, builder->AddInitializer("/const/pos_encoding_shape", + Int64OnnxConst({64, 64}, {2}))}, + 0); + } else { + pad = builder->AddInitializer( + "/const/pos_encoding_shape", + Int64OnnxConst({options_.batch_size, 64, 64}, {3})); + } + pad = builder->Expand( + "/attn_body/expand", + builder->AddInitializer( + "/const/pos_encoding", + *GetWeghtsConverter( + std::vector(kPosEncoding[0], kPosEncoding[0] + 64 * 64), + {1, 64, 64})), + pad); + flow = builder->Concat("/attn_body/padded_input", {flow, pad}, 2); + flow = builder->Reshape( + "/attn_body/reshape2", flow, + builder->AddInitializer("/const/att_body_shape2", + Int64OnnxConst({-1, 176}, {2}))); + } + + int embedding_size = weights.ip_emb_b.size(); + flow = builder->MatMul( + "/attn_body/matmul", flow, + *GetWeghtsConverter( + weights.ip_emb_w, + {NumResBlocks() > 0 ? NumFilters() : 176, embedding_size}, {1, 0})); + flow = builder->Add("/attn_body/add", flow, + *GetWeghtsConverter(weights.ip_emb_b, {embedding_size})); + flow = MakeActivation(builder, flow, "/attn_body", default_activation_); + + if (weights.ip_mult_gate.size() > 0 || weights.ip_add_gate.size() > 0) { + flow = builder->Reshape( + "/attn_body/ma_gating/rehape1", flow, + builder->AddInitializer("/const/ma_gating/shape1", + Int64OnnxConst({-1, 64, embedding_size}, {3}))); + if (weights.ip_mult_gate.size() > 0) { + flow = builder->Mul("/ip_mul_gate", flow, + *GetWeghtsConverter(weights.ip_mult_gate, + {64, embedding_size}, {1, 0})); + } + if (weights.ip_add_gate.size() > 0) { + flow = builder->Add("/ip_add_gate", flow, + *GetWeghtsConverter(weights.ip_add_gate, + {64, embedding_size}, {1, 0})); + } + flow = builder->Reshape( + "/attn_body/ma_gating/rehape2", flow, + builder->AddInitializer("/const/ma_gating/shape2", + Int64OnnxConst({-1, embedding_size}, {2}))); + } + + float alpha = std::pow(2.0f * NumEncBlocks(), 0.25f); + for (size_t i = 0; i < NumEncBlocks(); i++) { + flow = MakeEncoderLayer( + builder, weights.encoder[i], embedding_size, weights.encoder_head_count, + flow, "/encoder" + std::to_string(i), default_activation_, alpha); + } + return flow; +} + namespace { std::vector MakePolicyMap(const short* map, int size) { std::vector policy_map(1858); @@ -339,32 +576,46 @@ std::vector MakePolicyMap(const short* map, int size) { std::string Converter::MakeAttentionPolicy(OnnxBuilder* builder, const std::string& input, const LegacyWeights& weights) { - const int embedding_size = weights.ip_pol_b.size(); + const int embedding_size = weights.ip_emb_b.size(); + const int policy_embedding_size = weights.ip_pol_b.size(); const int policy_d_model = weights.ip2_pol_b.size(); - auto flow = - builder->Transpose("/policy/dense1/transpose", input, {0, 2, 3, 1}); - flow = builder->Reshape( - "/policy/dense1/reshape", flow, - builder->AddInitializer("/const/policy_shape", - Int64OnnxConst({-1, NumFilters()}, {2}))); + auto flow = input; + auto activation = + src_.format().network_format().network() >= + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT + ? default_activation_ + : ACTIVATION_SELU; + if (NumEncBlocks() == 0) { + flow = builder->Transpose("/policy/dense1/transpose", flow, {0, 2, 3, 1}); + + flow = builder->Reshape( + "/policy/dense1/reshape", flow, + builder->AddInitializer("/const/policy_shape", + Int64OnnxConst({-1, NumFilters()}, {2}))); + } flow = builder->MatMul( "/policy/dense1/matmul", flow, - *GetWeghtsConverter(weights.ip_pol_w, {NumFilters(), embedding_size}, + *GetWeghtsConverter(weights.ip_pol_w, + {NumEncBlocks() > 0 ? embedding_size : NumFilters(), + policy_embedding_size}, {1, 0})); - flow = builder->Add("/policy/dense1/add", flow, - *GetWeghtsConverter(weights.ip_pol_b, {embedding_size})); - flow = MakeActivation(builder, flow, "/policy/dense1", SELU); + flow = builder->Add( + "/policy/dense1/add", flow, + *GetWeghtsConverter(weights.ip_pol_b, {policy_embedding_size})); + flow = MakeActivation(builder, flow, "/policy/dense1", activation); + for (size_t i = 0; i < weights.pol_encoder.size(); i++) { std::string name = "/policy/enc_layer_" + std::to_string(i); - flow = MakeEncoderLayer(builder, weights.pol_encoder[i], embedding_size, - weights.pol_encoder_head_count, flow, name); + flow = MakeEncoderLayer( + builder, weights.pol_encoder[i], policy_embedding_size, + weights.pol_encoder_head_count, flow, name, activation); } auto encoder_out = flow; flow = builder->MatMul( "/policy/Q/matmul", encoder_out, - *GetWeghtsConverter(weights.ip2_pol_w, {embedding_size, policy_d_model}, - {1, 0})); + *GetWeghtsConverter(weights.ip2_pol_w, + {policy_embedding_size, policy_d_model}, {1, 0})); flow = builder->Add("/policy/Q/add", flow, *GetWeghtsConverter(weights.ip2_pol_b, {policy_d_model})); auto Q = builder->Reshape( @@ -373,8 +624,8 @@ std::string Converter::MakeAttentionPolicy(OnnxBuilder* builder, Int64OnnxConst({-1, 64, policy_d_model}, {3}))); flow = builder->MatMul( "/policy/K/matmul", encoder_out, - *GetWeghtsConverter(weights.ip3_pol_w, {embedding_size, policy_d_model}, - {1, 0})); + *GetWeghtsConverter(weights.ip3_pol_w, + {policy_embedding_size, policy_d_model}, {1, 0})); flow = builder->Add("/policy/K/add", flow, *GetWeghtsConverter(weights.ip3_pol_b, {policy_d_model})); auto K = builder->Reshape("/policy/K/reshape", flow, "/const/QK_shape"); @@ -443,6 +694,10 @@ void Converter::MakePolicyHead(pblczero::OnnxModel* onnx, OnnxBuilder* builder, onnx->set_output_policy(output); } else if (!weights.policy1.weights.empty()) { // Conv policy head. + if (NumEncBlocks() > 0) { + throw Exception( + "Convolutional policy not supported with attention body."); + } auto flow = MakeConvBlock(builder, weights.policy1, NumFilters(), NumFilters(), input, "/policy/conv1"); flow = MakeConvBlock(builder, weights.policy, NumFilters(), 80, flow, @@ -463,6 +718,9 @@ void Converter::MakePolicyHead(pblczero::OnnxModel* onnx, OnnxBuilder* builder, onnx->set_output_policy(output); } else { // Dense policy head. + if (NumEncBlocks() > 0) { + throw Exception("Classical policy not supported with attention body."); + } const int pol_channels = weights.policy.biases.size(); auto flow = MakeConvBlock(builder, weights.policy, NumFilters(), pol_channels, @@ -486,15 +744,29 @@ void Converter::MakePolicyHead(pblczero::OnnxModel* onnx, OnnxBuilder* builder, void Converter::MakeValueHead(pblczero::OnnxModel* onnx, OnnxBuilder* builder, const std::string& input, const LegacyWeights& weights) { - auto flow = MakeConvBlock(builder, weights.value, NumFilters(), 32, input, - "/value/conv", nullptr, "", true, 1); + std::string flow; + const int val_channels = NumEncBlocks() > 0 ? weights.ip_val_b.size() : 32; + if (NumEncBlocks() > 0) { + int embedding_size = weights.ip_emb_b.size(); + flow = builder->MatMul( + "/value/embed/matmul", input, + *GetWeghtsConverter(weights.ip_val_w, {embedding_size, val_channels}, + {1, 0})); + flow = builder->Add("/value/embed/add", flow, + *GetWeghtsConverter(weights.ip_val_b, {val_channels})); + flow = MakeActivation(builder, flow, "/value/embed", default_activation_); + } else { + flow = MakeConvBlock(builder, weights.value, NumFilters(), val_channels, + input, "/value/conv", nullptr, "", true, 1); + } flow = builder->Reshape( "/value/reshape", flow, builder->AddInitializer("/const/value_shape", - Int64OnnxConst({-1, 32 * 8 * 8}, {2}))); - flow = builder->MatMul( - "/value/dense1/matmul", flow, - *GetWeghtsConverter(weights.ip1_val_w, {32 * 8 * 8, 128}, {1, 0})); + Int64OnnxConst({-1, val_channels * 8 * 8}, {2}))); + flow = + builder->MatMul("/value/dense1/matmul", flow, + *GetWeghtsConverter(weights.ip1_val_w, + {val_channels * 8 * 8, 128}, {1, 0})); flow = builder->Add("/value/dense1/add", flow, *GetWeghtsConverter(weights.ip1_val_b, {128})); flow = MakeActivation(builder, flow, "/value/dense1", default_activation_); @@ -530,11 +802,25 @@ void Converter::MakeMovesLeftHead(pblczero::OnnxModel* onnx, pblczero::NetworkFormat::MOVES_LEFT_V1) { return; } - const int mlh_channels = weights.moves_left.biases.size(); + const int mlh_channels = NumEncBlocks() > 0 + ? weights.ip_mov_b.size() + : weights.moves_left.biases.size(); const int mlh_fc1_outputs = weights.ip1_mov_b.size(); - auto flow = - MakeConvBlock(builder, weights.moves_left, NumFilters(), mlh_channels, - input, "/mlh/conv", nullptr, "", true, 1); + std::string flow; + if (NumEncBlocks() > 0) { + int embedding_size = weights.ip_emb_b.size(); + flow = builder->MatMul( + "/mlh/embed/matmul", input, + *GetWeghtsConverter(weights.ip_mov_w, {embedding_size, mlh_channels}, + {1, 0})); + flow = builder->Add("/mlh/embed/add", flow, + *GetWeghtsConverter(weights.ip_mov_b, {mlh_channels})); + flow = MakeActivation(builder, flow, "/mlh/embed", default_activation_); + } else { + flow = + MakeConvBlock(builder, weights.moves_left, NumFilters(), mlh_channels, + input, "/mlh/conv", nullptr, "", true, 1); + } flow = builder->Reshape( "/mlh/reshape", flow, builder->AddInitializer("/const/mlh_shape", @@ -562,21 +848,28 @@ void Converter::GenerateOnnx(pblczero::OnnxModel* onnx) { LegacyWeights weights(src_.weights()); OnnxBuilder builder(options_.opset); - AddStdInitializers(&builder); - onnx->set_input_planes(options_.input_planes_name); builder.AddInput(options_.input_planes_name, {options_.batch_size, 112, 8, 8}, GetDataType()); + + auto flow = options_.input_planes_name; + // Input convolution. - auto flow = MakeConvBlock(&builder, weights.input, kInputPlanes, NumFilters(), - options_.input_planes_name, "/inputconv"); + if (NumResBlocks() > 0) { + flow = MakeConvBlock(&builder, weights.input, kInputPlanes, NumFilters(), + flow, "/inputconv"); + } // Residual tower. - for (size_t i = 0; i < NumBlocks(); ++i) { + for (size_t i = 0; i < NumResBlocks(); ++i) { flow = MakeResidualBlock(&builder, weights.residual[i], flow, "/block" + std::to_string(i)); } + if (NumEncBlocks() > 0) { + flow = MakeAttentionBody(&builder, flow, weights); + } + // Policy head. MakePolicyHead(onnx, &builder, flow, weights); // Value head. diff --git a/src/neural/onnx/converter.h b/src/neural/onnx/converter.h index 8980ffb858..b906aebfce 100644 --- a/src/neural/onnx/converter.h +++ b/src/neural/onnx/converter.h @@ -43,6 +43,7 @@ struct WeightsToOnnxConverterOptions { std::string output_mlh = "/output/mlh"; int batch_size = -1; int opset = 17; + bool alt_mish = false; }; // Converts "classical" weights file to weights file with embedded ONNX model. diff --git a/src/neural/onnx/network_onnx.cc b/src/neural/onnx/network_onnx.cc index 7d7950175c..c458b3a585 100644 --- a/src/neural/onnx/network_onnx.cc +++ b/src/neural/onnx/network_onnx.cc @@ -82,8 +82,8 @@ class OnnxComputation : public NetworkComputation { class OnnxNetwork : public Network { public: OnnxNetwork(const WeightsFile& file, const OptionsDict& options, - OnnxProvider provider, int gpu, bool fp16, int batch_size, - int steps); + OnnxProvider provider, int gpu, int threads, bool fp16, + int batch_size, int steps); std::unique_ptr NewComputation() override { if (fp16_) { return std::make_unique>(this); @@ -196,12 +196,11 @@ Ort::Value OnnxComputation::PrepareInputs(int start, int batch_size) { int end = std::min(start + batch_size, static_cast(raw_input_.size())); for (int i = start; i < end; i++) { for (const auto& plane : raw_input_[i]) { + DataType value = std::is_same::value + ? FP32toFP16(plane.value) + : plane.value; for (auto bit : IterateBits(plane.mask)) { - if (std::is_same::value) { - *(iter + bit) = FP32toFP16(plane.value); - } else { - *(iter + bit) = plane.value; - } + *(iter + bit) = value; } iter += 64; } @@ -251,10 +250,11 @@ void OnnxComputation::ComputeBlocking() { } } -Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int batch_size) { +Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int threads, + int batch_size) { Ort::SessionOptions options; OrtCUDAProviderOptions cuda_options; - // options.SetIntraOpNumThreads(1); + options.SetIntraOpNumThreads(threads); options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); if (batch_size > 0) { @@ -281,8 +281,6 @@ Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int batch_size) { options.AppendExecutionProvider_CUDA(cuda_options); break; case OnnxProvider::CPU: - // Doesn't really work. :-( There are two execution providers (CUDA and - // CPU) already added, don't know how to force it to use CPU. auto status = OrtSessionOptionsAppendExecutionProvider_CPU(options, 0); if (status) { std::string error_message = Ort::GetApi().GetErrorMessage(status); @@ -297,7 +295,7 @@ Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int batch_size) { } OnnxNetwork::OnnxNetwork(const WeightsFile& file, const OptionsDict&, - OnnxProvider provider, int gpu, bool fp16, + OnnxProvider provider, int gpu, int threads, bool fp16, int batch_size, int steps) : onnx_env_(ORT_LOGGING_LEVEL_WARNING, "lc0"), steps_(steps), @@ -313,9 +311,10 @@ OnnxNetwork::OnnxNetwork(const WeightsFile& file, const OptionsDict&, } for (int step = 1; step <= steps_; step++) - session_.emplace_back(onnx_env_, file.onnx_model().model().data(), - file.onnx_model().model().size(), - GetOptions(provider, gpu, batch_size_ * step)); + session_.emplace_back( + onnx_env_, file.onnx_model().model().data(), + file.onnx_model().model().size(), + GetOptions(provider, gpu, threads, batch_size_ * step)); const auto& md = file.onnx_model(); if (!md.has_input_planes()) { @@ -359,7 +358,10 @@ std::unique_ptr MakeOnnxNetwork(const std::optional& w, opts.GetOrDefault("batch", kProvider == OnnxProvider::DML ? 16 : -1); int steps = - opts.GetOrDefault("steps", kProvider == OnnxProvider::DML ? 8 : 1); + opts.GetOrDefault("steps", kProvider == OnnxProvider::DML ? 4 : 1); + + int threads = + opts.GetOrDefault("threads", kProvider == OnnxProvider::CPU ? 1 : 0); if (batch_size <= 0) batch_size = -1; // Variable batch size. @@ -367,13 +369,15 @@ std::unique_ptr MakeOnnxNetwork(const std::optional& w, "fp16", kProvider == OnnxProvider::CPU ? false : true); if (w->has_onnx_model()) { - return std::make_unique(*w, opts, kProvider, gpu, false, - batch_size, steps); + return std::make_unique(*w, opts, kProvider, gpu, threads, + false, batch_size, steps); } else { if (w->format().network_format().network() != pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && w->format().network_format().network() != - pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT) { + pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && + w->format().network_format().network() != + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT) { throw Exception("Network format " + pblczero::NetworkFormat::NetworkStructure_Name( w->format().network_format().network()) + @@ -410,12 +414,15 @@ std::unique_ptr MakeOnnxNetwork(const std::optional& w, } WeightsToOnnxConverterOptions converter_options; converter_options.opset = opts.GetOrDefault("opset", 17); + converter_options.alt_mish = opts.GetOrDefault( + "alt_mish", kProvider == OnnxProvider::CPU ? true : false); converter_options.data_type_ = fp16 ? WeightsToOnnxConverterOptions::DataType::kFloat16 : WeightsToOnnxConverterOptions::DataType::kFloat32; + auto converted = ConvertWeightsToOnnx(*w, converter_options); - return std::make_unique(converted, opts, kProvider, gpu, fp16, - batch_size, steps); + return std::make_unique(converted, opts, kProvider, gpu, + threads, fp16, batch_size, steps); } } diff --git a/src/neural/shared/activation.cc b/src/neural/shared/activation.cc index ecf154f2ed..1710d9e802 100644 --- a/src/neural/shared/activation.cc +++ b/src/neural/shared/activation.cc @@ -21,6 +21,8 @@ #include #include +#include "utils/exception.h" + #ifdef USE_ISPC #include "activation_ispc.h" #endif @@ -68,35 +70,45 @@ static inline float selu(float val) { float Activate(const float val, const ActivationFunction activation) { switch (activation) { - case RELU: + case ACTIVATION_RELU: return val > 0 ? val : 0; - case MISH: + case ACTIVATION_RELU_2: + return val > 0 ? val * val : 0; + case ACTIVATION_MISH: return mish(val); - case TANH: + case ACTIVATION_TANH: return tanhf(val); - case SIGMOID: + case ACTIVATION_SIGMOID: return 1.0f / (1.0f + expf(-val)); - case SELU: + case ACTIVATION_SELU: return selu(val); - case NONE: + case ACTIVATION_SWISH: + return val / (1.0f + expf(-val)); + case ACTIVATION_NONE: // Nothing to do. break; + default: + throw Exception("unsupported activation function"); } return val; } void Activate(const size_t len, const float* data, const float* bias, float* output, const ActivationFunction activation) { - if (activation == NONE) { + if (activation == ACTIVATION_NONE) { for (size_t b = 0; b < len; b++) { output[b] = data[b] + bias[b]; } - } else if (activation == RELU) { + } else if (activation == ACTIVATION_RELU) { +#ifndef USE_ISPC for (size_t b = 0; b < len; b++) { float val = data[b] + bias[b]; output[b] = val > 0 ? val : 0; } - } else if (activation == MISH) { +#else + ispc::ActivateRelu(len, 1.0f, data, bias, 0.0f, output); +#endif + } else if (activation == ACTIVATION_MISH) { #ifndef USE_ISPC for (size_t b = 0; b < len; b++) { float val = data[b] + bias[b]; @@ -104,6 +116,34 @@ void Activate(const size_t len, const float* data, const float* bias, } #else ispc::ActivateMish(len, 1.0f, data, bias, 0.0f, output); +#endif + } else if (activation == ACTIVATION_RELU_2) { +#ifndef USE_ISPC + for (size_t b = 0; b < len; b++) { + float val = data[b] + bias[b]; + output[b] = val > 0 ? val * val : 0; + } +#else + ispc::ActivateRelu_2(len, data, bias, output); +#endif + } else if (activation == ACTIVATION_SWISH) { +#ifndef USE_ISPC + for (size_t b = 0; b < len; b++) { + float val = data[b] + bias[b]; + output[b] = val / (1.0f + exp(-val)); + ; + } +#else + ispc::ActivateSwish(len, data, bias, output); +#endif + } else if (activation == ACTIVATION_SELU) { +#ifndef USE_ISPC + for (size_t b = 0; b < len; b++) { + float val = data[b] + bias[b]; + output[b] = selu(val); + } +#else + ispc::ActivateSelu(len, data, bias, output); #endif } else { for (size_t b = 0; b < len; b++) { @@ -116,17 +156,21 @@ void Activate(const size_t len, const float* data, const float* bias, void Activate(const size_t len, float gamma, const float* data, const float* bias, float beta, float* output, const ActivationFunction activation) { - if (activation == NONE) { + if (activation == ACTIVATION_NONE) { for (size_t b = 0; b < len; b++) { float val = gamma * data[b] + bias[b] + beta; output[b] = val; } - } else if (activation == RELU) { + } else if (activation == ACTIVATION_RELU) { +#ifndef USE_ISPC for (size_t b = 0; b < len; b++) { float val = gamma * data[b] + bias[b] + beta; output[b] = val > 0 ? val : 0; } - } else if (activation == MISH) { +#else + ispc::ActivateRelu(len, gamma, data, bias, beta, output); +#endif + } else if (activation == ACTIVATION_MISH) { #ifndef USE_ISPC for (size_t b = 0; b < len; b++) { float val = gamma * data[b] + bias[b] + beta; @@ -151,17 +195,21 @@ void BiasResidual(const size_t batch_size, const size_t channels, float* data, auto bias = biases[c]; auto arr = &data[c * kSquares]; auto res = &eltwise[c * kSquares]; - if (activation == NONE) { + if (activation == ACTIVATION_NONE) { for (size_t b = 0; b < kSquares; b++) { float val = res[b] + arr[b] + bias; arr[b] = val; } - } else if (activation == RELU) { + } else if (activation == ACTIVATION_RELU) { +#ifndef USE_ISPC for (size_t b = 0; b < kSquares; b++) { float val = res[b] + arr[b] + bias; arr[b] = val > 0 ? val : 0; } - } else if (activation == MISH) { +#else + ispc::ActivateRelu(kSquares, 1.0f, res, arr, bias, arr); +#endif + } else if (activation == ACTIVATION_MISH) { #ifndef USE_ISPC for (size_t b = 0; b < kSquares; b++) { float val = res[b] + arr[b] + bias; @@ -188,17 +236,17 @@ void BiasActivate(const size_t batch_size, const size_t channels, float* data, for (size_t c = 0; c < channels; ++c) { auto bias = biases[c]; auto arr = &data[c * kSquares]; - if (activation == NONE) { + if (activation == ACTIVATION_NONE) { for (size_t b = 0; b < kSquares; b++) { float val = arr[b] + bias; arr[b] = val; } - } else if (activation == RELU) { + } else if (activation == ACTIVATION_RELU) { for (size_t b = 0; b < kSquares; b++) { float val = arr[b] + bias; arr[b] = val > 0 ? val : 0; } - } else if (activation == MISH) { + } else if (activation == ACTIVATION_MISH) { #ifndef USE_ISPC for (size_t b = 0; b < kSquares; b++) { float val = arr[b] + bias; diff --git a/src/neural/shared/activation.h b/src/neural/shared/activation.h index 5937126b32..6f110886db 100644 --- a/src/neural/shared/activation.h +++ b/src/neural/shared/activation.h @@ -22,18 +22,37 @@ #include namespace lczero { -enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH }; +// The following list matches the one in net.proto. Ideally this would be done +// by including proto/net.pb.h, but this is incompatible with nvcc. +enum ActivationFunction { + ACTIVATION_DEFAULT = 0, + ACTIVATION_MISH = 1, + ACTIVATION_RELU = 2, + ACTIVATION_NONE = 3, + ACTIVATION_TANH = 4, + ACTIVATION_SIGMOID = 5, + ACTIVATION_SELU = 6, + ACTIVATION_SWISH = 7, + ACTIVATION_RELU_2 = 8, + ACTIVATION_SOFTMAX = 9, +}; + +struct Activations { + ActivationFunction default_activation = ACTIVATION_RELU; + ActivationFunction smolgen_activation = ACTIVATION_SWISH; + ActivationFunction ffn_activation = ACTIVATION_RELU_2; +}; // Softmax activation void SoftmaxActivation(const size_t size, const float* input, float* output); -void BiasResidual(const size_t batch_size, const size_t channels, float * data, +void BiasResidual(const size_t batch_size, const size_t channels, float* data, const float* biases, const float* eltwise, - const ActivationFunction activation = RELU); + const ActivationFunction activation); -void BiasActivate(const size_t batch_size, const size_t channels, float * data, +void BiasActivate(const size_t batch_size, const size_t channels, float* data, const float* biases, - const ActivationFunction activation = RELU); + const ActivationFunction activation); float Activate(const float val, const ActivationFunction activation); diff --git a/src/neural/shared/activation.ispc b/src/neural/shared/activation.ispc index 27cc36d116..987dc3e689 100644 --- a/src/neural/shared/activation.ispc +++ b/src/neural/shared/activation.ispc @@ -1,6 +1,6 @@ /* This file is part of Leela Chess Zero. - Copyright (C) 2022 The LCZero Authors + Copyright (C) 2022-2023 The LCZero Authors Leela Chess is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -35,3 +35,67 @@ export void ActivateMish(uniform const size_t len, uniform float gamma, output[b] = mish(val); } } + +export void ActivateRelu(uniform const size_t len, uniform float gamma, + const uniform float data[], const uniform float bias[], + uniform float beta, uniform float output[]) { + foreach (b = 0 ... len) { + float val = gamma * data[b] + bias[b] + beta; + output[b] = val > 0 ? val : 0; + } +} + +export void ActivateSwish(uniform const size_t len, const uniform float data[], + const uniform float bias[], uniform float output[]) { + foreach (b = 0 ... len) { + float val = data[b] + bias[b]; + output[b] = val / (1.0f + exp(-val)); + } +} + +export void ActivateRelu_2(uniform const size_t len, const uniform float data[], + const uniform float bias[], uniform float output[]) { + foreach (b = 0 ... len) { + float val = data[b] + bias[b]; + output[b] = val > 0 ? val * val : 0; + } +} + +static inline float selu(float val) { + float alpha = 1.67326324f, scale = 1.05070098f; + if (val > 0) { + return scale * val; + } else { + return scale * alpha * (exp(val) - 1.0f); + } +} + +export void ActivateSelu(uniform const size_t len, const uniform float data[], + const uniform float bias[], uniform float output[]) { + foreach (b = 0 ... len) { + float val = data[b] + bias[b]; + output[b] = selu(val); + } +} + +export void SoftmaxActivation(uniform const size_t size, + const uniform float input[], + uniform float output[]) { + float vmax = -3.4e38f; + foreach (c = 0 ... size) { + if (input[c] > vmax) vmax = input[c]; + } + uniform float alpha = reduce_max(vmax); + + float t = 0.0f; + foreach (c = 0 ... size) { + float val = exp(input[c] - alpha); + output[c] = val; + t += val; + } + uniform float denom = 1.0f / reduce_add(t); + + foreach (c = 0 ... size) { + output[c] *= denom; + } +} diff --git a/src/neural/shared/attention_policy_map.h b/src/neural/shared/attention_policy_map.h index 5ab0966654..3cd910acb9 100644 --- a/src/neural/shared/attention_policy_map.h +++ b/src/neural/shared/attention_policy_map.h @@ -381,29 +381,324 @@ const short kAttnPolicyMap[] = { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1852, 1853, 1854, 1855, 1856, 1857}; -} // namespace lczero - - - - - - - - - - - - - - - - - - - - - - - - +constexpr int kNumPosEncodingChannels = 64; + +const float kPosEncoding[64][kNumPosEncodingChannels] = { + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0}; +} // namespace lczero diff --git a/src/selfplay/game.cc b/src/selfplay/game.cc index c06a0e3d85..2c0ef82d79 100644 --- a/src/selfplay/game.cc +++ b/src/selfplay/game.cc @@ -88,13 +88,13 @@ SelfPlayGame::SelfPlayGame(PlayerOptions white, PlayerOptions black, SearchParams(*black.uci_options).GetHistoryFill(), white.network->GetCapabilities().input_format) { orig_fen_ = opening.start_fen; - tree_[0] = std::make_shared(); + tree_[0] = std::make_shared(*options_[0].uci_options); tree_[0]->ResetToPosition(orig_fen_, {}); if (shared_tree) { tree_[1] = tree_[0]; } else { - tree_[1] = std::make_shared(); + tree_[1] = std::make_shared(*options_[1].uci_options); tree_[1]->ResetToPosition(orig_fen_, {}); } int ply = 0; @@ -132,9 +132,8 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training, bool blacks_move = tree_[0]->IsBlackToMove(); // If we are training, verify that input formats are consistent. - if (training && - options_[0].network->GetCapabilities().input_format != - options_[1].network->GetCapabilities().input_format) { + if (training && options_[0].network->GetCapabilities().input_format != + options_[1].network->GetCapabilities().input_format) { throw Exception("Can't mix networks with different input format!"); } // Take syzygy tablebases from player1 options. @@ -162,6 +161,7 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training, const int idx = blacks_move ? 1 : 0; if (!options_[idx].uci_options->Get(kReuseTreeId)) { tree_[idx]->TrimTreeAtHead(); + tree_[idx]->TTClear(); } { std::lock_guard lock(mutex_); @@ -180,7 +180,7 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training, } search_ = std::make_unique( - *tree_[idx], options_[idx].network, std::move(responder), + tree_[idx].get(), options_[idx].network, std::move(responder), /* searchmoves */ MoveList(), std::chrono::steady_clock::now(), std::move(stoppers), /* infinite */ false, *options_[idx].uci_options, options_[idx].cache, @@ -206,9 +206,8 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training, max_eval_[0] = std::max(max_eval_[0], blacks_move ? best_l : best_w); max_eval_[1] = std::max(max_eval_[1], best_d); max_eval_[2] = std::max(max_eval_[2], blacks_move ? best_w : best_l); - if (enable_resign && - move_number >= - options_[idx].uci_options->Get(kResignEarliestMoveId)) { + if (enable_resign && move_number >= options_[idx].uci_options->Get( + kResignEarliestMoveId)) { const float resignpct = options_[idx].uci_options->Get(kResignPercentageId) / 100; if (options_[idx].uci_options->Get(kResignWDLStyleId)) { @@ -296,10 +295,11 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training, } // Append training data. The GameResult is later overwritten. NNCacheLock nneval = - search_->GetCachedNNEval(tree_[idx]->GetCurrentHead()); + search_->GetCachedNNEval(tree_[idx]->GetPositionHistory()); training_data_.Add(tree_[idx]->GetCurrentHead(), tree_[idx]->GetPositionHistory(), best_eval, - played_eval, best_is_proof, best_move, move, nneval); + played_eval, best_is_proof, best_move, move, nneval, + search_->GetParams().GetPolicySoftmaxTemp()); } // Must reset the search before mutating the tree. search_.reset(); @@ -312,16 +312,9 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training, } std::vector SelfPlayGame::GetMoves() const { - std::vector moves; - for (Node* node = tree_[0]->GetCurrentHead(); - node != tree_[0]->GetGameBeginNode(); node = node->GetParent()) { - moves.push_back(node->GetParent()->GetEdgeToNode(node)->GetMove()); - } std::vector result; Position pos = tree_[0]->GetPositionHistory().Starting(); - while (!moves.empty()) { - Move move = moves.back(); - moves.pop_back(); + for (auto move : tree_[0]->GetMoves()) { if (!chess960_) move = pos.GetBoard().GetLegacyMove(move); pos = Position(pos, move); // Position already flipped, therefore flip the move if white to move. diff --git a/src/selfplay/game.h b/src/selfplay/game.h index 918c328ce1..c05ce10916 100644 --- a/src/selfplay/game.h +++ b/src/selfplay/game.h @@ -105,6 +105,7 @@ class SelfPlayGame { // Node tree for player1 and player2. If the tree is shared between players, // tree_[0] == tree_[1]. std::shared_ptr tree_[2]; + std::string orig_fen_; int start_ply_; diff --git a/src/trainingdata/reader.cc b/src/trainingdata/reader.cc index 64f89e5f82..cac8abb55d 100644 --- a/src/trainingdata/reader.cc +++ b/src/trainingdata/reader.cc @@ -36,7 +36,7 @@ InputPlanes PlanesFromTrainingData(const V6TrainingData& data) { result.back().mask = ReverseBitsInBytes(data.planes[i]); } switch (data.input_format) { - case pblczero::NetworkFormat::InputFormat::INPUT_CLASSICAL_112_PLANE: { + case pblczero::NetworkFormat::INPUT_CLASSICAL_112_PLANE: { result.emplace_back(); result.back().mask = data.castling_us_ooo != 0 ? ~0LL : 0LL; result.emplace_back(); diff --git a/src/trainingdata/trainingdata.cc b/src/trainingdata/trainingdata.cc index 1285dc7b49..64099d24f4 100644 --- a/src/trainingdata/trainingdata.cc +++ b/src/trainingdata/trainingdata.cc @@ -114,7 +114,8 @@ void V6TrainingDataArray::Write(TrainingDataWriter* writer, GameResult result, void V6TrainingDataArray::Add(const Node* node, const PositionHistory& history, Eval best_eval, Eval played_eval, bool best_is_proven, Move best_move, - Move played_move, const NNCacheLock& nneval) { + Move played_move, const NNCacheLock& nneval, + float softmax_temp) { V6TrainingData result; const auto& position = history.Last(); @@ -146,24 +147,20 @@ void V6TrainingDataArray::Add(const Node* node, const PositionHistory& history, // Set moves probabilities according to their relative amount of visits. // Compute Kullback-Leibler divergence in nats (between policy and visits). float kld_sum = 0; - float max_p = -std::numeric_limits::infinity(); std::vector intermediate; if (nneval) { - int last_idx = 0; + // The cache stores policies in GenerateLegalMoves() order. + auto legal_moves = history.Last().GetBoard().GenerateLegalMoves(); for (const auto& child : node->Edges()) { - auto nn_idx = child.edge()->GetMove().as_nn_index(transform); + auto move = child.edge()->GetMove(); float p = 0; - for (int i = 0; i < nneval->p.size(); i++) { - // Optimization: usually moves are stored in the same order as queried. - const auto& move = nneval->p[last_idx++]; - if (last_idx == nneval->p.size()) last_idx = 0; - if (move.first == nn_idx) { - p = move.second; + for (size_t i = 0; i < legal_moves.size(); i++) { + if (move == legal_moves[i]) { + p = nneval->eval->edges[i].GetP(); break; } } intermediate.emplace_back(p); - max_p = std::max(max_p, p); } } float total = 0.0; @@ -172,7 +169,8 @@ void V6TrainingDataArray::Add(const Node* node, const PositionHistory& history, auto nn_idx = child.edge()->GetMove().as_nn_index(transform); float fracv = total_n > 0 ? child.GetN() / static_cast(total_n) : 1; if (nneval) { - float P = std::exp(*it - max_p); + // Undo any softmax temperature in the cached data. + float P = std::pow(*it, softmax_temp); if (fracv > 0) { kld_sum += fracv * std::log(fracv / P); } @@ -236,9 +234,9 @@ void V6TrainingDataArray::Add(const Node* node, const PositionHistory& history, Eval orig_eval; if (nneval) { - orig_eval.wl = nneval->q; - orig_eval.d = nneval->d; - orig_eval.ml = nneval->m; + orig_eval.wl = nneval->eval->q; + orig_eval.d = nneval->eval->d; + orig_eval.ml = nneval->eval->m; } else { orig_eval.wl = std::numeric_limits::quiet_NaN(); orig_eval.d = std::numeric_limits::quiet_NaN(); diff --git a/src/trainingdata/trainingdata.h b/src/trainingdata/trainingdata.h index 6fc3b3b8a5..601b8a80d9 100644 --- a/src/trainingdata/trainingdata.h +++ b/src/trainingdata/trainingdata.h @@ -28,6 +28,7 @@ #pragma once #include "mcts/node.h" +#include "neural/cache.h" #include "trainingdata/writer.h" namespace lczero { @@ -98,7 +99,7 @@ class V6TrainingDataArray { // Add a chunk. void Add(const Node* node, const PositionHistory& history, Eval best_eval, Eval played_eval, bool best_is_proven, Move best_move, - Move played_move, const NNCacheLock& nneval); + Move played_move, const NNCacheLock& nneval, float softmax_temp); // Writes training data to a file. void Write(TrainingDataWriter* writer, GameResult result, diff --git a/src/utils/cppattributes.h b/src/utils/cppattributes.h index 63a5827ecb..13e95f870b 100644 --- a/src/utils/cppattributes.h +++ b/src/utils/cppattributes.h @@ -27,6 +27,8 @@ #pragma once +#include + // Enable thread safety attributes only with clang. // The attributes can be safely erased when compiling with other compilers. #if defined(__clang__) && (!defined(SWIG)) @@ -37,24 +39,11 @@ #define CAPABILITY(x) ATTRIBUTE__(capability(x)) #define SCOPED_CAPABILITY ATTRIBUTE__(scoped_lockable) -#define GUARDED_BY(x) ATTRIBUTE__(guarded_by(x)) -#define PT_GUARDED_BY(x) ATTRIBUTE__(pt_guarded_by(x)) -#define ACQUIRED_BEFORE(...) ATTRIBUTE__(acquired_before(__VA_ARGS__)) -#define ACQUIRED_AFTER(...) ATTRIBUTE__(acquired_after(__VA_ARGS__)) -#define REQUIRES(...) ATTRIBUTE__(requires_capability(__VA_ARGS__)) -#define REQUIRES_SHARED(...) \ - ATTRIBUTE__(requires_shared_capability(__VA_ARGS__)) #define ACQUIRE(...) ATTRIBUTE__(acquire_capability(__VA_ARGS__)) -#define ACQUIRE_SHARED(...) ATTRIBUTE__(acquire_shared_capability(__VA_ARGS__)) #define RELEASE(...) ATTRIBUTE__(release_capability(__VA_ARGS__)) +#define ACQUIRE_SHARED(...) ATTRIBUTE__(acquire_shared_capability(__VA_ARGS__)) #define RELEASE_SHARED(...) ATTRIBUTE__(release_shared_capability(__VA_ARGS__)) -#define TRY_ACQUIRE(...) ATTRIBUTE__(try_acquire_capability(__VA_ARGS__)) -#define TRY_ACQUIRE_SHARED(...) \ - ATTRIBUTE__(try_acquire_shared_capability(__VA_ARGS__)) -#define EXCLUDES(...) ATTRIBUTE__(locks_excluded(__VA_ARGS__)) -#define ASSERT_CAPABILITY(x) ATTRIBUTE__(assert_capability(x)) -#define ASSERT_SHARED_CAPABILITY(x) ATTRIBUTE__(assert_shared_capability(x)) -#define RETURN_CAPABILITY(x) ATTRIBUTE__(lock_returned(x)) +#define REQUIRES(...) ATTRIBUTE__(requires_capability(__VA_ARGS__)) +#define REQUIRES_SHARED(...) \ + ATTRIBUTE__(requires_shared_capability(__VA_ARGS__)) #define PACKED_STRUCT ATTRIBUTE__(packed) - -#define NO_THREAD_SAFETY_ANALYSIS ATTRIBUTE__(no_thread_safety_analysis) diff --git a/src/utils/fastmath.h b/src/utils/fastmath.h index 3011b556b5..6aa49412aa 100644 --- a/src/utils/fastmath.h +++ b/src/utils/fastmath.h @@ -85,6 +85,13 @@ inline float FastLog(const float a) { // Fast approximate exp(x). Does only limited range checking. inline float FastExp(const float a) { return FastExp2(1.442695040f * a); } +// Safeguarded fast logistic function, based on FastExp(). +inline float FastLogistic(const float a) { + if (a > 20.0f) {return 1.0f;} + if (a < -20.0f) {return 0.0f;} + return 1.0f / (1.0f + FastExp(-a)); +} + inline float FastSign(const float a) { // Microsoft compiler does not have a builtin for copysign and emits a // library call which is too expensive for hot paths. diff --git a/src/utils/fp16_utils.cc b/src/utils/fp16_utils.cc deleted file mode 100644 index 826389b1e9..0000000000 --- a/src/utils/fp16_utils.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* - This file is part of Leela Chess Zero. - Copyright (C) 2020 The LCZero Authors - - Leela Chess is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - Leela Chess is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with Leela Chess. If not, see . - - Additional permission under GNU GPL version 3 section 7 - - If you modify this Program, or any covered work, by linking or - combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA - Toolkit and the NVIDIA CUDA Deep Neural Network library (or a - modified version of those libraries), containing parts covered by the - terms of the respective license agreement, the licensors of this - Program grant you additional permission to convey the resulting work. -*/ - -#include -#include - -// Define NO_F16C to avoid the F16C intrinsics. Also disabled with NO_POPCNT -// since it catches most processors without F16C instructions. - -#if defined(_M_IX86) || defined(_M_X64) || defined(__i386__) || \ - defined(__x86_64__) -#include -#else -#define NO_F16C -#endif - -namespace lczero { - -uint16_t FP32toFP16(float f32) { -#if defined(NO_POPCNT) || defined(NO_F16C) || \ - (defined(__GNUC__) && !defined(__F16C__)) - unsigned int x; - unsigned int sign = 0; - memcpy(&x, &f32, sizeof(float)); - if (x & 0x80000000) sign = 0x8000; - x &= 0x7fffffff; - if (x >= 0x477ff000) { - if ((x & 0x7f800000) == 0x7f800000 && (x & 0x7fffff)) { - x = ((x >> 13) - 0x38000) | 0x200; - } else { - x = 0x7c00; - } - } else if (x <= 0x33000000) - x = 0; - else if (x <= 0x387fefff) { - int shift = 126 - ((x >> 23) & 0xff); - x = (x & 0x7fffff) | 0x800000; - if (x & (0x17fffff >> (24 - shift))) x += 0x800000 >> (24 - shift); - x >>= shift; - } else { - // Adjust exponent and round to nearest even. - if (x & 0x2fff) { - x -= 0x37fff000; - } else { - x -= 0x38000000; - } - x >>= 13; - } - return x | sign; -#else - __m128 A = _mm_set_ss(f32); - __m128i H = _mm_cvtps_ph(A, 0); - return _mm_extract_epi16(H, 0); -#endif -} - -float FP16toFP32(uint16_t f16) { -#if defined(NO_POPCNT) || defined(NO_F16C) || \ - (defined(__GNUC__) && !defined(__F16C__)) - unsigned int x; - float f; - x = f16 & 0x7fff; - if ((x & 0x7c00) == 0) { - f = 5.9604645e-8f * x; - memcpy(&x, &f, sizeof(float)); - } else if (x >= 0x7c00) { - if (x & 0x1ff) x |= 0x200; - x = (x + 0x38000) << 13; - } else { - x = (x + 0x1c000) << 13; - } - if (f16 & 0x8000) x |= 0x80000000; - memcpy(&f, &x, sizeof(float)); - return f; -#else - __m128i H = _mm_setzero_si128(); - H = _mm_insert_epi16(H, f16, 0); - __m128 A = _mm_cvtph_ps(H); - return _mm_cvtss_f32(A); -#endif -} - -} // namespace lczero diff --git a/src/utils/fp16_utils.h b/src/utils/fp16_utils.h index fadf83d031..2680536599 100644 --- a/src/utils/fp16_utils.h +++ b/src/utils/fp16_utils.h @@ -25,9 +25,88 @@ Program grant you additional permission to convey the resulting work. */ #pragma once + +#include +#include + +// Define NO_F16C to avoid the F16C intrinsics. Also disabled with NO_POPCNT +// since it catches most processors without F16C instructions. +#if defined(_M_IX86) || defined(_M_X64) || defined(__i386__) || \ + defined(__x86_64__) +#include +#else +#define NO_F16C +#endif + namespace lczero { -uint16_t FP32toFP16(float f32); -float FP16toFP32(uint16_t f16); +#if defined(NO_POPCNT) || defined(NO_F16C) || \ + (defined(__GNUC__) && !defined(__F16C__)) + +inline uint16_t FP32toFP16(float f32) { + unsigned int x; + unsigned int sign = 0; + memcpy(&x, &f32, sizeof(float)); + if (x & 0x80000000) sign = 0x8000; + x &= 0x7fffffff; + if (x >= 0x477ff000) { + if ((x & 0x7f800000) == 0x7f800000 && (x & 0x7fffff)) { + x = ((x >> 13) - 0x38000) | 0x200; + } else { + x = 0x7c00; + } + } else if (x <= 0x33000000) + x = 0; + else if (x <= 0x387fefff) { + int shift = 126 - ((x >> 23) & 0xff); + x = (x & 0x7fffff) | 0x800000; + if (x & (0x17fffff >> (24 - shift))) x += 0x800000 >> (24 - shift); + x >>= shift; + } else { + // Adjust exponent and round to nearest even. + if (x & 0x2fff) { + x -= 0x37fff000; + } else { + x -= 0x38000000; + } + x >>= 13; + } + return x | sign; +} + +inline float FP16toFP32(uint16_t f16) { + unsigned int x; + float f; + x = f16 & 0x7fff; + if ((x & 0x7c00) == 0) { + f = 5.9604645e-8f * x; + memcpy(&x, &f, sizeof(float)); + } else if (x >= 0x7c00) { + if (x & 0x1ff) x |= 0x200; + x = (x + 0x38000) << 13; + } else { + x = (x + 0x1c000) << 13; + } + if (f16 & 0x8000) x |= 0x80000000; + memcpy(&f, &x, sizeof(float)); + return f; +} + +#else + +inline uint16_t FP32toFP16(float f32) { + __m128 A = _mm_set_ss(f32); + __m128i H = _mm_cvtps_ph(A, 0); + return _mm_extract_epi16(H, 0); +} + +inline float FP16toFP32(uint16_t f16) { + __m128i H = _mm_setzero_si128(); + H = _mm_insert_epi16(H, f16, 0); + __m128 A = _mm_cvtph_ps(H); + return _mm_cvtss_f32(A); +} + +#endif } // namespace lczero diff --git a/src/utils/numa.cc b/src/utils/numa.cc index 3a9699a171..f373fb2c97 100644 --- a/src/utils/numa.cc +++ b/src/utils/numa.cc @@ -41,7 +41,7 @@ int Numa::threads_per_core_ = 1; void Numa::Init() { #if defined(_WIN64) && _WIN32_WINNT >= 0x0601 SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX* buffer; - DWORD len; + DWORD len = 0; GetLogicalProcessorInformationEx(RelationProcessorCore, NULL, &len); buffer = static_cast(malloc(len)); GetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &len); diff --git a/src/utils/optionsparser.cc b/src/utils/optionsparser.cc index e725348555..2c76d35484 100644 --- a/src/utils/optionsparser.cc +++ b/src/utils/optionsparser.cc @@ -324,8 +324,12 @@ bool StringOption::ProcessShortFlagWithValue(char flag, } std::string StringOption::GetHelp(const OptionsDict& dict) const { - return FormatFlag(GetShortFlag(), GetLongFlag() + "=STRING", GetHelpText(), - GetUciOption(), GetVal(dict)); + std::string long_flag = GetLongFlag(); + if (!long_flag.empty()) { + long_flag += "=STRING"; + } + return FormatFlag(GetShortFlag(), long_flag, GetHelpText(), GetUciOption(), + GetVal(dict)); } std::string StringOption::GetOptionString(const OptionsDict& dict) const { diff --git a/src/utils/protomessage.cc b/src/utils/protomessage.cc index dfa1b8aaf2..224ea20ffb 100644 --- a/src/utils/protomessage.cc +++ b/src/utils/protomessage.cc @@ -100,26 +100,83 @@ void ProtoMessage::MergeFromString(std::string_view str) { } void ProtoMessage::AppendVarInt(int field_id, std::uint64_t value, - std::string* out) const { + std::string* out) { *out += EncodeVarInt(field_id << 3); *out += EncodeVarInt(value); } void ProtoMessage::AppendInt64(int field_id, std::uint64_t value, - std::string* out) const { + std::string* out) { *out += EncodeVarInt(1 + (field_id << 3)); WriteFixed(value, 8, out); } void ProtoMessage::AppendInt32(int field_id, std::uint32_t value, - std::string* out) const { + std::string* out) { *out += EncodeVarInt(5 + (field_id << 3)); WriteFixed(value, 4, out); } void ProtoMessage::AppendString(int field_id, std::string_view value, - std::string* out) const { + std::string* out) { *out += EncodeVarInt(2 + (field_id << 3)); *out += EncodeVarInt(value.size()); *out += value; } +void ProtoMessage::AppendJsonFieldPrefix(const std::string& name, + bool* is_first, std::string* out) { + if (*is_first) { + out->append(","); + *is_first = false; + } + AppendJsonValue(name, out); + out->append(":"); +} + +namespace { +std::string EscapeJsonString(const std::string& str) { + static const char kHex[] = "0123456789abcdef"; + std::string out; + for (const auto c : str) { + if (c >= 0 && c <= ' ') { + out += std::string("\\u00") + kHex[c / 16] + kHex[c % 16]; + } else if (c == '\\') { + out += "\\\\"; + } else if (c == '"') { + out += "\\\""; + } else { + out += c; + } + } + return out; +} + +} // namespace + +void ProtoMessage::AppendJsonValue(const std::string& val, std::string* out) { + out->append("\""); + out->append(EscapeJsonString(val)); + out->append("\""); +} +void ProtoMessage::AppendJsonValue(bool val, std::string* out) { + out->append(val ? "true" : "false"); +} +void ProtoMessage::AppendJsonValue(double val, std::string* out) { + out->append(std::to_string(val)); +} +void ProtoMessage::AppendJsonValue(uint64_t val, std::string* out) { + out->append(std::to_string(val)); +} +void ProtoMessage::AppendJsonValue(int64_t val, std::string* out) { + out->append(std::to_string(val)); +} +void ProtoMessage::AppendJsonValue(uint32_t val, std::string* out) { + out->append(std::to_string(val)); +} +void ProtoMessage::AppendJsonValue(int32_t val, std::string* out) { + out->append(std::to_string(val)); +} +void ProtoMessage::AppendJsonValue(const ProtoMessage& val, std::string* out) { + out->append(val.OutputAsJson()); +} + } // namespace lczero \ No newline at end of file diff --git a/src/utils/protomessage.h b/src/utils/protomessage.h index 2d90f06546..a25c0c3f57 100644 --- a/src/utils/protomessage.h +++ b/src/utils/protomessage.h @@ -1,12 +1,12 @@ #pragma once +#include #include #include #include #include #include #include -#include #include // Undef g++ macros to ged rid of warnings. @@ -27,6 +27,7 @@ class ProtoMessage { void ParseFromString(std::string_view); void MergeFromString(std::string_view); virtual std::string OutputAsString() const = 0; + virtual std::string OutputAsJson() const = 0; protected: template @@ -40,17 +41,46 @@ class ProtoMessage { } } - void AppendVarInt(int field_id, std::uint64_t value, std::string* out) const; - void AppendInt64(int field_id, std::uint64_t value, std::string* out) const; - void AppendInt32(int field_id, std::uint32_t value, std::string* out) const; - void AppendString(int field_id, std::string_view value, - std::string* out) const; + static void AppendVarInt(int field_id, std::uint64_t value, std::string* out); + static void AppendInt64(int field_id, std::uint64_t value, std::string* out); + static void AppendInt32(int field_id, std::uint32_t value, std::string* out); + static void AppendString(int field_id, std::string_view value, + std::string* out); + template + static void AppendJsonRepeatedField(const std::string& name, + const std::vector& val, bool* is_first, + std::string* out) { + AppendJsonFieldPrefix(name, is_first, out); + out->append("["); + for (std::size_t i = 0; i < val.size(); ++i) { + if (i > 0) out->append(","); + AppendJsonValue(val[i], out); + } + out->append("]"); + } + template + static void AppendJsonField(const std::string& name, const T& val, + bool* is_first, std::string* out) { + AppendJsonFieldPrefix(name, is_first, out); + AppendJsonValue(val, out); + } private: virtual void SetVarInt(int /* field_id */, uint64_t /* value */) {} virtual void SetInt64(int /* field_id */, uint64_t /* value */) {} virtual void SetInt32(int /* field_id */, uint32_t /* value */) {} virtual void SetString(int /* field_id */, std::string_view /* value */) {} + + static void AppendJsonFieldPrefix(const std::string& name, bool* is_first, + std::string* out); + static void AppendJsonValue(const std::string& val, std::string* out); + static void AppendJsonValue(bool val, std::string* out); + static void AppendJsonValue(double val, std::string* out); + static void AppendJsonValue(uint64_t val, std::string* out); + static void AppendJsonValue(int64_t val, std::string* out); + static void AppendJsonValue(uint32_t val, std::string* out); + static void AppendJsonValue(int32_t val, std::string* out); + static void AppendJsonValue(const ProtoMessage& val, std::string* out); }; } // namespace lczero \ No newline at end of file diff --git a/src/version.inc b/src/version.inc index cf3a3aab94..56f884bdc6 100644 --- a/src/version.inc +++ b/src/version.inc @@ -1,4 +1,4 @@ #define LC0_VERSION_MAJOR 0 -#define LC0_VERSION_MINOR 30 +#define LC0_VERSION_MINOR 31 #define LC0_VERSION_PATCH 0 -#define LC0_VERSION_POSTFIX "dev" +#define LC0_VERSION_POSTFIX "dag" diff --git a/subprojects/abseil-cpp.wrap b/subprojects/abseil-cpp.wrap new file mode 100644 index 0000000000..4e3e63bcb3 --- /dev/null +++ b/subprojects/abseil-cpp.wrap @@ -0,0 +1,23 @@ +[wrap-file] +directory = abseil-cpp-20211102.0 +source_url = https://github.com/abseil/abseil-cpp/archive/20211102.0.tar.gz +source_filename = abseil-cpp-20211102.0.tar.gz +source_hash = dcf71b9cba8dc0ca9940c4b316a0c796be8fab42b070bb6b7cab62b48f0e66c4 +patch_filename = abseil-cpp_20211102.0-3_patch.zip +patch_url = https://wrapdb.mesonbuild.com/v2/abseil-cpp_20211102.0-3/get_patch +patch_hash = 3b49ff46df64414bac034b58bf36a07457375b6f8d8b5158e341157c6f9efad2 + +[provide] +absl_base = absl_base_dep +absl_container = absl_container_dep +absl_debugging = absl_debugging_dep +absl_flags = absl_flags_dep +absl_hash = absl_hash_dep +absl_numeric = absl_numeric_dep +absl_random = absl_random_dep +absl_status = absl_status_dep +absl_strings = absl_strings_dep +absl_synchronization = absl_synchronization_dep +absl_time = absl_time_dep +absl_types = absl_types_dep + diff --git a/subprojects/eigen.wrap b/subprojects/eigen.wrap index 1973b8998a..e46839c90b 100644 --- a/subprojects/eigen.wrap +++ b/subprojects/eigen.wrap @@ -1,10 +1,12 @@ [wrap-file] -directory = eigen-3.3.7 +directory = eigen-3.4.0 +source_url = https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.bz2 +source_filename = eigen-3.4.0.tar.bz2 +source_hash = b4c198460eba6f28d34894e3a5710998818515104d6e74e5cc331ce31e46e626 +patch_filename = eigen_3.4.0-1_patch.zip +patch_url = https://wrapdb.mesonbuild.com/v2/eigen_3.4.0-1/get_patch +patch_hash = fae999acdb3ea23eada3becdbde7f7f76755e94ad85fee7775b7ab1cf12e84e3 -source_url = https://gitlab.com/libeigen/eigen/-/archive/3.3.7/eigen-3.3.7.tar.bz2 -source_filename = eigen-3.3.7.tar.bz2 -source_hash = 685adf14bd8e9c015b78097c1dc22f2f01343756f196acdc76a678e1ae352e11 +[provide] +eigen3 = eigen_dep -patch_url = https://github.com/borg323/eigen/files/5124100/eigen-3.3.7-1u-wrap.zip -patch_filename = eigen-3.3.7-1u-wrap.zip -patch_hash = f9d8cc10cb135a5cc3761f316e84d3e3e42949970245f4256470b8ec71eac3e3