Skip to content

Commit

Permalink
Move root checks to Params and default to same
Browse files Browse the repository at this point in the history
  • Loading branch information
Mardak committed Mar 8, 2019
1 parent 2f9f00f commit 39625b2
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 21 deletions.
11 changes: 7 additions & 4 deletions src/mcts/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<ChoiceOption>(kFpuStrategyId, fpu_strategy) = "reduction";
options->Add<FloatOption>(kFpuValueId, -100.0f, 100.0f) = 1.2f;
fpu_strategy.push_back("same");
options->Add<ChoiceOption>(kFpuStrategyAtRootId, fpu_strategy) = "absolute";
options->Add<ChoiceOption>(kFpuStrategyAtRootId, fpu_strategy) = "same";
options->Add<FloatOption>(kFpuValueAtRootId, -100.0f, 100.0f) = 1.0f;
options->Add<IntOption>(kCacheHistoryLengthId, 0, 7) = 0;
options->Add<FloatOption>(kPolicySoftmaxTempId, 0.1f, 10.0f) = 2.2f;
Expand Down Expand Up @@ -232,10 +232,13 @@ SearchParams::SearchParams(const OptionsDict& options)
"absolute"),
kFpuValue(options.Get<float>(kFpuValueId.GetId())),
kFpuAbsoluteAtRoot(
(options.Get<std::string>(kFpuStrategyAtRootId.GetId()) == "same" &&
kFpuAbsolute) ||
options.Get<std::string>(kFpuStrategyAtRootId.GetId()) == "absolute"),
kFpuReductionAtRoot(options.Get<std::string>(
kFpuStrategyAtRootId.GetId()) == "reduction"),
kFpuValueAtRoot(options.Get<float>(kFpuValueAtRootId.GetId())),
kFpuValueAtRoot(options.Get<std::string>(kFpuStrategyAtRootId.GetId()) ==
"same"
? kFpuValue
: options.Get<float>(kFpuValueAtRootId.GetId())),
kCacheHistoryLength(options.Get<int>(kCacheHistoryLengthId.GetId())),
kPolicySoftmaxTemp(options.Get<float>(kPolicySoftmaxTempId.GetId())),
kMaxCollisionEvents(options.Get<int>(kMaxCollisionEventsId.GetId())),
Expand Down
8 changes: 2 additions & 6 deletions src/mcts/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,8 @@ class SearchParams {
return options_.Get<bool>(kLogLiveStatsId.GetId());
}
float GetSmartPruningFactor() const { return kSmartPruningFactor; }
bool GetFpuAbsolute() const { return kFpuAbsolute; }
float GetFpuValue() const { return kFpuValue; }
bool GetFpuAbsoluteAtRoot() const { return kFpuAbsoluteAtRoot; }
bool GetFpuReductionAtRoot() const { return kFpuReductionAtRoot; }
float GetFpuValueAtRoot() const { return kFpuValueAtRoot; }
bool GetFpuAbsolute(bool at_root) const { return at_root ? kFpuAbsoluteAtRoot : kFpuAbsolute; }
float GetFpuValue(bool at_root) const { return at_root ? kFpuValueAtRoot : kFpuValue; }
int GetCacheHistoryLength() const { return kCacheHistoryLength; }
float GetPolicySoftmaxTemp() const { return kPolicySoftmaxTemp; }
int GetMaxCollisionEvents() const { return kMaxCollisionEvents; }
Expand Down Expand Up @@ -149,7 +146,6 @@ class SearchParams {
const bool kFpuAbsolute;
const float kFpuValue;
const bool kFpuAbsoluteAtRoot;
const bool kFpuReductionAtRoot;
const float kFpuValueAtRoot;
const int kCacheHistoryLength;
const float kPolicySoftmaxTemp;
Expand Down
15 changes: 4 additions & 11 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,10 @@ int64_t Search::GetTimeToDeadline() const {

namespace {
inline float GetFpu(const SearchParams& params, Node* node, bool is_root_node) {
// Use root FPU behavior unless it's "same"
if (is_root_node) {
if (params.GetFpuAbsoluteAtRoot()) return params.GetFpuValueAtRoot();
if (params.GetFpuReductionAtRoot())
return -node->GetQ() -
params.GetFpuValueAtRoot() * std::sqrt(node->GetVisitedPolicy());
}
return params.GetFpuAbsolute()
? params.GetFpuValue()
: -node->GetQ() -
params.GetFpuValue() * std::sqrt(node->GetVisitedPolicy());
const auto value = params.GetFpuValue(is_root_node);
return params.GetFpuAbsolute(is_root_node)
? value
: -node->GetQ() - value * std::sqrt(node->GetVisitedPolicy());
}

inline float ComputeCpuct(const SearchParams& params, uint32_t N) {
Expand Down

0 comments on commit 39625b2

Please sign in to comment.