diff --git a/src/mcts/params.cc b/src/mcts/params.cc index 5c1665d218..416aca51b7 100644 --- a/src/mcts/params.cc +++ b/src/mcts/params.cc @@ -109,24 +109,27 @@ const OptionId SearchParams::kSmartPruningFactorId{ "pruning is deactivated."}; const OptionId SearchParams::kFpuStrategyId{ "fpu-strategy", "FpuStrategy", - "How is an eval of unvisited node determined. \"reduction\" subtracts " - "--fpu-reduction value from the parent eval. \"absolute\" sets eval of " - "unvisited nodes to the value specified in --fpu-value."}; -// TODO(crem) Make FPU in "reduction" mode use fpu-value too. For now it's kept -// for backwards compatibility. -const OptionId SearchParams::kFpuReductionId{ - "fpu-reduction", "FpuReduction", - "\"First Play Urgency\" reduction (used when FPU strategy is " - "\"reduction\"). Normally when a move has no visits, " - "it's eval is assumed to be equal to parent's eval. With non-zero FPU " - "reduction, eval of unvisited move is decreased by that value, " - "discouraging visits of unvisited moves, and saving those visits for " - "(hopefully) more promising moves."}; + "How is an eval of unvisited node determined. \"First Play Urgency\" " + "changes search behavior to visit unvisited nodes earlier or later by " + "using a placeholder eval before checking the network. The value specified " + "with --fpu-value results in \"reduction\" subtracting that value from the " + "parent eval while \"absolute\" directly uses that value."}; const OptionId SearchParams::kFpuValueId{ "fpu-value", "FpuValue", - "\"First Play Urgency\" value. When FPU strategy is \"absolute\", value of " - "unvisited node is assumed to be equal to this value, and does not depend " - "on parent eval."}; + "\"First Play Urgency\" value used to adjust unvisited node eval based on " + "--fpu-strategy."}; +const OptionId SearchParams::kFpuStrategyAtRootId{ + "fpu-strategy-at-root", "FpuStrategyAtRoot", + "How is an eval of unvisited root children determined. Just like " + "--fpu-strategy except only at the root level and adjusts unvisited root " + "children eval with --fpu-value-at-root. In addition to matching the " + "strategies from --fpu-strategy, this can be \"same\" to disable the " + "special root behavior."}; +const OptionId SearchParams::kFpuValueAtRootId{ + "fpu-value-at-root", "FpuValueAtRoot", + "\"First Play Urgency\" value used to adjust unvisited root children eval " + "based on --fpu-strategy-at-root. Has no effect if --fpu-strategy-at-root " + "is \"same\"."}; const OptionId SearchParams::kCacheHistoryLengthId{ "cache-history-length", "CacheHistoryLength", "Length of history, in half-moves, to include into the cache key. When " @@ -197,8 +200,10 @@ void SearchParams::Populate(OptionsParser* options) { options->Add(kSmartPruningFactorId, 0.0f, 10.0f) = 1.33f; std::vector fpu_strategy = {"reduction", "absolute"}; options->Add(kFpuStrategyId, fpu_strategy) = "reduction"; - options->Add(kFpuReductionId, -100.0f, 100.0f) = 1.2f; - options->Add(kFpuValueId, -1.0f, 1.0f) = -1.0f; + options->Add(kFpuValueId, -100.0f, 100.0f) = 1.2f; + fpu_strategy.push_back("same"); + options->Add(kFpuStrategyAtRootId, fpu_strategy) = "same"; + options->Add(kFpuValueAtRootId, -100.0f, 100.0f) = 1.0f; options->Add(kCacheHistoryLengthId, 0, 7) = 0; options->Add(kPolicySoftmaxTempId, 0.1f, 10.0f) = 2.2f; options->Add(kMaxCollisionEventsId, 1, 1024) = 32; @@ -225,8 +230,15 @@ SearchParams::SearchParams(const OptionsDict& options) kSmartPruningFactor(options.Get(kSmartPruningFactorId.GetId())), kFpuAbsolute(options.Get(kFpuStrategyId.GetId()) == "absolute"), - kFpuReduction(options.Get(kFpuReductionId.GetId())), kFpuValue(options.Get(kFpuValueId.GetId())), + kFpuAbsoluteAtRoot( + (options.Get(kFpuStrategyAtRootId.GetId()) == "same" && + kFpuAbsolute) || + options.Get(kFpuStrategyAtRootId.GetId()) == "absolute"), + kFpuValueAtRoot(options.Get(kFpuStrategyAtRootId.GetId()) == + "same" + ? kFpuValue + : options.Get(kFpuValueAtRootId.GetId())), kCacheHistoryLength(options.Get(kCacheHistoryLengthId.GetId())), kPolicySoftmaxTemp(options.Get(kPolicySoftmaxTempId.GetId())), kMaxCollisionEvents(options.Get(kMaxCollisionEventsId.GetId())), diff --git a/src/mcts/params.h b/src/mcts/params.h index 21bcf0d9a2..c106e6c46e 100644 --- a/src/mcts/params.h +++ b/src/mcts/params.h @@ -78,9 +78,8 @@ class SearchParams { return options_.Get(kLogLiveStatsId.GetId()); } float GetSmartPruningFactor() const { return kSmartPruningFactor; } - bool GetFpuAbsolute() const { return kFpuAbsolute; } - float GetFpuReduction() const { return kFpuReduction; } - float GetFpuValue() const { return kFpuValue; } + bool GetFpuAbsolute(bool at_root) const { return at_root ? kFpuAbsoluteAtRoot : kFpuAbsolute; } + float GetFpuValue(bool at_root) const { return at_root ? kFpuValueAtRoot : kFpuValue; } int GetCacheHistoryLength() const { return kCacheHistoryLength; } float GetPolicySoftmaxTemp() const { return kPolicySoftmaxTemp; } int GetMaxCollisionEvents() const { return kMaxCollisionEvents; } @@ -116,8 +115,9 @@ class SearchParams { static const OptionId kLogLiveStatsId; static const OptionId kSmartPruningFactorId; static const OptionId kFpuStrategyId; - static const OptionId kFpuReductionId; static const OptionId kFpuValueId; + static const OptionId kFpuStrategyAtRootId; + static const OptionId kFpuValueAtRootId; static const OptionId kCacheHistoryLengthId; static const OptionId kPolicySoftmaxTempId; static const OptionId kMaxCollisionEventsId; @@ -144,8 +144,9 @@ class SearchParams { const bool kNoise; const float kSmartPruningFactor; const bool kFpuAbsolute; - const float kFpuReduction; const float kFpuValue; + const bool kFpuAbsoluteAtRoot; + const float kFpuValueAtRoot; const int kCacheHistoryLength; const float kPolicySoftmaxTemp; const int kMaxCollisionEvents; diff --git a/src/mcts/search.cc b/src/mcts/search.cc index 29d995f0ab..ae4169e6d9 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -189,13 +189,10 @@ int64_t Search::GetTimeToDeadline() const { namespace { inline float GetFpu(const SearchParams& params, Node* node, bool is_root_node) { - return params.GetFpuAbsolute() - ? params.GetFpuValue() - : ((is_root_node && params.GetNoise()) || - !params.GetFpuReduction()) - ? -node->GetQ() - : -node->GetQ() - params.GetFpuReduction() * - std::sqrt(node->GetVisitedPolicy()); + const auto value = params.GetFpuValue(is_root_node); + return params.GetFpuAbsolute(is_root_node) + ? value + : -node->GetQ() - value * std::sqrt(node->GetVisitedPolicy()); } inline float ComputeCpuct(const SearchParams& params, uint32_t N) { diff --git a/src/selfplay/tournament.cc b/src/selfplay/tournament.cc index 6849797f35..5c2d7b4e59 100644 --- a/src/selfplay/tournament.cc +++ b/src/selfplay/tournament.cc @@ -98,7 +98,7 @@ void SelfPlayTournament::PopulateOptions(OptionsParser* options) { defaults->Set(SearchParams::kSmartPruningFactorId.GetId(), 0.0f); defaults->Set(SearchParams::kTemperatureId.GetId(), 1.0f); defaults->Set(SearchParams::kNoiseId.GetId(), true); - defaults->Set(SearchParams::kFpuReductionId.GetId(), 0.0f); + defaults->Set(SearchParams::kFpuValueId.GetId(), 0.0f); defaults->Set(SearchParams::kHistoryFillId.GetId(), "no"); defaults->Set(NetworkFactory::kBackendId.GetId(), "multiplexing");