Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

To ensure MakeTerminal value overwrites any MCTS calculated value #677

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,17 @@ std::string Node::DebugString() const {
return oss.str();
}

void Node::MakeTerminal(GameResult result) {
is_terminal_ = true;
if (result == GameResult::DRAW) {
q_ = 0.0f;
} else if (result == GameResult::WHITE_WON) {
q_ = 1.0f;
} else if (result == GameResult::BLACK_WON) {
q_ = -1.0f;
}
// Marks this node as terminal and pins q_ to a fixed score for `result`.
//
//   result  Outcome to record. DRAW scores 0; WHITE_WON scores positive;
//           BLACK_WON scores negative. Any other value leaves q_ untouched.
//   depth   Ply count used to discount decisive results so that quicker
//           ones score higher in magnitude. The visible call sites pass
//           search_->played_history_.GetLength().
//
// A decisive result is scored as +/-(10 - depth / 50). The magnitude is
// clamped to at least 1.0: FinalizeScoreUpdate() only re-averages q_ while it
// lies strictly inside (-0.98, 0.98), so without the clamp a depth above 450
// would let later visits overwrite the terminal value, and a depth above 500
// would flip the sign of a win into a loss.
void Node::MakeTerminal(GameResult result, int depth) {
  is_terminal_ = true;

  if (result == GameResult::DRAW) {
    q_ = 0.0f;
    return;
  }

  const bool white_won = (result == GameResult::WHITE_WON);
  const bool black_won = (result == GameResult::BLACK_WON);
  if (!white_won && !black_won) {
    // Not a finished-game result: keep whatever q_ already holds.
    return;
  }

  // Depth-discounted magnitude, never allowed to fall below a plain win.
  float magnitude = 10.0f - (depth / 50.0f);
  if (magnitude < 1.0f) {
    magnitude = 1.0f;
  }
  q_ = white_won ? magnitude : -magnitude;
}

bool Node::TryStartScoreUpdate() {
Expand All @@ -240,8 +242,10 @@ void Node::CancelScoreUpdate(int multivisit) {

void Node::FinalizeScoreUpdate(float v, int multivisit) {
// Recompute Q.
q_ += multivisit * (v - q_) / (n_ + multivisit);
// If first visit, update parent's sum of policies visited at least once.
if (q_ > -0.98f && q_ < 0.98f) { // keep terminal values of -1 and 1 when found
q_ += multivisit * (v - q_) / (n_ + multivisit);
}
// If first visit, update parent's sum of policies visited at least once.
if (n_ == 0 && parent_ != nullptr) {
parent_->visited_policy_ += parent_->edges_[index_].GetP();
}
Expand Down
2 changes: 1 addition & 1 deletion src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class Node {
// Returns the number of entries in edges_, narrowed to uint16_t.
uint16_t GetNumEdges() const {
  return static_cast<uint16_t>(edges_.size());
}

// Makes the node terminal and sets its score.
void MakeTerminal(GameResult result);
void MakeTerminal(GameResult result, int depth);

// If this node is not in the process of being expanded by another thread
// (which can happen only if n==0 and n-in-flight==1), mark the node as
Expand Down
19 changes: 9 additions & 10 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,7 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend(
// Either terminal or unexamined leaf node -- the end of this playout.
if (!node->HasChildren()) {
if (node->IsTerminal()) {
IncrementNInFlight(node, search_->root_node_, collision_limit - 1);
return NodeToProcess::TerminalHit(node, depth, collision_limit);
return NodeToProcess::TerminalHit(node, depth, 1);
} else {
return NodeToProcess::Extension(node, depth);
}
Expand Down Expand Up @@ -968,9 +967,9 @@ void SearchWorker::ExtendNode(Node* node) {
if (legal_moves.empty()) {
// Could be a checkmate or a stalemate
if (board.IsUnderCheck()) {
node->MakeTerminal(GameResult::WHITE_WON);
node->MakeTerminal(GameResult::WHITE_WON, search_->played_history_.GetLength());
} else {
node->MakeTerminal(GameResult::DRAW);
node->MakeTerminal(GameResult::DRAW, search_->played_history_.GetLength());
}
return;
}
Expand All @@ -979,17 +978,17 @@ void SearchWorker::ExtendNode(Node* node) {
// if they are root, then thinking about them is the point.
if (node != search_->root_node_) {
if (!board.HasMatingMaterial()) {
node->MakeTerminal(GameResult::DRAW);
node->MakeTerminal(GameResult::DRAW, search_->played_history_.GetLength());
return;
}

if (history_.Last().GetNoCaptureNoPawnPly() >= 100) {
node->MakeTerminal(GameResult::DRAW);
node->MakeTerminal(GameResult::DRAW, search_->played_history_.GetLength());
return;
}

if (history_.Last().GetRepetitions() >= 2) {
node->MakeTerminal(GameResult::DRAW);
node->MakeTerminal(GameResult::DRAW, search_->played_history_.GetLength());
return;
}

Expand All @@ -1005,11 +1004,11 @@ void SearchWorker::ExtendNode(Node* node) {
if (state != FAIL) {
// If the colors seem backwards, check the checkmate check above.
if (wdl == WDL_WIN) {
node->MakeTerminal(GameResult::BLACK_WON);
node->MakeTerminal(GameResult::BLACK_WON, search_->played_history_.GetLength());
} else if (wdl == WDL_LOSS) {
node->MakeTerminal(GameResult::WHITE_WON);
node->MakeTerminal(GameResult::WHITE_WON, search_->played_history_.GetLength());
} else { // Cursed wins and blessed losses count as draws.
node->MakeTerminal(GameResult::DRAW);
node->MakeTerminal(GameResult::DRAW, search_->played_history_.GetLength());
}
search_->tb_hits_.fetch_add(1, std::memory_order_acq_rel);
return;
Expand Down