diff --git a/src/mcts/params.cc b/src/mcts/params.cc index e30b7d94b4..416aca51b7 100644 --- a/src/mcts/params.cc +++ b/src/mcts/params.cc @@ -202,7 +202,7 @@ void SearchParams::Populate(OptionsParser* options) { options->Add(kFpuStrategyId, fpu_strategy) = "reduction"; options->Add(kFpuValueId, -100.0f, 100.0f) = 1.2f; fpu_strategy.push_back("same"); - options->Add(kFpuStrategyAtRootId, fpu_strategy) = "absolute"; + options->Add(kFpuStrategyAtRootId, fpu_strategy) = "same"; options->Add(kFpuValueAtRootId, -100.0f, 100.0f) = 1.0f; options->Add(kCacheHistoryLengthId, 0, 7) = 0; options->Add(kPolicySoftmaxTempId, 0.1f, 10.0f) = 2.2f; @@ -232,10 +232,13 @@ SearchParams::SearchParams(const OptionsDict& options) "absolute"), kFpuValue(options.Get(kFpuValueId.GetId())), kFpuAbsoluteAtRoot( + (options.Get(kFpuStrategyAtRootId.GetId()) == "same" && + kFpuAbsolute) || options.Get(kFpuStrategyAtRootId.GetId()) == "absolute"), - kFpuReductionAtRoot(options.Get( - kFpuStrategyAtRootId.GetId()) == "reduction"), - kFpuValueAtRoot(options.Get(kFpuValueAtRootId.GetId())), + kFpuValueAtRoot(options.Get(kFpuStrategyAtRootId.GetId()) == + "same" + ? kFpuValue + : options.Get(kFpuValueAtRootId.GetId())), kCacheHistoryLength(options.Get(kCacheHistoryLengthId.GetId())), kPolicySoftmaxTemp(options.Get(kPolicySoftmaxTempId.GetId())), kMaxCollisionEvents(options.Get(kMaxCollisionEventsId.GetId())), diff --git a/src/mcts/params.h b/src/mcts/params.h index 5d6612ee6c..c106e6c46e 100644 --- a/src/mcts/params.h +++ b/src/mcts/params.h @@ -78,11 +78,8 @@ class SearchParams { return options_.Get(kLogLiveStatsId.GetId()); } float GetSmartPruningFactor() const { return kSmartPruningFactor; } - bool GetFpuAbsolute() const { return kFpuAbsolute; } - float GetFpuValue() const { return kFpuValue; } - bool GetFpuAbsoluteAtRoot() const { return kFpuAbsoluteAtRoot; } - bool GetFpuReductionAtRoot() const { return kFpuReductionAtRoot; } - float GetFpuValueAtRoot() const { return kFpuValueAtRoot; } + bool GetFpuAbsolute(bool at_root) const { return at_root ? kFpuAbsoluteAtRoot : kFpuAbsolute; } + float GetFpuValue(bool at_root) const { return at_root ? kFpuValueAtRoot : kFpuValue; } int GetCacheHistoryLength() const { return kCacheHistoryLength; } float GetPolicySoftmaxTemp() const { return kPolicySoftmaxTemp; } int GetMaxCollisionEvents() const { return kMaxCollisionEvents; } @@ -149,7 +146,6 @@ class SearchParams { const bool kFpuAbsolute; const float kFpuValue; const bool kFpuAbsoluteAtRoot; - const bool kFpuReductionAtRoot; const float kFpuValueAtRoot; const int kCacheHistoryLength; const float kPolicySoftmaxTemp; diff --git a/src/mcts/search.cc b/src/mcts/search.cc index cdb100125c..ae4169e6d9 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -189,17 +189,10 @@ int64_t Search::GetTimeToDeadline() const { namespace { inline float GetFpu(const SearchParams& params, Node* node, bool is_root_node) { - // Use root FPU behavior unless it's "same" - if (is_root_node) { - if (params.GetFpuAbsoluteAtRoot()) return params.GetFpuValueAtRoot(); - if (params.GetFpuReductionAtRoot()) - return -node->GetQ() - - params.GetFpuValueAtRoot() * std::sqrt(node->GetVisitedPolicy()); - } - return params.GetFpuAbsolute() - ? params.GetFpuValue() - : -node->GetQ() - - params.GetFpuValue() * std::sqrt(node->GetVisitedPolicy()); + const auto value = params.GetFpuValue(is_root_node); + return params.GetFpuAbsolute(is_root_node) + ? value + : -node->GetQ() - value * std::sqrt(node->GetVisitedPolicy()); } inline float ComputeCpuct(const SearchParams& params, uint32_t N) {