Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] MCTS Scoring functions #2358

Open
wants to merge 14 commits into
base: gh/vmoens/6/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions torchrl/data/map/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,6 @@ def vertices(
if h in memo and not use_path:
continue
memo.add(h)
r = tree.rollout
if r is not None:
r = r["next", "observation"]
if use_path:
result[cur_path] = tree
elif use_id:
Expand Down
5 changes: 5 additions & 0 deletions torchrl/modules/mcts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .scores import PUCTScore, UCBScore
100 changes: 100 additions & 0 deletions torchrl/modules/mcts/scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import math
from abc import abstractmethod
from enum import Enum

from tensordict import NestedKey, TensorDictBase
from tensordict.nn import TensorDictModuleBase
from torch import nn


class MCTSScore(TensorDictModuleBase):
@abstractmethod
def forward(self, node):
pass


class PUCTScore(MCTSScore):
c: float

def __init__(
self,
*,
c: float,
win_count_key: NestedKey = "win_count",
visits_key: NestedKey = "visits",
total_visits_key: NestedKey = "total_visits",
prior_prob_key: NestedKey = "prior_prob",
score_key: NestedKey = "score",
):
super().__init__()
self.c = c
self.win_count_key = win_count_key
self.visits_key = visits_key
self.total_visits_key = total_visits_key
self.prior_prob_key = prior_prob_key
self.score_key = score_key
self.in_keys = [
self.win_count_key,
self.prior_prob_key,
self.total_visits_key,
self.visits_key,
]
self.out_keys = [self.score_key]

def forward(self, node: TensorDictBase) -> TensorDictBase:
win_count = node.get(self.win_count_key)
visits = node.get(self.visits_key)
n_total = node.get(self.total_visits_key)
prior_prob = node.get(self.prior_prob_key)
node.set(
self.score_key,
(win_count / visits) + self.c * prior_prob * n_total.sqrt() / (1 + visits),
)
return node


class UCBScore(MCTSScore):
c: float

def __init__(
self,
*,
c: float,
win_count_key: NestedKey = "win_count",
visits_key: NestedKey = "visits",
total_visits_key: NestedKey = "total_visits",
score_key: NestedKey = "score",
):
super().__init__()
self.c = c
self.win_count_key = win_count_key
self.visits_key = visits_key
self.total_visits_key = total_visits_key
self.score_key = score_key
self.in_keys = [self.win_count_key, self.total_visits_key, self.visits_key]
self.out_keys = [self.score_key]

def forward(self, node: TensorDictBase) -> TensorDictBase:
win_count = node.get(self.win_count_key)
visits = node.get(self.visits_key)
n_total = node.get(self.total_visits_key)
node.set(
self.score_key,
(win_count / visits) + self.c * n_total.sqrt() / (1 + visits),
)
return node


class MCTSScores(Enum):
PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value
UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002
UCB1_TUNED = "UCB1-Tuned"
EXP3 = "EXP3"
PUCT_VARIANT = "PUCT-Variant"
Loading