diff --git a/.circleci/config.yml b/.circleci/config.yml
index c96fcc903b..6743e5ef02 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -13,38 +13,54 @@ jobs:
- run:
name: Update Meson
command: pip3 install --upgrade meson==0.58.1
- - run:
- name: Create Meson build dirs
- command: mkdir build-gcc && mkdir build-clang
- - run:
- name: Meson Clang
- environment:
- CC: clang
- CXX: clang++
- command: meson build-clang
- run:
name: Meson GCC
environment:
CC: gcc-8
CXX: g++-8
- command: meson build-gcc
- - run:
- name: Build Clang
- command: |
- cd build-clang
- ninja
+ command: meson build-gcc -Dgtest=false
- run:
name: Build GCC
command: |
cd build-gcc
ninja -j 4
+ "mac":
+ macos:
+ xcode: 13.4.1
+ steps:
+ - checkout
- run:
- command: cp build-clang/lc0 /tmp/lc0-clang
+ name: "Pull Submodules"
+ command: |
+ git submodule init
+ git submodule update --remote
- run:
- command: cp build-gcc/lc0 /tmp/lc0-g++
- - store_artifacts:
- path: /tmp/lc0-clang
- destination: lc0-ubuntu-18-04-clang
+ name: Install build tools
+ command: |
+ pip3 install meson==0.63
+ pip3 install ninja
+ brew install ispc
+ - run:
+ name: Build lc0
+ command: |
+ meson build --buildtype=release -Dgtest=false -Dopencl=false
+ cd build
+ ninja
+ - run:
+ name: Build lc0 arm
+ command: |
+ meson build-arm --buildtype=release -Dgtest=false -Dopencl=false --cross-file cross-files/aarch64-darwin
+ cd build-arm
+ ninja
+ - run:
+ name: Make universal binary
+ command: lipo -create -o /tmp/lc0 build/lc0 build-arm/lc0
- store_artifacts:
- path: /tmp/lc0-g++
- destination: lc0-ubuntu-18-04-g++
+ path: /tmp/lc0
+ destination: lc0-macos_12.3.1
+workflows:
+ version: 2
+ builds:
+ jobs:
+ - build
+ - "mac"
diff --git a/.gitignore b/.gitignore
index c90b403032..ea18c9ee58 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,5 @@ testdata/
xcuserdata
.clang-tidy
compile_flags.txt
-.vscode
\ No newline at end of file
+.vscode
+.mesonpy*
\ No newline at end of file
diff --git a/FLAGS.md b/FLAGS.md
index 7d687657f2..6a974ccd4b 100644
--- a/FLAGS.md
+++ b/FLAGS.md
@@ -41,7 +41,6 @@ List of command line flags:
| --slowmover=NUM | Scale thinking time | Parameter value `X` means that the whole remaining time is split in such a way that the current move gets `X × Y` seconds, and next moves will get `1 × Y` seconds. However, due to smart pruning, the engine usually doesn't use all allocated time. Default: `2.2`|
| --move-overhead=NUM | Move time overhead in milliseconds | How much overhead should the engine allocate for every move (to counteract things like slow connection, interprocess communication, etc.). Default: `100` ms. |
| --minibatch-size=NUM | Minibatch size for NN inference | How many positions the engine tries to batch together for computation. Theoretically larger batches may reduce strengths a bit, especially on a small number of playouts. Default is `256`. Every backend/hardware has different optimal value (e.g., `1` if batching is not supported). |
-| --max-prefetch=NUM | Maximum prefetch nodes per NN call | When the engine can't gather a large enough batch for immediate use, try to prefetch up to `X` positions, which are likely to be useful soon, and put them in the cache. Default: `32`. |
| --cpuct=NUM | Cpuct MCTS option | C_puct constant from Upper Confidence Tree search algorithm. Higher values promote more exploration/wider search, lower values promote more confidence/deeper search. Default: `1.2`. |
| --temperature=NUM | Initial temperature | Tau value from softmax formula. If equal to 0, the engine also picks the best move to make. Larger values increase randomness while making the move. Default: `0` |
| --tempdecay-moves=NUM | Moves with temperature decay | Reduce temperature for every move linearly from initial temperature to `0`, during this number of moves since the game started. `0` disables temperature decay. Default: `0` |
diff --git a/README.md b/README.md
index 8a76835ee9..1c2c6cb406 100644
--- a/README.md
+++ b/README.md
@@ -9,27 +9,27 @@ Lc0 is a UCI-compliant chess engine designed to play chess via neural network, s
Lc0 can be acquired either via a git clone or an archive download from GitHub. Be aware that there is a required submodule which isn't included in source archives.
-For essentially all purposes, including selfplay game generation and match play, we highly recommend using the latest `release/version` branch (for example `release/0.28`), which is equivalent to using the latest version tag.
+For essentially all purposes, including selfplay game generation and match play, we highly recommend using the latest `release/version` branch (for example `release/0.29`), which is equivalent to using the latest version tag.
Versioning follows the Semantic Versioning guidelines, with major, minor and patch sections. The training server enforces game quality using the versions output by the client and engine.
Download using git:
-```
-git clone -b release/0.28 --recurse-submodules https://github.com/LeelaChessZero/lc0.git
+```shell
+git clone -b release/0.29 --recurse-submodules https://github.com/LeelaChessZero/lc0.git
```
If you have cloned already an old version, fetch, view and checkout a new branch:
-```
+```shell
git fetch --all
git branch --all
-git checkout -t remotes/origin/release/0.28
+git checkout -t remotes/origin/release/0.29
```
If you prefer to download an archive, you need to also download and place the submodule:
- * Download the [.zip](https://api.github.com/repos/LeelaChessZero/lc0/zipball/release/0.28) file ([.tar.gz](https://api.github.com/repos/LeelaChessZero/lc0/tarball/release/0.28) archive is also available)
+ * Download the [.zip](https://api.github.com/repos/LeelaChessZero/lc0/zipball/release/0.29) file ([.tar.gz](https://api.github.com/repos/LeelaChessZero/lc0/tarball/release/0.29) archive is also available)
* Extract
* Download https://github.com/LeelaChessZero/lczero-common/archive/master.zip (also available as [.tar.gz](https://github.com/LeelaChessZero/lczero-common/archive/master.tar.gz))
* Move the second archive into the first archive's `libs/lczero-common/` folder and extract
@@ -48,7 +48,7 @@ Backend support includes (in theory) any CBLAS-compatible library for CPU usage,
Finally, lc0 requires a compiler supporting C++17. Minimal versions seem to be g++ v8.0, clang v5.0 (with C++17 stdlib) or Visual Studio 2017.
-*Note* that cuda checks the compiler version and stops even with newer compilers, and to work around this we have added the `nvcc_ccbin` build option. This is more of an issue with new Linux versions, where we recommend to install `g++-7` and add `-Dnvcc_ccbin=g++-7` to the `build.sh` command.
+*Note* that cuda checks the compiler version and stops even with newer compilers; to work around this we have added the `nvcc_ccbin` build option. This is mostly an issue with newer Linux versions, but you can get around it by using an earlier version of gcc just for cuda. As an example, adding `-Dnvcc_ccbin=g++-9` to the `build.sh` command line will use g++-9 with cuda instead of the system compiler.
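+For example (a minimal sketch, assuming `g++-9` is installed alongside a newer default compiler):
+```shell
+# Use g++-9 only as the CUDA host compiler; everything else builds with the default compiler.
+./build.sh -Dnvcc_ccbin=g++-9
+```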
Given those basics, the OS and backend specific instructions are below.
@@ -179,7 +179,7 @@ You'll need to be running the latest Raspberry Pi OS "buster".
1. Install OpenBLAS
-```
+```shell
git clone https://github.com/xianyi/OpenBLAS.git
cd OpenBLAS/
make
@@ -189,20 +189,20 @@ cd ..
2. Install Meson
-```
-pip3 install meson
-pip3 install ninja
+```shell
+pip install meson
+pip install ninja
```
3. Install compiler and standard libraries
-```
+```shell
sudo apt install clang-6.0 libstdc++-8-dev
```
4. Clone lc0 and compile
-```
+```shell
git clone https://github.com/LeelaChessZero/lc0.git
cd lc0
git submodule update --init --recursive
@@ -211,6 +211,18 @@ CC=clang-6.0 CXX=clang++-6.0 ./build.sh -Ddefault_library=static
5. The resulting binary will be in build/release
+## Python bindings
+
+Python bindings can be built and installed as follows.
+
+```shell
+pip install --user git+https://github.com/LeelaChessZero/lc0.git
+```
+
+This will build the package `lczero-bindings` and install it to your Python user install directory.
+All the `lc0` functionality related to position evaluation is now available in the module `lczero.backends`.
+An example interactive session can be found [here](https://github.com/LeelaChessZero/lc0/pull/1261#issuecomment-622951248).
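+Once installed, a quick sanity check (a minimal sketch) is to confirm the module imports:
+```shell
+python3 -c "import lczero.backends; print('lczero.backends is available')"
+```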
+
## License
Leela Chess is free software: you can redistribute it and/or modify
diff --git a/appveyor.yml b/appveyor.yml
index fb64dc3d18..e111a6bdb6 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -109,17 +109,22 @@ cache:
- C:\ndk\android-ndk-r19c\toolchains\llvm\prebuilt\windows-x86_64
before_build:
- cmd: git submodule update --init --recursive
-- cmd: IF %BLAS%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
-- cmd: IF %ANDROID%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
+- cmd: IF %BLAS%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
+- cmd: IF %ANDROID%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
- cmd: SET BUILD_BLAS=%BLAS%
- cmd: IF %OPENCL%==true SET BUILD_BLAS=true
- cmd: IF %DX%==true SET BUILD_BLAS=true
- cmd: SET EMBED=false
- cmd: IF %APPVEYOR_REPO_TAG%==true IF %ANDROID%==true SET EMBED=true
+- cmd: SET POPCNT=true
+- cmd: IF %NAME%==cpu-openblas SET POPCNT=false
+- cmd: SET F16C=true
+- cmd: IF %NAME%==cpu-openblas SET F16C=false
+- cmd: IF %CUDA%==true SET F16C=false
- cmd: SET EXTRA=
- cmd: IF %ANDROID%==false SET EXTRA=-Db_vscrt=md
- cmd: IF %ONNX_DML%==true SET EXTRA=-Db_vscrt=md -Donnx_libdir=C:\cache\%ONNX_NAME%\lib -Donnx_include=C:\cache\%ONNX_NAME%\include
-- cmd: IF %ANDROID%==false meson build --backend vs2017 --buildtype release -Dgtest=%GTEST% -Dopencl=%OPENCL% -Dblas=%BUILD_BLAS% -Ddnnl=true -Ddx=%DX% -Dcudnn=%CUDNN% -Donednn=%ONEDNN% -Dispc_native_only=false -Dpopcnt=false -Dcudnn_include="%CUDA_PATH%\include","%CUDA_PATH%\cuda\include" -Dcudnn_libdirs="%CUDA_PATH%\lib\x64","%CUDA_PATH%\cuda\lib\x64" -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\dist64\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\dist64\lib" -Ddnnl_dir="%PKG_FOLDER%\%DNNL_NAME%" -Dopencl_include="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\include" -Dopencl_libdirs="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\lib\x64" -Ddefault_library=static -Dmalloc=mimalloc -Dmimalloc_libdir="%MIMALLOC_PATH%"\out\msvc-x64\Release %EXTRA%
+- cmd: IF %ANDROID%==false meson build --backend vs2017 --buildtype release -Dgtest=%GTEST% -Dopencl=%OPENCL% -Dblas=%BUILD_BLAS% -Ddnnl=true -Ddx=%DX% -Dcudnn=%CUDNN% -Donednn=%ONEDNN% -Dispc_native_only=false -Dpopcnt=%POPCNT% -Df16c=%F16C% -Dcudnn_include="%CUDA_PATH%\include","%CUDA_PATH%\cuda\include" -Dcudnn_libdirs="%CUDA_PATH%\lib\x64","%CUDA_PATH%\cuda\lib\x64" -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\dist64\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\dist64\lib" -Ddnnl_dir="%PKG_FOLDER%\%DNNL_NAME%" -Dopencl_include="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\include" -Dopencl_libdirs="%PKG_FOLDER%\opencl-nug.0.777.77\build\native\lib\x64" -Ddefault_library=static -Dmalloc=mimalloc -Dmimalloc_libdir="%MIMALLOC_PATH%"\out\msvc-x64\Release %EXTRA%
- cmd: IF %ANDROID%==true meson arm64-v8a --buildtype release -Dgtest=false -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\android-aarch64\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\android-aarch64\lib" -Dembed=%EMBED% -Ddefault_library=static --cross-file crossfile-aarch64
- cmd: IF %ANDROID%==true meson armeabi-v7a --buildtype release -Dgtest=false -Dopenblas_include="%PKG_FOLDER%\OpenBLAS\android-armv7a\include" -Dopenblas_libdirs="%PKG_FOLDER%\OpenBLAS\android-armv7a\lib" -Dembed=%EMBED% -Ddefault_library=static --cross-file crossfile-armv7a -Dispc=false -Dneon=false
build_script:
diff --git a/build.cmd b/build.cmd
index 9f0d3d3144..9b607e01be 100644
--- a/build.cmd
+++ b/build.cmd
@@ -30,7 +30,11 @@ set CXX=cl
set CC_LD=link
set CXX_LD=link
-if exist "C:\Program Files (x86)\Microsoft Visual Studio\2019" (
+if exist "C:\Program Files\Microsoft Visual Studio\2022" (
+ where /q cl
+ if errorlevel 1 call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64
+ set backend=vs2022
+) else if exist "C:\Program Files (x86)\Microsoft Visual Studio\2019" (
where /q cl
if errorlevel 1 call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64
set backend=vs2019
diff --git a/build.sh b/build.sh
index deca6e99db..419c0229d2 100755
--- a/build.sh
+++ b/build.sh
@@ -1,9 +1,10 @@
#!/usr/bin/env bash
-pushd "$(dirname "$0")"
-
set -e
+# Move to this script's directory.
+CDPATH= cd -- "$(dirname -- "$0")"
+
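+# Usage sketch: ./build.sh [plain|debug|debugoptimized|release|minsize] [meson options...]
+# e.g. ./build.sh debug -Dgtest=false; set INSTALL_PREFIX to also install after building.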
case $1 in
plain|debug|debugoptimized|release|minsize)
BUILDTYPE=$1
@@ -16,27 +17,15 @@ esac
BUILDDIR=build/${BUILDTYPE}
-if ! hash meson 2>/dev/null && [ -x ${HOME}/.local/bin/meson ]
-then
- export PATH=${PATH}:${HOME}/.local/bin
-fi
-
-if [ -f ${BUILDDIR}/build.ninja ]
-then
- meson configure ${BUILDDIR} -Dbuildtype=${BUILDTYPE} -Dprefix=${INSTALL_PREFIX:-/usr/local} "$@"
-else
- meson ${BUILDDIR} --buildtype ${BUILDTYPE} --prefix ${INSTALL_PREFIX:-/usr/local} "$@"
-fi
-
-cd ${BUILDDIR}
-
-NINJA=$(awk '/ninja/ {ninja=$4} END {print ninja}' meson-logs/meson-log.txt)
+MESON=$(PATH="${PATH}:${HOME}/.local/bin" command -v meson || :)
+MESON=${MESON:?"Could not find meson. Is it installed and in PATH?"}
-if [ -n "${INSTALL_PREFIX}" ]
+if [ -f "${BUILDDIR}/build.ninja" ]
then
- ${NINJA} install
+ "${MESON}" configure "${BUILDDIR}" -Dbuildtype="${BUILDTYPE}" -Dprefix="${INSTALL_PREFIX:-/usr/local}" "$@"
else
- ${NINJA}
+ "${MESON}" "${BUILDDIR}" --buildtype "${BUILDTYPE}" --prefix "${INSTALL_PREFIX:-/usr/local}" "$@"
fi
-popd
+"${MESON}" compile -C "${BUILDDIR}"
+[ -n "${INSTALL_PREFIX}" ] && "${MESON}" install -C "${BUILDDIR}"
diff --git a/changelog.txt b/changelog.txt
index 39444f5848..c869b9f5c5 100644
--- a/changelog.txt
+++ b/changelog.txt
@@ -1,4 +1,52 @@
-v0.29.0-rc0 (2022-04-03)
+v0.30.0-rc1 (2023-04-24)
+~~~~~~~
+* Support for networks with attention body and smolgen added to blas, cuda,
+ metal and onnx backends.
+* Persistent L2 cache optimization for the cuda backend. Use the
+ `cache_opt=true` backend option to turn it on.
+* Some performance improvements for the cuda, onnx and blas backends.
+* Added the `threads` backend option to onnx; it defaults to 0 (let the
+ onnxruntime decide) except for onnx-cpu, which defaults to 1.
+* The onnx-dml package now includes a `directml.dll` installation script.
+* Some users experienced memory issues with onnx-dml, so the defaults were
+ changed. This may affect performance, in which case you can use the `steps=8`
+ backend option to get the old behavior.
+* The Python bindings are available as a package, see the README for
+ instructions.
+* Some assorted fixes and code cleanups.
+
+v0.29.0 (2022-12-13)
+~~~~~~~
+* Updated onednn version to the latest one.
+
+v0.29.0-rc1 (2022-12-09)
+~~~~~~~
+* New metal backend for apple systems. This is now the default backend for
+ macos builds.
+* New onnx-dml backend to use DirectML under windows; it has better net
+ compatibility than dx12 and is faster than opencl. See the README for usage
+ instructions; a separate download of the DirectML dll is required.
+* Full attention policy support in cuda, cudnn, metal, onnx, blas, dnnl, and
+ eigen backends.
+* Partial attention policy support in onednn backend (good enough for T79).
+* Now the onnx backends can use fp16 when running with a network file (not with
+ .onnx model files). This is the default for onnx-cuda and onnx-dml, and can be
+ switched on or off by setting the `fp16` backend option to `true` or
+ `false` respectively.
+* The onednn package comes with a dnnl build that allows running on an intel gpu
+ by adding `gpu=0` to the backend options.
+* The default net is now 791556 for most backends except opencl and dx12 that
+ get 753723 (as they lack attention policy support).
+* Support for using pgn book with long lines in training: selfplay can start at
+ a random point in the book.
+* New "simple" time manager.
+* Support for double Fischer random chess (dfrc).
+* Added TC-dependent output to the backendbench assistant.
+* Starting with this version, the check backend compares policy for valid moves
+ after softmax.
+* Some assorted fixes and code cleanups.
+
+v0.29.0-rc0 (2022-04-03)
~~~~~~~
* Initial support for attention policy, only cuda backend and partially in
blas/dnnl/eigen (good enough for T79).
diff --git a/cross-files/aarch64-darwin b/cross-files/aarch64-darwin
new file mode 100644
index 0000000000..441f101d56
--- /dev/null
+++ b/cross-files/aarch64-darwin
@@ -0,0 +1,27 @@
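+# Meson cross file for building a macOS arm64 (Apple Silicon) lc0 on an x86-64 host.
+# Used e.g. as: meson build-arm --buildtype=release --cross-file cross-files/aarch64-darwin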
+
+[host_machine]
+system = 'darwin'
+cpu_family = 'aarch64'
+cpu = 'aarch64'
+endian = 'little'
+
+[properties]
+
+
+[binaries]
+c = 'clang'
+cpp = 'clang++'
+objc = 'clang'
+objcpp = 'clang++'
+ar = 'ar'
+ld = 'ld'
+
+[built-in options]
+c_args = ['-arch', 'arm64']
+cpp_args = ['-arch', 'arm64']
+objc_args = ['-arch', 'arm64']
+objcpp_args = ['-arch', 'arm64']
+c_link_args = ['-arch', 'arm64']
+cpp_link_args = ['-arch', 'arm64']
+objc_link_args = ['-arch', 'arm64']
+objcpp_link_args = ['-arch', 'arm64']
diff --git a/cross-files/aarch64-linux-android b/cross-files/aarch64-linux-android
index 40e1faca3a..4a55d838de 100644
--- a/cross-files/aarch64-linux-android
+++ b/cross-files/aarch64-linux-android
@@ -7,7 +7,7 @@
[host_machine]
system = 'android'
-cpu_family = 'arm'
+cpu_family = 'aarch64'
cpu = 'aarch64'
endian = 'little'
diff --git a/dist/README-onnx-dml.txt b/dist/README-onnx-dml.txt
index 6b721d2c8e..5e34b3eb52 100644
--- a/dist/README-onnx-dml.txt
+++ b/dist/README-onnx-dml.txt
@@ -4,8 +4,9 @@ Lc0 is a UCI-compliant chess engine designed to play chess via
neural network, specifically those of the LeelaChessZero project
(https://lczero.org).
-To run this version you will most likely need a very recent DirectML dll.
-You can download the currently latest nuget installer package from
+To run this version you will most likely need a very recent DirectML dll,
+which you can get by running the included `install.cmd` script. Alternatively,
+you can download the currently latest nuget installer package from
.
If you don't know how to use nuget installer packages, you can change the
extension to .zip and open it as a normal zip file, the dll you need is
diff --git a/dist/install-dml.cmd b/dist/install-dml.cmd
new file mode 100644
index 0000000000..099f42958c
--- /dev/null
+++ b/dist/install-dml.cmd
@@ -0,0 +1,31 @@
+@echo off
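+rem Downloads the Microsoft.AI.DirectML nuget package and extracts DirectML.dll
+rem next to lc0.exe, using the curl and tar tools bundled with recent Windows 10.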
+where /q tar
+if errorlevel 1 goto error
+
+where /q lc0.exe
+if errorlevel 1 cd /d %~dp0
+where /q lc0.exe
+if errorlevel 1 (
+ echo This script must run in the lc0 folder.
+ pause
+ exit /b
+)
+
+cls
+echo Installing the DirectML.dll version required by the Lc0 onnx-dml backend.
+curl -# --ssl-no-revoke -o tmp_directml.zip "https://globalcdn.nuget.org/packages/microsoft.ai.directml.1.10.0.nupkg"
+if errorlevel 1 goto error
+
+tar -xzOf tmp_directml.zip bin/x64-win/DirectML.dll >DirectML.dll
+if errorlevel 1 goto error
+
+del /q tmp_directml.zip
+echo Installation successful.
+pause
+exit /b
+
+:error
+cls
+echo Installation failed - see the README for an alternative approach.
+pause
+
diff --git a/libs/lczero-common b/libs/lczero-common
index 4dfa4ce833..fafda0f59c 160000
--- a/libs/lczero-common
+++ b/libs/lczero-common
@@ -1 +1 @@
-Subproject commit 4dfa4ce8339357819f7de01517e6297d4c768cdf
+Subproject commit fafda0f59c8511b5d933ef758c1e4b10a62da1e0
diff --git a/meson.build b/meson.build
index 49b60c1cd6..f073c42705 100644
--- a/meson.build
+++ b/meson.build
@@ -16,7 +16,7 @@
project('lc0', 'cpp',
default_options : ['cpp_std=c++17', 'b_ndebug=if-release', 'warning_level=3', 'b_lto=true', 'b_vscrt=mt'],
- meson_version: '>=0.52')
+ meson_version: '>=0.54')
cc = meson.get_compiler('cpp')
@@ -48,7 +48,7 @@ endif
if host_machine.system() == 'windows'
add_project_arguments('-DNOMINMAX', language : 'cpp')
endif
-if host_machine.cpu_family() == 'arm'
+if ['arm', 'aarch64'].contains(host_machine.cpu_family())
if get_option('neon')
add_project_arguments(cc.get_supported_arguments(['-mfpu=neon']), language : 'cpp')
add_project_link_arguments(cc.get_supported_arguments(['-mfpu=neon']), language : 'cpp')
@@ -209,7 +209,6 @@ files += [
'src/utils/random.cc',
'src/utils/string.cc',
'src/utils/weights_adapter.cc',
- 'src/utils/fp16_utils.cc',
'src/version.cc',
]
includes += include_directories('src')
@@ -267,12 +266,14 @@ if get_option('build_backends')
if get_option('blas')
if get_option('mkl') and mkl_lib.found()
- add_project_arguments(['-DUSE_MKL', '-DUSE_BLAS'], language : 'cpp')
mkl_inc = get_option('mkl_include')
if run_command('scripts/checkdir.py', mkl_inc).returncode() == 0
includes += include_directories(mkl_inc)
endif
- deps += [ mkl_lib ]
+ if cc.has_header('mkl.h')
+ add_project_arguments(['-DUSE_MKL', '-DUSE_BLAS'], language : 'cpp')
+ deps += [ mkl_lib ]
+ endif
elif get_option('dnnl') and dnnl_lib.found()
add_project_arguments(['-DUSE_DNNL', '-DUSE_BLAS'], language : 'cpp')
@@ -313,7 +314,7 @@ if get_option('build_backends')
ispc_arch = 'x86-64'
ispc_extra_args = []
if get_option('ispc') and ispc.found()
- ispc_native_only = get_option('ispc_native_only')
+ ispc_native_only = get_option('ispc_native_only') and not meson.is_cross_build()
if host_machine.system() == 'windows'
outputnames = [ '@BASENAME@.obj']
if not ispc_native_only
@@ -330,26 +331,27 @@ if get_option('build_backends')
'@BASENAME@_avx512knl.o', '@BASENAME@_avx512skx.o' ]
endif
endif
- if ispc_native_only
- ispc_target = 'host'
- else
- ispc_target = 'sse2-i32x8,sse4-i32x8,avx1-i32x8,avx2-i32x8,avx512knl-i32x16,avx512skx-i32x16'
- endif
+ ispc_target = 'sse2-i32x8,sse4-i32x8,avx1-i32x8,avx2-i32x8,avx512knl-i32x16,avx512skx-i32x16'
if host_machine.system() == 'android'
ispc_extra_args += ['--target-os=android']
- if host_machine.cpu_family() == 'arm'
- outputnames = [ '@BASENAME@.o']
- if host_machine.cpu() == 'aarch64'
- ispc_target = 'neon-i32x8'
- ispc_arch = 'aarch64'
- else
- ispc_target = 'neon-i32x4'
- ispc_arch = 'arm'
- endif
+ endif
+
+ if ['arm', 'aarch64'].contains(host_machine.cpu_family())
+ outputnames = [ '@BASENAME@.o']
+ if host_machine.cpu_family() == 'aarch64'
+ ispc_target = 'neon-i32x8'
+ ispc_arch = 'aarch64'
+ else
+ ispc_target = 'neon-i32x4'
+ ispc_arch = 'arm'
endif
endif
+ if ispc_native_only
+ ispc_target = 'host'
+ endif
+
iscp_gen = generator(ispc,
output: [ '@BASENAME@_ispc.h', outputnames ],
arguments: [ '-O2', '--wno-perf', '--arch=' + ispc_arch,
@@ -377,6 +379,7 @@ if get_option('build_backends')
if get_option('ispc') and ispc.found()
files += iscp_gen.process('src/neural/blas/winograd_transform.ispc')
+ files += iscp_gen.process('src/neural/blas/layer_norm.ispc')
files += iscp_gen.process('src/neural/shared/activation.ispc')
add_project_arguments('-DUSE_ISPC', language : 'cpp')
endif
@@ -458,6 +461,7 @@ if get_option('build_backends')
includes += include_directories('src/neural/cuda/')
cuda_arguments = ['-c', '@INPUT@', '-o', '@OUTPUT@',
+ '-I', meson.source_root() + '/subprojects/abseil-cpp-20211102.0',
'-I', meson.current_source_dir() + '/src']
if host_machine.system() == 'windows'
if get_option('b_vscrt') == 'mt'
@@ -497,27 +501,32 @@ if get_option('build_backends')
)
# Handling of fp16 cuda code.
- nvcc_arch = '-arch=compute_70'
- nvcc_sm_list = ['sm_80', 'sm_75', 'sm_86', 'sm_70']
+ nvcc_sm_list = ['80', '75', '86', '70', '89', '90']
if host_machine.system() != 'windows'
- nvcc_arch = '-arch=compute_60'
- nvcc_sm_list += ['sm_60']
+ nvcc_sm_list += ['60']
if ['arm', 'aarch64'].contains(host_machine.cpu_family())
# Add Jetson versions to the list.
message('Jetson support enabled.')
- nvcc_arch = '-arch=compute_53'
- nvcc_sm_list += ['sm_72', 'sm_62', 'sm_53']
+ nvcc_sm_list += ['72', '62', '53', '87']
endif
endif
# Ignore the given CC for fp16 when it is not in the supported list.
if cuda_cc == '' or not nvcc_sm_list.contains('sm_' + cuda_cc)
- nvcc_extra_args = [nvcc_arch]
+ nvcc_extra_args = []
nvcc_help = run_command(nvcc, '-h').stdout()
foreach x : nvcc_sm_list
- if nvcc_help.contains(x)
- nvcc_extra_args += '-code=' + x
+ if nvcc_help.contains('sm_' + x)
+ nvcc_extra_args += '-gencode=arch=compute_' + x + ',code=sm_' + x
endif
endforeach
+ # For forward compatibility.
+ if nvcc_help.contains('sm_90') # Cuda 12+
+ nvcc_extra_args += '-gencode=arch=compute_90,code=compute_90'
+ elif nvcc_help.contains('sm_80') # Cuda 11+
+ nvcc_extra_args += '-gencode=arch=compute_80,code=compute_80'
+ elif nvcc_help.contains('sm_75') # Cuda 10+
+ nvcc_extra_args += '-gencode=arch=compute_75,code=compute_75'
+ endif
endif
files += custom_target('cuda fp16 code',
input : 'src/neural/cuda/fp16_kernels.cu',
@@ -619,6 +628,15 @@ endif
#############################################################################
## Dependencies
#############################################################################
+
+ ## ~~~~~~
+ ## Abseil
+ ## ~~~~~~
+ # $ meson wrap install abseil-cpp
+ absl = subproject('abseil-cpp', default_options : ['warning_level=0'])
+ deps += absl.get_variable('absl_container_dep').as_system()
+ includes += absl.get_variable('absl_include_dir')
+
## ~~~~
## zlib
## ~~~~
@@ -650,6 +668,10 @@ if not get_option('popcnt')
add_project_arguments('-DNO_POPCNT', language : 'cpp')
endif
+if not get_option('f16c')
+ add_project_arguments('-DNO_F16C', language : 'cpp')
+endif
+
if not get_option('pext')
add_project_arguments('-DNO_PEXT', language : 'cpp')
endif
@@ -723,5 +745,7 @@ if get_option('python_bindings')
python.extension_module('backends',
[py_files + files],
include_directories: [includes],
- dependencies: [cpython] + deps)
+ dependencies: [cpython] + deps,
+ subdir: 'lczero',
+ install: true)
endif
diff --git a/meson_options.txt b/meson_options.txt
index 002584666e..cc2364be45 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -133,6 +133,11 @@ option('popcnt',
value: true,
description: 'Use the popcnt instruction')
+option('f16c',
+ type: 'boolean',
+ value: true,
+ description: 'Use native fp16 conversion instructions')
+
option('pext',
type: 'boolean',
value: false,
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..b55ffdc83f
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,27 @@
+[build-system]
+requires = ["meson-python"]
+build-backend = "mesonpy"
+
+[project]
+name = "lczero_bindings"
+version = "0.1.0"
+description = "Leela Chess Zero Python bindings"
+authors = [{ name = "The LCZero Authors" }]
+license = {file = "COPYING"}
+readme = "README.md"
+requires-python = ">=3.7"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "Topic :: Games/Entertainment :: Board Games",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Environment :: GPU"
+]
+
+[project.urls]
+homepage = "https://github.com/LeelaChessZero/lc0"
+
+[tool.meson-python.args]
+dist = []
+setup = ["-Dpython_bindings=true"]
+compile = []
+install = []
\ No newline at end of file
diff --git a/scripts/appveyor_win_package.cmd b/scripts/appveyor_win_package.cmd
index 57296c35f3..76124cd5ec 100644
--- a/scripts/appveyor_win_package.cmd
+++ b/scripts/appveyor_win_package.cmd
@@ -27,10 +27,12 @@ IF %NAME%==onednn copy "%PKG_FOLDER%\%DNNL_NAME%\THIRD-PARTY-PROGRAMS" dist\DNNL
IF %NAME%==onednn 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\DNNL-LICENSE
IF %NAME%==onednn 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\DNNL-THIRD-PARTY-PROGRAMS
IF %ONNX_DML%==true type dist\README-onnx-dml.txt |more /P > dist\README.txt
+IF %ONNX_DML%==true type dist\install-dml.cmd |more /P > dist\install.cmd
IF %ONNX_DML%==true copy "%PKG_FOLDER%\%ONNX_NAME%\LICENSE" dist\ONNX-DML-LICENSE
IF %ONNX_DML%==true copy "%PKG_FOLDER%\%ONNX_NAME%\ThirdPartyNotices.txt" dist\ONNX-DML-ThirdPartyNotices.txt
IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip "%PKG_FOLDER%\%ONNX_NAME%\lib\onnxruntime.dll"
IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\README.txt
+IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\install.cmd
IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\ONNX-DML-LICENSE
IF %ONNX_DML%==true 7z a lc0-%APPVEYOR_REPO_TAG_NAME%-windows-%NAME%.zip .\dist\ONNX-DML-ThirdPartyNotices.txt
IF %OPENCL%==true type scripts\check_opencl.bat |more /P > dist\check_opencl.bat
diff --git a/scripts/compile_proto.py b/scripts/compile_proto.py
index b7bd4eda6b..1a21c2da4e 100755
--- a/scripts/compile_proto.py
+++ b/scripts/compile_proto.py
@@ -81,6 +81,7 @@
class Lexer:
+
def __init__(self, text):
self.text = text
self.grammar = [(re.compile(x, re.S + re.M), y) for x, y in GRAMMAR]
@@ -130,13 +131,16 @@ def NextTokenOrWhitespace(self):
def Error(self, text):
'''Throws an error with context in the file read.'''
+ line = self.text[:self.cur_offset].count('\n') + 1
line_start = self.text.rfind('\n', 0, self.cur_offset) + 1
line_end = self.text.find('\n', line_start)
+ if line_end == -1:
+ line_end = len(self.text)
sys.stderr.write('%s:\n' % text)
sys.stderr.write(self.text[line_start:line_end] + '\n')
sys.stderr.write(' ' * (self.cur_offset - line_start) + '^^^\n')
- raise ValueError("Parse error: %s at offset %d." %
- (text, self.cur_offset))
+ raise ValueError("Parse error: %s at line %d column %d." %
+ (text, line, (self.cur_offset - line_start)))
def ReadIdentifierPath(lexer):
@@ -155,7 +159,7 @@ def LookupType(name, stack):
for x in y:
if x.GetName() == name[0]:
if len(name) == 1:
- return x.GetType()
+ return x
else:
return LookupType(name[1:], [x.GetTypes()])
raise ValueError("Cannot find type: %s." % '.'.join(name))
@@ -167,6 +171,7 @@ def LookupType(name, stack):
class ProtoTypeParser:
+
def __init__(self, lexer, object_stack):
token, match = lexer.Pick()
if token in TYPES:
@@ -175,10 +180,16 @@ def __init__(self, lexer, object_stack):
lexer.Consume(token)
elif token == 'identifier':
self.name = ReadIdentifierPath(lexer)
- self.typetype = LookupType(self.name, object_stack)
+ self.typetype = 'forward'
else:
lexer.Error('Type expected')
+ def LookupForwardFieldType(self, object_stack):
+ if self.IsForward():
+ typ = LookupType(self.name, object_stack)
+ self.typetype = typ.GetType()
+ self.name = [typ.GetFullName()]
+
def IsZigzag(self):
if self.typetype == 'basic':
return self.name in ZIGZAG_TYPES
@@ -188,7 +199,7 @@ def GetCppType(self):
if self.typetype == 'basic':
return TYPES[self.name]
else:
- return '::'.join(self.name)
+ return '_'.join(self.name)
def GetVariableCppType(self):
if self.IsBytesType():
@@ -196,6 +207,9 @@ def GetVariableCppType(self):
else:
return self.GetCppType()
+ def IsEnumType(self):
+ return self.typetype == 'enum'
+
def IsVarintType(self):
return self.typetype == 'enum' or (self.typetype == 'basic'
and self.name in VARINT_TYPES)
@@ -231,6 +245,9 @@ def GetWireType(self):
def IsMessage(self):
return self.typetype == 'message'
+ def IsForward(self):
+ return self.typetype == 'forward'
+
def IsIntegralType(self):
if self.typetype == 'basic':
if self.name == 'double':
@@ -251,6 +268,7 @@ def IsIntegralType(self):
class ProtoFieldParser:
+
def __init__(self, lexer, object_stack):
token, match = lexer.Pick()
if token not in ['repeated', 'optional', 'required']:
@@ -266,6 +284,9 @@ def __init__(self, lexer, object_stack):
def IsType(self):
return False
+ def LookupForwardFieldType(self, object_stack):
+ self.type.LookupForwardFieldType(object_stack)
+
def GetParser(self):
name = self.name.group(0)
if self.type.IsMessage():
@@ -331,42 +352,91 @@ def GenerateOutput(self, w):
w.Write('%s %s(%d, %s, &out);' %
(prefix, fname[wire_id], self.number, name))
- def GenerateFunctions(self, w):
+ def GenerateJsonOutput(self, w):
+ name = self.name.group(0)
+ if self.category == 'repeated':
+ prefix = 'if (!%s_.empty())' % name
+ funcname = 'AppendJsonRepeatedField'
+ else:
+ prefix = 'if (has_%s_)' % name
+ funcname = 'AppendJsonField'
+ if self.type.IsEnumType():
+ value = '%s_Name(%s_)' % (self.type.GetCppType(), name)
+ else:
+ value = name + "_"
+ w.Write('%s %s("%s", %s, &first, &out);' %
+ (prefix, funcname, name, value))
+
+ def GenerateFunctionDeclarations(self, w):
name = self.name.group(0)
cpp_type = self.type.GetCppType()
var_cpp_type = self.type.GetVariableCppType()
if self.category == 'repeated':
if self.type.IsMessage():
- w.Write("%s* add_%s() { return &%s_.emplace_back(); }" %
- (cpp_type, name, name))
+ w.Write("%s* add_%s();" % (cpp_type, name))
else:
- w.Write("void add_%s(%s val) { %s_.emplace_back(val); }" %
- (name, cpp_type, name))
- w.Write("const std::vector<%s>& %s() const { return %s_; }" %
- (var_cpp_type, name, name))
+ w.Write("void add_%s(%s val);" % (name, cpp_type))
+ w.Write("const std::vector<%s>& %s() const;" %
+ (var_cpp_type, name))
if self.type.IsMessage():
- w.Write("const %s& %s(size_t idx) const { return %s_[idx]; }" %
- (cpp_type, name, name))
+ w.Write("const %s& %s(size_t idx) const;" % (cpp_type, name))
else:
- w.Write("%s %s(size_t idx) const { return %s_[idx]; }" %
- (cpp_type, name, name))
- w.Write("size_t %s_size() const { return %s_.size(); }" %
- (name, name))
+ w.Write("%s %s(size_t idx) const;" % (cpp_type, name))
+ w.Write("size_t %s_size() const;" % (name))
else:
- w.Write("bool has_%s() const { return has_%s_; }" % (name, name))
+ w.Write("bool has_%s() const;" % (name))
if self.type.IsMessage():
- w.Write("const %s& %s() const { return %s_; }" %
- (cpp_type, name, name))
- w.Write("%s* mutable_%s() {" % (cpp_type, name))
+ w.Write("const %s& %s() const;" % (cpp_type, name))
+ w.Write("%s* mutable_%s();" % (cpp_type, name))
+ else:
+ w.Write("%s %s() const;" % (cpp_type, name))
+ w.Write("void set_%s(%s val);" % (name, cpp_type))
+
+ def GenerateFunctionDefinitions(self, w, class_name):
+ name = self.name.group(0)
+ cpp_type = self.type.GetCppType()
+ var_cpp_type = self.type.GetVariableCppType()
+ if self.category == 'repeated':
+ if self.type.IsMessage():
+ w.Write(
+ "inline %s* %s::add_%s() { return &%s_.emplace_back(); }" %
+ (cpp_type, class_name, name, name))
+ else:
+ w.Write(
+ "inline void %s::add_%s(%s val) { %s_.emplace_back(val); }"
+ % (class_name, name, cpp_type, name))
+ w.Write(
+ "inline const std::vector<%s>& %s::%s() const { return %s_; }"
+ % (var_cpp_type, class_name, name, name))
+ if self.type.IsMessage():
+ w.Write(
+ "inline const %s& %s::%s(size_t idx) const { return %s_[idx]; }"
+ % (cpp_type, class_name, name, name))
+ else:
+ w.Write(
+ "inline %s %s::%s(size_t idx) const { return %s_[idx]; }" %
+ (cpp_type, class_name, name, name))
+ w.Write(
+ "inline size_t %s::%s_size() const { return %s_.size(); }" %
+ (class_name, name, name))
+ else:
+ w.Write("inline bool %s::has_%s() const { return has_%s_; }" %
+ (class_name, name, name))
+ if self.type.IsMessage():
+ w.Write("inline const %s& %s::%s() const { return %s_; }" %
+ (cpp_type, class_name, name, name))
+ w.Write("inline %s* %s::mutable_%s() {" %
+ (cpp_type, class_name, name))
w.Indent()
w.Write('has_%s_ = true;' % (name))
w.Write('return &%s_;' % name)
w.Unindent()
w.Write("}")
else:
- w.Write("%s %s() const { return %s_; }" %
- (cpp_type, name, name))
- w.Write("void set_%s(%s val) {" % (name, cpp_type))
+ w.Write("inline %s %s::%s() const { return %s_; }" %
+ (cpp_type, class_name, name, name))
+ w.Write("inline void %s::set_%s(%s val) {" %
+ (class_name, name, cpp_type))
w.Indent()
w.Write("has_%s_ = true;" % name)
w.Write("%s_ = val;" % name)
@@ -385,10 +455,12 @@ def GenerateVariable(self, w):
class ProtoEnumParser:
- def __init__(self, lexer):
+
+ def __init__(self, lexer, scope):
lexer.Consume('enum')
self.name = lexer.Consume('identifier').group(0)
self.values = []
+ self.scope = scope[:]
lexer.Consume('{')
while True:
token, match = lexer.Pick()
@@ -404,21 +476,54 @@ def __init__(self, lexer):
def GetName(self):
return self.name
+ def GetFullName(self):
+ return '_'.join([x.GetName() for x in self.scope] + [self.name])
+
def GetType(self):
return 'enum'
def IsType(self):
return True
- def Generate(self, w):
+ def ResolveForwardDeclarations(self, _):
+ pass
+
+ def GenerateMessageDeclarations(self, w):
+ pass
+
+ def GenerateMessageDefinitions(self, w):
+ pass
+
+ def GenerateFunctionDefinitions(self, w):
+ pass
+
+ def GenerateEnumDefinitions(self, w):
# Protobuf enum is mapped directly to C++ enum.
- w.Write('enum %s {' % self.name)
+ w.Write('enum %s : int {' % self.GetFullName())
w.Indent()
for key, value in self.values:
- w.Write('%s = %d,' % (key, value))
+ w.Write('%s_%s = %d,' % (self.GetFullName(), key, value))
+ w.Unindent()
+ w.Write('};')
+ w.Write('inline std::string %s_Name(%s val) {' %
+ (self.GetFullName(), self.GetFullName()))
+ w.Indent()
+ w.Write('switch (val) {')
+ w.Indent()
+ for key, _ in self.values:
+ w.Write('case %s_%s:' % (self.GetFullName(), key))
+ w.Write(' return "%s";' % key)
w.Unindent()
w.Write('};')
- # Static array of all possible enum values.
+ w.Write('return "%s(" + std::to_string(val) + ")";' % self.name)
+ w.Unindent()
+ w.Write('}')
+
+ def GenerateUsingDirectives(self, w):
+ w.Write('using %s = %s;' % (self.name, self.GetFullName()))
+ for key, _ in self.values:
+ w.Write('static constexpr %s %s =' % (self.name, key))
+ w.Write(' %s_%s;' % (self.GetFullName(), key))
w.Write('static constexpr std::array<%s,%d> %s_AllValues = {' %
(self.name, len(self.values), self.name))
w.Indent()
@@ -430,22 +535,18 @@ def Generate(self, w):
w.Write('static std::string %s_Name(%s val) {' %
(self.name, self.name))
w.Indent()
- w.Write('switch (val) {')
- w.Indent()
- for key, _ in self.values:
- w.Write('case %s:' % key)
- w.Write(' return "%s";' % key)
- w.Unindent()
- w.Write('};')
- w.Write('return "%s(" + std::to_string(val) + ")";' % self.name)
+ w.Write('return %s_Name(val);' % (self.GetFullName()))
w.Unindent()
w.Write('}')
class ProtoMessageParser:
- def __init__(self, lexer, type_stack):
+
+ def __init__(self, lexer, type_stack, scope):
+ type_stack[0].append(self)
self.types = []
self.fields = []
+ self.scope = scope[:]
lexer.Consume('message')
self.name = lexer.Consume('identifier').group(0)
lexer.Consume('{')
@@ -454,10 +555,10 @@ def __init__(self, lexer, type_stack):
if token == '}':
break
elif token == 'message':
- self.types.append(
- ProtoMessageParser(lexer, [self.types, *type_stack]))
+ ProtoMessageParser(lexer, [self.types, *type_stack],
+ self.scope + [self])
elif token == 'enum':
- self.types.append(ProtoEnumParser(lexer))
+ self.types.append(ProtoEnumParser(lexer, self.scope + [self]))
elif token in ['repeated', 'optional', 'required']:
self.fields.append(
ProtoFieldParser(lexer, [self.types, *type_stack]))
@@ -468,6 +569,9 @@ def __init__(self, lexer, type_stack):
def GetName(self):
return self.name
+ def GetFullName(self):
+ return '_'.join([x.GetName() for x in self.scope] + [self.name])
+
def GetType(self):
return 'message'
@@ -483,7 +587,15 @@ def GetFieldsGruppedByWireType(self):
type_to_fields.setdefault(x.type.GetWireType(), []).append(x)
return type_to_fields
- def WriteFieldParser(self, w, wire_id, fields):
+ def ResolveForwardDeclarations(self, type_stack):
+ type_stack.append(self.types)
+ for x in self.types:
+ x.ResolveForwardDeclarations(type_stack)
+ for x in self.fields:
+ x.LookupForwardFieldType(type_stack)
+ type_stack.pop()
+
+ def WriteFieldParserDeclaration(self, w, wire_id, fields):
fname = {0: 'SetVarInt', 1: 'SetInt64', 2: 'SetString', 5: 'SetInt32'}
tname = {
0: 'std::uint64_t',
@@ -491,8 +603,19 @@ def WriteFieldParser(self, w, wire_id, fields):
2: 'std::string_view',
5: 'std::uint32_t'
}
- w.Write('void %s(int field_id, %s val) override {' %
+ w.Write('void %s(int field_id, %s val) final;' %
(fname[wire_id], tname[wire_id]))
+
+ def WriteFieldParserDefinition(self, w, wire_id, fields):
+ fname = {0: 'SetVarInt', 1: 'SetInt64', 2: 'SetString', 5: 'SetInt32'}
+ tname = {
+ 0: 'std::uint64_t',
+ 1: 'std::uint64_t',
+ 2: 'std::string_view',
+ 5: 'std::uint32_t'
+ }
+ w.Write('inline void %s::%s(int field_id, %s val) {' %
+ (self.GetFullName(), fname[wire_id], tname[wire_id]))
w.Indent()
w.Write('switch (field_id) {')
w.Indent()
@@ -503,19 +626,67 @@ def WriteFieldParser(self, w, wire_id, fields):
w.Unindent()
w.Write('}')
- def Generate(self, w):
+ def GenerateUsingDirectives(self, w):
+ w.Write('using %s = %s;' % (self.name, self.GetFullName()))
+
+ def GenerateMessageDeclarations(self, w):
+ w.Write(f'class %s;' % self.GetFullName())
+ for x in self.types:
+ x.GenerateMessageDeclarations(w)
+
+ def GenerateEnumDefinitions(self, w):
+ for x in self.types:
+ x.GenerateEnumDefinitions(w)
+
+ def GenerateMessageDefinitions(self, w):
+ # Writing nested messages.
+ for x in self.types:
+ if x.GetType() == 'message':
+ x.GenerateMessageDefinitions(w)
# Protobuf message is a C++ class.
- w.Write('class %s : public lczero::ProtoMessage {' % self.name)
+ w.Write('class %s final : public lczero::ProtoMessage {' %
+ self.GetFullName())
w.Write(' public:')
w.Indent()
- # Writing submessages and enums.
+ # Writing using directives.
for x in self.types:
- x.Generate(w)
+ x.GenerateUsingDirectives(w)
+ # Writing function declarations.
for x in self.fields:
w.Write('')
- x.GenerateFunctions(w)
+ x.GenerateFunctionDeclarations(w)
w.Write('')
- w.Write('std::string OutputAsString() const override {')
+ w.Write('std::string OutputAsString() const final;')
+ w.Write('std::string OutputAsJson() const final;')
+ w.Write('void Clear() final;')
+
+ w.Unindent()
+ w.Write('')
+ w.Write(' private:')
+ w.Indent()
+ for k, v in self.GetFieldsGruppedByWireType().items():
+ self.WriteFieldParserDeclaration(w, k, v)
+ w.Write('')
+ for x in self.fields:
+ x.GenerateVariable(w)
+ w.Unindent()
+ w.Write('};')
+ w.Write('')
+
+ def GenerateFunctionDefinitions(self, w):
+ # Writing nested messages.
+ for x in self.types:
+ if x.GetType() == 'message':
+ x.GenerateFunctionDefinitions(w)
+ self.GenerateOutputAsStringFunc(w)
+ self.GenerateOutputAsJsonFunc(w)
+ self.GenerateClearFunc(w)
+ self.GenerateParserFuncs(w)
+ self.GenerateFieldAccessorFuncs(w)
+
+ def GenerateOutputAsStringFunc(self, w):
+ w.Write('inline std::string %s::OutputAsString() const {' %
+ self.GetFullName())
w.Indent()
w.Write('std::string out;')
for x in sorted(self.fields, key=lambda x: x.number):
@@ -523,31 +694,44 @@ def Generate(self, w):
w.Write('return out;')
w.Unindent()
w.Write('}')
- w.Write('')
- w.Write('void Clear() override {')
+
+ def GenerateOutputAsJsonFunc(self, w):
+ w.Write('inline std::string %s::OutputAsJson() const {' %
+ self.GetFullName())
w.Indent()
+ if self.fields:
+ w.Write('bool first = true;')
+ w.Write('std::string out = "{";')
for x in self.fields:
- x.GenerateClear(w)
+ x.GenerateJsonOutput(w)
+ w.Write('out += "}";')
+ w.Write('return out;')
w.Unindent()
w.Write('}')
- w.Unindent()
- w.Write('')
- w.Write(' private:')
+
+ def GenerateClearFunc(self, w):
+ w.Write('inline void %s::Clear() {' % self.GetFullName())
w.Indent()
- for k, v in self.GetFieldsGruppedByWireType().items():
- self.WriteFieldParser(w, k, v)
- w.Write('')
for x in self.fields:
- x.GenerateVariable(w)
+ x.GenerateClear(w)
w.Unindent()
- w.Write('};')
+ w.Write('}')
+
+ def GenerateParserFuncs(self, w):
+ for k, v in self.GetFieldsGruppedByWireType().items():
+ self.WriteFieldParserDefinition(w, k, v)
+
+ def GenerateFieldAccessorFuncs(self, w):
+ for x in self.fields:
+ x.GenerateFunctionDefinitions(w, self.GetFullName())
class ProtoFileParser:
'''Root grammar of .proto file'''
+
def __init__(self, lexer):
self.package = None
- self.objects = []
+ self.types = []
while True:
token, match = lexer.Pick()
if token == 'EOF':
@@ -558,6 +742,8 @@ def __init__(self, lexer):
self.ParsePackage(lexer)
elif token == 'message':
self.ParseMessage(lexer)
+ elif token == 'enum':
+ self.ParseEnum(lexer)
else:
lexer.Error('Expected message or something similar')
@@ -575,7 +761,10 @@ def ParsePackage(self, lexer):
lexer.Consume(';')
def ParseMessage(self, lexer):
- self.objects.append(ProtoMessageParser(lexer, [self.objects]))
+ ProtoMessageParser(lexer, [self.types], [])
+
+ def ParseEnum(self, lexer):
+ self.types.append(ProtoEnumParser(lexer, []))
def Generate(self, w):
w.Write('// This file is AUTOGENERATED, do not edit.')
@@ -583,16 +772,32 @@ def Generate(self, w):
w.Write('#include "utils/protomessage.h"')
for x in self.package:
w.Write('namespace %s {' % x)
- w.Indent()
- for object in self.objects:
- object.Generate(w)
- w.Unindent()
+ w.Write('')
+ w.Write('// Forward declarations.')
+ for object in self.types:
+ object.GenerateMessageDeclarations(w)
+ for object in self.types:
+ object.GenerateEnumDefinitions(w)
+ w.Write('')
+ w.Write('// Class declarations.')
+ for object in self.types:
+ object.GenerateMessageDefinitions(w)
+ w.Write('')
+ w.Write('// Function definitions.')
+ for object in self.types:
+ object.GenerateFunctionDefinitions(w)
for x in reversed(self.package):
w.Write('} // namespace %s' % x)
+ def ResolveForwardDeclarations(self):
+ type_stack = [self.types]
+ for object in self.types:
+ object.ResolveForwardDeclarations(type_stack)
+
class Writer:
'''A helper class for writing file line by line with indent.'''
+
def __init__(self, fo):
self.fo = fo
self.indent = 0
@@ -626,5 +831,6 @@ def Write(self, text):
with open(args.input, 'r') as input, open(dest_path, 'w') as output:
proto_file = ProtoFileParser(Lexer(input.read()))
+ proto_file.ResolveForwardDeclarations()
writer = Writer(output)
proto_file.Generate(writer)
diff --git a/src/benchmark/backendbench.cc b/src/benchmark/backendbench.cc
index 6792f9b778..401d8bea1b 100644
--- a/src/benchmark/backendbench.cc
+++ b/src/benchmark/backendbench.cc
@@ -29,6 +29,7 @@
#include "chess/board.h"
#include "mcts/node.h"
+#include "neural/encoder.h"
#include "neural/factory.h"
#include "utils/optionsparser.h"
diff --git a/src/benchmark/benchmark.cc b/src/benchmark/benchmark.cc
index 5679ff45b6..a48d7b1e3f 100644
--- a/src/benchmark/benchmark.cc
+++ b/src/benchmark/benchmark.cc
@@ -97,12 +97,12 @@ void Benchmark::Run() {
NNCache cache;
cache.SetCapacity(option_dict.Get(kNNCacheSizeId));
- NodeTree tree;
+ NodeTree tree = {option_dict};
tree.ResetToPosition(position, {});
const auto start = std::chrono::steady_clock::now();
auto search = std::make_unique(
- tree, network.get(),
+ &tree, network.get(),
std::make_unique(
std::bind(&Benchmark::OnBestMove, this, std::placeholders::_1),
std::bind(&Benchmark::OnInfo, this, std::placeholders::_1)),
diff --git a/src/chess/position.cc b/src/chess/position.cc
index ec085f1e37..9efbde0f6a 100644
--- a/src/chess/position.cc
+++ b/src/chess/position.cc
@@ -78,7 +78,7 @@ Position::Position(const ChessBoard& board, int rule50_ply, int game_ply)
}
uint64_t Position::Hash() const {
- return HashCat({us_board_.Hash(), static_cast(repetitions_)});
+ return us_board_.Hash();
}
std::string Position::DebugString() const { return us_board_.DebugString(); }
@@ -130,7 +130,7 @@ int PositionHistory::ComputeLastMoveRepetitions(int* cycle_length) const {
// TODO(crem) implement hash/cache based solution.
if (last.GetRule50Ply() < 4) return 0;
- for (int idx = positions_.size() - 3; idx >= 0; idx -= 2) {
+ for (int idx = positions_.size() - 5; idx >= 0; idx -= 2) {
const auto& pos = positions_[idx];
if (pos.GetBoard() == last.GetBoard()) {
*cycle_length = positions_.size() - 1 - idx;
diff --git a/src/chess/position.h b/src/chess/position.h
index 69241467c4..879e08fb2c 100644
--- a/src/chess/position.h
+++ b/src/chess/position.h
@@ -27,6 +27,7 @@
#pragma once
+#include
#include
#include "chess/board.h"
@@ -99,11 +100,22 @@ GameResult operator-(const GameResult& res);
class PositionHistory {
public:
PositionHistory() = default;
- PositionHistory(const PositionHistory& other) = default;
+ PositionHistory(const PositionHistory& other) {
+ positions_.reserve(
+ std::max(other.positions_.size() + 1, other.positions_.capacity()));
+ positions_ = other.positions_;
+ }
PositionHistory(PositionHistory&& other) = default;
- PositionHistory& operator=(const PositionHistory& other) = default;
- PositionHistory& operator=(PositionHistory&& other) = default;
+ PositionHistory& operator=(const PositionHistory& other) {
+ if (this == &other) return *this;
+ positions_.clear();
+ positions_.reserve(
+ std::max(other.positions_.size() + 1, other.positions_.capacity()));
+ positions_ = other.positions_;
+ return *this;
+ }
+ PositionHistory& operator=(PositionHistory&& other) = default;
// Returns first position of the game (or fen from which it was initialized).
const Position& Starting() const { return positions_.front(); }
diff --git a/src/engine.cc b/src/engine.cc
index a632ebb0ad..fe267a55b4 100644
--- a/src/engine.cc
+++ b/src/engine.cc
@@ -96,7 +96,7 @@ void EngineController::PopulateOptions(OptionsParser* options) {
NetworkFactory::PopulateOptions(options);
options->Add(kThreadsOptionId, 1, 128) = kDefaultThreads;
- options->Add(kNNCacheSizeId, 0, 999999999) = 2000000;
+ options->Add(kNNCacheSizeId, 0, 999999999) = 2000;
SearchParams::Populate(options);
options->Add(kSyzygyTablebaseId);
@@ -132,9 +132,11 @@ void EngineController::UpdateFromUciOptions() {
if (!syzygy_tb_->init(tb_paths)) {
CERR << "Failed to load Syzygy tablebases!";
syzygy_tb_ = nullptr;
- } else {
- tb_paths_ = tb_paths;
}
+ tb_paths_ = tb_paths;
+ } else if (tb_paths.empty()) {
+ syzygy_tb_ = nullptr;
+ tb_paths_.clear();
}
// Network.
@@ -189,7 +191,7 @@ Position EngineController::ApplyPositionMoves() {
board.SetFromFen(current_position_.fen, &no_capture_ply, &game_move);
int game_ply = 2 * game_move - (board.flipped() ? 1 : 2);
Position pos(board, no_capture_ply, game_ply);
- for (std::string move_str: current_position_.moves) {
+ for (std::string move_str : current_position_.moves) {
Move move(move_str);
if (pos.IsBlackToMove()) move.Mirror();
pos = Position(pos, move);
@@ -204,7 +206,7 @@ void EngineController::SetupPosition(
UpdateFromUciOptions();
- if (!tree_) tree_ = std::make_unique();
+ if (!tree_) tree_ = std::make_unique(options_);
std::vector moves;
for (const auto& move : moves_str) moves.emplace_back(move);
@@ -294,7 +296,7 @@ void EngineController::Go(const GoParams& params) {
auto stopper = time_manager_->GetStopper(params, *tree_.get());
search_ = std::make_unique(
- *tree_, network_.get(), std::move(responder),
+ tree_.get(), network_.get(), std::move(responder),
StringsToMovelist(params.searchmoves, tree_->HeadPosition().GetBoard()),
*move_start_time_, std::move(stopper), params.infinite || params.ponder,
options_, &cache_, syzygy_tb_.get());
diff --git a/src/lc0ctl/describenet.cc b/src/lc0ctl/describenet.cc
index f67997f1fc..634589f66b 100644
--- a/src/lc0ctl/describenet.cc
+++ b/src/lc0ctl/describenet.cc
@@ -96,6 +96,20 @@ void ShowNetworkFormatInfo(const pblczero::Net& weights) {
COUT << Justify("MLH")
<< NetworkFormat::MovesLeftFormat_Name(net_format.moves_left());
}
+ if (net_format.has_default_activation()) {
+ COUT << Justify("Default activation")
+ << NetworkFormat::DefaultActivation_Name(
+ net_format.default_activation());
+ }
+ if (net_format.has_smolgen_activation()) {
+ COUT << Justify("Smolgen activation")
+ << NetworkFormat::ActivationFunction_Name(
+ net_format.smolgen_activation());
+ }
+ if (net_format.has_ffn_activation()) {
+ COUT << Justify("FFN activation")
+ << NetworkFormat::ActivationFunction_Name(net_format.ffn_activation());
+ }
}
void ShowNetworkTrainingInfo(const pblczero::Net& weights) {
@@ -124,33 +138,84 @@ void ShowNetworkTrainingInfo(const pblczero::Net& weights) {
}
}
+void ShowNetworkWeightsBodyInfo(const pblczero::Net& weights) {
+ const auto& w = weights.weights();
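+ // Note: params() holds 16-bit values, so size()/2 below gives the number of elements.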
+ if (w.encoder_size() > 0) {
+ COUT << Justify("Encoders") << w.encoder_size();
+ COUT << Justify("Encoder heads") << w.headcount();
+ COUT << Justify("Embedding size") << w.ip_emb_b().params().size() / 2;
+ COUT << Justify("Dmodel") << w.encoder(0).mha().q_b().params().size() / 2;
+ COUT << Justify("Encoder DFF")
+ << w.encoder(0).ffn().dense1_b().params().size() / 2;
+ } else {
+ COUT << Justify("Blocks") << w.residual_size();
+ int se_count = 0;
+ for (size_t i = 0; i < w.residual_size(); i++) {
+ if (w.residual(i).has_se()) se_count++;
+ }
+ if (se_count > 0) {
+ COUT << Justify("SE blocks") << se_count;
+ }
+ COUT << Justify("Filters")
+ << w.input().weights().params().size() / 2 / 112 / 9;
+ }
+}
+
+void ShowNetworkWeightsPolicyInfo(const pblczero::Net& weights) {
+ using pblczero::NetworkFormat;
+ const auto& w = weights.weights();
+ const auto& format = weights.format().network_format();
+ auto pol_activation = NetworkFormat::ACTIVATION_DEFAULT;
+ if (format.policy() == NetworkFormat::POLICY_ATTENTION) {
+ // Non-attentionbody nets use hardcoded SELU as policy activation and FFN
+ // activations.
+ auto ffn_activation = format.ffn_activation();
+ if (w.encoder_size() == 0) {
+ pol_activation = NetworkFormat::ACTIVATION_SELU;
+ ffn_activation = NetworkFormat::ACTIVATION_SELU;
+ }
+
+ COUT << Justify("Policy") << "Attention";
+ COUT << Justify("Policy activation")
+ << NetworkFormat::ActivationFunction_Name(pol_activation);
+
+ if (w.pol_encoder_size() > 0) {
+ COUT << Justify("Policy encoders") << w.pol_encoder_size();
+ COUT << Justify("Policy encoder heads") << w.pol_headcount();
+ COUT << Justify("Policy encoder Dmodel")
+ << w.pol_encoder(0).mha().q_b().params().size() / 2;
+ COUT << Justify("Policy encoder DFF")
+ << w.pol_encoder(0).ffn().dense1_b().params().size() / 2;
+ COUT << Justify("Policy FFN activation")
+ << NetworkFormat::ActivationFunction_Name(ffn_activation);
+ }
+ COUT << Justify("Policy Dmodel") << w.ip2_pol_b().params().size() / 2;
+ } else {
+ COUT << Justify("Policy") << (w.has_policy1() ? "Convolution" : "Dense");
+ COUT << Justify("Policy activation")
+ << NetworkFormat::ActivationFunction_Name(pol_activation);
+ if (!w.has_policy1()) {
+ int policy_channels = w.policy().biases().params().size() / 2;
+ if (policy_channels == 0) {
+ policy_channels = w.policy().bn_means().params().size() / 2;
+ }
+ COUT << Justify("Policy channels") << policy_channels;
+ }
+ }
+}
+
void ShowNetworkWeightsInfo(const pblczero::Net& weights) {
if (!weights.has_weights()) return;
COUT << "\nWeights";
COUT << "~~~~~~~";
+
+ ShowNetworkWeightsBodyInfo(weights);
+ ShowNetworkWeightsPolicyInfo(weights);
+
const auto& w = weights.weights();
- COUT << Justify("Blocks") << w.residual_size();
- int se_count = 0;
- for (size_t i = 0; i < w.residual_size(); i++) {
- if (w.residual(i).has_se()) se_count++;
- }
- if (se_count > 0) {
- COUT << Justify("SE blocks") << se_count;
- }
-
- COUT << Justify("Filters")
- << w.input().weights().params().size() / 2 / 112 / 9;
- COUT << Justify("Policy") << (w.has_policy1() ? "Convolution" : "Dense");
- if (!w.has_policy1()) {
- int policy_channels = w.policy().biases().params().size() / 2;
- if (policy_channels == 0) {
- policy_channels = w.policy().bn_means().params().size() / 2;
- }
- COUT << Justify("Policy channels") << policy_channels;
- }
COUT << Justify("Value")
<< (w.ip2_val_w().params().size() / 2 % 3 == 0 ? "WDL" : "Classical");
- COUT << Justify("MLH") << (w.has_moves_left() ? "Present" : "Absent");
+ COUT << Justify("MLH") << (w.has_ip2_mov_w() ? "Present" : "Absent");
}
void ShowNetworkOnnxInfo(const pblczero::Net& weights,
diff --git a/src/mcts/node.cc b/src/mcts/node.cc
index 9bda7f144a..8547805e0e 100644
--- a/src/mcts/node.cc
+++ b/src/mcts/node.cc
@@ -27,95 +27,24 @@
#include "mcts/node.h"
+#include
+
#include
#include
#include
#include
#include
+#include
+#include
#include
#include
+#include
-#include "neural/encoder.h"
-#include "neural/network.h"
#include "utils/exception.h"
#include "utils/hashcat.h"
namespace lczero {
-/////////////////////////////////////////////////////////////////////////
-// Node garbage collector
-/////////////////////////////////////////////////////////////////////////
-
-namespace {
-// Periodicity of garbage collection, milliseconds.
-const int kGCIntervalMs = 100;
-
-// Every kGCIntervalMs milliseconds release nodes in a separate GC thread.
-class NodeGarbageCollector {
- public:
- NodeGarbageCollector() : gc_thread_([this]() { Worker(); }) {}
-
- // Takes ownership of a subtree, to dispose it in a separate thread when
- // it has time.
- void AddToGcQueue(std::unique_ptr node, size_t solid_size = 0) {
- if (!node) return;
- Mutex::Lock lock(gc_mutex_);
- subtrees_to_gc_.emplace_back(std::move(node));
- subtrees_to_gc_solid_size_.push_back(solid_size);
- }
-
- ~NodeGarbageCollector() {
- // Flips stop flag and waits for a worker thread to stop.
- stop_.store(true);
- gc_thread_.join();
- }
-
- private:
- void GarbageCollect() {
- while (!stop_.load()) {
- // Node will be released in destructor when mutex is not locked.
- std::unique_ptr<Node> node_to_gc;
- size_t solid_size = 0;
- {
- // Lock the mutex and move last subtree from subtrees_to_gc_ into
- // node_to_gc.
- Mutex::Lock lock(gc_mutex_);
- if (subtrees_to_gc_.empty()) return;
- node_to_gc = std::move(subtrees_to_gc_.back());
- subtrees_to_gc_.pop_back();
- solid_size = subtrees_to_gc_solid_size_.back();
- subtrees_to_gc_solid_size_.pop_back();
- }
- // Solid is a hack...
- if (solid_size != 0) {
- for (size_t i = 0; i < solid_size; i++) {
- node_to_gc.get()[i].~Node();
- }
- std::allocator<Node> alloc;
- alloc.deallocate(node_to_gc.release(), solid_size);
- }
- }
- }
-
- void Worker() {
- while (!stop_.load()) {
- std::this_thread::sleep_for(std::chrono::milliseconds(kGCIntervalMs));
- GarbageCollect();
- };
- }
-
- mutable Mutex gc_mutex_;
- std::vector<std::unique_ptr<Node>> subtrees_to_gc_ GUARDED_BY(gc_mutex_);
- std::vector<size_t> subtrees_to_gc_solid_size_ GUARDED_BY(gc_mutex_);
-
- // When true, Worker() should stop and exit.
- std::atomic<bool> stop_{false};
- std::thread gc_thread_;
-};
-
-NodeGarbageCollector gNodeGc;
-} // namespace
-
/////////////////////////////////////////////////////////////////////////
// Edge
/////////////////////////////////////////////////////////////////////////
@@ -188,115 +117,102 @@ std::unique_ptr<Edge[]> Edge::FromMovelist(const MoveList& moves) {
}
/////////////////////////////////////////////////////////////////////////
-// Node
+// LowNode + Node
/////////////////////////////////////////////////////////////////////////
-Node* Node::CreateSingleChildNode(Move move) {
- assert(!edges_);
- assert(!child_);
- edges_ = Edge::FromMovelist({move});
- num_edges_ = 1;
- child_ = std::make_unique<Node>(this, 0);
- return child_.get();
+// Put @low_node at the end of TT @gc_queue, if both @gc_queue and @low_node
+// are not null and @low_node is TT and about to become parent-less (has only
+// one parent).
+static void TTGCEnqueue(GCQueue* gc_queue, const LowNode* low_node) {
+ if (gc_queue && low_node && low_node->IsTT() &&
+ low_node->GetNumParents() == 1)
+ gc_queue->push_back(low_node->GetHash());
}
-void Node::CreateEdges(const MoveList& moves) {
- assert(!edges_);
- assert(!child_);
- edges_ = Edge::FromMovelist(moves);
- num_edges_ = moves.size();
-}
+void Node::Trim(GCQueue* gc_queue) {
+ wl_ = 0.0f;
+
+ TTGCEnqueue(gc_queue, low_node_);
+ UnsetLowNode();
+ // sibling_
-Node::ConstIterator Node::Edges() const {
- return {*this, !solid_children_ ? &child_ : nullptr};
+ d_ = 0.0f;
+ m_ = 0.0f;
+ n_ = 0;
+ n_in_flight_ = 0;
+
+ // edge_
+
+ // index_
+
+ terminal_type_ = Terminal::NonTerminal;
+ lower_bound_ = GameResult::BLACK_WON;
+ upper_bound_ = GameResult::WHITE_WON;
+ repetition_ = false;
}
-Node::Iterator Node::Edges() {
- return {*this, !solid_children_ ? &child_ : nullptr};
+
+Node* Node::GetChild() const {
+ if (!low_node_) return nullptr;
+ return low_node_->GetChild()->get();
}
+bool Node::HasChildren() const { return low_node_ && low_node_->HasChildren(); }
+
float Node::GetVisitedPolicy() const {
float sum = 0.0f;
- for (auto* node : VisitedNodes()) sum += GetEdgeToNode(node)->GetP();
+ for (auto* node : VisitedNodes()) sum += node->GetP();
return sum;
}
-Edge* Node::GetEdgeToNode(const Node* node) const {
- assert(node->parent_ == this);
- assert(node->index_ < num_edges_);
- return &edges_[node->index_];
+uint32_t Node::GetNInFlight() const {
+ return n_in_flight_.load(std::memory_order_acquire);
+}
+
+uint32_t Node::GetChildrenVisits() const {
+ return low_node_ ? low_node_->GetChildrenVisits() : 0;
}
-Edge* Node::GetOwnEdge() const { return GetParent()->GetEdgeToNode(this); }
+uint32_t Node::GetTotalVisits() const {
+ return low_node_ ? low_node_->GetN() : 0;
+}
+
+const Edge& LowNode::GetEdgeAt(uint16_t index) const { return edges_[index]; }
std::string Node::DebugString() const {
std::ostringstream oss;
- oss << " Term:" << static_cast(terminal_type_) << " This:" << this
- << " Parent:" << parent_ << " Index:" << index_
- << " Child:" << child_.get() << " Sibling:" << sibling_.get()
- << " WL:" << wl_ << " N:" << n_ << " N_:" << n_in_flight_
- << " Edges:" << static_cast(num_edges_)
+ oss << " This:" << this << " LowNode:" << low_node_
+ << " Index:" << index_ << " Move:" << GetMove().as_string()
+ << " Sibling:" << sibling_.get() << " P:" << GetP() << " WL:" << wl_
+ << " D:" << d_ << " M:" << m_ << " N:" << n_ << " N_:" << n_in_flight_
+ << " Term:" << static_cast(terminal_type_)
<< " Bounds:" << static_cast(lower_bound_) - 2 << ","
- << static_cast(upper_bound_) - 2
- << " Solid:" << solid_children_;
+ << static_cast(upper_bound_) - 2;
return oss.str();
}
-bool Node::MakeSolid() {
- if (solid_children_ || num_edges_ == 0 || IsTerminal()) return false;
- // Can only make solid if no immediate leaf children are in flight since we
- // allow the search code to hold references to leaf nodes across locks.
- Node* old_child_to_check = child_.get();
- uint32_t total_in_flight = 0;
- while (old_child_to_check != nullptr) {
- if (old_child_to_check->GetN() <= 1 &&
- old_child_to_check->GetNInFlight() > 0) {
- return false;
- }
- if (old_child_to_check->IsTerminal() &&
- old_child_to_check->GetNInFlight() > 0) {
- return false;
- }
- total_in_flight += old_child_to_check->GetNInFlight();
- old_child_to_check = old_child_to_check->sibling_.get();
- }
- // If the total of children in flight is not the same as self, then there are
- // collisions against immediate children (which don't update the GetNInFlight
- // of the leaf) and its not safe.
- if (total_in_flight != GetNInFlight()) {
- return false;
- }
- std::allocator<Node> alloc;
- auto* new_children = alloc.allocate(num_edges_);
- for (int i = 0; i < num_edges_; i++) {
- new (&(new_children[i])) Node(this, i);
- }
- std::unique_ptr<Node> old_child = std::move(child_);
- while (old_child) {
- int index = old_child->index_;
- new_children[index] = std::move(*old_child.get());
- // This isn't needed, but it helps crash things faster if something has gone wrong.
- old_child->parent_ = nullptr;
- gNodeGc.AddToGcQueue(std::move(old_child));
- new_children[index].UpdateChildrenParents();
- old_child = std::move(new_children[index].sibling_);
- }
- // This is a hack.
- child_ = std::unique_ptr<Node>(new_children);
- solid_children_ = true;
- return true;
+std::string LowNode::DebugString() const {
+ std::ostringstream oss;
+ oss << " This:" << this << " Hash:" << hash_
+ << " Edges:" << edges_.get()
+ << " NumEdges:" << static_cast(num_edges_)
+ << " Child:" << child_.get() << " WL:" << wl_ << " D:" << d_
+ << " M:" << m_ << " N:" << n_ << " NP:" << num_parents_
+ << " Term:" << static_cast(terminal_type_)
+ << " Bounds:" << static_cast(lower_bound_) - 2 << ","
+ << static_cast(upper_bound_) - 2
+ << " IsTransposition:" << is_transposition;
+ return oss.str();
}
-void Node::SortEdges() {
- assert(edges_);
- assert(!child_);
+void Edge::SortEdges(Edge* edges, int num_edges) {
// Sorting on raw p_ is the same as sorting on GetP() as a side effect of
// the encoding, and it's noticeably faster.
- std::sort(edges_.get(), (edges_.get() + num_edges_),
+ std::sort(edges, (edges + num_edges),
[](const Edge& a, const Edge& b) { return a.p_ > b.p_; });
}
-void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) {
- if (type != Terminal::TwoFold) SetBounds(result, result);
+void LowNode::MakeTerminal(GameResult result, float plies_left, Terminal type) {
+ SetBounds(result, result);
terminal_type_ = type;
m_ = plies_left;
if (result == GameResult::DRAW) {
@@ -308,34 +224,105 @@ void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) {
} else if (result == GameResult::BLACK_WON) {
wl_ = -1.0f;
d_ = 0.0f;
- // Terminal losses have no uncertainty and no reason for their U value to be
- // comparable to another non-loss choice. Force this by clearing the policy.
- if (GetParent() != nullptr) GetOwnEdge()->SetP(0.0f);
}
+
+ assert(WLDMInvariantsHold());
}
-void Node::MakeNotTerminal() {
+void LowNode::MakeNotTerminal(const Node* node) {
+ assert(edges_);
+ if (!IsTerminal()) return;
+
terminal_type_ = Terminal::NonTerminal;
+ lower_bound_ = GameResult::BLACK_WON;
+ upper_bound_ = GameResult::WHITE_WON;
n_ = 0;
+ wl_ = 0.0;
+ d_ = 0.0;
+ m_ = 0.0;
- // If we have edges, we've been extended (1 visit), so include children too.
- if (edges_) {
- n_++;
- for (const auto& child : Edges()) {
+ // Include children too.
+ if (node->GetNumEdges() > 0) {
+ for (const auto& child : node->Edges()) {
const auto n = child.GetN();
if (n > 0) {
n_ += n;
// Flip Q for opponent.
// Default values don't matter as n is > 0.
- wl_ += -child.GetWL(0.0f) * n;
+ wl_ += child.GetWL(0.0f) * n;
d_ += child.GetD(0.0f) * n;
+ m_ += child.GetM(0.0f) * n;
}
}
// Recompute with current eval (instead of network's) and children's eval.
wl_ /= n_;
d_ /= n_;
+ m_ /= n_;
+ }
+
+ assert(WLDMInvariantsHold());
+}
+
+void LowNode::SetBounds(GameResult lower, GameResult upper) {
+ lower_bound_ = lower;
+ upper_bound_ = upper;
+}
+
+uint8_t Node::GetNumEdges() const {
+ return low_node_ ? low_node_->GetNumEdges() : 0;
+}
+
+void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) {
+ SetBounds(result, result);
+ terminal_type_ = type;
+ m_ = plies_left;
+ if (result == GameResult::DRAW) {
+ wl_ = 0.0f;
+ d_ = 1.0f;
+ } else if (result == GameResult::WHITE_WON) {
+ wl_ = 1.0f;
+ d_ = 0.0f;
+ } else if (result == GameResult::BLACK_WON) {
+ wl_ = -1.0f;
+ d_ = 0.0f;
+ // Terminal losses have no uncertainty and no reason for their U value to be
+ // comparable to another non-loss choice. Force this by clearing the policy.
+ SetP(0.0f);
+ }
+
+ assert(WLDMInvariantsHold());
+}
+
+void Node::MakeNotTerminal(bool also_low_node) {
+ // At least one of node and low node pair needs to be a terminal.
+ if (!IsTerminal() &&
+ (!also_low_node || !low_node_ || !low_node_->IsTerminal()))
+ return;
+
+ terminal_type_ = Terminal::NonTerminal;
+ repetition_ = false;
+ if (low_node_) { // Two-fold or derived terminal.
+ // Revert low node first.
+ if (also_low_node && low_node_) low_node_->MakeNotTerminal(this);
+
+ auto [lower_bound, upper_bound] = low_node_->GetBounds();
+ lower_bound_ = -upper_bound;
+ upper_bound_ = -lower_bound;
+ n_ = low_node_->GetN();
+ wl_ = -low_node_->GetWL();
+ d_ = low_node_->GetD();
+ m_ = low_node_->GetM() + 1;
+ } else { // Real terminal.
+ lower_bound_ = GameResult::BLACK_WON;
+ upper_bound_ = GameResult::WHITE_WON;
+ n_ = 0;
+ wl_ = 0.0f;
+ d_ = 0.0f;
+ m_ = 0.0f;
}
+
+ assert(WLDMInvariantsHold());
}
void Node::SetBounds(GameResult lower, GameResult upper) {
@@ -344,108 +331,323 @@ void Node::SetBounds(GameResult lower, GameResult upper) {
}
bool Node::TryStartScoreUpdate() {
- if (n_ == 0 && n_in_flight_ > 0) return false;
- ++n_in_flight_;
+ if (n_ > 0) {
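+ // Node already has real visits, so just add more virtual loss.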
+ n_in_flight_.fetch_add(1, std::memory_order_acq_rel);
+ } else {
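+ // First visit: let only one thread claim it by atomically raising
+ // n_in_flight_ from 0 to 1.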
+ uint32_t expected_n_if_flight_ = 0;
+ if (!n_in_flight_.compare_exchange_strong(expected_n_if_flight_, 1,
+ std::memory_order_acq_rel)) {
+ return false;
+ }
+ }
+
return true;
}
-void Node::CancelScoreUpdate(int multivisit) {
- n_in_flight_ -= multivisit;
+void Node::CancelScoreUpdate(uint32_t multivisit) {
+ assert(GetNInFlight() >= (uint32_t)multivisit);
+ n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel);
+}
+
+void LowNode::FinalizeScoreUpdate(float v, float d, float m,
+ uint32_t multivisit) {
+ assert(edges_);
+ // Recompute Q.
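+ // wl_, d_ and m_ are running averages over visits, so shift them towards
+ // the new value by multivisit / (n_ + multivisit).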
+ wl_ += multivisit * (v - wl_) / (n_ + multivisit);
+ d_ += multivisit * (d - d_) / (n_ + multivisit);
+ m_ += multivisit * (m - m_) / (n_ + multivisit);
+
+ assert(WLDMInvariantsHold());
+
+ // Increment N.
+ n_ += multivisit;
+}
+
+void LowNode::AdjustForTerminal(float v, float d, float m,
+ uint32_t multivisit) {
+ assert(static_cast(multivisit) <= n_);
+
+ // Recompute Q.
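+ // Here v, d and m are deltas applied to multivisit already-counted visits.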
+ wl_ += multivisit * v / n_;
+ d_ += multivisit * d / n_;
+ m_ += multivisit * m / n_;
+
+ assert(WLDMInvariantsHold());
}
-void Node::FinalizeScoreUpdate(float v, float d, float m, int multivisit) {
+void Node::FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit) {
// Recompute Q.
wl_ += multivisit * (v - wl_) / (n_ + multivisit);
d_ += multivisit * (d - d_) / (n_ + multivisit);
m_ += multivisit * (m - m_) / (n_ + multivisit);
+ assert(WLDMInvariantsHold());
+
// Increment N.
n_ += multivisit;
// Decrement virtual loss.
- n_in_flight_ -= multivisit;
+ assert(GetNInFlight() >= (uint32_t)multivisit);
+ n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel);
}
-void Node::AdjustForTerminal(float v, float d, float m, int multivisit) {
+void Node::AdjustForTerminal(float v, float d, float m, uint32_t multivisit) {
+ assert(static_cast(multivisit) <= n_);
+
// Recompute Q.
wl_ += multivisit * v / n_;
d_ += multivisit * d / n_;
m_ += multivisit * m / n_;
+
+ assert(WLDMInvariantsHold());
}
-void Node::RevertTerminalVisits(float v, float d, float m, int multivisit) {
- // Compute new n_ first, as reducing a node to 0 visits is a special case.
- const int n_new = n_ - multivisit;
- if (n_new <= 0) {
- // If n_new == 0, reset all relevant values to 0.
- wl_ = 0.0;
- d_ = 1.0;
- m_ = 0.0;
- n_ = 0;
- } else {
- // Recompute Q and M.
- wl_ -= multivisit * (v - wl_) / n_new;
- d_ -= multivisit * (d - d_) / n_new;
- m_ -= multivisit * (m - m_) / n_new;
- // Decrement N.
- n_ -= multivisit;
+void Node::IncrementNInFlight(uint32_t multivisit) {
+ n_in_flight_.fetch_add(multivisit, std::memory_order_acq_rel);
+}
+
+void LowNode::ReleaseChildren(GCQueue* gc_queue) {
+ for (auto child = GetChild()->get(); child != nullptr;
+ child = child->GetSibling()->get()) {
+ TTGCEnqueue(gc_queue, child->GetLowNode());
}
+ child_.reset();
}
-void Node::UpdateChildrenParents() {
- if (!solid_children_) {
- Node* cur_child = child_.get();
- while (cur_child != nullptr) {
- cur_child->parent_ = this;
- cur_child = cur_child->sibling_.get();
- }
- } else {
- Node* child_array = child_.get();
- for (int i = 0; i < num_edges_; i++) {
- child_array[i].parent_ = this;
+void LowNode::ReleaseChildrenExceptOne(Node* node_to_save, GCQueue* gc_queue) {
+ // Stores node which will have to survive (or nullptr if it's not found).
+ atomic_unique_ptr<Node> saved_node;
+ // Pointer to unique_ptr, so that we could move from it.
+ for (auto node = &child_; *node != nullptr; node = (*node)->GetSibling()) {
+ // If current node is the one that we have to save.
+ if (node->get() == node_to_save) {
+ // Save the node, and take the ownership from the unique_ptr.
+ saved_node = std::move(*node);
+ node = &saved_node;
+ } else {
+ TTGCEnqueue(gc_queue, (*node)->GetLowNode());
}
}
+ // Kill all remaining siblings.
+ saved_node->GetSibling()->reset();
+ // Make saved node the only child. (kills previous siblings).
+ child_ = std::move(saved_node);
}
-void Node::ReleaseChildren() {
- gNodeGc.AddToGcQueue(std::move(child_), solid_children_ ? num_edges_ : 0);
+void Node::ReleaseChildrenExceptOne(Node* node_to_save,
+ GCQueue* gc_queue) const {
+ // Sometimes we have no graph yet or a reverted terminal without a low node.
+ if (low_node_) low_node_->ReleaseChildrenExceptOne(node_to_save, gc_queue);
}
-void Node::ReleaseChildrenExceptOne(Node* node_to_save) {
- if (solid_children_) {
- std::unique_ptr<Node> saved_node;
- if (node_to_save != nullptr) {
- saved_node = std::make_unique<Node>(this, node_to_save->index_);
- *saved_node = std::move(*node_to_save);
- }
- gNodeGc.AddToGcQueue(std::move(child_), num_edges_);
- child_ = std::move(saved_node);
- if (child_) {
- child_->UpdateChildrenParents();
+void Node::SetLowNode(LowNode* low_node) {
+ assert(!low_node_);
+ low_node->AddParent();
+ low_node_ = low_node;
+}
+void Node::UnsetLowNode() {
+ if (low_node_) low_node_->RemoveParent();
+ low_node_ = nullptr;
+}
+
+static std::string PtrToNodeName(const void* ptr) {
+ std::ostringstream oss;
+ oss << "n_" << ptr;
+ return oss.str();
+}
+
+std::string LowNode::DotNodeString() const {
+ std::ostringstream oss;
+ oss << PtrToNodeName(this) << " ["
+ << "shape=box";
+ // Adjust formatting to limit node size.
+ oss << std::fixed << std::setprecision(3);
+ oss << ",label=\"" //
+ << std::showpos //
+ << "WL=" << wl_ //
+ << std::noshowpos //
+ << "\\lD=" << d_ << "\\lM=" << m_ << "\\lN=" << n_ << "\\l\"";
+ // Set precision for tooltip.
+ oss << std::fixed << std::showpos << std::setprecision(5);
+ oss << ",tooltip=\"" //
+ << std::showpos //
+ << "WL=" << wl_ //
+ << std::noshowpos //
+ << "\\nD=" << d_ << "\\nM=" << m_ << "\\nN=" << n_
+ << "\\nNP=" << num_parents_
+ << "\\nTerm=" << static_cast(terminal_type_) //
+ << std::showpos //
+ << "\\nBounds=" << static_cast(lower_bound_) - 2 << ","
+ << static_cast(upper_bound_) - 2
+ << "\\nIsTransposition=" << is_transposition //
+ << std::noshowpos //
+ << "\\n\\nThis=" << this << "\\nEdges=" << edges_.get()
+ << "\\nNumEdges=" << static_cast(num_edges_)
+ << "\\nChild=" << child_.get() << "\\n\"";
+ oss << "];";
+ return oss.str();
+}
+
+std::string Node::DotEdgeString(bool as_opponent, const LowNode* parent) const {
+ std::ostringstream oss;
+ oss << (parent == nullptr ? "top" : PtrToNodeName(parent)) << " -> "
+ << (low_node_ ? PtrToNodeName(low_node_) : PtrToNodeName(this)) << " [";
+ oss << "label=\""
+ << (parent == nullptr ? "N/A" : GetMove(as_opponent).as_string())
+ << "\\lN=" << n_ << "\\lN_=" << n_in_flight_;
+ oss << "\\l\"";
+ // Set precision for tooltip.
+ oss << std::fixed << std::setprecision(5);
+ oss << ",labeltooltip=\""
+ << "P=" << (parent == nullptr ? 0.0f : GetP()) //
+ << std::showpos //
+ << "\\nWL= " << wl_ //
+ << std::noshowpos //
+ << "\\nD=" << d_ << "\\nM=" << m_ << "\\nN=" << n_
+ << "\\nN_=" << n_in_flight_
+ << "\\nTerm=" << static_cast(terminal_type_) //
+ << std::showpos //
+ << "\\nBounds=" << static_cast(lower_bound_) - 2 << ","
+ << static_cast(upper_bound_) - 2 << "\\n\\nThis=" << this //
+ << std::noshowpos //
+ << "\\nLowNode=" << low_node_ << "\\nParent=" << parent
+ << "\\nIndex=" << index_ << "\\nSibling=" << sibling_.get() << "\\n\"";
+ oss << "];";
+ return oss.str();
+}
+
+std::string Node::DotGraphString(bool as_opponent) const {
+ std::ostringstream oss;
+ std::unordered_set<const LowNode*> seen;
+ std::list<std::pair<const Node*, bool>> unvisited_fifo;
+
+ oss << "strict digraph {" << std::endl;
+ oss << "edge ["
+ << "headport=n"
+ << ",tooltip=\" \"" // Remove default tooltips from edge parts.
+ << "];" << std::endl;
+ oss << "node ["
+ << "shape=point" // For fake nodes.
+ << ",style=filled" // Show tooltip everywhere on the node.
+ << ",fillcolor=ivory"
+ << "];" << std::endl;
+ oss << "ranksep=" << 4.0f * std::log10(GetN()) << std::endl;
+
+ oss << DotEdgeString(!as_opponent) << std::endl;
+ if (low_node_) {
+ seen.insert(low_node_);
+ unvisited_fifo.push_back(std::pair(this, as_opponent));
+ }
+
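+ // Breadth-first walk over the graph; seen ensures each low node (possibly a
+ // transposition with several parents) is emitted only once.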
+ while (!unvisited_fifo.empty()) {
+ auto [parent_node, parent_as_opponent] = unvisited_fifo.front();
+ unvisited_fifo.pop_front();
+
+ auto parent_low_node = parent_node->GetLowNode();
+ seen.insert(parent_low_node);
+ oss << parent_low_node->DotNodeString() << std::endl;
+
+ for (auto& child_edge : parent_node->Edges()) {
+ auto child = child_edge.node();
+ if (child == nullptr) break;
+
+ oss << child->DotEdgeString(parent_as_opponent) << std::endl;
+ auto child_low_node = child->GetLowNode();
+ if (child_low_node != nullptr &&
+ (seen.find(child_low_node) == seen.end())) {
+ seen.insert(child_low_node);
+ unvisited_fifo.push_back(std::pair(child, !parent_as_opponent));
+ }
}
- solid_children_ = false;
- } else {
- // Stores node which will have to survive (or nullptr if it's not found).
- std::unique_ptr saved_node;
- // Pointer to unique_ptr, so that we could move from it.
- for (std::unique_ptr<Node>* node = &child_; *node;
- node = &(*node)->sibling_) {
- // If current node is the one that we have to save.
- if (node->get() == node_to_save) {
- // Kill all remaining siblings.
- gNodeGc.AddToGcQueue(std::move((*node)->sibling_));
- // Save the node, and take the ownership from the unique_ptr.
- saved_node = std::move(*node);
- break;
+ }
+
+ oss << "}" << std::endl;
+
+ return oss.str();
+}
+
+bool Node::ZeroNInFlight() const {
+ std::unordered_set<const LowNode*> seen;
+ std::list<const Node*> unvisited_fifo;
+ size_t nonzero_node_count = 0;
+
+ if (GetNInFlight() > 0) {
+ std::cerr << DebugString() << std::endl;
+ ++nonzero_node_count;
+ }
+ if (low_node_) {
+ seen.insert(low_node_);
+ unvisited_fifo.push_back(this);
+ }
+
+ while (!unvisited_fifo.empty()) {
+ auto parent_node = unvisited_fifo.front();
+ unvisited_fifo.pop_front();
+
+ for (auto& child_edge : parent_node->Edges()) {
+ auto child = child_edge.node();
+ if (child == nullptr) break;
+
+ if (child->GetNInFlight() > 0) {
+ std::cerr << child->DebugString() << std::endl;
+ ++nonzero_node_count;
+ }
+
+ auto child_low_node = child->GetLowNode();
+ if (child_low_node != nullptr &&
+ (seen.find(child_low_node) == seen.end())) {
+ seen.insert(child_low_node);
+ unvisited_fifo.push_back(child);
}
}
- // Make saved node the only child. (kills previous siblings).
- gNodeGc.AddToGcQueue(std::move(child_));
- child_ = std::move(saved_node);
}
- if (!child_) {
- num_edges_ = 0;
- edges_.reset(); // Clear edges list.
+
+ if (nonzero_node_count > 0) {
+ std::cerr << "GetNInFlight() is nonzero on " << nonzero_node_count
+ << " nodes" << std::endl;
+ return false;
}
+
+ return true;
+}
+
+void Node::SortEdges() const {
+ assert(low_node_);
+ low_node_->SortEdges();
+}
+
+uint64_t Node::GetHash() const {
+ if (low_node_) {
+ return low_node_->GetHash();
+ } else {
+ return 0;
+ }
+}
+bool Node::IsTT() const { return low_node_ && low_node_->IsTT(); }
+
+static constexpr float wld_tolerance = 0.000001f;
+static constexpr float m_tolerance = 0.000001f;
+
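+// WL must stay in [-1, 1], D in [0, 1], M non-negative and |WL + D| at most 1,
+// all within a small tolerance.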
+static bool WLDMInvariantsHold(float wl, float d, float m) {
+ return -(1.0f + wld_tolerance) < wl && wl < (1.0f + wld_tolerance) && //
+ -(0.0f + wld_tolerance) < d && d < (1.0f + wld_tolerance) && //
+ -(0.0f + m_tolerance) < m && //
+ std::abs(wl + d) < (1.0f + wld_tolerance);
+}
+
+bool Node::WLDMInvariantsHold() const {
+ if (lczero::WLDMInvariantsHold(GetWL(), GetD(), GetM())) return true;
+
+ std::cerr << DebugString() << std::endl;
+
+ return false;
+}
+
+bool LowNode::WLDMInvariantsHold() const {
+ if (lczero::WLDMInvariantsHold(GetWL(), GetD(), GetM())) return true;
+
+ std::cerr << DebugString() << std::endl;
+
+ return false;
}
/////////////////////////////////////////////////////////////////////////
@@ -465,33 +667,64 @@ std::string EdgeAndNode::DebugString() const {
void NodeTree::MakeMove(Move move) {
if (HeadPosition().IsBlackToMove()) move.Mirror();
const auto& board = HeadPosition().GetBoard();
+ auto hash = GetHistoryHash(history_);
+ move = board.GetModernMove(move); // TODO: Why convert here?
+ // Find edge for @move, if it exists.
Node* new_head = nullptr;
- for (auto& n : current_head_->Edges()) {
- if (board.IsSameMove(n.GetMove(), move)) {
- new_head = n.GetOrSpawnNode(current_head_);
- // Ensure head is not terminal, so search can extend or visit children of
- // "terminal" positions, e.g., WDL hits, converted terminals, 3-fold draw.
- if (new_head->IsTerminal()) new_head->MakeNotTerminal();
- break;
+ while (new_head == nullptr) {
+ for (auto& n : current_head_->Edges()) {
+ if (board.IsSameMove(n.GetMove(), move)) {
+ new_head = n.GetOrSpawnNode(current_head_);
+ // Ensure head is not terminal, so search can extend or visit children
+ // of "terminal" positions, e.g., WDL hits, converted terminals, 3-fold
+ // draw.
+ if (new_head->IsTerminal()) new_head->MakeNotTerminal();
+ break;
+ }
+ }
+
+ if (new_head != nullptr) break;
+
+ // Current head node (if any) is non-TT, does not have a matching edge and
+ // will be removed by NonTTMaintenance later.
+ current_head_->UnsetLowNode();
+
+ // Check TT first, then create, if necessary.
+ auto tt_iter = tt_.find(hash);
+ if (tt_iter != tt_.end()) {
+ current_head_->SetLowNode(tt_iter->second.get());
+ if (current_head_->IsTerminal()) current_head_->MakeNotTerminal();
+ } else {
+ non_tt_.emplace_back(std::make_unique<LowNode>(hash, MoveList({move}),
+ static_cast<uint16_t>(0)));
+ current_head_->SetLowNode(non_tt_.back().get());
}
}
- move = board.GetModernMove(move);
- current_head_->ReleaseChildrenExceptOne(new_head);
- new_head = current_head_->child_.get();
- current_head_ =
- new_head ? new_head : current_head_->CreateSingleChildNode(move);
+
+ // Remove edges that will not be needed any more.
+ current_head_->ReleaseChildrenExceptOne(new_head, &gc_queue_);
+ new_head = current_head_->GetChild();
+
+ // Move damaged node from TT to non-TT to avoid reuse.
+ // It can have TT parents, until they get garbage collected.
+ if (current_head_->IsTT()) {
+ auto tt_iter = tt_.find(current_head_->GetHash());
+ tt_iter->second->ClearTT();
+ non_tt_.emplace_back(std::move(tt_iter->second));
+ tt_.erase(tt_iter);
+ }
+
+ current_head_ = new_head;
+
history_.Append(move);
+ moves_.push_back(move);
}
void NodeTree::TrimTreeAtHead() {
- // If solid, this will be empty before move and will be moved back empty
- // afterwards which is fine.
- auto tmp = std::move(current_head_->sibling_);
- // Send dependent nodes for GC instead of destroying them immediately.
- current_head_->ReleaseChildren();
- *current_head_ = Node(current_head_->GetParent(), current_head_->index_);
- current_head_->sibling_ = std::move(tmp);
+ current_head_->Trim(&gc_queue_);
+ // Free unused non-TT low nodes.
+ NonTTMaintenance();
}
bool NodeTree::ResetToPosition(const std::string& starting_fen,
@@ -508,11 +741,12 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen,
}
if (!gamebegin_node_) {
- gamebegin_node_ = std::make_unique<Node>(nullptr, 0);
+ gamebegin_node_ = std::make_unique<Node>(0);
}
history_.Reset(starting_board, no_capture_ply,
full_moves * 2 - (starting_board.flipped() ? 1 : 2));
+ moves_.clear();
Node* old_head = current_head_;
current_head_ = gamebegin_node_.get();
@@ -522,6 +756,9 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen,
if (old_head == current_head_) seen_old_head = true;
}
+ // Remove any non-TT nodes that were not reused.
+ NonTTMaintenance();
+
// MakeMove guarantees that no siblings exist; but, if we didn't see the old
// head, it means we might have a position that was an ancestor to a
// previously searched position, which means that the current_head_ might
@@ -532,11 +769,82 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen,
}
void NodeTree::DeallocateTree() {
- // Same as gamebegin_node_.reset(), but actual deallocation will happen in
- // GC thread.
- gNodeGc.AddToGcQueue(std::move(gamebegin_node_));
- gamebegin_node_ = nullptr;
+ gamebegin_node_.reset();
current_head_ = nullptr;
+ // Free all nodes.
+ // There may be non-TT children of TT nodes that were not garbage collected
+ // fast enough.
+ NonTTMaintenance();
+ TTClear();
+ non_tt_.clear();
+ gc_queue_.clear();
+}
+
+LowNode* NodeTree::TTFind(uint64_t hash) {
+ auto tt_iter = tt_.find(hash);
+ if (tt_iter != tt_.end()) {
+ return tt_iter->second.get();
+ } else {
+ return nullptr;
+ }
+}
+
+std::pair<LowNode*, bool> NodeTree::TTGetOrCreate(uint64_t hash) {
+ auto [tt_iter, is_tt_miss] =
+ tt_.insert({hash, std::make_unique<LowNode>(hash)});
+ return {tt_iter->second.get(), is_tt_miss};
+}
+
+void NodeTree::TTMaintenance() { TTGCSome(0); }
+
+void NodeTree::TTClear() {
+ // Make sure destructors don't fail.
+ absl::c_for_each(
+ tt_, [](const auto& item) { item.second->ReleaseChildren(nullptr); });
+ // Remove any released non-TT children of TT nodes that were not garbage
+ // collected fast enough.
+ NonTTMaintenance();
+ tt_.clear();
+ gc_queue_.clear();
+}
+
+LowNode* NodeTree::NonTTAddClone(const LowNode& node) {
+ non_tt_.push_back(std::make_unique<LowNode>(node));
+ return non_tt_.back().get();
+}
+
+void NodeTree::NonTTMaintenance() {
+ // Release children of parent-less nodes.
+ absl::c_for_each(non_tt_, [this](const auto& item) {
+ if (item->GetNumParents() == 0) item->ReleaseChildren(&gc_queue_);
+ });
+ // Erase parent-less nodes.
+ for (auto item = non_tt_.begin(); item != non_tt_.end();) {
+ if ((*item)->GetNumParents() == 0) {
+ item = non_tt_.erase(item);
+ } else {
+ ++item;
+ }
+ }
+}
+
+bool NodeTree::TTGCSome(size_t count) {
+ if (gc_queue_.empty()) return false;
+
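+ // A count of 0 means drain the whole queue in one call.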
+ for (auto n = count > 0 ? std::min(count, gc_queue_.size())
+ : gc_queue_.size();
+ n > 0; --n) {
+ auto hash = gc_queue_.front();
+ gc_queue_.pop_front();
+ auto tt_iter = tt_.find(hash);
+ if (tt_iter != tt_.end()) {
+ if (tt_iter->second->GetNumParents() == 0) {
+ tt_.erase(tt_iter);
+ }
+ }
+ }
+
+ return gc_queue_.empty();
}
} // namespace lczero
diff --git a/src/mcts/node.h b/src/mcts/node.h
index 2982de24da..324013b0ab 100644
--- a/src/mcts/node.h
+++ b/src/mcts/node.h
@@ -27,8 +27,12 @@
#pragma once
+#include
+
#include
#include
+#include
+#include
#include
#include
#include
@@ -36,26 +40,26 @@
#include "chess/board.h"
#include "chess/callbacks.h"
#include "chess/position.h"
-#include "neural/cache.h"
-#include "neural/encoder.h"
-#include "proto/net.pb.h"
+#include "mcts/params.h"
#include "utils/mutex.h"
namespace lczero {
-// Children of a node are stored the following way:
-// * Edges and Nodes edges point to are stored separately.
-// * There may be dangling edges (which don't yet point to any Node object yet)
-// * Edges are stored are a simple array on heap.
-// * Nodes are stored as a linked list, and contain index_ field which shows
-// which edge of a parent that node points to.
-// Or they are stored a contiguous array of Node objects on the heap if
-// solid_children_ is true. If the children have been 'solidified' their
-// sibling links are unused and left empty. In this state there are no
-// dangling edges, but the nodes may not have ever received any visits.
+// Terminology:
+// * Edge - a potential edge with a move and policy information.
+// * Node - an existing edge with number of visits and evaluation.
+// * LowNode - a node with number of visits, evaluation and edges.
+//
+// Storage:
+// * Potential edges are stored in a simple array inside the LowNode as edges_.
+// * Existing edges are stored in a linked list starting with a child_ pointer
+// in the LowNode and continuing with a sibling_ pointer in each Node.
+// * Existing edges have a copy of their potential edge counterpart, index_
+// among potential edges and are linked to the target LowNode via the
+// low_node_ pointer.
//
// Example:
-// Parent Node
+// LowNode
// |
// +-------------+-------------+----------------+--------------+
// | | | | |
@@ -64,21 +68,103 @@ namespace lczero {
// Node, Q=0.5 Node, Q=-0.2
//
// Is represented as:
-// +--------------+
-// | Parent Node |
-// +--------------+ +--------+
-// | edges_ | -------------------------------------> | Edge[] |
-// | | +------------+ +--------+
-// | child_ | -> | Node | | Nf3 |
-// +--------------+ +------------+ | Bc5 |
-// | index_ = 1 | | a4 |
-// | q_ = 0.5 | +------------+ | Qxf7 |
-// | sibling_ | -> | Node | | a3 |
-// +------------+ +------------+ +--------+
-// | index_ = 3 |
-// | q_ = -0.2 |
-// | sibling_ | -> nullptr
-// +------------+
+// +-----------------+
+// | LowNode |
+// +-----------------+ +--------+
+// | edges_ | -------------------------------------> | Edge[] |
+// | | +------------+ +--------+
+// | child_ | -> | Node | | Nf3 |
+// | | +------------+ | Bc5 |
+// | ... | | edge_ | | a4 |
+// | | | index_ = 1 | | Qxf7 |
+// | | | q_ = 0.5 | +------------+ | a3 |
+// | | | sibling_ | -> | Node | +--------+
+// | | +------------+ +------------+
+// | | | edge_ |
+// +-----------------+ | index_ = 3 |
+// | q_ = -0.2 |
+// | sibling_ | -> nullptr
+// +------------+
+
+// Define __i386__ or __arm__ also for 32 bit Windows.
+#if defined(_M_IX86)
+#define __i386__
+#endif
+#if defined(_M_ARM) && !defined(_M_AMD64)
+#define __arm__
+#endif
+
+// Atomic unique_ptr based on the public domain code from
+// https://stackoverflow.com/a/42811152 .
+template <typename T>
+class atomic_unique_ptr {
+ using pointer = T*;
+ using unique_pointer = std::unique_ptr;
+
+ public:
+ // Manage no pointer.
+ constexpr atomic_unique_ptr() noexcept : ptr() {}
+
+ // Make pointer @p managed.
+ explicit atomic_unique_ptr(pointer p) noexcept : ptr(p) {}
+
+ // Move the managed pointer ownership from another atomic_unique_ptr.
+ atomic_unique_ptr(atomic_unique_ptr&& p) noexcept : ptr(p.release()) {}
+ // Move the managed pointer ownership from another atomic_unique_ptr.
+ atomic_unique_ptr& operator=(atomic_unique_ptr&& p) noexcept {
+ reset(p.release());
+ return *this;
+ }
+
+ // Move the managed object ownership from a unique_ptr.
+ atomic_unique_ptr(unique_pointer&& p) noexcept : ptr(p.release()) {}
+ // Move the managed object ownership from a unique_ptr.
+ atomic_unique_ptr& operator=(unique_pointer&& p) noexcept {
+ reset(p.release());
+ return *this;
+ }
+
+ // Replace the managed pointer, deleting the old one.
+ void reset(pointer p = pointer()) noexcept {
+ auto old = ptr.exchange(p, std::memory_order_acquire);
+ if (old) delete old;
+ }
+ // Delete the managed pointer, if any.
+ ~atomic_unique_ptr() { reset(); }
+
+ // Returns the managed pointer.
+ operator pointer() const noexcept { return ptr; }
+ // Returns the managed pointer.
+ pointer operator->() const noexcept { return ptr; }
+ // Returns the managed pointer.
+ pointer get() const noexcept { return ptr; }
+
+ // Checks whether there is a managed pointer.
+ explicit operator bool() const noexcept { return ptr != pointer(); }
+
+ // Replace the managed pointer, returning the old one without deleting it.
+ pointer set(pointer p = pointer()) noexcept {
+ return ptr.exchange(p, std::memory_order_acquire);
+ }
+ // Return the managed pointer and release its ownership.
+ pointer release() noexcept { return set(pointer()); }
+
+ // Move managed pointer from @source, iff the managed pointer equals
+ // @expected.
+ bool compare_exchange(pointer expected,
+ atomic_unique_ptr& source) noexcept {
+ if (ptr.compare_exchange_strong(expected, source.ptr,
+ std::memory_order_acq_rel)) {
+ source.release();
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ private:
+ std::atomic<pointer> ptr;
+};
class Node;
class Edge {
@@ -98,6 +184,8 @@ class Edge {
// Debug information about the edge.
std::string DebugString() const;
+ static void SortEdges(Edge* edges, int num_edges);
+
private:
// Move corresponding to this node. From the point of view of a player,
// i.e. black's e7e5 is stored as e2e4.
@@ -116,6 +204,29 @@ struct Eval {
float ml;
};
+struct NNEval {
+ // To minimize the number of padding bytes and to avoid having unnecessary
+ // padding when new fields are added, we arrange the fields by size, largest
+ // to smallest.
+
+ // 8 byte fields on 64-bit platforms, 4 byte on 32-bit.
+ // Array of edges.
+ std::unique_ptr<Edge[]> edges;
+
+ // 4 byte fields.
+ float q = 0.0f;
+ float d = 0.0f;
+ float m = 0.0f;
+
+ // 1 byte fields.
+ // Number of edges in @edges.
+ uint8_t num_edges = 0;
+};
+
+typedef std::pair<GameResult, GameResult> Bounds;
+
+enum class Terminal : uint8_t { NonTerminal, EndOfGame, Tablebase };
+
class EdgeAndNode;
template <bool is_const>
class Edge_Iterator;
@@ -123,47 +234,56 @@ class Edge_Iterator;
template <bool is_const>
class VisitedNode_Iterator;
+class LowNode;
+typedef std::list GCQueue;
class Node {
public:
using Iterator = Edge_Iterator<false>;
using ConstIterator = Edge_Iterator<true>;
- enum class Terminal : uint8_t { NonTerminal, EndOfGame, Tablebase, TwoFold };
-
- // Takes pointer to a parent node and own index in a parent.
- Node(Node* parent, uint16_t index)
- : parent_(parent),
+ // Takes own @index in the parent.
+ Node(uint16_t index)
+ : index_(index),
+ terminal_type_(Terminal::NonTerminal),
+ lower_bound_(GameResult::BLACK_WON),
+ upper_bound_(GameResult::WHITE_WON),
+ repetition_(false) {}
+ // Takes own @edge and @index in the parent.
+ Node(const Edge& edge, uint16_t index)
+ : edge_(edge),
index_(index),
terminal_type_(Terminal::NonTerminal),
lower_bound_(GameResult::BLACK_WON),
upper_bound_(GameResult::WHITE_WON),
- solid_children_(false) {}
-
- // We have a custom destructor, but its behavior does not need to be emulated
- // during move operations so default is fine.
- Node(Node&& move_from) = default;
- Node& operator=(Node&& move_from) = default;
-
- // Allocates a new edge and a new node. The node must have no edges before
- // that.
- Node* CreateSingleChildNode(Move m);
-
- // Creates edges from a movelist. There has to be no edges before that.
- void CreateEdges(const MoveList& moves);
-
- // Gets parent node.
- Node* GetParent() const { return parent_; }
+ repetition_(false) {}
+ ~Node() { UnsetLowNode(); }
+
+ // Trim node, resetting everything except parent, sibling, edge and index.
+ void Trim(GCQueue* gc_queue);
+
+ // Get first child.
+ Node* GetChild() const;
+ // Get next sibling.
+ atomic_unique_ptr<Node>* GetSibling() { return &sibling_; }
+ // Moves sibling in.
+ void MoveSiblingIn(atomic_unique_ptr<Node>& sibling) {
+ sibling_ = std::move(sibling);
+ }
// Returns whether a node has children.
- bool HasChildren() const { return static_cast<bool>(edges_); }
+ bool HasChildren() const;
// Returns sum of policy priors which have had at least one playout.
float GetVisitedPolicy() const;
uint32_t GetN() const { return n_; }
- uint32_t GetNInFlight() const { return n_in_flight_; }
- uint32_t GetChildrenVisits() const { return n_ > 0 ? n_ - 1 : 0; }
- // Returns n = n_if_flight.
- int GetNStarted() const { return n_ + n_in_flight_; }
+ uint32_t GetNInFlight() const;
+ uint32_t GetChildrenVisits() const;
+ uint32_t GetTotalVisits() const;
+ // Returns n + n_in_flight.
+ int GetNStarted() const {
+ return n_ + n_in_flight_.load(std::memory_order_acquire);
+ }
+
float GetQ(float draw_score) const { return wl_ + draw_score * d_; }
// Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1
// for terminal nodes.
@@ -174,25 +294,16 @@ class Node {
// Returns whether the node is known to be draw/lose/win.
bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; }
bool IsTbTerminal() const { return terminal_type_ == Terminal::Tablebase; }
- bool IsTwoFoldTerminal() const { return terminal_type_ == Terminal::TwoFold; }
- typedef std::pair Bounds;
Bounds GetBounds() const { return {lower_bound_, upper_bound_}; }
- uint8_t GetNumEdges() const { return num_edges_; }
- // Output must point to at least max_needed floats.
- void CopyPolicy(int max_needed, float* output) const {
- if (!edges_) return;
- int loops = std::min(static_cast<int>(num_edges_), max_needed);
- for (int i = 0; i < loops; i++) {
- output[i] = edges_[i].GetP();
- }
- }
+ uint8_t GetNumEdges() const;
// Makes the node terminal and sets its score.
- void MakeTerminal(GameResult result, float plies_left = 0.0f,
+ void MakeTerminal(GameResult result, float plies_left = 1.0f,
Terminal type = Terminal::EndOfGame);
- // Makes the node not terminal and updates its visits.
- void MakeNotTerminal();
+ // Makes the node not terminal and recomputes bounds, visits and values.
+ // Changes low node as well unless @also_low_node is false.
+ void MakeNotTerminal(bool also_low_node = true);
void SetBounds(GameResult lower, GameResult upper);
// If this node is not in the process of being expanded by another thread
@@ -201,28 +312,19 @@ class Node {
// Otherwise return false.
bool TryStartScoreUpdate();
// Decrements n-in-flight back.
- void CancelScoreUpdate(int multivisit);
+ void CancelScoreUpdate(uint32_t multivisit);
// Updates the node with newly computed value v.
// Updates:
// * Q (weighted average of all V in a subtree)
- // * N (+=1)
- // * N-in-flight (-=1)
- void FinalizeScoreUpdate(float v, float d, float m, int multivisit);
+ // * N (+=multivisit)
+ // * N-in-flight (-=multivisit)
+ void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
- void AdjustForTerminal(float v, float d, float m, int multivisit);
- // Revert visits to a node which ended in a now reverted terminal.
- void RevertTerminalVisits(float v, float d, float m, int multivisit);
+ void AdjustForTerminal(float v, float d, float m, uint32_t multivisit);
// When search decides to treat one visit as several (in case of collisions
// or visiting terminal nodes several times), it amplifies the visit by
// incrementing n_in_flight.
- void IncrementNInFlight(int multivisit) { n_in_flight_ += multivisit; }
-
- // Updates max depth, if new depth is larger.
- void UpdateMaxDepth(int depth);
-
- // Calculates the full depth if new depth is larger, updates it, returns
- // in depth parameter, and returns true if it was indeed updated.
- bool UpdateFullDepth(uint16_t* depth);
+ void IncrementNInFlight(uint32_t multivisit);
// Returns range for iterating over edges.
ConstIterator Edges() const;
@@ -232,48 +334,55 @@ class Node {
VisitedNode_Iterator VisitedNodes() const;
VisitedNode_Iterator VisitedNodes();
- // Deletes all children.
- void ReleaseChildren();
-
// Deletes all children except one.
// The node provided may be moved, so should not be relied upon to exist
// afterwards.
- void ReleaseChildrenExceptOne(Node* node);
+ void ReleaseChildrenExceptOne(Node* node_to_save, GCQueue* gc_queue) const;
- // For a child node, returns corresponding edge.
- Edge* GetEdgeToNode(const Node* node) const;
+ // Returns move from the point of view of the player making it (if as_opponent
+ // is false) or as opponent (if as_opponent is true).
+ Move GetMove(bool as_opponent = false) const {
+ return edge_.GetMove(as_opponent);
+ }
+ // Returns or sets value of Move policy prior returned from the neural net
+ // (but can be changed by adding Dirichlet noise or when turning terminal).
+ // Must be in [0,1].
+ float GetP() const { return edge_.GetP(); }
+ void SetP(float val) { edge_.SetP(val); }
- // Returns edge to the own node.
- Edge* GetOwnEdge() const;
+ LowNode* GetLowNode() const { return low_node_; }
+
+ void SetLowNode(LowNode* low_node);
+ void UnsetLowNode();
// Debug information about the node.
std::string DebugString() const;
+ // Return string describing the edge from node's parent to its low node in the
+ // Graphviz dot format.
+ std::string DotEdgeString(bool as_opponent = false,
+ const LowNode* parent = nullptr) const;
+ // Return string describing the graph starting at this node in the Graphviz
+ // dot format.
+ std::string DotGraphString(bool as_opponent = false) const;
- // Reallocates this node's children to be in a solid block, if possible and not
- // already done. Returns true if the transformation was performed.
- bool MakeSolid();
+ // Returns true if graph under this node has every n_in_flight_ == 0 and
+ // prints offending nodes and low nodes and stats to cerr otherwise.
+ bool ZeroNInFlight() const;
- void SortEdges();
+ void SortEdges() const;
- // Index in parent edges - useful for correlated ordering.
+ // Index in parent's edges - useful for correlated ordering.
uint16_t Index() const { return index_; }
- ~Node() {
- if (solid_children_ && child_) {
- // As a hack, solid_children is actually storing an array in here, release
- // so we can correctly invoke the array delete.
- for (int i = 0; i < num_edges_; i++) {
- child_.get()[i].~Node();
- }
- std::allocator alloc;
- alloc.deallocate(child_.release(), num_edges_);
- }
- }
+ void SetRepetition() { repetition_ = true; }
+ bool IsRepetition() const { return repetition_; }
- private:
- // For each child, ensures that its parent pointer is pointing to this.
- void UpdateChildrenParents();
+ uint64_t GetHash() const;
+ bool IsTT() const;
+ bool WLDMInvariantsHold() const;
+
+ private:
// To minimize the number of padding bytes and to avoid having unnecessary
// padding when new fields are added, we arrange the fields by size, largest
// to smallest.
@@ -281,22 +390,16 @@ class Node {
// 8 byte fields.
// Average value (from value head of neural network) of all visited nodes in
// subtree. For terminal nodes, eval is stored. This is from the perspective
- // of the player who "just" moved to reach this position, rather than from the
- // perspective of the player-to-move for the position.
- // WL stands for "W minus L". Is equal to Q if draw score is 0.
+ // of the player who "just" moved to reach this position, rather than from
+ // the perspective of the player-to-move for the position. WL stands for "W
+ // minus L". Is equal to Q if draw score is 0.
double wl_ = 0.0f;
// 8 byte fields on 64-bit platforms, 4 byte on 32-bit.
- // Array of edges.
- std::unique_ptr edges_;
- // Pointer to a parent node. nullptr for the root.
- Node* parent_ = nullptr;
- // Pointer to a first child. nullptr for a leaf node.
- // As a 'hack' actually a unique_ptr to Node[] if solid_children.
- std::unique_ptr child_;
+ // Pointer to the low node.
+ LowNode* low_node_ = nullptr;
// Pointer to a next sibling. nullptr if there are no further siblings.
- // Also null in the solid case.
- std::unique_ptr sibling_;
+ atomic_unique_ptr sibling_;
// 4 byte fields.
// Averaged draw probability. Works similarly to WL, except that D is not
@@ -309,48 +412,228 @@ class Node {
// (AKA virtual loss.) How many threads currently process this node (started
// but not finished). This value is added to n during selection which node
// to pick in MCTS, and also when selecting the best move.
- uint32_t n_in_flight_ = 0;
+ std::atomic<uint32_t> n_in_flight_ = 0;
+
+ // Move and policy for this edge.
+ Edge edge_;
// 2 byte fields.
// Index of this node is parent's edge list.
uint16_t index_;
+ // 1 byte fields.
+ // Bit fields using parts of uint8_t fields initialized in the constructor.
+ // Whether or not this node ends the game (with a win for either side or a
+ // draw).
+ Terminal terminal_type_ : 2;
+ // Best and worst result for this node.
+ GameResult lower_bound_ : 2;
+ GameResult upper_bound_ : 2;
+ // Edge was handled as a repetition at some point.
+ bool repetition_ : 1;
+};
+
+// Check that Node still fits into an expected cache line size.
+static_assert(sizeof(Node) <= 64, "Node is too large");
+
+class LowNode {
+ public:
+ // For TT nodes.
+ LowNode(uint64_t hash)
+ : hash_(hash),
+ terminal_type_(Terminal::NonTerminal),
+ lower_bound_(GameResult::BLACK_WON),
+ upper_bound_(GameResult::WHITE_WON),
+ is_transposition(false),
+ is_tt_(true) {}
+ // Init from another low node, but use it for NNEval only.
+ // For non-TT nodes.
+ LowNode(const LowNode& p)
+ : wl_(p.wl_),
+ hash_(p.hash_),
+ d_(p.d_),
+ m_(p.m_),
+ num_edges_(p.num_edges_),
+ terminal_type_(Terminal::NonTerminal),
+ lower_bound_(GameResult::BLACK_WON),
+ upper_bound_(GameResult::WHITE_WON),
+ is_transposition(false),
+ is_tt_(false) {
+ assert(p.edges_);
+ edges_ = std::make_unique<Edge[]>(num_edges_);
+ std::memcpy(edges_.get(), p.edges_.get(), num_edges_ * sizeof(Edge));
+ }
+ // Init @edges_ with moves from @moves and 0 policy.
+ // Also create the first child at @index.
+ // For non-TT nodes.
+ LowNode(uint64_t hash, const MoveList& moves, uint16_t index)
+ : hash_(hash),
+ num_edges_(moves.size()),
+ terminal_type_(Terminal::NonTerminal),
+ lower_bound_(GameResult::BLACK_WON),
+ upper_bound_(GameResult::WHITE_WON),
+ is_transposition(false),
+ is_tt_(false) {
+ edges_ = Edge::FromMovelist(moves);
+ child_ = std::make_unique<Node>(edges_[index], index);
+ }
+
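+ // Copy edges and the initial eval from a fresh NN evaluation. Only valid on
+ // a low node that has no edges, visits or children yet.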
+ void SetNNEval(const NNEval* eval) {
+ assert(!edges_);
+ assert(n_ == 0);
+ assert(!child_);
+
+ edges_ = std::make_unique<Edge[]>(eval->num_edges);
+ std::memcpy(edges_.get(), eval->edges.get(),
+ eval->num_edges * sizeof(Edge));
+
+ wl_ = eval->q;
+ d_ = eval->d;
+ m_ = eval->m;
+
+ assert(WLDMInvariantsHold());
+
+ num_edges_ = eval->num_edges;
+ }
+
+ // Gets the first child.
+ atomic_unique_ptr<Node>* GetChild() { return &child_; }
+
+ // Returns whether a node has children.
+ bool HasChildren() const { return num_edges_ > 0; }
+
+ uint32_t GetN() const { return n_; }
+ uint32_t GetChildrenVisits() const { return n_ - 1; }
+
+ // Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1
+ // for terminal nodes.
+ float GetWL() const { return wl_; }
+ float GetD() const { return d_; }
+ float GetM() const { return m_; }
+
+ // Returns whether the node is known to be draw/loss/win.
+ bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; }
+ Bounds GetBounds() const { return {lower_bound_, upper_bound_}; }
+ Terminal GetTerminalType() const { return terminal_type_; }
+
+ uint8_t GetNumEdges() const { return num_edges_; }
+ // Gets pointer to the start of the edge array.
+ Edge* GetEdges() const { return edges_.get(); }
+
+ // Makes the node terminal and sets its score.
+ void MakeTerminal(GameResult result, float plies_left = 0.0f,
+ Terminal type = Terminal::EndOfGame);
+ // Makes the low node not terminal and recomputes bounds, visits and values
+ // using incoming @node.
+ void MakeNotTerminal(const Node* node);
+ void SetBounds(GameResult lower, GameResult upper);
+
+ // Decrements n-in-flight back.
+ void CancelScoreUpdate(uint32_t multivisit);
+ // Updates the node with newly computed value v.
+ // Updates:
+ // * Q (weighted average of all V in a subtree)
+ // * N (+=multivisit)
+ // * N-in-flight (-=multivisit)
+ void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit);
+ // Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
+ void AdjustForTerminal(float v, float d, float m, uint32_t multivisit);
+
+ // Deletes all children.
+ void ReleaseChildren(GCQueue* gc_queue);
+
+ // Deletes all children except one.
+ // The node provided may be moved, so should not be relied upon to exist
+ // afterwards.
+ void ReleaseChildrenExceptOne(Node* node_to_save, GCQueue* gc_queue);
+
+ // Return move policy for edge/node at @index.
+ const Edge& GetEdgeAt(uint16_t index) const;
+
+ // Debug information about the node.
+ std::string DebugString() const;
+ // Return string describing this node in the Graphviz dot format.
+ std::string DotNodeString() const;
+
+ void SortEdges() {
+ assert(edges_);
+ assert(!child_);
+ Edge::SortEdges(edges_.get(), num_edges_);
+ }
+
+ // Add new parent with @n_in_flight visits.
+ void AddParent() {
+ ++num_parents_;
+
+ assert(num_parents_ > 0);
+
+ is_transposition |= num_parents_ > 1;
+ }
+ // Remove parent and its first visit.
+ void RemoveParent() {
+ assert(num_parents_ > 0);
+ --num_parents_;
+ }
+ uint16_t GetNumParents() const { return num_parents_; }
+ bool IsTransposition() const { return is_transposition; }
+
+ uint64_t GetHash() const { return hash_; }
+ bool IsTT() const { return is_tt_; }
+ void ClearTT() { is_tt_ = !is_tt_; }
+
+ bool WLDMInvariantsHold() const;
+
+ private:
+ // To minimize the number of padding bytes and to avoid having unnecessary
+ // padding when new fields are added, we arrange the fields by size, largest
+ // to smallest.
+
+ // 8 byte fields.
+ // Average value (from value head of neural network) of all visited nodes in
+ // subtree. For terminal nodes, eval is stored. This is from the perspective
+ // of the player who "just" moved to reach this position, rather than from the
+ // perspective of the player-to-move for the position.
+ // WL stands for "W minus L". Is equal to Q if draw score is 0.
+ double wl_ = 0.0f;
+ // Position hash and a TT key.
+ uint64_t hash_ = 0;
+
+ // 8 byte fields on 64-bit platforms, 4 byte on 32-bit.
+ // Array of edges.
+ std::unique_ptr<Edge[]> edges_;
+ // Pointer to the first child. nullptr when no children.
+ atomic_unique_ptr<Node> child_;
+
+ // 4 byte fields.
+ // Averaged draw probability. Works similarly to WL, except that D is not
+ // flipped depending on the side to move.
+ float d_ = 0.0f;
+ // Estimated remaining plies.
+ float m_ = 0.0f;
+ // How many completed visits this node had.
+ uint32_t n_ = 0;
+
+ // 2 byte fields.
+ // Number of parents.
+ uint16_t num_parents_ = 0;
+
// 1 byte fields.
// Number of edges in @edges_.
uint8_t num_edges_ = 0;
-
// Bit fields using parts of uint8_t fields initialized in the constructor.
// Whether or not this node ends the game (with a win for either side or a draw).
Terminal terminal_type_ : 2;
// Best and worst result for this node.
GameResult lower_bound_ : 2;
GameResult upper_bound_ : 2;
- // Whether the child_ is actually an array of equal length to edges.
- bool solid_children_ : 1;
-
- // TODO(mooskagh) Unfriend NodeTree.
- friend class NodeTree;
- friend class Edge_Iterator<true>;
- friend class Edge_Iterator<false>;
- friend class Edge;
- friend class VisitedNode_Iterator<true>;
- friend class VisitedNode_Iterator<false>;
+ // Low node is a transposition (for ever).
+ bool is_transposition : 1;
+ // Low node is in TT, i.e. was not evaluated or was modified.
+ bool is_tt_ : 1;
};
-// Define __i386__ or __arm__ also for 32 bit Windows.
-#if defined(_M_IX86)
-#define __i386__
-#endif
-#if defined(_M_ARM) && !defined(_M_AMD64)
-#define __arm__
-#endif
-
-// A basic sanity check. This must be adjusted when Node members are adjusted.
-#if defined(__i386__) || (defined(__arm__) && !defined(__aarch64__))
-static_assert(sizeof(Node) == 48, "Unexpected size of Node for 32bit compile");
-#else
-static_assert(sizeof(Node) == 64, "Unexpected size of Node");
-#endif
+// Check that LowNode still fits into an expected cache line size.
+static_assert(sizeof(LowNode) <= 64, "LowNode is too large");
// Contains Edge and Node pair and set of proxy functions to simplify access
// to them.
@@ -391,13 +674,15 @@ class EdgeAndNode {
// Whether the node is known to be terminal.
bool IsTerminal() const { return node_ ? node_->IsTerminal() : false; }
bool IsTbTerminal() const { return node_ ? node_->IsTbTerminal() : false; }
- Node::Bounds GetBounds() const {
+ Bounds GetBounds() const {
return node_ ? node_->GetBounds()
- : Node::Bounds{GameResult::BLACK_WON, GameResult::WHITE_WON};
+ : Bounds{GameResult::BLACK_WON, GameResult::WHITE_WON};
}
// Edge related getters.
- float GetP() const { return edge_->GetP(); }
+ float GetP() const {
+ return node_ != nullptr ? node_->GetP() : edge_->GetP();
+ }
Move GetMove(bool flip = false) const {
return edge_ ? edge_->GetMove(flip) : Move();
}
@@ -436,21 +721,20 @@ class EdgeAndNode {
template <bool is_const>
class Edge_Iterator : public EdgeAndNode {
public:
- using Ptr = std::conditional_t<is_const, const std::unique_ptr<Node>*,
- std::unique_ptr<Node>*>;
+ using Ptr = std::conditional_t<is_const, const atomic_unique_ptr<Node>*,
+ atomic_unique_ptr<Node>*>;
// Creates "end()" iterator.
Edge_Iterator() {}
- // Creates "begin()" iterator. Also happens to be a range constructor.
- // child_ptr will be nullptr if parent_node is solid children.
- Edge_Iterator(const Node& parent_node, Ptr child_ptr)
- : EdgeAndNode(parent_node.edges_.get(), nullptr),
- node_ptr_(child_ptr),
- total_count_(parent_node.num_edges_) {
- if (edge_ && child_ptr != nullptr) Actualize();
- if (edge_ && child_ptr == nullptr) {
- node_ = parent_node.child_.get();
+ // Creates "begin()" iterator.
+ Edge_Iterator(LowNode* parent_node)
+ : EdgeAndNode(parent_node != nullptr ? parent_node->GetEdges() : nullptr,
+ nullptr) {
+ if (parent_node != nullptr) {
+ node_ptr_ = parent_node->GetChild();
+ total_count_ = parent_node->GetNumEdges();
+ if (edge_) Actualize();
}
}
@@ -466,11 +750,7 @@ class Edge_Iterator : public EdgeAndNode {
edge_ = nullptr;
} else {
++edge_;
- if (node_ptr_ != nullptr) {
- Actualize();
- } else {
- ++node_;
- }
+ Actualize();
}
}
Edge_Iterator& operator*() { return *this; }
@@ -478,27 +758,43 @@ class Edge_Iterator : public EdgeAndNode {
// If there is node, return it. Otherwise spawn a new one and return it.
Node* GetOrSpawnNode(Node* parent) {
if (node_) return node_; // If there is already a node, return it.
- // Should never reach here in solid mode.
- assert(node_ptr_ != nullptr);
- Actualize(); // But maybe other thread already did that.
- if (node_) return node_; // If it did, return.
- // Now we are sure we have to create a new node.
- // Suppose there are nodes with idx 3 and 7, and we want to insert one with
- // idx 5. Here is how it looks like:
- // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.7)
- // Here is how we do that:
- // 1. Store pointer to a node idx_.7:
- // node_ptr_ -> &Node(idx_.3).sibling_ -> nullptr
- // tmp -> Node(idx_.7)
- std::unique_ptr tmp = std::move(*node_ptr_);
- // 2. Create fresh Node(idx_.5):
- // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.5)
- // tmp -> Node(idx_.7)
- *node_ptr_ = std::make_unique<Node>(parent, current_idx_);
- // 3. Attach stored pointer back to a list:
- // node_ptr_ ->
- // &Node(idx_.3).sibling_ -> Node(idx_.5).sibling_ -> Node(idx_.7)
- (*node_ptr_)->sibling_ = std::move(tmp);
+
+ // We likely need to add a new node, prepare it now.
+ atomic_unique_ptr<Node> new_node = std::make_unique<Node>(
+ parent->GetLowNode()->GetEdgeAt(current_idx_), current_idx_);
+ while (true) {
+ auto node = Actualize(); // But maybe other thread already did that.
+ if (node_) return node_; // If it did, return.
+
+ // New node needs to be added, but we might be in a race with another
+ // thread doing what we do or adding a different index to the same
+ // sibling.
+
+ // Suppose there are nodes with idx 3 and 7, and we want to insert one
+ // with idx 5. Here is what it looks like:
+ // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.7)
+ // Here is how we do that:
+ // 1. Store pointer to a node idx_.7:
+ // node_ptr_ -> &Node(idx_.3).sibling_ -> nullptr
+ // tmp -> Node(idx_.7)
+ // 2. Create fresh Node(idx_.5):
+ // node_ptr_ -> &Node(idx_.3).sibling_ -> Node(idx_.5)
+ // tmp -> Node(idx_.7)
+ // 3. Attach stored pointer back to a list:
+ // node_ptr_ ->
+ // &Node(idx_.3).sibling_ -> Node(idx_.5).sibling_ -> Node(idx_.7)
+
+ // Atomically add the new node into the right place.
+ // Set new node's sibling to the expected sibling seen by Actualize in
+ // node_ptr_.
+ auto new_sibling = new_node->GetSibling();
+ new_sibling->set(node);
+ // Try to atomically insert the new node and stop if it works.
+ if (node_ptr_->compare_exchange(node, new_node)) break;
+ // Recover from failure and try again.
+ // Release expected sibling to avoid double free.
+ new_sibling->release();
+ }
// 4. Actualize:
// node_ -> &Node(idx_.5)
// node_ptr_ -> &Node(idx_.5).sibling_ -> Node(idx_.7)
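For illustration only (not part of the patch), the insertion loop documented above roughly reduces to the following standalone sketch of a compare-and-swap insert into an index-sorted, singly linked sibling list; it uses plain std::atomic and raw pointers instead of lc0's atomic_unique_ptr:

    #include <atomic>

    struct Item {
      explicit Item(int idx) : index(idx) {}
      int index;
      std::atomic<Item*> next{nullptr};
    };

    // Returns the Item owning `idx`, inserting a freshly allocated one if
    // needed. Assumes items are never removed concurrently, as in the search
    // tree above.
    Item* GetOrInsert(std::atomic<Item*>* link, int idx) {
      Item* fresh = new Item(idx);
      while (true) {
        Item* cur = link->load(std::memory_order_acquire);
        // Advance past smaller indices; other threads may extend the list
        // under us, so always re-read through the links.
        while (cur != nullptr && cur->index < idx) {
          link = &cur->next;
          cur = link->load(std::memory_order_acquire);
        }
        if (cur != nullptr && cur->index == idx) {
          delete fresh;  // Another thread already created this child.
          return cur;
        }
        // Point the new item at the observed successor, then try to publish.
        fresh->next.store(cur, std::memory_order_relaxed);
        if (link->compare_exchange_strong(cur, fresh,
                                          std::memory_order_release,
                                          std::memory_order_relaxed)) {
          return fresh;
        }
        // CAS failed: another thread changed this link; rescan and retry.
      }
    }

The patch's version additionally has to hand ownership of the expected sibling back (the release() call above) after a failed CAS, because atomic_unique_ptr owns its target and the new node must not end up freeing a sibling the list still owns.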
@@ -507,24 +803,30 @@ class Edge_Iterator : public EdgeAndNode {
}
private:
- void Actualize() {
- // This must never be called in solid mode.
- assert(node_ptr_ != nullptr);
+ // Moves node_ptr_ as close as possible to the target index and returns the
+ // contents of node_ptr_ for use by atomic insert in GetOrSpawnNode.
+ Node* Actualize() {
// If node_ptr_ is behind, advance it.
// This is needed (and has to be 'while' rather than 'if') as other threads
// could spawn new nodes between &node_ptr_ and *node_ptr_ while we didn't
// see.
- while (*node_ptr_ && (*node_ptr_)->index_ < current_idx_) {
- node_ptr_ = &(*node_ptr_)->sibling_;
+ // Read the direct pointer just once as other threads may change it between
+ // uses.
+ auto node = node_ptr_->get();
+ while (node != nullptr && node->Index() < current_idx_) {
+ node_ptr_ = node->GetSibling();
+ node = node_ptr_->get();
}
// If in the end node_ptr_ points to the node that we need, populate node_
// and advance node_ptr_.
- if (*node_ptr_ && (*node_ptr_)->index_ == current_idx_) {
- node_ = (*node_ptr_).get();
- node_ptr_ = &node_->sibling_;
+ if (node != nullptr && node->Index() == current_idx_) {
+ node_ = node;
+ node_ptr_ = node->GetSibling();
} else {
node_ = nullptr;
}
+
+ return node;
}
// Pointer to a pointer to the next node. Has to be a pointer to pointer
@@ -534,6 +836,9 @@ class Edge_Iterator : public EdgeAndNode {
uint16_t total_count_ = 0;
};
+inline Node::ConstIterator Node::Edges() const { return {this->GetLowNode()}; }
+inline Node::Iterator Node::Edges() { return {this->GetLowNode()}; }
+
// TODO(crem) Replace this with less hacky iterator once we support C++17.
// This class has multiple hypostases within one class:
// * Range (begin() and end() functions)
@@ -549,16 +854,17 @@ class VisitedNode_Iterator {
// Creates "end()" iterator.
VisitedNode_Iterator() {}
- // Creates "begin()" iterator. Also happens to be a range constructor.
- // child_ptr will be nullptr if parent_node is solid children.
- VisitedNode_Iterator(const Node& parent_node, Node* child_ptr)
- : node_ptr_(child_ptr),
- total_count_(parent_node.num_edges_),
- solid_(parent_node.solid_children_) {
- if (node_ptr_ != nullptr && node_ptr_->GetN() == 0) {
- operator++();
+ // Creates "begin()" iterator.
+ VisitedNode_Iterator(LowNode* parent_node) {
+ if (parent_node != nullptr) {
+ node_ptr_ = parent_node->GetChild()->get();
+ total_count_ = parent_node->GetNumEdges();
+ if (node_ptr_ != nullptr && node_ptr_->GetN() == 0) {
+ operator++();
+ }
}
}
+
// These are technically wrong, but are usable to compare with end().
bool operator==(const VisitedNode_Iterator& other) const {
return node_ptr_ == other.node_ptr_;
@@ -574,72 +880,56 @@ class VisitedNode_Iterator {
// Functions to support iterator interface.
// Equality comparison operators are inherited from EdgeAndNode.
void operator++() {
- if (solid_) {
- while (++current_idx_ != total_count_ &&
- node_ptr_[current_idx_].GetN() == 0) {
- if (node_ptr_[current_idx_].GetNInFlight() == 0) {
- // Once there is not even n in flight, we can skip to the end. This is
- // due to policy being in sorted order meaning that additional n in
- // flight are always selected from the front of the section with no n
- // in flight or visited.
- current_idx_ = total_count_;
- break;
- }
- }
- if (current_idx_ == total_count_) {
+ do {
+ node_ptr_ = node_ptr_->GetSibling()->get();
+ // If NStarted is 0 we can jump straight to the end: edges are sorted by
+ // policy, so each time a new edge becomes best for the first time it is
+ // always the first edge of the trailing section whose NStarted is still 0.
+ if (node_ptr_ != nullptr && node_ptr_->GetN() == 0 &&
+ node_ptr_->GetNInFlight() == 0) {
node_ptr_ = nullptr;
+ break;
}
- } else {
- do {
- node_ptr_ = node_ptr_->sibling_.get();
- // If n started is 0, can jump direct to end due to sorted policy
- // ensuring that each time a new edge becomes best for the first time,
- // it is always the first of the section at the end that has NStarted of
- // 0.
- if (node_ptr_ != nullptr && node_ptr_->GetN() == 0 &&
- node_ptr_->GetNInFlight() == 0) {
- node_ptr_ = nullptr;
- break;
- }
- } while (node_ptr_ != nullptr && node_ptr_->GetN() == 0);
- }
- }
- Node* operator*() {
- if (solid_) {
- return &(node_ptr_[current_idx_]);
- } else {
- return node_ptr_;
- }
+ } while (node_ptr_ != nullptr && node_ptr_->GetN() == 0);
}
+ Node* operator*() { return node_ptr_; }
private:
// Pointer to current node.
Node* node_ptr_ = nullptr;
uint16_t current_idx_ = 0;
uint16_t total_count_ = 0;
- bool solid_ = false;
};
inline VisitedNode_Iterator Node::VisitedNodes() const {
- return {*this, child_.get()};
+ return {this->GetLowNode()};
}
inline VisitedNode_Iterator Node::VisitedNodes() {
- return {*this, child_.get()};
+ return {this->GetLowNode()};
}
class NodeTree {
public:
+ // Transposition Table (TT) type for holding all normal low nodes in the DAG.
+ typedef absl::flat_hash_map>
+ TranspositionTable;
+
+ // Apply search params.
+ NodeTree(const SearchParams& params)
+ : hash_history_length_(params.GetCacheHistoryLength() + 1) {}
+ // When search params are not available.
+ NodeTree() : hash_history_length_(1) {}
~NodeTree() { DeallocateTree(); }
+
// Adds a move to current_head_.
void MakeMove(Move move);
// Resets the current head to ensure it doesn't carry over details from a
// previous search.
void TrimTreeAtHead();
- // Sets the position in a tree, trying to reuse the tree.
- // If @auto_garbage_collect, old tree is garbage collected immediately. (may
- // take some milliseconds)
- // Returns whether a new position the same game as old position (with some
- // moves added). Returns false, if the position is completely different,
+ // Sets the position in the tree, trying to reuse the tree.
+ // Returns whether the new position is the same game as the old position
+ // (with some moves added). Returns false if the position is completely different,
// or if it's shorter than before.
bool ResetToPosition(const std::string& starting_fen,
const std::vector<Move>& moves);
@@ -649,14 +939,60 @@ class NodeTree {
Node* GetCurrentHead() const { return current_head_; }
Node* GetGameBeginNode() const { return gamebegin_node_.get(); }
const PositionHistory& GetPositionHistory() const { return history_; }
+ const std::vector<Move>& GetMoves() const { return moves_; }
+
+ // Look up a low node in the Transposition Table by @hash and return it, or
+ // nullptr on failure.
+ LowNode* TTFind(uint64_t hash);
+ // Get a low node for the @hash from the Transposition Table or create a
+ // new low node and insert it into the Transposition Table if it is not there
+ // already. Return the low node for the hash.
+ std::pair TTGetOrCreate(uint64_t hash);
+ // Evict unused low nodes from the Transposition Table.
+ void TTMaintenance();
+ // Clear the Transposition Table.
+ // NOTE: Safe only when non-TT nodes were already detached.
+ void TTClear();
+ // Release the first @count items from TT GC list. @count == 0 means release
+ // all. Return true if there is more to release.
+ bool TTGCSome(size_t count = 1);
+
+ // Add a clone of low @node to special nodes outside of the Transposition
+ // Table and return it.
+ LowNode* NonTTAddClone(const LowNode& node);
+
+ size_t AllocatedNodeCount() const { return tt_.size() + non_tt_.size(); }
+
+ // Get position hash used for TT nodes and NN cache.
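+ // Since the hash covers the last hash_history_length_ positions, identical
+ // positions reached through a different recent history may hash differently.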
+ uint64_t GetHistoryHash(const PositionHistory& history) const {
+ return history.HashLast(hash_history_length_);
+ }
private:
void DeallocateTree();
+
+ // Evict unused non-TT low nodes.
+ void NonTTMaintenance();
+
// A node which to start search from.
Node* current_head_ = nullptr;
// Root node of a game tree.
std::unique_ptr<Node> gamebegin_node_;
PositionHistory history_;
+ std::vector<Move> moves_;
+
+ // Transposition Table (TT) for holding references to all normal low nodes in
+ // the DAG.
+ TranspositionTable tt_;
+ // Collection of low nodes that are not fit for Transposition Table due to
+ // noise or incomplete information.
+ std::vector> non_tt_;
+
+ // Number of history positions included in the node hashes used for the TT and NN cache.
+ int hash_history_length_;
+
+ // Garbage collection queue.
+ GCQueue gc_queue_;
};
} // namespace lczero
diff --git a/src/mcts/params.cc b/src/mcts/params.cc
index 6fbc4cfb30..88a6eb14a5 100644
--- a/src/mcts/params.cc
+++ b/src/mcts/params.cc
@@ -28,8 +28,11 @@
#include "mcts/params.h"
#include
+#include
+#include
#include "utils/exception.h"
+#include "utils/string.h"
#if __has_include("params_override.h")
#include "params_override.h"
@@ -38,11 +41,8 @@
#ifndef DEFAULT_MINIBATCH_SIZE
#define DEFAULT_MINIBATCH_SIZE 256
#endif
-#ifndef DEFAULT_MAX_PREFETCH
-#define DEFAULT_MAX_PREFETCH 32
-#endif
#ifndef DEFAULT_TASK_WORKERS
-#define DEFAULT_TASK_WORKERS 4
+#define DEFAULT_TASK_WORKERS 3
#endif
namespace lczero {
@@ -55,6 +55,97 @@ FillEmptyHistory EncodeHistoryFill(std::string history_fill) {
return FillEmptyHistory::NO;
}
+ContemptPerspective EncodeContemptPerspective(std::string perspective) {
+ if (perspective == "sidetomove") return ContemptPerspective::STM;
+ if (perspective == "white") return ContemptPerspective::WHITE;
+ if (perspective == "black") return ContemptPerspective::BLACK;
+ assert(perspective == "none");
+ return ContemptPerspective::NONE;
+}
+
+float GetContempt(std::string name, std::string contempt_str) {
+ float contempt = 0;
+ for (auto& entry : StrSplit(contempt_str, ",")) {
+ auto parts = StrSplit(entry, "=");
+ if (parts.size() == 1) {
+ try {
+ contempt = std::stof(parts[0]);
+ } catch (std::exception& e) {
+ throw Exception("Invalid default contempt: " + entry);
+ }
+ } else if (parts.size() == 2) {
+ if (std::search(name.begin(), name.end(), parts[0].begin(),
+ parts[0].end(), [](unsigned char a, unsigned char b) {
+ return std::tolower(a) == std::tolower(b);
+ }) != name.end()) {
+ try {
+ contempt = std::stof(parts[1]);
+ } catch (std::exception& e) {
+ throw Exception("Invalid contempt entry: " + entry);
+ }
+ break;
+ }
+ } else {
+ throw Exception("Invalid contempt entry: " + entry);
+ }
+ }
+ return contempt;
+}
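+// For example (illustrative values): with UCI_Opponent set to
+// "GM Stockfish 15" and Contempt set to "50,stockfish=100", the bare value 50
+// is read first and the case-insensitive substring match on "stockfish" then
+// overrides it, so the function returns 100.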
+
+// Calculate ratio and diff for the WDL conversion from the contempt settings.
+// The more accurate model, allowing book-bias-dependent Elo calculation. It
+// doesn't take the opponent's lower accuracy into account and needs clamping.
+SearchParams::WDLRescaleParams AccurateWDLRescaleParams(
+ float contempt, float draw_rate_target, float draw_rate_reference,
+ float book_exit_bias, float contempt_max, float contempt_attenuation) {
+ float scale_target =
+ 1.0f / std::log((1.0f + draw_rate_target) / (1.0f - draw_rate_target));
+ float scale_reference = 1.0f / std::log((1.0f + draw_rate_reference) /
+ (1.0f - draw_rate_reference));
+ float ratio = scale_target / scale_reference;
+ float diff =
+ scale_target / (scale_reference * scale_reference) /
+ (1.0f /
+ std::pow(std::cosh(0.5f * (1 - book_exit_bias) / scale_target), 2) +
+ 1.0f /
+ std::pow(std::cosh(0.5f * (1 + book_exit_bias) / scale_target), 2)) *
+ std::log(10) / 200 * std::clamp(contempt, -contempt_max, contempt_max) *
+ contempt_attenuation;
+ return SearchParams::WDLRescaleParams(ratio, diff);
+}
+
+// Calculate ratio and diff for the WDL conversion from the contempt settings.
+// A less accurate Elo model, but one that automatically chooses the draw rate
+// and accuracy based on the absolute Elo of both sides. It doesn't require
+// clamping, but still applies the clamping parameter.
+SearchParams::WDLRescaleParams SimplifiedWDLRescaleParams(
+ float contempt, float draw_rate_reference, float elo_active,
+ float contempt_max, float contempt_attenuation) {
+ // Parameters for the Elo dependent draw rate and scaling:
+ const float scale_zero = 15.0f;
+ const float elo_slope = 425.0f;
+ const float offset = 6.75f;
+
+ float scale_reference = 1.0f / std::log((1.0f + draw_rate_reference) /
+ (1.0f - draw_rate_reference));
+ float elo_opp =
+ elo_active - std::clamp(contempt, -contempt_max, contempt_max);
+ float scale_active =
+ 1.0f / (1.0f / scale_zero + std::exp(elo_active / elo_slope - offset));
+ float scale_opp =
+ 1.0f / (1.0f / scale_zero + std::exp(elo_opp / elo_slope - offset));
+ float scale_target =
+ std::sqrt((scale_active * scale_active + scale_opp * scale_opp) / 2.0f);
+ float ratio = scale_target / scale_reference;
+ float mu_active =
+ -std::log(10) / 200 * scale_zero * elo_slope *
+ std::log(1.0f + std::exp(-elo_active / elo_slope + offset) / scale_zero);
+ float mu_opp =
+ -std::log(10) / 200 * scale_zero * elo_slope *
+ std::log(1.0f + std::exp(-elo_opp / elo_slope + offset) / scale_zero);
+ float diff = (mu_active - mu_opp) * contempt_attenuation;
+ return SearchParams::WDLRescaleParams(ratio, diff);
+}
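+// Both helpers produce (ratio, diff) for the logistic WDL model applied by
+// WDLRescale() in search.cc: with w = sigmoid((-1 + mu) / s) and
+// l = sigmoid((-1 - mu) / s), `ratio` rescales the sharpness s and `diff`
+// shifts the center mu (scaled by s^2 and the side-to-move sign).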
} // namespace
const OptionId SearchParams::kMiniBatchSizeId{
@@ -62,11 +153,6 @@ const OptionId SearchParams::kMiniBatchSizeId{
"How many positions the engine tries to batch together for parallel NN "
"computation. Larger batches may reduce strength a bit, especially with a "
"small number of playouts."};
-const OptionId SearchParams::kMaxPrefetchBatchId{
- "max-prefetch", "MaxPrefetch",
- "When the engine cannot gather a large enough batch for immediate use, try "
- "to prefetch up to X positions which are likely to be useful soon, and put "
- "them into cache."};
const OptionId SearchParams::kCpuctId{
"cpuct", "CPuct",
"cpuct_init constant from \"UCT search\" algorithm. Higher values promote "
@@ -269,15 +355,52 @@ const OptionId SearchParams::kDrawScoreWhiteId{
const OptionId SearchParams::kDrawScoreBlackId{
"draw-score-black", "DrawScoreBlack",
"Adjustment, added to a draw score of a black player."};
+const OptionId SearchParams::kContemptPerspectiveId{
+ "contempt-perspective", "ContemptPerspective",
+ "Affects the way asymmetric WDL parameters are applied. Default is "
+ "'sidetomove' for matches, use 'white' and 'black' for analysis. Use "
+ "'none' to deactivate contempt and the WDL conversion."};
+const OptionId SearchParams::kContemptId{
+ "contempt", "Contempt",
+ "The simulated Elo advantage for the WDL conversion. Comma separated "
+ "list in the form [name=]value, where the name is compared with the "
+ "`UCI_Opponent` value to find the appropriate contempt value. The default "
+ "value is taken from `UCI_RatingAdv` and will be overridden if either a "
+ "value without name is given, or if a name match is found."};
+const OptionId SearchParams::kContemptMaxValueId{
+ "contempt-max-value", "ContemptMaxValue",
+ "The maximum value of contempt used. Higher values will be capped."};
+const OptionId SearchParams::kWDLCalibrationEloId{
+ "wdl-calibration-elo", "WDLCalibrationElo",
+ "Elo of the active side, adjusted for time control relative to rapid."};
+const OptionId SearchParams::kWDLContemptAttenuationId{
+ "wdl-contempt-attenuation", "WDLContemptAttenuation",
+ "Scales how Elo advantage is applied for contempt. Use 1.0 for realistic "
+ "analysis, and 0.5-0.6 for optimal match performance."};
+const OptionId SearchParams::kWDLEvalObjectivityId{
+ "wdl-eval-objectivity", "WDLEvalObjectivity",
+ "When calculating the centipawn eval output, decides how objective or how "
+ "contempt-influenced the reported eval should be. A value of 0.0 reports "
+ "the internally used WDL values, while 1.0 attempts an objective eval."
+const OptionId SearchParams::kWDLDrawRateTargetId{
+ "wdl-draw-rate-target", "WDLDrawRateTarget",
+ "To define the accuracy of play, the target draw rate in equal "
+ "positions is used as a proxy."};
+const OptionId SearchParams::kWDLDrawRateReferenceId{
+ "wdl-draw-rate-reference", "WDLDrawRateReference",
+ "Set this to the draw rate predicted by the used neural network at "
+ "default settings. The accuracy rescaling is done relative to the "
+ "reference draw rate."};
+const OptionId SearchParams::kWDLBookExitBiasId{
+ "wdl-book-exit-bias", "WDLBookExitBias",
+ "The book exit bias used when measuring engine Elo. Value of startpos is "
+ "around 0.2, value of 50% white win is 1. Only relevant if target draw "
+ "rate is above 80%."};
const OptionId SearchParams::kNpsLimitId{
"nps-limit", "NodesPerSecondLimit",
"An option to specify an upper limit to the nodes per second searched. The "
"accuracy depends on the minibatch size used, increasing for lower sizes, "
"and on the length of the search. Zero to disable."};
-const OptionId SearchParams::kSolidTreeThresholdId{
- "solid-tree-threshold", "SolidTreeThreshold",
- "Only nodes with at least this number of visits will be considered for "
- "solidification for improved cache locality."};
const OptionId SearchParams::kTaskWorkersPerSearchWorkerId{
"task-workers", "TaskWorkers",
"The number of task workers to use to help the search worker."};
@@ -316,12 +439,19 @@ const OptionId SearchParams::kMaxCollisionVisitsScalingEndId{
const OptionId SearchParams::kMaxCollisionVisitsScalingPowerId{
"max-collision-visits-scaling-power", "MaxCollisionVisitsScalingPower",
"Power to apply to the interpolation between 1 and max to make it curved."};
+const OptionId SearchParams::kUCIOpponentId{
+ "", "UCI_Opponent",
+ "UCI option used by the GUI to pass the name and other information about "
+ "the current opponent."};
+const OptionId SearchParams::kUCIRatingAdvId{
+ "", "UCI_RatingAdv",
+ "UCI extension used by some GUIs to pass the estimated Elo advantage over "
+ "the current opponent, used as the default contempt value."};
void SearchParams::Populate(OptionsParser* options) {
// Here the UCI-optimized defaults are set.
// Many of them are overridden with training specific values in tournament.cc.
options->Add(kMiniBatchSizeId, 1, 1024) = DEFAULT_MINIBATCH_SIZE;
- options->Add(kMaxPrefetchBatchId, 0, 1024) = DEFAULT_MAX_PREFETCH;
options->Add(kCpuctId, 0.0f, 100.0f) = 1.745f;
options->Add(kCpuctAtRootId, 0.0f, 100.0f) = 1.745f;
options->Add(kCpuctBaseId, 1.0f, 1000000000.0f) = 38739.0f;
@@ -369,7 +499,8 @@ void SearchParams::Populate(OptionsParser* options) {
"centipawn_2018",
"win_percentage",
"Q",
- "W-L"};
+ "W-L",
+ "WDL_mu"};
options->Add(kScoreTypeId, score_type) = "centipawn";
std::vector<std::string> history_fill_opt{"no", "fen_only", "always"};
options->Add(kHistoryFillId, history_fill_opt) = "fen_only";
@@ -386,8 +517,22 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add(kDrawScoreOpponentId, -100, 100) = 0;
options->Add(kDrawScoreWhiteId, -100, 100) = 0;
options->Add(kDrawScoreBlackId, -100, 100) = 0;
+ std::vector<std::string> perspective = {"sidetomove", "white", "black",
+ "none"};
+ options->Add(kContemptPerspectiveId, perspective) =
+ "sidetomove";
+ // The default kContemptId is empty, so the initial contempt value is taken
+ // from kUCIRatingAdvId. Adding any value (without name) in the comma
+ // separated kContemptId list will override this.
+ options->Add(kContemptId) = "";
+ options->Add(kContemptMaxValueId, 0, 10000.0f) = 420.0f;
+ options->Add(kWDLCalibrationEloId, 0, 10000.0f) = 0.0f;
+ options->Add(kWDLContemptAttenuationId, -10.0f, 10.0f) = 1.0f;
+ options->Add(kWDLEvalObjectivityId, 0.0f, 1.0f) = 1.0f;
+ options->Add(kWDLDrawRateTargetId, 0.001f, 0.999f) = 0.5f;
+ options->Add(kWDLDrawRateReferenceId, 0.001f, 0.999f) = 0.5f;
+ options->Add(kWDLBookExitBiasId, -2.0f, 2.0f) = 0.65f;
options->Add(kNpsLimitId, 0.0f, 1e6f) = 0.0f;
- options->Add(kSolidTreeThresholdId, 1, 2000000000) = 100;
options->Add(kTaskWorkersPerSearchWorkerId, 0, 128) =
DEFAULT_TASK_WORKERS;
options->Add(kMinimumWorkSizeForProcessingId, 2, 100000) = 20;
@@ -397,6 +542,8 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add(kMinimumWorkPerTaskForProcessingId, 1, 100000) = 8;
options->Add(kIdlingMinimumWorkId, 0, 10000) = 0;
options->Add(kThreadIdlingThresholdId, 0, 128) = 1;
+ options->Add(kUCIOpponentId);
+ options->Add(kUCIRatingAdvId, -10000.0f, 10000.0f) = 0.0f;
options->HideOption(kNoiseEpsilonId);
options->HideOption(kNoiseAlphaId);
@@ -415,6 +562,10 @@ void SearchParams::Populate(OptionsParser* options) {
options->HideOption(kTemperatureEndgameId);
options->HideOption(kTemperatureWinpctCutoffId);
options->HideOption(kTemperatureVisitOffsetId);
+ options->HideOption(kContemptMaxValueId);
+ options->HideOption(kWDLContemptAttenuationId);
+ options->HideOption(kWDLDrawRateTargetId);
+ options->HideOption(kWDLBookExitBiasId);
}
SearchParams::SearchParams(const OptionsDict& options)
@@ -465,12 +616,32 @@ SearchParams::SearchParams(const OptionsDict& options)
kDrawScoreOpponent{options.Get(kDrawScoreOpponentId) / 100.0f},
kDrawScoreWhite{options.Get(kDrawScoreWhiteId) / 100.0f},
kDrawScoreBlack{options.Get(kDrawScoreBlackId) / 100.0f},
+ kContemptPerspective(EncodeContemptPerspective(
+ options.Get(kContemptPerspectiveId))),
+ kContempt(options.IsDefault(kContemptId)
+ ? options.Get(kUCIRatingAdvId)
+ : GetContempt(options.Get(kUCIOpponentId),
+ options.Get(kContemptId))),
+ kWDLRescaleParams(
+ options.Get(kWDLCalibrationEloId) == 0
+ ? AccurateWDLRescaleParams(
+ kContempt, options.Get(kWDLDrawRateTargetId),
+ options.Get(kWDLDrawRateReferenceId),
+ options.Get(kWDLBookExitBiasId),
+ options.Get(kContemptMaxValueId),
+ options.Get(kWDLContemptAttenuationId))
+ : SimplifiedWDLRescaleParams(
+ kContempt, options.Get(kWDLDrawRateReferenceId),
+ options.Get(kWDLCalibrationEloId),
+ options.Get(kContemptMaxValueId),
+ options.Get(kWDLContemptAttenuationId))),
+ kWDLEvalObjectivity(options.Get(kWDLEvalObjectivityId)),
kMaxOutOfOrderEvals(std::max(
1, static_cast(options.Get(kMaxOutOfOrderEvalsId) *
options.Get(kMiniBatchSizeId)))),
kNpsLimit(options.Get(kNpsLimitId)),
- kSolidTreeThreshold(options.Get(kSolidTreeThresholdId)),
- kTaskWorkersPerSearchWorker(options.Get(kTaskWorkersPerSearchWorkerId)),
+ kTaskWorkersPerSearchWorker(
+ options.Get(kTaskWorkersPerSearchWorkerId)),
kMinimumWorkSizeForProcessing(
options.Get(kMinimumWorkSizeForProcessingId)),
kMinimumWorkSizeForPicking(
diff --git a/src/mcts/params.h b/src/mcts/params.h
index 7fcf0ef879..599992d0eb 100644
--- a/src/mcts/params.h
+++ b/src/mcts/params.h
@@ -33,19 +33,28 @@
namespace lczero {
+enum class ContemptPerspective { STM, WHITE, BLACK, NONE };
+
class SearchParams {
public:
SearchParams(const OptionsDict& options);
SearchParams(const SearchParams&) = delete;
+ // Use a struct so both WDL rescale parameters can be computed together and kept const.
+ struct WDLRescaleParams {
+ WDLRescaleParams(float r, float d) {
+ ratio = r;
+ diff = d;
+ }
+ float ratio;
+ float diff;
+ };
+
// Populates UciOptions with search parameters.
static void Populate(OptionsParser* options);
// Parameter getters.
- int GetMiniBatchSize() const { return kMiniBatchSize; }
- int GetMaxPrefetchBatch() const {
- return options_.Get(kMaxPrefetchBatchId);
- }
+ uint32_t GetMiniBatchSize() const { return kMiniBatchSize; }
float GetCpuct(bool at_root) const { return at_root ? kCpuctAtRoot : kCpuct; }
float GetCpuctBase(bool at_root) const {
return at_root ? kCpuctBaseAtRoot : kCpuctBase;
@@ -104,13 +113,18 @@ class SearchParams {
}
bool GetDisplayCacheUsage() const { return kDisplayCacheUsage; }
int GetMaxConcurrentSearchers() const { return kMaxConcurrentSearchers; }
+ ContemptPerspective GetContemptPerspective() const {
+ return kContemptPerspective;
+ }
float GetSidetomoveDrawScore() const { return kDrawScoreSidetomove; }
float GetOpponentDrawScore() const { return kDrawScoreOpponent; }
float GetWhiteDrawDelta() const { return kDrawScoreWhite; }
float GetBlackDrawDelta() const { return kDrawScoreBlack; }
- int GetMaxOutOfOrderEvals() const { return kMaxOutOfOrderEvals; }
+ float GetWDLRescaleRatio() const { return kWDLRescaleParams.ratio; }
+ float GetWDLRescaleDiff() const { return kWDLRescaleParams.diff; }
+ float GetWDLEvalObjectivity() const { return kWDLEvalObjectivity; }
+ uint32_t GetMaxOutOfOrderEvals() const { return kMaxOutOfOrderEvals; }
float GetNpsLimit() const { return kNpsLimit; }
- int GetSolidTreeThreshold() const { return kSolidTreeThreshold; }
int GetTaskWorkersPerSearchWorker() const {
return kTaskWorkersPerSearchWorker;
@@ -141,7 +155,6 @@ class SearchParams {
// Search parameter IDs.
static const OptionId kMiniBatchSizeId;
- static const OptionId kMaxPrefetchBatchId;
static const OptionId kCpuctId;
static const OptionId kCpuctAtRootId;
static const OptionId kCpuctBaseId;
@@ -184,13 +197,21 @@ class SearchParams {
static const OptionId kMovesLeftSlopeId;
static const OptionId kDisplayCacheUsageId;
static const OptionId kMaxConcurrentSearchersId;
+ static const OptionId kContemptPerspectiveId;
static const OptionId kDrawScoreSidetomoveId;
static const OptionId kDrawScoreOpponentId;
static const OptionId kDrawScoreWhiteId;
static const OptionId kDrawScoreBlackId;
+ static const OptionId kContemptId;
+ static const OptionId kContemptMaxValueId;
+ static const OptionId kWDLCalibrationEloId;
+ static const OptionId kWDLContemptAttenuationId;
+ static const OptionId kWDLEvalObjectivityId;
+ static const OptionId kWDLDrawRateTargetId;
+ static const OptionId kWDLDrawRateReferenceId;
+ static const OptionId kWDLBookExitBiasId;
static const OptionId kMaxOutOfOrderEvalsId;
static const OptionId kNpsLimitId;
- static const OptionId kSolidTreeThresholdId;
static const OptionId kTaskWorkersPerSearchWorkerId;
static const OptionId kMinimumWorkSizeForProcessingId;
static const OptionId kMinimumWorkSizeForPickingId;
@@ -201,6 +222,8 @@ class SearchParams {
static const OptionId kMaxCollisionVisitsScalingStartId;
static const OptionId kMaxCollisionVisitsScalingEndId;
static const OptionId kMaxCollisionVisitsScalingPowerId;
+ static const OptionId kUCIOpponentId;
+ static const OptionId kUCIRatingAdvId;
private:
const OptionsDict& options_;
@@ -244,9 +267,12 @@ class SearchParams {
const float kDrawScoreOpponent;
const float kDrawScoreWhite;
const float kDrawScoreBlack;
+ const ContemptPerspective kContemptPerspective;
+ const float kContempt;
+ const WDLRescaleParams kWDLRescaleParams;
+ const float kWDLEvalObjectivity;
const int kMaxOutOfOrderEvals;
const float kNpsLimit;
- const int kSolidTreeThreshold;
const int kTaskWorkersPerSearchWorker;
const int kMinimumWorkSizeForProcessing;
const int kMinimumWorkSizeForPicking;
diff --git a/src/mcts/search.cc b/src/mcts/search.cc
index 6fe8d1c7e6..557f1c7a12 100644
--- a/src/mcts/search.cc
+++ b/src/mcts/search.cc
@@ -34,12 +34,11 @@
#include
#include
#include
+#include
#include
#include
#include "mcts/node.h"
-#include "neural/cache.h"
-#include "neural/encoder.h"
#include "utils/fastmath.h"
#include "utils/random.h"
@@ -111,6 +110,11 @@ class MEvaluator {
const float child_m = child.GetM(parent_m_);
float m = std::clamp(m_slope_ * (child_m - parent_m_), -m_cap_, m_cap_);
m *= FastSign(-q);
+ if (q_threshold_ > 0.0f && q_threshold_ < 1.0f) {
+ // This allows a smooth M effect with higher q thresholds, which is
+ // necessary for using MLH together with contempt.
+ q = std::max(0.0f, (std::abs(q) - q_threshold_)) / (1.0f - q_threshold_);
+ }
m *= a_constant_ + a_linear_ * std::abs(q) + a_square_ * q * q;
return m;
}
@@ -145,7 +149,7 @@ class MEvaluator {
} // namespace
-Search::Search(const NodeTree& tree, Network* network,
+Search::Search(NodeTree* dag, Network* network,
std::unique_ptr<UciResponder> uci_responder,
const MoveList& searchmoves,
std::chrono::steady_clock::time_point start_time,
@@ -154,10 +158,11 @@ Search::Search(const NodeTree& tree, Network* network,
SyzygyTablebase* syzygy_tb)
: ok_to_respond_bestmove_(!infinite),
stopper_(std::move(stopper)),
- root_node_(tree.GetCurrentHead()),
+ root_node_(dag->GetCurrentHead()),
cache_(cache),
+ dag_(dag),
syzygy_tb_(syzygy_tb),
- played_history_(tree.GetPositionHistory()),
+ played_history_(dag->GetPositionHistory()),
network_(network),
params_(options),
searchmoves_(searchmoves),
@@ -194,6 +199,44 @@ void ApplyDirichletNoise(Node* node, float eps, double alpha) {
}
} // namespace
+namespace {
+// WDL conversion formula based on a random-walk model.
+inline void WDLRescale(float& v, float& d, float* mu_uci,
+ float wdl_rescale_ratio, float wdl_rescale_diff,
+ float sign, bool invert) {
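+ // `sign` flips the direction of the mu shift according to the configured
+ // contempt perspective; `invert` applies the inverse mapping, as used below
+ // when converting internal values back for UCI output.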
+ if (invert) {
+ wdl_rescale_diff = -wdl_rescale_diff;
+ wdl_rescale_ratio = 1.0f / wdl_rescale_ratio;
+ }
+ auto w = (1 + v - d) / 2;
+ auto l = (1 - v - d) / 2;
+ // Safeguard against numerical issues; skip WDL transformation if WDL is too
+ // extreme.
+ const float zero = 0.0001f;
+ const float one = 0.9999f;
+ if (w > zero && d > zero && l > zero && w < one && d < one && l < one) {
+ auto a = FastLog(1 / l - 1);
+ auto b = FastLog(1 / w - 1);
+ auto s = 2 / (a + b);
+ // Safeguard against unrealistically broad WDL distributions coming from
+ // the NN. Could be made into a parameter, but probably unnecessary.
+ if (!invert) s = std::min(1.4f, s);
+ auto mu = (a - b) / (a + b);
+ auto s_new = s * wdl_rescale_ratio;
+ if (invert) {
+ std::swap(s, s_new);
+ s = std::min(1.4f, s);
+ }
+ auto mu_new = mu + sign * s * s * wdl_rescale_diff;
+ auto w_new = FastLogistic((-1.0f + mu_new) / s_new);
+ auto l_new = FastLogistic((-1.0f - mu_new) / s_new);
+ v = w_new - l_new;
+ d = std::max(0.0f, 1.0f - w_new - l_new);
+ if (mu_uci) *mu_uci = mu_new;
+ }
+}
+} // namespace
+
void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) {
const auto max_pv = params_.GetMultiPv();
const auto edges = GetBestChildrenNoTemperature(root_node_, max_pv, 0);
@@ -235,12 +278,25 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) {
++multipv;
uci_infos.emplace_back(common_info);
auto& uci_info = uci_infos.back();
- const auto wl = edge.GetWL(default_wl);
- const auto floatD = edge.GetD(default_d);
+ auto wl = edge.GetWL(default_wl);
+ auto d = edge.GetD(default_d);
+ float mu_uci = 0.0f;
+ // Only the diff effect is inverted, so we only need to call if diff != 0.
+ if (params_.GetContemptPerspective() != ContemptPerspective::NONE) {
+ auto sign =
+ ((params_.GetContemptPerspective() == ContemptPerspective::STM) ||
+ ((params_.GetContemptPerspective() == ContemptPerspective::BLACK) ==
+ played_history_.IsBlackToMove()))
+ ? 1.0f
+ : -1.0f;
+ WDLRescale(wl, d, &mu_uci, params_.GetWDLRescaleRatio(),
+ params_.GetWDLRescaleDiff() * params_.GetWDLEvalObjectivity(),
+ sign, true);
+ }
const auto q = edge.GetQ(default_q, draw_score);
if (edge.IsTerminal() && wl != 0.0f) {
uci_info.mate = std::copysign(
- std::round(edge.GetM(0.0f)) / 2 + (edge.IsTbTerminal() ? 101 : 1),
+ std::round(edge.GetM(0.0f) + 1) / 2 + (edge.IsTbTerminal() ? 100 : 0),
wl);
} else if (score_type == "centipawn_with_drawscore") {
uci_info.score = 90 * tan(1.5637541897 * q);
@@ -256,20 +312,30 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) {
uci_info.score = q * 10000;
} else if (score_type == "W-L") {
uci_info.score = wl * 10000;
+ } else if (score_type == "WDL_mu") {
+ // Reports the WDL mu value whenever it is reasonable, and defaults to
+ // centipawn otherwise.
+ float centipawn_score = 90 * tan(1.5637541897 * wl);
+ uci_info.score =
+ mu_uci != 0.0f && std::abs(wl) + d < 0.99f &&
+ (std::abs(mu_uci) < 1.0f ||
+ std::abs(centipawn_score) < std::abs(100 * mu_uci))
+ ? 100 * mu_uci
+ : centipawn_score;
}
- auto w =
- std::max(0, static_cast<int>(std::round(500.0 * (1.0 + wl - floatD))));
- auto l =
- std::max(0, static_cast<int>(std::round(500.0 * (1.0 - wl - floatD))));
+ auto wdl_w =
+ std::max(0, static_cast<int>(std::round(500.0 * (1.0 + wl - d))));
+ auto wdl_l =
+ std::max(0, static_cast<int>(std::round(500.0 * (1.0 - wl - d))));
// Using 1000-w-l so that W+D+L add up to 1000.0.
- auto d = 1000 - w - l;
- if (d < 0) {
- w = std::min(1000, std::max(0, w + d / 2));
- l = 1000 - w;
- d = 0;
+ auto wdl_d = 1000 - wdl_w - wdl_l;
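+ // For example, wl = 0.30 and d = 0.50 give wdl_w = 400, wdl_d = 500,
+ // wdl_l = 100.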
+ if (wdl_d < 0) {
+ wdl_w = std::min(1000, std::max(0, wdl_w + wdl_d / 2));
+ wdl_l = 1000 - wdl_w;
+ wdl_d = 0;
}
- uci_info.wdl = ThinkingInfo::WDL{w, d, l};
+ uci_info.wdl = ThinkingInfo::WDL{wdl_w, wdl_d, wdl_l};
if (network_->GetCapabilities().has_mlh()) {
uci_info.moves_left = static_cast<int>(
(1.0f + edge.GetM(1.0f + root_node_->GetM())) / 2.0f);
@@ -278,10 +344,13 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) {
if (per_pv_counters) uci_info.nodes = edge.GetN();
bool flip = played_history_.IsBlackToMove();
int depth = 0;
+ auto history = played_history_;
for (auto iter = edge; iter;
iter = GetBestChildNoTemperature(iter.node(), depth), flip = !flip) {
uci_info.pv.push_back(iter.GetMove(flip));
- if (!iter.node()) break; // Last edge was dangling, cannot continue.
+ history.Append(iter.GetMove());
+ // Last edge was dangling or a draw by repetition, cannot continue.
+ if (!iter.node() || history.Last().GetRepetitions() >= 2) break;
depth += 1;
}
}
@@ -371,13 +440,12 @@ inline float ComputeCpuct(const SearchParams& params, uint32_t N,
} // namespace
std::vector<std::string> Search::GetVerboseStats(Node* node) const {
- assert(node == root_node_ || node->GetParent() == root_node_);
const bool is_root = (node == root_node_);
const bool is_odd_depth = !is_root;
const bool is_black_to_move = (played_history_.IsBlackToMove() == is_root);
const float draw_score = GetDrawScore(is_odd_depth);
const float fpu = GetFpu(params_, node, is_root, draw_score);
- const float cpuct = ComputeCpuct(params_, node->GetN(), is_root);
+ const float cpuct = ComputeCpuct(params_, node->GetTotalVisits(), is_root);
const float U_coeff =
cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
std::vector<EdgeAndNode> edges;
@@ -419,9 +487,13 @@ std::vector Search::GetVerboseStats(Node* node) const {
std::optional<float> v;
if (n && n->IsTerminal()) {
v = n->GetQ(sign * draw_score);
- } else {
- NNCacheLock nneval = GetCachedNNEval(n);
- if (nneval) v = -nneval->q;
+ } else if (n) {
+ auto history = played_history_;
+ if (!is_root) {
+ history.Append(node->GetMove());
+ }
+ NNCacheLock nneval = GetCachedNNEval(history);
+ if (nneval) v = -nneval->eval->q;
}
if (v) {
print(oss, "(V: ", sign * *v, ") ", 7, 4);
@@ -436,10 +508,13 @@ std::vector Search::GetVerboseStats(Node* node) const {
up = -up;
std::swap(lo, up);
}
- *oss << (lo == up ? "(T) "
- : lo == GameResult::DRAW && up == GameResult::WHITE_WON ? "(W) "
- : lo == GameResult::BLACK_WON && up == GameResult::DRAW ? "(L) "
- : "");
+ *oss << (lo == up
+ ? "(T) "
+ : lo == GameResult::DRAW && up == GameResult::WHITE_WON
+ ? "(W) "
+ : lo == GameResult::BLACK_WON && up == GameResult::DRAW
+ ? "(L) "
+ : "");
}
};
@@ -502,18 +577,8 @@ void Search::SendMovesStats() const REQUIRES(counters_mutex_) {
}
}
-NNCacheLock Search::GetCachedNNEval(const Node* node) const {
- if (!node) return {};
-
- std::vector<Move> moves;
- for (; node != root_node_; node = node->GetParent()) {
- moves.push_back(node->GetOwnEdge()->GetMove());
- }
- PositionHistory history(played_history_);
- for (auto iter = moves.rbegin(), end = moves.rend(); iter != end; ++iter) {
- history.Append(*iter);
- }
- const auto hash = history.HashLast(params_.GetCacheHistoryLength() + 1);
+NNCacheLock Search::GetCachedNNEval(const PositionHistory& history) const {
+ const auto hash = dag_->GetHistoryHash(history);
NNCacheLock nneval(cache_, hash);
return nneval;
}
@@ -838,7 +903,8 @@ void Search::PopulateCommonIterationStats(IterationStats* stats) {
nps_start_time_ = std::chrono::steady_clock::now();
}
}
- stats->total_nodes = total_playouts_ + initial_visits_;
+ stats->total_visits = total_playouts_ + initial_visits_;
+ stats->total_allocated_nodes = dag_->AllocatedNodeCount();
stats->nodes_since_movestart = total_playouts_;
stats->batches_since_movestart = total_batches_;
stats->average_depth = cum_depth_ / (total_playouts_ ? total_playouts_ : 1);
@@ -957,10 +1023,9 @@ void Search::Wait() {
void Search::CancelSharedCollisions() REQUIRES(nodes_mutex_) {
for (auto& entry : shared_collisions_) {
- Node* node = entry.first;
- for (node = node->GetParent(); node != root_node_->GetParent();
- node = node->GetParent()) {
- node->CancelScoreUpdate(entry.second);
+ auto path = entry.first;
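+ // Walk from the collision node's parent back to the root (skipping the
+ // collision entry itself) and cancel the in-flight visits along the way.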
+ for (auto it = ++(path.crbegin()); it != path.crend(); ++it) {
+ std::get<0>(*it)->CancelScoreUpdate(entry.second);
}
}
shared_collisions_.clear();
@@ -972,7 +1037,16 @@ Search::~Search() {
{
SharedMutex::Lock lock(nodes_mutex_);
CancelSharedCollisions();
+
+#ifndef NDEBUG
+ assert(root_node_->ZeroNInFlight());
+#endif
}
+
+ // Free previously released nodes that were not reused during this search.
+ dag_->TTMaintenance();
+ dag_->TTMaintenance();
+
LOGFILE << "Search destroyed.";
}
@@ -1036,14 +1110,13 @@ void SearchWorker::RunTasks(int tid) {
if (task != nullptr) {
switch (task->task_type) {
case PickTask::kGathering: {
- PickNodesToExtendTask(task->start, task->base_depth,
- task->collision_limit, task->moves_to_base,
- &(task->results), &(task_workspaces_[tid]));
+ PickNodesToExtendTask(task->start_path, task->collision_limit,
+ task->history, &(task->results),
+ &(task_workspaces_[tid]));
break;
}
case PickTask::kProcessing: {
- ProcessPickedTask(task->start_idx, task->end_idx,
- &(task_workspaces_[tid]));
+ ProcessPickedTask(task->start_idx, task->end_idx);
break;
}
}
@@ -1059,7 +1132,7 @@ void SearchWorker::ExecuteOneIteration() {
if (params_.GetMaxConcurrentSearchers() != 0) {
while (true) {
- // If search is stop, we've not gathered or done anything and we don't
+ // If search is stopped, we've not gathered or done anything and we don't
// want to, so we can safely skip all below. But make sure we have done
// at least one iteration.
if (search_->stop_.load(std::memory_order_acquire) &&
@@ -1087,9 +1160,6 @@ void SearchWorker::ExecuteOneIteration() {
// 2b. Collect collisions.
CollectCollisions();
- // 3. Prefetch into cache.
- MaybePrefetchIntoCache();
-
if (params_.GetMaxConcurrentSearchers() != 0) {
search_->pending_searchers_.fetch_add(1, std::memory_order_acq_rel);
}
@@ -1132,8 +1202,9 @@ void SearchWorker::ExecuteOneIteration() {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
void SearchWorker::InitializeIteration(
std::unique_ptr<NetworkComputation> computation) {
- computation_ = std::make_unique<CachingComputation>(std::move(computation),
- search_->cache_);
+ computation_ = std::make_unique<CachingComputation>(
+ std::move(computation), search_->network_->GetCapabilities().input_format,
+ params_.GetHistoryFill(), search_->cache_);
computation_->Reserve(params_.GetMiniBatchSize());
minibatch_.clear();
minibatch_.reserve(2 * params_.GetMiniBatchSize());
@@ -1166,7 +1237,7 @@ int CalculateCollisionsLeft(int64_t nodes, const SearchParams& params) {
void SearchWorker::GatherMinibatch() {
// Total number of nodes to process.
- int minibatch_size = 0;
+ uint32_t minibatch_size = 0;
int cur_n = 0;
{
SharedMutex::Lock lock(search_->nodes_mutex_);
@@ -1176,7 +1247,7 @@ void SearchWorker::GatherMinibatch() {
// applied, which doesn't clearly make sense to include here...
int64_t remaining_n =
latest_time_manager_hints_.GetEstimatedRemainingPlayouts();
- int collisions_left = CalculateCollisionsLeft(
+ uint32_t collisions_left = CalculateCollisionsLeft(
std::min(static_cast<int64_t>(cur_n), remaining_n), params_);
// Number of nodes processed out of order.
@@ -1223,39 +1294,45 @@ void SearchWorker::GatherMinibatch() {
++minibatch_size;
}
- bool needs_wait = false;
- int ppt_start = new_start;
- if (params_.GetTaskWorkersPerSearchWorker() > 0 &&
- non_collisions >= params_.GetMinimumWorkSizeForProcessing()) {
- const int num_tasks = std::clamp(
- non_collisions / params_.GetMinimumWorkPerTaskForProcessing(), 2,
- params_.GetTaskWorkersPerSearchWorker() + 1);
- // Round down, left overs can go to main thread so it waits less.
- int per_worker = non_collisions / num_tasks;
- needs_wait = true;
- ResetTasks();
- int found = 0;
- for (int i = new_start; i < static_cast(minibatch_.size()); i++) {
- auto& picked_node = minibatch_[i];
- if (picked_node.IsCollision()) {
- continue;
- }
- ++found;
- if (found == per_worker) {
- picking_tasks_.emplace_back(ppt_start, i + 1);
- task_count_.fetch_add(1, std::memory_order_acq_rel);
- ppt_start = i + 1;
- found = 0;
- if (picking_tasks_.size() == static_cast(num_tasks - 1)) {
- break;
+ {
+ // This lock must be held until after the task_completed_ wait succeeds
+ // below, since the tasks perform work that assumes they hold the lock,
+ // even though it is actually this thread that does.
+ SharedMutex::Lock lock(search_->nodes_mutex_);
+
+ bool needs_wait = false;
+ int ppt_start = new_start;
+ if (params_.GetTaskWorkersPerSearchWorker() > 0 &&
+ non_collisions >= params_.GetMinimumWorkSizeForProcessing()) {
+ const int num_tasks = std::clamp(
+ non_collisions / params_.GetMinimumWorkPerTaskForProcessing(), 2,
+ params_.GetTaskWorkersPerSearchWorker() + 1);
+ // Round down, left overs can go to main thread so it waits less.
+ int per_worker = non_collisions / num_tasks;
+ needs_wait = true;
+ ResetTasks();
+ int found = 0;
+ for (int i = new_start; i < static_cast<int>(minibatch_.size()); i++) {
+ auto& picked_node = minibatch_[i];
+ if (picked_node.IsCollision()) {
+ continue;
+ }
+ ++found;
+ if (found == per_worker) {
+ picking_tasks_.emplace_back(ppt_start, i + 1);
+ task_count_.fetch_add(1, std::memory_order_acq_rel);
+ ppt_start = i + 1;
+ found = 0;
+ if (picking_tasks_.size() == static_cast<size_t>(num_tasks - 1)) {
+ break;
+ }
}
}
}
- }
- ProcessPickedTask(ppt_start, static_cast(minibatch_.size()),
- &main_workspace_);
- if (needs_wait) {
- WaitForTasks();
+ ProcessPickedTask(ppt_start, static_cast<int>(minibatch_.size()));
+ if (needs_wait) {
+ WaitForTasks();
+ }
}
bool some_ooo = false;
for (int i = static_cast<int>(minibatch_.size()) - 1; i >= new_start; i--) {
@@ -1273,14 +1350,13 @@ void SearchWorker::GatherMinibatch() {
// This may remove too many items, but hopefully most of the time they
// will just be added back in the next gather.
if (minibatch_[i].IsCollision()) {
- Node* node = minibatch_[i].node;
- for (node = node->GetParent();
- node != search_->root_node_->GetParent();
- node = node->GetParent()) {
- node->CancelScoreUpdate(minibatch_[i].multivisit);
+ for (auto it = ++(minibatch_[i].path.crbegin());
+ it != minibatch_[i].path.crend(); ++it) {
+ std::get<0>(*it)->CancelScoreUpdate(minibatch_[i].multivisit);
}
minibatch_.erase(minibatch_.begin() + i);
} else if (minibatch_[i].ooo_completed) {
+ FetchSingleNodeResult(&minibatch_[i], minibatch_[i], 0);
DoBackupUpdateSingleNode(minibatch_[i]);
minibatch_.erase(minibatch_.begin() + i);
--minibatch_size;
@@ -1292,15 +1368,13 @@ void SearchWorker::GatherMinibatch() {
// If there was no OOO, there can still be collisions.
// There are no OOO though.
// Also terminals when OOO is disabled.
- if (!minibatch_[i].nn_queried) continue;
+ if (!minibatch_[i].ShouldAddToInput()) continue;
if (minibatch_[i].is_cache_hit) {
// Since minibatch_[i] holds cache lock, this is guaranteed to succeed.
computation_->AddInputByHash(minibatch_[i].hash,
std::move(minibatch_[i].lock));
} else {
- computation_->AddInput(minibatch_[i].hash,
- std::move(minibatch_[i].input_planes),
- std::move(minibatch_[i].probabilities_to_cache));
+ computation_->AddInput(minibatch_[i].hash, minibatch_[i].history);
}
}
@@ -1316,11 +1390,9 @@ void SearchWorker::GatherMinibatch() {
int extra = std::min(picked_node.maxvisit, collisions_left) -
picked_node.multivisit;
picked_node.multivisit += extra;
- Node* node = picked_node.node;
- for (node = node->GetParent();
- node != search_->root_node_->GetParent();
- node = node->GetParent()) {
- node->IncrementNInFlight(extra);
+ for (auto it = ++(picked_node.path.crbegin());
+ it != picked_node.path.crend(); ++it) {
+ std::get<0>(*it)->IncrementNInFlight(extra);
}
}
if ((collisions_left -= picked_node.multivisit) <= 0) return;
@@ -1330,51 +1402,20 @@ void SearchWorker::GatherMinibatch() {
}
}
-void SearchWorker::ProcessPickedTask(int start_idx, int end_idx,
- TaskWorkspace* workspace) {
- auto& history = workspace->history;
- history = search_->played_history_;
-
+void SearchWorker::ProcessPickedTask(int start_idx, int end_idx) {
for (int i = start_idx; i < end_idx; i++) {
auto& picked_node = minibatch_[i];
if (picked_node.IsCollision()) continue;
- auto* node = picked_node.node;
-
- // If node is already known as terminal (win/loss/draw according to rules
- // of the game), it means that we already visited this node before.
+ // If node is a collision, known as a terminal (win/loss/draw according to
+ // the rules of the game) or has a low node, it means that we have already
+ // visited this node before and can't extend it.
if (picked_node.IsExtendable()) {
// Node was never visited, extend it.
- ExtendNode(node, picked_node.depth, picked_node.moves_to_visit, &history);
- if (!node->IsTerminal()) {
- picked_node.nn_queried = true;
- const auto hash = history.HashLast(params_.GetCacheHistoryLength() + 1);
- picked_node.hash = hash;
- picked_node.lock = NNCacheLock(search_->cache_, hash);
- picked_node.is_cache_hit = picked_node.lock;
- if (!picked_node.is_cache_hit) {
- int transform;
- picked_node.input_planes = EncodePositionForNN(
- search_->network_->GetCapabilities().input_format, history, 8,
- params_.GetHistoryFill(), &transform);
- picked_node.probability_transform = transform;
-
- std::vector& moves = picked_node.probabilities_to_cache;
- // Legal moves are known, use them.
- moves.reserve(node->GetNumEdges());
- for (const auto& edge : node->Edges()) {
- moves.emplace_back(edge.GetMove().as_nn_index(transform));
- }
- } else {
- picked_node.probability_transform = TransformForPosition(
- search_->network_->GetCapabilities().input_format, history);
- }
- }
- }
- if (params_.GetOutOfOrderEval() && picked_node.CanEvalOutOfOrder()) {
- // Perform out of order eval for the last entry in minibatch_.
- FetchSingleNodeResult(&picked_node, picked_node, 0);
- picked_node.ooo_completed = true;
+ ExtendNode(picked_node);
}
+
+ picked_node.ooo_completed =
+ params_.GetOutOfOrderEval() && picked_node.CanEvalOutOfOrder();
}
}
@@ -1412,8 +1453,10 @@ void SearchWorker::PickNodesToExtend(int collision_limit) {
// Since the tasks perform work which assumes they have the lock, even though
// actually this thread does.
SharedMutex::Lock lock(search_->nodes_mutex_);
- PickNodesToExtendTask(search_->root_node_, 0, collision_limit, empty_movelist,
- &minibatch_, &main_workspace_);
+ history_.Trim(search_->played_history_.GetLength());
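+ // Each path entry is a (node, repetitions, plies since previous repetition)
+ // tuple as produced by GetRepetitions(); the root is seeded with zeros.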
+ PickNodesToExtendTask({std::make_tuple(search_->root_node_, 0, 0)},
+ collision_limit, history_, &minibatch_,
+ &main_workspace_);
WaitForTasks();
for (int i = 0; i < static_cast(picking_tasks_.size()); i++) {
@@ -1424,52 +1467,78 @@ void SearchWorker::PickNodesToExtend(int collision_limit) {
}
}
-void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node,
- int depth) {
- // Check whether first repetition was before root. If yes, remove
- // terminal status of node and revert all visits in the tree.
- // Length of repetition was stored in m_. This code will only do
- // something when tree is reused and twofold visits need to be
- // reverted.
- if (child_node->IsTwoFoldTerminal() && depth < child_node->GetM()) {
- // Take a mutex - any SearchWorker specific mutex... since this is
- // not safe to do concurrently between multiple tasks.
- Mutex::Lock lock(picking_tasks_mutex_);
- int depth_counter = 0;
- // Cache node's values as we reset them in the process. We could
- // manually set wl and d, but if we want to reuse this for reverting
- // other terminal nodes this is the way to go.
- const auto wl = child_node->GetWL();
- const auto d = child_node->GetD();
- const auto m = child_node->GetM();
- const auto terminal_visits = child_node->GetN();
- for (Node* node_to_revert = child_node; node_to_revert != nullptr;
- node_to_revert = node_to_revert->GetParent()) {
- // Revert all visits on twofold draw when making it non terminal.
- node_to_revert->RevertTerminalVisits(wl, d, m + (float)depth_counter,
- terminal_visits);
- depth_counter++;
- // Even if original tree still exists, we don't want to revert
- // more than until new root.
- if (depth_counter > depth) break;
- // If wl != 0, we would have to switch signs at each depth.
- }
- // Mark the prior twofold draw as non terminal to extend it again.
- child_node->MakeNotTerminal();
- // When reverting the visits, we also need to revert the initial
- // visits, as we reused fewer nodes than anticipated.
- search_->initial_visits_ -= terminal_visits;
- // Max depth doesn't change when reverting the visits, and
- // cum_depth_ only counts the average depth of new nodes, not reused
- // ones.
+// Depth starts with 0 at root, so number of plies in PV equals depth.
+std::pair<int, int> SearchWorker::GetRepetitions(int depth,
+ const Position& position) {
+ const auto repetitions = position.GetRepetitions();
+
+ if (repetitions == 0) return {0, 0};
+
+ if (repetitions >= 2) return {repetitions, 0};
+
+ const auto plies = position.GetPliesSincePrevRepetition();
+ if (params_.GetTwoFoldDraws() && /*repetitions == 1 &&*/ depth >= 4 &&
+ depth >= plies) {
+ return {1, plies};
}
+
+ return {0, 0};
+}
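+// In short: no repetition yields {0, 0}; an actual threefold (two or more
+// prior repetitions) yields {repetitions, 0}; and a single repetition counts
+// as a two-fold draw only when enabled, deep enough (depth >= 4) and with the
+// previous occurrence inside the search tree (depth >= plies), returning
+// {1, plies}.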
+
+// Check if PickNodesToExtendTask should stop picking at this @node.
+bool SearchWorker::ShouldStopPickingHere(Node* node, bool is_root_node,
+ int repetitions) {
+ constexpr float wl_diff_limit = 0.01f;
+ constexpr float d_diff_limit = 0.01f;
+ constexpr float m_diff_limit = 2.0f;
+
+ if (node->GetN() == 0 || node->IsTerminal()) return true;
+
+ // Only stop at root when there is no other option.
+ assert(!is_root_node || node == search_->root_node_);
+ if (is_root_node) return false;
+
+ // Stop at draws by repetition.
+ if (repetitions >= 2) return true;
+
+ // Check if Node and LowNode differ significantly.
+ auto low_node = node->GetLowNode();
+ assert(low_node);
+
+ // Only known transpositions can differ.
+ if (!low_node->IsTransposition()) return false;
+
+ // LowNode is terminal when Node is not.
+ if (low_node->IsTerminal()) return true;
+
+ // Bounds differ (swap).
+ auto [low_node_lower, low_node_upper] = low_node->GetBounds();
+ auto [node_lower, node_upper] = node->GetBounds();
+ if (low_node_lower != -node_upper || low_node_upper != -node_lower)
+ return true;
+
+ // WL differs significantly (flip).
+ auto wl_diff = std::abs(low_node->GetWL() + node->GetWL());
+ if (wl_diff >= wl_diff_limit) return true;
+
+ // D differs significantly.
+ auto d_diff = std::abs(low_node->GetD() - node->GetD());
+ if (d_diff >= d_diff_limit) return true;
+
+ // M differs significantly (increment).
+ auto m_diff = std::abs(low_node->GetM() + 1 - node->GetM());
+ if (m_diff >= m_diff_limit) return true;
+
+ return false;
}
void SearchWorker::PickNodesToExtendTask(
- Node* node, int base_depth, int collision_limit,
- const std::vector& moves_to_base,
+ const BackupPath& path, int collision_limit, PositionHistory& history,
std::vector<NodeToProcess>* receiver,
TaskWorkspace* workspace) NO_THREAD_SAFETY_ANALYSIS {
+ assert(path.size() == (size_t)history.GetLength() -
+ search_->played_history_.GetLength() + 1);
+
// TODO: Bring back pre-cached nodes created outside locks in a way that works
// with tasks.
// TODO: pre-reserve visits_to_perform for expected depth and likely maximum
@@ -1481,15 +1550,16 @@ void SearchWorker::PickNodesToExtendTask(
vtp_last_filled.clear();
auto& current_path = workspace->current_path;
current_path.clear();
- auto& moves_to_path = workspace->moves_to_path;
- moves_to_path = moves_to_base;
+ auto& full_path = workspace->full_path;
+ full_path = path;
+ assert(full_path.size() > 0);
+ auto [node, repetitions, moves_left] = full_path.back();
// Sometimes receiver is reused, other times not, so only jump start if small.
if (receiver->capacity() < 30) {
receiver->reserve(receiver->size() + 30);
}
- // These 2 are 'filled pre-emptively'.
- std::array<float, 256> current_pol;
+ // This one is 'filled pre-emptively'.
+ std::array<float, 256> current_util;
// These 3 are 'filled on demand'.
@@ -1515,6 +1585,7 @@ void SearchWorker::PickNodesToExtendTask(
current_path.push_back(-1);
while (current_path.size() > 0) {
+ assert(full_path.size() >= path.size());
// First prepare visits_to_perform.
if (current_path.back() == -1) {
// Need to do n visits, where n is either collision_limit, or comes from
@@ -1526,31 +1597,37 @@ void SearchWorker::PickNodesToExtendTask(
}
// First check if node is terminal or not-expanded. If either, then create
// a collision of appropriate size and pop current_path.
- if (node->GetN() == 0 || node->IsTerminal()) {
+ if (ShouldStopPickingHere(node, is_root_node, repetitions)) {
if (is_root_node) {
// Root node is special - since its not reached from anywhere else, so
// it needs its own logic. Still need to create the collision to
// ensure the outer gather loop gives up.
if (node->TryStartScoreUpdate()) {
cur_limit -= 1;
- minibatch_.push_back(NodeToProcess::Visit(
- node, static_cast(current_path.size() + base_depth)));
+ minibatch_.push_back(
+ NodeToProcess::Visit(full_path, search_->played_history_));
completed_visits++;
}
}
// Visits are created elsewhere, just need the collisions here.
if (cur_limit > 0) {
int max_count = 0;
- if (cur_limit == collision_limit && base_depth == 0 &&
+ if (cur_limit == collision_limit && path.size() == 1 &&
max_limit > cur_limit) {
max_count = max_limit;
}
- receiver->push_back(NodeToProcess::Collision(
- node, static_cast(current_path.size() + base_depth),
- cur_limit, max_count));
+ receiver->push_back(
+ NodeToProcess::Collision(full_path, cur_limit, max_count));
completed_visits += cur_limit;
}
- node = node->GetParent();
+ history.Pop();
+ full_path.pop_back();
+ if (full_path.size() > 0) {
+ std::tie(node, repetitions, moves_left) = full_path.back();
+ } else {
+ node = nullptr;
+ repetitions = 0;
+ }
current_path.pop_back();
continue;
}
@@ -1570,34 +1647,20 @@ void SearchWorker::PickNodesToExtendTask(
vtp_last_filled.push_back(-1);
// Cache all constant UCT parameters.
- // When we're near the leaves we can copy less of the policy, since there
- // is no way iteration will ever reach it.
- // TODO: This is a very conservative formula. It assumes every visit we're
- // aiming to add is going to trigger a new child, and that any visits
- // we've already had have also done so and then a couple extra since we go
- // to 2 unvisited to get second best in worst case.
- // Unclear we can do better without having already walked the children.
- // Which we are putting off until after policy is copied so we can create
- // visited policy without having to cache it in the node (allowing the
- // node to stay at 64 bytes).
+
int max_needed = node->GetNumEdges();
- if (!is_root_node || root_move_filter.empty()) {
- max_needed = std::min(max_needed, node->GetNStarted() + cur_limit + 2);
- }
- node->CopyPolicy(max_needed, current_pol.data());
for (int i = 0; i < max_needed; i++) {
current_util[i] = std::numeric_limits<float>::lowest();
}
// Root depth is 1 here, while for GetDrawScore() it's 0-based, that's why
// the weirdness.
- const float draw_score = ((current_path.size() + base_depth) % 2 == 0)
- ? odd_draw_score
- : even_draw_score;
+ const float draw_score =
+ (full_path.size() % 2 == 0) ? odd_draw_score : even_draw_score;
m_evaluator.SetParent(node);
float visited_pol = 0.0f;
for (Node* child : node->VisitedNodes()) {
int index = child->Index();
- visited_pol += current_pol[index];
+ visited_pol += child->GetP();
float q = child->GetQ(draw_score);
current_util[index] = q + m_evaluator.GetM(child, q);
}
@@ -1609,7 +1672,8 @@ void SearchWorker::PickNodesToExtendTask(
}
}
- const float cpuct = ComputeCpuct(params_, node->GetN(), is_root_node);
+ const float cpuct =
+ ComputeCpuct(params_, node->GetTotalVisits(), is_root_node);
const float puct_mult =
cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
int cache_filled_idx = -1;
@@ -1635,7 +1699,7 @@ void SearchWorker::PickNodesToExtendTask(
const float util = current_util[idx];
if (idx > cache_filled_idx) {
current_score[idx] =
- current_pol[idx] * puct_mult / (1 + nstarted) + util;
+ cur_iters[idx].GetP() * puct_mult / (1 + nstarted) + util;
cache_filled_idx++;
}
if (is_root_node) {
@@ -1683,7 +1747,7 @@ void SearchWorker::PickNodesToExtendTask(
if (best_without_u < second_best) {
const auto n1 = current_nstarted[best_idx] + 1;
estimated_visits_to_change_best = static_cast<int>(
- std::max(1.0f, std::min(current_pol[best_idx] * puct_mult /
+ std::max(1.0f, std::min(cur_iters[best_idx].GetP() * puct_mult /
(second_best - best_without_u) -
n1 + 1,
1e9f)));
@@ -1702,43 +1766,35 @@ void SearchWorker::PickNodesToExtendTask(
}
(*visits_to_perform.back())[best_idx] += new_visits;
cur_limit -= new_visits;
- Node* child_node = best_edge.GetOrSpawnNode(/* parent */ node);
-
- // Probably best place to check for two-fold draws consistently.
- // Depth starts with 1 at root, so real depth is depth - 1.
- EnsureNodeTwoFoldCorrectForDepth(
- child_node, current_path.size() + base_depth + 1 - 1);
- bool decremented = false;
+ Node* child_node = best_edge.GetOrSpawnNode(/* parent */ node);
+ history.Append(best_edge.GetMove());
+ auto [child_repetitions, child_moves_left] =
+ GetRepetitions(full_path.size(), history.Last());
+ full_path.push_back({child_node, child_repetitions, child_moves_left});
if (child_node->TryStartScoreUpdate()) {
current_nstarted[best_idx]++;
new_visits -= 1;
- decremented = true;
- if (child_node->GetN() > 0 && !child_node->IsTerminal()) {
+ if (ShouldStopPickingHere(child_node, false, child_repetitions)) {
+            // Decrement visits_to_perform so that the collision created
+            // below doesn't include this visit.
+ (*visits_to_perform.back())[best_idx] -= 1;
+ receiver->push_back(NodeToProcess::Visit(full_path, history));
+ completed_visits++;
+ } else {
child_node->IncrementNInFlight(new_visits);
current_nstarted[best_idx] += new_visits;
}
- current_score[best_idx] = current_pol[best_idx] * puct_mult /
+ current_score[best_idx] = cur_iters[best_idx].GetP() * puct_mult /
(1 + current_nstarted[best_idx]) +
current_util[best_idx];
}
- if ((decremented &&
- (child_node->GetN() == 0 || child_node->IsTerminal()))) {
- // Reduce 1 for the visits_to_perform to ensure the collision created
- // doesn't include this visit.
- (*visits_to_perform.back())[best_idx] -= 1;
- receiver->push_back(NodeToProcess::Visit(
- child_node,
- static_cast<uint16_t>(current_path.size() + 1 + base_depth)));
- completed_visits++;
- receiver->back().moves_to_visit.reserve(moves_to_path.size() + 1);
- receiver->back().moves_to_visit = moves_to_path;
- receiver->back().moves_to_visit.push_back(best_edge.GetMove());
- }
if (best_idx > vtp_last_filled.back() &&
(*visits_to_perform.back())[best_idx] > 0) {
vtp_last_filled.back() = best_idx;
}
+ history.Pop();
+ full_path.pop_back();
}
is_root_node = false;
// Actively do any splits now rather than waiting for potentially long
@@ -1753,29 +1809,32 @@ void SearchWorker::PickNodesToExtendTask(
collision_limit -
params_.GetMinimumRemainingWorkSizeForPicking()) {
Node* child_node = cur_iters[i].GetOrSpawnNode(/* parent */ node);
+ history.Append(cur_iters[i].GetMove());
+ auto [child_repetitions, child_moves_left] =
+ GetRepetitions(full_path.size(), history.Last());
+ full_path.push_back(
+ {child_node, child_repetitions, child_moves_left});
// Don't split if not expanded or terminal.
- if (child_node->GetN() == 0 || child_node->IsTerminal()) continue;
-
- bool passed = false;
- {
- // Multiple writers, so need mutex here.
- Mutex::Lock lock(picking_tasks_mutex_);
- // Ensure not to exceed size of reservation.
- if (picking_tasks_.size() < MAX_TASKS) {
- moves_to_path.push_back(cur_iters[i].GetMove());
- picking_tasks_.emplace_back(
- child_node, current_path.size() - 1 + base_depth + 1,
- moves_to_path, child_limit);
- moves_to_path.pop_back();
- task_count_.fetch_add(1, std::memory_order_acq_rel);
- task_added_.notify_all();
- passed = true;
- passed_off += child_limit;
+ if (!ShouldStopPickingHere(child_node, false, child_repetitions)) {
+ bool passed = false;
+ {
+ // Multiple writers, so need mutex here.
+ Mutex::Lock lock(picking_tasks_mutex_);
+ // Ensure not to exceed size of reservation.
+ if (picking_tasks_.size() < MAX_TASKS) {
+ picking_tasks_.emplace_back(full_path, history, child_limit);
+ task_count_.fetch_add(1, std::memory_order_acq_rel);
+ task_added_.notify_all();
+ passed = true;
+ passed_off += child_limit;
+ }
+ }
+ if (passed) {
+ (*visits_to_perform.back())[i] = 0;
}
}
- if (passed) {
- (*visits_to_perform.back())[i] = 0;
- }
+ history.Pop();
+ full_path.pop_back();
}
}
// Fall through to select the first child.
@@ -1787,14 +1846,13 @@ void SearchWorker::PickNodesToExtendTask(
for (auto& child : node->Edges()) {
idx++;
if (idx > min_idx && (*visits_to_perform.back())[idx] > 0) {
- if (moves_to_path.size() != current_path.size() + base_depth) {
- moves_to_path.push_back(child.GetMove());
- } else {
- moves_to_path.back() = child.GetMove();
- }
current_path.back() = idx;
current_path.push_back(-1);
node = child.GetOrSpawnNode(/* parent */ node);
+ history.Append(child.GetMove());
+ std::tie(repetitions, moves_left) =
+ GetRepetitions(full_path.size(), history.Last());
+ full_path.push_back({node, repetitions, moves_left});
found_child = true;
break;
}
@@ -1802,8 +1860,14 @@ void SearchWorker::PickNodesToExtendTask(
}
}
if (!found_child) {
- node = node->GetParent();
- if (!moves_to_path.empty()) moves_to_path.pop_back();
+ history.Pop();
+ full_path.pop_back();
+ if (full_path.size() > 0) {
+ std::tie(node, repetitions, moves_left) = full_path.back();
+ } else {
+ node = nullptr;
+ repetitions = 0;
+ }
current_path.pop_back();
vtp_buffer.push_back(std::move(visits_to_perform.back()));
visits_to_perform.pop_back();
@@ -1812,22 +1876,20 @@ void SearchWorker::PickNodesToExtendTask(
}
}
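For orientation, here is a minimal, self-contained sketch (with hypothetical stand-in types, not the lc0 API) of the bookkeeping pattern the rewritten picker uses above: a full_path of (node, repetitions, moves_left) entries is kept in lockstep with the position history, pushed on every descent and popped on every ascent.

    #include <tuple>
    #include <vector>

    struct Node;                  // stand-in; only used through pointers here
    using Move = unsigned short;  // stand-in for lc0's Move

    // One path entry: node, repetition count of its position, and plies since
    // the previous occurrence (a moves-left hint for repetition draws).
    using PathEntry = std::tuple<Node*, int, int>;

    struct PathWalker {
      std::vector<PathEntry> full_path;
      std::vector<Move> history;  // stand-in for PositionHistory

      void Descend(Node* child, Move m, int repetitions, int moves_left) {
        history.push_back(m);
        full_path.push_back({child, repetitions, moves_left});
      }

      // Undo one descent; returns the new current node, or nullptr at the root.
      Node* Ascend() {
        history.pop_back();
        full_path.pop_back();
        return full_path.empty() ? nullptr : std::get<0>(full_path.back());
      }
    };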
-void SearchWorker::ExtendNode(Node* node, int depth,
- const std::vector<Move>& moves_to_node,
- PositionHistory* history) {
- // Initialize position sequence with pre-move position.
- history->Trim(search_->played_history_.GetLength());
- for (size_t i = 0; i < moves_to_node.size(); i++) {
- history->Append(moves_to_node[i]);
- }
+void SearchWorker::ExtendNode(NodeToProcess& picked_node) {
+ const auto path = picked_node.path;
+ assert(!std::get<0>(path.back())->GetLowNode());
+
+ const PositionHistory& history = picked_node.history;
// We don't need the mutex because other threads will see that N=0 and
// N-in-flight=1 and will not touch this node.
- const auto& board = history->Last().GetBoard();
- auto legal_moves = board.GenerateLegalMoves();
+ const auto& board = history.Last().GetBoard();
+ std::vector<Move> legal_moves = board.GenerateLegalMoves();
// Check whether it's a draw/lose by position. Importantly, we must check
// these before doing the by-rule checks below.
+ auto node = picked_node.node;
if (legal_moves.empty()) {
// Could be a checkmate or a stalemate
if (board.IsUnderCheck()) {
@@ -1846,58 +1908,41 @@ void SearchWorker::ExtendNode(Node* node, int depth,
return;
}
- if (history->Last().GetRule50Ply() >= 100) {
+ if (history.Last().GetRule50Ply() >= 100) {
node->MakeTerminal(GameResult::DRAW);
return;
}
- const auto repetitions = history->Last().GetRepetitions();
- // Mark two-fold repetitions as draws according to settings.
- // Depth starts with 1 at root, so number of plies in PV is depth - 1.
- if (repetitions >= 2) {
- node->MakeTerminal(GameResult::DRAW);
- return;
- } else if (repetitions == 1 && depth - 1 >= 4 &&
- params_.GetTwoFoldDraws() &&
- depth - 1 >= history->Last().GetPliesSincePrevRepetition()) {
- const auto cycle_length = history->Last().GetPliesSincePrevRepetition();
- // use plies since first repetition as moves left; exact if forced draw.
- node->MakeTerminal(GameResult::DRAW, (float)cycle_length,
- Node::Terminal::TwoFold);
- return;
+ // Handle repetition draws as pseudo-terminals.
+ if (picked_node.repetitions >= 2) {
+ // Not a real terminal, set low node.
}
-
- // Neither by-position or by-rule termination, but maybe it's a TB position.
- if (search_->syzygy_tb_ && !search_->root_is_in_dtz_ &&
- board.castlings().no_legal_castle() &&
- history->Last().GetRule50Ply() == 0 &&
- (board.ours() | board.theirs()).count() <=
- search_->syzygy_tb_->max_cardinality()) {
+ // Neither a by-position nor a by-rule termination, but maybe it's a TB
+ // position.
+ else if (search_->syzygy_tb_ && !search_->root_is_in_dtz_ &&
+ board.castlings().no_legal_castle() &&
+ history.Last().GetRule50Ply() == 0 &&
+ (board.ours() | board.theirs()).count() <=
+ search_->syzygy_tb_->max_cardinality()) {
ProbeState state;
const WDLScore wdl =
- search_->syzygy_tb_->probe_wdl(history->Last(), &state);
+ search_->syzygy_tb_->probe_wdl(history.Last(), &state);
// Only fail state means the WDL is wrong, probe_wdl may produce correct
// result with a stat other than OK.
if (state != FAIL) {
// TB nodes don't have NN evaluation, assign M from parent node.
float m = 0.0f;
- // Need a lock to access parent, in case MakeSolid is in progress.
- {
- SharedMutex::SharedLock lock(search_->nodes_mutex_);
- auto parent = node->GetParent();
- if (parent) {
- m = std::max(0.0f, parent->GetM() - 1.0f);
- }
+ if (path.size() > 1) {
+ auto parent = std::get<0>(path[path.size() - 2]);
+ m = std::max(0.0f, parent->GetM() - 1.0f);
}
// If the colors seem backwards, check the checkmate check above.
if (wdl == WDL_WIN) {
- node->MakeTerminal(GameResult::BLACK_WON, m,
- Node::Terminal::Tablebase);
+ node->MakeTerminal(GameResult::BLACK_WON, m, Terminal::Tablebase);
} else if (wdl == WDL_LOSS) {
- node->MakeTerminal(GameResult::WHITE_WON, m,
- Node::Terminal::Tablebase);
+ node->MakeTerminal(GameResult::WHITE_WON, m, Terminal::Tablebase);
} else { // Cursed wins and blessed losses count as draws.
- node->MakeTerminal(GameResult::DRAW, m, Node::Terminal::Tablebase);
+ node->MakeTerminal(GameResult::DRAW, m, Terminal::Tablebase);
}
search_->tb_hits_.fetch_add(1, std::memory_order_acq_rel);
return;
@@ -1905,42 +1950,19 @@ void SearchWorker::ExtendNode(Node* node, int depth,
}
}
- // Add legal moves as edges of this node.
- node->CreateEdges(legal_moves);
-}
+ picked_node.nn_queried = true; // Node::SetLowNode() required.
-// Returns whether node was already in cache.
-bool SearchWorker::AddNodeToComputation(Node* node) {
- const auto hash = history_.HashLast(params_.GetCacheHistoryLength() + 1);
- if (search_->cache_->ContainsKey(hash)) {
- return true;
- }
- int transform;
- auto planes =
- EncodePositionForNN(search_->network_->GetCapabilities().input_format,
- history_, 8, params_.GetHistoryFill(), &transform);
-
- std::vector<uint16_t> moves;
-
- if (node && node->HasChildren()) {
- // Legal moves are known, use them.
- moves.reserve(node->GetNumEdges());
- for (const auto& edge : node->Edges()) {
- moves.emplace_back(edge.GetMove().as_nn_index(transform));
- }
+ // Check the transposition table first and NN cache second before asking for
+ // NN evaluation.
+ picked_node.hash = search_->dag_->GetHistoryHash(history);
+ auto tt_low_node = search_->dag_->TTFind(picked_node.hash);
+ if (tt_low_node != nullptr) {
+ picked_node.tt_low_node = tt_low_node;
+ picked_node.is_tt_hit = true;
} else {
- // Cache pseudolegal moves. A bit of a waste, but faster.
- const auto& pseudolegal_moves =
- history_.Last().GetBoard().GeneratePseudolegalMoves();
- moves.reserve(pseudolegal_moves.size());
- for (auto iter = pseudolegal_moves.begin(), end = pseudolegal_moves.end();
- iter != end; ++iter) {
- moves.emplace_back(iter->as_nn_index(transform));
- }
+ picked_node.lock = NNCacheLock(search_->cache_, picked_node.hash);
+ picked_node.is_cache_hit = picked_node.lock;
}
-
- computation_->AddInput(hash, std::move(planes), std::move(moves));
- return false;
}
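To make the lookup order described in the comment above concrete (transposition table first, NN cache second, a fresh NN evaluation only as a last resort), a hedged sketch follows; the map types and names are placeholders, not the actual NodeTree/NNCache interfaces.

    #include <cstdint>
    #include <optional>
    #include <unordered_map>

    struct Eval { float q, d, m; };  // placeholder for a cached NN evaluation

    std::unordered_map<std::uint64_t, Eval> transposition_table;  // placeholder
    std::unordered_map<std::uint64_t, Eval> nn_cache;             // placeholder

    // Returns a stored evaluation if either store already has this position
    // hash; otherwise the caller queues the position for network evaluation.
    std::optional<Eval> Lookup(std::uint64_t hash) {
      if (auto it = transposition_table.find(hash);
          it != transposition_table.end()) {
        return it->second;
      }
      if (auto it = nn_cache.find(hash); it != nn_cache.end()) return it->second;
      return std::nullopt;
    }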
// 2b. Copy collisions into shared collisions.
@@ -1949,187 +1971,81 @@ void SearchWorker::CollectCollisions() {
for (const NodeToProcess& node_to_process : minibatch_) {
if (node_to_process.IsCollision()) {
- search_->shared_collisions_.emplace_back(node_to_process.node,
+ search_->shared_collisions_.emplace_back(node_to_process.path,
node_to_process.multivisit);
}
}
}
-// 3. Prefetch into cache.
-// ~~~~~~~~~~~~~~~~~~~~~~~
-void SearchWorker::MaybePrefetchIntoCache() {
- // TODO(mooskagh) Remove prefetch into cache if node collisions work well.
- // If there are requests to NN, but the batch is not full, try to prefetch
- // nodes which are likely useful in future.
- if (search_->stop_.load(std::memory_order_acquire)) return;
- if (computation_->GetCacheMisses() > 0 &&
- computation_->GetCacheMisses() < params_.GetMaxPrefetchBatch()) {
- history_.Trim(search_->played_history_.GetLength());
- SharedMutex::SharedLock lock(search_->nodes_mutex_);
- PrefetchIntoCache(
- search_->root_node_,
- params_.GetMaxPrefetchBatch() - computation_->GetCacheMisses(), false);
- }
-}
-
-// Prefetches up to @budget nodes into cache. Returns number of nodes
-// prefetched.
-int SearchWorker::PrefetchIntoCache(Node* node, int budget, bool is_odd_depth) {
- const float draw_score = search_->GetDrawScore(is_odd_depth);
- if (budget <= 0) return 0;
-
- // We are in a leaf, which is not yet being processed.
- if (!node || node->GetNStarted() == 0) {
- if (AddNodeToComputation(node)) {
- // Make it return 0 to make it not use the slot, so that the function
- // tries hard to find something to cache even among unpopular moves.
- // In practice that slows things down a lot though, as it's not always
- // easy to find what to cache.
- return 1;
- }
- return 1;
- }
-
- assert(node);
- // n = 0 and n_in_flight_ > 0, that means the node is being extended.
- if (node->GetN() == 0) return 0;
- // The node is terminal; don't prefetch it.
- if (node->IsTerminal()) return 0;
-
- // Populate all subnodes and their scores.
- typedef std::pair<float, EdgeAndNode> ScoredEdge;
- std::vector scores;
- const float cpuct =
- ComputeCpuct(params_, node->GetN(), node == search_->root_node_);
- const float puct_mult =
- cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
- const float fpu =
- GetFpu(params_, node, node == search_->root_node_, draw_score);
- for (auto& edge : node->Edges()) {
- if (edge.GetP() == 0.0f) continue;
- // Flip the sign of a score to be able to easily sort.
- // TODO: should this use logit_q if set??
- scores.emplace_back(-edge.GetU(puct_mult) - edge.GetQ(fpu, draw_score),
- edge);
- }
-
- size_t first_unsorted_index = 0;
- int total_budget_spent = 0;
- int budget_to_spend = budget; // Initialize for the case where there's only
- // one child.
- for (size_t i = 0; i < scores.size(); ++i) {
- if (search_->stop_.load(std::memory_order_acquire)) break;
- if (budget <= 0) break;
-
- // Sort next chunk of a vector. 3 at a time. Most of the time it's fine.
- if (first_unsorted_index != scores.size() &&
- i + 2 >= first_unsorted_index) {
- const int new_unsorted_index =
- std::min(scores.size(), budget < 2 ? first_unsorted_index + 2
- : first_unsorted_index + 3);
- std::partial_sort(scores.begin() + first_unsorted_index,
- scores.begin() + new_unsorted_index, scores.end(),
- [](const ScoredEdge& a, const ScoredEdge& b) {
- return a.first < b.first;
- });
- first_unsorted_index = new_unsorted_index;
- }
-
- auto edge = scores[i].second;
- // Last node gets the same budget as prev-to-last node.
- if (i != scores.size() - 1) {
- // Sign of the score was flipped for sorting, so flip it back.
- const float next_score = -scores[i + 1].first;
- // TODO: As above - should this use logit_q if set?
- const float q = edge.GetQ(-fpu, draw_score);
- if (next_score > q) {
- budget_to_spend =
- std::min(budget, int(edge.GetP() * puct_mult / (next_score - q) -
- edge.GetNStarted()) +
- 1);
- } else {
- budget_to_spend = budget;
- }
- }
- history_.Append(edge.GetMove());
- const int budget_spent =
- PrefetchIntoCache(edge.node(), budget_to_spend, !is_odd_depth);
- history_.Pop();
- budget -= budget_spent;
- total_budget_spent += budget_spent;
- }
- return total_budget_spent;
-}
-
// 4. Run NN computation.
// ~~~~~~~~~~~~~~~~~~~~~~
-void SearchWorker::RunNNComputation() { computation_->ComputeBlocking(); }
+void SearchWorker::RunNNComputation() {
+ computation_->ComputeBlocking(params_.GetPolicySoftmaxTemp());
+}
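The policy softmax with temperature that was previously applied in FetchSingleNodeResult (and is now handed to ComputeBlocking through GetPolicySoftmaxTemp) amounts to the computation below; this is a standalone sketch of the formula, not the backend's actual implementation.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // p_i = exp((x_i - max_x) / T) / sum_j exp((x_j - max_x) / T).
    // Subtracting the maximum logit keeps the exponentials numerically stable.
    std::vector<float> SoftmaxWithTemperature(std::vector<float> logits, float t) {
      if (logits.empty()) return logits;
      const float max_x = *std::max_element(logits.begin(), logits.end());
      float total = 0.0f;
      for (float& x : logits) {
        x = std::exp((x - max_x) / t);
        total += x;
      }
      const float scale = total > 0.0f ? 1.0f / total : 1.0f;
      for (float& x : logits) x *= scale;
      return logits;
    }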
// 5. Retrieve NN computations (and terminal values) into nodes.
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
void SearchWorker::FetchMinibatchResults() {
+ SharedMutex::Lock nodes_lock(search_->nodes_mutex_);
// Populate NN/cached results, or terminal results, into nodes.
int idx_in_computation = 0;
for (auto& node_to_process : minibatch_) {
FetchSingleNodeResult(&node_to_process, *computation_, idx_in_computation);
- if (node_to_process.nn_queried) ++idx_in_computation;
+ if (node_to_process.ShouldAddToInput()) ++idx_in_computation;
}
}
template <typename Computation>
void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process,
const Computation& computation,
- int idx_in_computation) {
- if (node_to_process->IsCollision()) return;
- Node* node = node_to_process->node;
- if (!node_to_process->nn_queried) {
- // Terminal nodes don't involve the neural NetworkComputation, nor do
- // they require any further processing after value retrieval.
- node_to_process->v = node->GetWL();
- node_to_process->d = node->GetD();
- node_to_process->m = node->GetM();
- return;
- }
- // For NN results, we need to populate policy as well as value.
- // First the value...
- node_to_process->v = -computation.GetQVal(idx_in_computation);
- node_to_process->d = computation.GetDVal(idx_in_computation);
- node_to_process->m = computation.GetMVal(idx_in_computation);
- // ...and secondly, the policy data.
- // Calculate maximum first.
- float max_p = -std::numeric_limits<float>::infinity();
- // Intermediate array to store values when processing policy.
- // There are never more than 256 valid legal moves in any legal position.
- std::array<float, 256> intermediate;
- int counter = 0;
- for (auto& edge : node->Edges()) {
- float p = computation.GetPVal(
- idx_in_computation,
- edge.GetMove().as_nn_index(node_to_process->probability_transform));
- intermediate[counter++] = p;
- max_p = std::max(max_p, p);
- }
- float total = 0.0;
- for (int i = 0; i < counter; i++) {
- // Perform softmax and take into account policy softmax temperature T.
- // Note that we want to calculate (exp(p-max_p))^(1/T) = exp((p-max_p)/T).
- float p =
- FastExp((intermediate[i] - max_p) / params_.GetPolicySoftmaxTemp());
- intermediate[i] = p;
- total += p;
- }
- counter = 0;
- // Normalize P values to add up to 1.0.
- const float scale = total > 0.0f ? 1.0f / total : 1.0f;
- for (auto& edge : node->Edges()) {
- edge.edge()->SetP(intermediate[counter++] * scale);
+ int idx_in_computation)
+ REQUIRES(search_->nodes_mutex_) {
+ if (!node_to_process->nn_queried) return;
+
+ if (!node_to_process->is_tt_hit) {
+ auto [tt_low_node, is_tt_miss] =
+ search_->dag_->TTGetOrCreate(node_to_process->hash);
+
+ assert(tt_low_node != nullptr);
+ node_to_process->tt_low_node = tt_low_node;
+ if (is_tt_miss) {
+ auto nn_eval = computation.GetNNEval(idx_in_computation).get();
+ if (params_.GetContemptPerspective() != ContemptPerspective::NONE) {
+ bool root_stm =
+ (params_.GetContemptPerspective() == ContemptPerspective::STM
+ ? !(search_->played_history_.Last().IsBlackToMove())
+ : (params_.GetContemptPerspective() ==
+ ContemptPerspective::WHITE));
+ auto sign = (root_stm ^ node_to_process->history.IsBlackToMove())
+ ? 1.0f
+ : -1.0f;
+ if (params_.GetWDLRescaleRatio() != 1.0f ||
+ params_.GetWDLRescaleDiff() != 0.0f) {
+ float v = nn_eval->q;
+ float d = nn_eval->d;
+ WDLRescale(v, d, nullptr, params_.GetWDLRescaleRatio(),
+ params_.GetWDLRescaleDiff(), sign, false);
+ nn_eval->q = v;
+ nn_eval->d = d;
+ }
+ }
+ node_to_process->tt_low_node->SetNNEval(nn_eval);
+ }
}
+
+ // Add NN results to node.
+ Node* node = node_to_process->node;
// Add Dirichlet noise if enabled and at root.
if (params_.GetNoiseEpsilon() && node == search_->root_node_) {
+ auto low_node = search_->dag_->NonTTAddClone(*node_to_process->tt_low_node);
+ assert(low_node != nullptr);
+ node->SetLowNode(low_node);
ApplyDirichletNoise(node, params_.GetNoiseEpsilon(),
params_.GetNoiseAlpha());
+ node->SortEdges();
+ } else {
+ node->SetLowNode(node_to_process->tt_low_node);
}
- node->SortEdges();
}
// 6. Propagate the new nodes' information to all their parents in the tree.
@@ -2150,60 +2066,152 @@ void SearchWorker::DoBackupUpdate() {
search_->total_batches_ += 1;
}
+bool SearchWorker::MaybeAdjustForTerminalOrTransposition(
+ Node* n, const LowNode* nl, float& v, float& d, float& m,
+ uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta,
+ bool& update_parent_bounds) const {
+ if (n->IsTerminal()) {
+ v = n->GetWL();
+ d = n->GetD();
+ m = n->GetM();
+
+ return true;
+ }
+
+ // Use information from transposition or a new terminal.
+ if (nl->IsTransposition() || nl->IsTerminal()) {
+ // Adapt information from low node to node by flipping Q sign, bounds,
+ // result and incrementing m.
+ v = -nl->GetWL();
+ d = nl->GetD();
+ m = nl->GetM() + 1;
+ // When starting at or going through a transposition/terminal, make sure to
+ // use the information it has already acquired.
+ n_to_fix = n->GetN();
+ v_delta = v - n->GetWL();
+ d_delta = d - n->GetD();
+ m_delta = m - n->GetM();
+ // Update bounds.
+ if (params_.GetStickyEndgames()) {
+ auto tt = nl->GetTerminalType();
+ if (tt != Terminal::NonTerminal) {
+ GameResult r;
+ if (v == 1.0f) {
+ r = GameResult::WHITE_WON;
+ } else if (v == -1.0f) {
+ r = GameResult::BLACK_WON;
+ } else {
+ r = GameResult::DRAW;
+ }
+
+ n->MakeTerminal(r, m, tt);
+ update_parent_bounds = true;
+ } else {
+ auto [lower, upper] = nl->GetBounds();
+ n->SetBounds(-upper, -lower);
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+}
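The adaptation performed above when pulling a transposition's or terminal's low-node values into a node is just a perspective flip: negate the win/loss term, keep the draw probability, and add one ply to the moves-left estimate. A tiny illustrative sketch with a hypothetical value struct:

    struct Values {
      float wl;  // win-minus-loss expectation
      float d;   // draw probability
      float m;   // moves-left estimate
    };

    // Low-node values are from the perspective of the side to move after the
    // edge, so the parent-side node sees WL negated and one more ply to go.
    inline Values AdaptFromLowNode(const Values& low) {
      return {-low.wl, low.d, low.m + 1.0f};
    }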
+
+// Use information from terminal status or low node to update node and node's
+// parent low node and so on until the root is reached. Low node may become a
+// transposition and/or get more information even during this batch. Both low
+// node and node may adjust bounds and become a terminal during this batch.
void SearchWorker::DoBackupUpdateSingleNode(
const NodeToProcess& node_to_process) REQUIRES(search_->nodes_mutex_) {
- Node* node = node_to_process.node;
if (node_to_process.IsCollision()) {
// Collisions are handled via shared_collisions instead.
return;
}
+ auto path = node_to_process.path;
+ auto [n, nr, nm] = path.back();
// For the first visit to a terminal, maybe update parent bounds too.
auto update_parent_bounds =
- params_.GetStickyEndgames() && node->IsTerminal() && !node->GetN();
-
- // Backup V value up to a root. After 1 visit, V = Q.
- float v = node_to_process.v;
- float d = node_to_process.d;
- float m = node_to_process.m;
- int n_to_fix = 0;
+ params_.GetStickyEndgames() && n->IsTerminal() && !n->GetN();
+ auto nl = n->GetLowNode();
+ float v = 0.0f;
+ float d = 0.0f;
+ float m = 0.0f;
+ uint32_t n_to_fix = 0;
float v_delta = 0.0f;
float d_delta = 0.0f;
float m_delta = 0.0f;
- uint32_t solid_threshold =
- static_cast<uint32_t>(params_.GetSolidTreeThreshold());
- for (Node *n = node, *p; n != search_->root_node_->GetParent(); n = p) {
- p = n->GetParent();
-
- // Current node might have become terminal from some other descendant, so
- // backup the rest of the way with more accurate values.
- if (n->IsTerminal()) {
- v = n->GetWL();
- d = n->GetD();
- m = n->GetM();
- }
+
+ // Update the low node at the start of the backup path first, but only visit
+ // it the first time that backup sees it.
+ if (nl && nl->GetN() == 0) {
+ nl->FinalizeScoreUpdate(nl->GetWL(), nl->GetD(), nl->GetM(),
+ node_to_process.multivisit);
+ }
+
+ if (nr >= 2) {
+ // A three-fold itself has to be handled as a terminal to produce relevant
+ // results, unlike two-folds, which can keep updating their "real" values.
+ n->SetRepetition();
+ v = 0.0f;
+ d = 1.0f;
+ m = 1;
+ } else if (!MaybeAdjustForTerminalOrTransposition(n, nl, v, d, m, n_to_fix,
+ v_delta, d_delta, m_delta,
+ update_parent_bounds)) {
+ // If there is nothing better, use original NN values adjusted for node.
+ v = -nl->GetWL();
+ d = nl->GetD();
+ m = nl->GetM() + 1;
+ }
+
+ // Backup V value up to a root. After 1 visit, V = Q.
+ for (auto it = path.crbegin(); it != path.crend();
+ /* ++it in the body */) {
n->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit);
if (n_to_fix > 0 && !n->IsTerminal()) {
+ // The first part of the path might be newer, as it was removed and recreated.
+ n_to_fix = std::min(n_to_fix, n->GetN());
n->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix);
}
- if (n->GetN() >= solid_threshold) {
- if (n->MakeSolid() && n == search_->root_node_) {
- // If we make the root solid, the current_best_edge_ becomes invalid and
- // we should repopulate it.
- search_->current_best_edge_ =
- search_->GetBestChildNoTemperature(search_->root_node_, 0);
- }
+
+ // Stop delta update on repetition "terminal" and propagate a draw above
+ // repetitions valid on the current path.
+ // Only do this after edge update to have good values if play goes here.
+ if (nr == 1 && !n->IsTerminal()) {
+ n->SetRepetition();
+ v = 0.0f;
+ d = 1.0f;
+ m = nm + 1;
}
+ if (n->IsRepetition()) n_to_fix = 0;
// Nothing left to do without ancestors to update.
- if (!p) break;
+ if (++it == path.crend()) break;
+ auto [p, pr, pm] = *it;
+ auto pl = p->GetLowNode();
+
+ assert(!p->IsTerminal() ||
+ (p->IsTerminal() && pl->IsTerminal() && p->GetWL() == -pl->GetWL() &&
+ p->GetD() == pl->GetD()));
+ // If parent low node is already a (new) terminal, then change propagated
+ // values and stop terminal adjustment.
+ if (pl->IsTerminal()) {
+ v = pl->GetWL();
+ d = pl->GetD();
+ m = pl->GetM();
+ n_to_fix = 0;
+ }
+ pl->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit);
+ if (n_to_fix > 0) {
+ pl->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix);
+ }
bool old_update_parent_bounds = update_parent_bounds;
- // If parent already is terminal further adjustment is not required.
- if (p->IsTerminal()) n_to_fix = 0;
// Try setting parent bounds except the root or those already terminal.
update_parent_bounds =
- update_parent_bounds && p != search_->root_node_ && !p->IsTerminal() &&
+ update_parent_bounds && p != search_->root_node_ && !pl->IsTerminal() &&
MaybeSetBounds(p, m, &n_to_fix, &v_delta, &d_delta, &m_delta);
// Q will be flipped for opponent.
@@ -2211,6 +2219,10 @@ void SearchWorker::DoBackupUpdateSingleNode(
v_delta = -v_delta;
m++;
+ MaybeAdjustForTerminalOrTransposition(p, pl, v, d, m, n_to_fix, v_delta,
+ d_delta, m_delta,
+ update_parent_bounds);
+
// Update the stats.
// Best move.
// If update_parent_bounds was set, we just adjusted bounds on the
@@ -2226,19 +2238,25 @@ void SearchWorker::DoBackupUpdateSingleNode(
search_->current_best_edge_ =
search_->GetBestChildNoTemperature(search_->root_node_, 0);
}
+
+ n = p;
+ nr = pr;
+ nm = pm;
}
search_->total_playouts_ += node_to_process.multivisit;
- search_->cum_depth_ += node_to_process.depth * node_to_process.multivisit;
- search_->max_depth_ = std::max(search_->max_depth_, node_to_process.depth);
+ search_->cum_depth_ +=
+ node_to_process.path.size() * node_to_process.multivisit;
+ search_->max_depth_ =
+ std::max(search_->max_depth_, (uint16_t)node_to_process.path.size());
}
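A condensed, hedged view of the repetition handling in the backup above: a position that has already repeated twice is scored as a draw outright, and a first repetition on the path switches the values being propagated to a draw from that point upward. The names and the exact moves-left convention are illustrative only.

    struct BackupValues { float v, d, m; };  // value, draw probability, moves left

    // repetitions: occurrences of the current position earlier on the path;
    // plies_since_repeat: distance back to the previous occurrence.
    inline BackupValues ApplyRepetitionDraw(int repetitions,
                                            int plies_since_repeat,
                                            BackupValues propagated) {
      if (repetitions >= 2) return {0.0f, 1.0f, 1.0f};  // treat as a draw terminal
      if (repetitions == 1) {
        return {0.0f, 1.0f, static_cast<float>(plies_since_repeat) + 1.0f};
      }
      return propagated;  // no repetition: keep backing up the existing values
    }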
-bool SearchWorker::MaybeSetBounds(Node* p, float m, int* n_to_fix,
+bool SearchWorker::MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix,
float* v_delta, float* d_delta,
float* m_delta) const {
auto losing_m = 0.0f;
auto prefer_tb = false;
- // Determine the maximum (lower, upper) bounds across all children.
+ // Determine the maximum (lower, upper) bounds across all edges.
// (-1,-1) Loss (initial and lowest bounds)
// (-1, 0) Can't Win
// (-1, 1) Regular node
@@ -2274,6 +2292,7 @@ bool SearchWorker::MaybeSetBounds(Node* p, float m, int* n_to_fix,
// Win ( 1, 1) -> (-1,-1) Loss
// Nothing left to do for ancestors if the parent would be a regular node.
+ auto pl = p->GetLowNode();
if (lower == GameResult::BLACK_WON && upper == GameResult::WHITE_WON) {
return false;
} else if (lower == upper) {
@@ -2281,19 +2300,19 @@ bool SearchWorker::MaybeSetBounds(Node* p, float m, int* n_to_fix,
// it terminal preferring shorter wins and longer losses.
*n_to_fix = p->GetN();
assert(*n_to_fix > 0);
- float cur_v = p->GetWL();
- float cur_d = p->GetD();
- float cur_m = p->GetM();
+ pl->MakeTerminal(
+ upper, (upper == GameResult::BLACK_WON ? std::max(losing_m, m) : m),
+ prefer_tb ? Terminal::Tablebase : Terminal::EndOfGame);
+ // v, d and m will be set in MaybeAdjustForTerminalOrTransposition.
+ *v_delta = pl->GetWL() + p->GetWL();
+ *d_delta = pl->GetD() - p->GetD();
+ *m_delta = pl->GetM() + 1 - p->GetM();
p->MakeTerminal(
-upper,
(upper == GameResult::BLACK_WON ? std::max(losing_m, m) : m) + 1.0f,
- prefer_tb ? Node::Terminal::Tablebase : Node::Terminal::EndOfGame);
- // Negate v_delta because we're calculating for the parent, but immediately
- // afterwards we'll negate v_delta in case it has come from the child.
- *v_delta = -(p->GetWL() - cur_v);
- *d_delta = p->GetD() - cur_d;
- *m_delta = p->GetM() - cur_m;
+ prefer_tb ? Terminal::Tablebase : Terminal::EndOfGame);
} else {
+ pl->SetBounds(lower, upper);
p->SetBounds(-upper, -lower);
}
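Finally, the bound propagation in MaybeSetBounds maps a child's (lower, upper) bounds to the parent by swapping and negating them, as the Win (1, 1) -> Loss (-1, -1) comment indicates. A small sketch using a hypothetical enum with the same -1/0/+1 convention:

    #include <utility>

    enum class Result : int { kLoss = -1, kDraw = 0, kWin = 1 };

    // Seen from the parent, the child's bounds swap roles and change sign:
    // e.g. a child proven Win (1, 1) becomes a parent Loss (-1, -1).
    inline std::pair<Result, Result> FlipBoundsForParent(Result lower,
                                                         Result upper) {
      return {static_cast<Result>(-static_cast<int>(upper)),
              static_cast<Result>(-static_cast<int>(lower))};
    }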
diff --git a/src/mcts/search.h b/src/mcts/search.h
index 1323320bca..b2547b9b03 100644
--- a/src/mcts/search.h
+++ b/src/mcts/search.h
@@ -33,6 +33,8 @@
#include
#include
#include
+#include
+#include
#include "chess/callbacks.h"
#include "chess/uciloop.h"
@@ -40,16 +42,17 @@
#include "mcts/params.h"
#include "mcts/stoppers/timemgr.h"
#include "neural/cache.h"
-#include "neural/network.h"
#include "syzygy/syzygy.h"
#include "utils/logging.h"
#include "utils/mutex.h"
namespace lczero {
+typedef std::vector<std::tuple<Node*, int, int>> BackupPath;
+
class Search {
public:
- Search(const NodeTree& tree, Network* network,
+ Search(NodeTree* dag, Network* network,
std::unique_ptr<UciResponder> uci_responder,
const MoveList& searchmoves,
std::chrono::steady_clock::time_point start_time,
@@ -96,7 +99,7 @@ class Search {
void ResetBestMove();
// Returns NN eval for a given node from cache, if that node is cached.
- NNCacheLock GetCachedNNEval(const Node* node) const;
+ NNCacheLock GetCachedNNEval(const PositionHistory& history) const;
private:
// Computes the best move, maybe with temperature (according to the settings).
@@ -163,6 +166,7 @@ class Search {
Node* root_node_;
NNCache* cache_;
+ NodeTree* dag_;
SyzygyTablebase* syzygy_tb_;
// Fixed positions which happened before the search.
const PositionHistory& played_history_;
@@ -196,7 +200,7 @@ class Search {
std::atomic<int> backend_waiting_counter_{0};
std::atomic