vrmccfr.go
package cfr

import (
	"math/rand"

	"github.com/timpalpant/go-cfr/internal/f32"
)
// VRMCCFR performs variance-reduced Monte Carlo CFR traversals:
// MC-CFR with per-action baselines used as control variates.
type VRMCCFR struct {
	strategyProfile      StrategyProfile
	traversingSampler    Sampler
	notTraversingSampler Sampler

	slicePool *floatSlicePool
	mapPool   *keyIntMapPool
	rng       *rand.Rand

	traversingPlayer int
	sampledActions   map[string]int
}
// NewVRMCCFR creates a new VRMCCFR solver that updates the given strategy
// profile, using the given samplers for the traversing and non-traversing
// players.
func NewVRMCCFR(strategyProfile StrategyProfile, traversingSampler, notTraversingSampler Sampler) *VRMCCFR {
	return &VRMCCFR{
		strategyProfile:      strategyProfile,
		traversingSampler:    traversingSampler,
		notTraversingSampler: notTraversingSampler,

		slicePool: &floatSlicePool{},
		mapPool:   &keyIntMapPool{},
		rng:       rand.New(rand.NewSource(rand.Int63())),
	}
}
// Run performs one sampled traversal of the game tree rooted at node,
// updating regrets, baselines, and the average strategy in the strategy
// profile. The traversing player alternates with the profile's iteration count.
func (c *VRMCCFR) Run(node GameTreeNode) float32 {
	iter := c.strategyProfile.Iter()
	c.traversingPlayer = int(iter % 2)
	c.sampledActions = c.mapPool.alloc()
	defer c.mapPool.free(c.sampledActions)
	return c.runHelper(node, node.Player(), 1.0, 1.0)
}
func (c *VRMCCFR) runHelper(node GameTreeNode, lastPlayer int, sampleProb, reachProb float32) float32 {
	var ev float32
	switch node.Type() {
	case TerminalNodeType:
		ev = float32(node.Utility(lastPlayer))
	case ChanceNodeType:
		ev = c.handleChanceNode(node, lastPlayer, sampleProb, reachProb)
	default:
		sgn := getSign(lastPlayer, node.Player())
		ev = sgn * c.handlePlayerNode(node, sampleProb, reachProb)
	}

	node.Close()
	return ev
}
func (c *VRMCCFR) handleChanceNode(node GameTreeNode, lastPlayer int, sampleProb, reachProb float32) float32 {
	child, p := node.SampleChild()
	return c.runHelper(child, lastPlayer, float32(p)*sampleProb, float32(p)*reachProb)
}
func (c *VRMCCFR) handlePlayerNode(node GameTreeNode, sampleProb, reachProb float32) float32 {
	if node.Player() == c.traversingPlayer {
		return c.handleTraversingPlayerNode(node, sampleProb, reachProb)
	} else {
		return c.handleSampledPlayerNode(node, sampleProb, reachProb)
	}
}
// handleTraversingPlayerNode updates the traversing player's regrets at this
// node using baseline-corrected sampled counterfactual values.
func (c *VRMCCFR) handleTraversingPlayerNode(node GameTreeNode, sampleProb, reachProb float32) float32 {
	player := node.Player()
	nChildren := node.NumChildren()
	if nChildren == 1 {
		// Optimization to skip trivial nodes with no real choice.
		child := node.GetChild(0)
		return c.runHelper(child, player, sampleProb, reachProb)
	}

	policy := c.strategyProfile.GetPolicy(node)
	baseline := policy.GetBaseline()
	qs := c.slicePool.alloc(nChildren)
	copy(qs, c.traversingSampler.Sample(node, policy))
	regrets := c.slicePool.alloc(nChildren)
	oldSampledActions := c.sampledActions
	c.sampledActions = c.mapPool.alloc()
	for i, q := range qs {
		child := node.GetChild(i)
		// Start from the baseline; actions that are not sampled (q == 0)
		// keep the baseline as their value estimate.
		uHat := baseline[i]
		if q > 0 {
			u := c.runHelper(child, player, q*sampleProb, reachProb)
			// Control-variate correction: dividing the deviation from the
			// baseline by the sample probability q keeps the estimator
			// unbiased while reducing its variance when the baseline is
			// close to the true action value.
			uHat += (u - baseline[i]) / q
			policy.UpdateBaseline(1.0/q, i, u)
		}

		regrets[i] = uHat
	}

	cfValue := f32.DotUnitary(policy.GetStrategy(), regrets)
	// Convert action values into instantaneous regrets: value(a) - value(strategy).
	f32.AddConst(-cfValue, regrets)
	policy.AddRegret(reachProb/sampleProb, qs, regrets)
	c.slicePool.free(qs)
	c.slicePool.free(regrets)
	c.mapPool.free(c.sampledActions)
	c.sampledActions = oldSampledActions
	return cfValue
}
// handleSampledPlayerNode handles nodes belonging to the non-traversing
// player: actions are sampled according to the configured sampler, and the
// average strategy and baselines are updated, but no regret is accumulated.
func (c *VRMCCFR) handleSampledPlayerNode(node GameTreeNode, sampleProb, reachProb float32) float32 {
	policy := c.strategyProfile.GetPolicy(node)
	player := node.Player()
	nChildren := node.NumChildren()
	baseline := policy.GetBaseline()
	strategy := policy.GetStrategy()

	// Update the average strategy for this node.
	// We perform "stochastic" updates as described in the MC-CFR paper.
	if sampleProb > 0 {
		policy.AddStrategyWeight(1.0 / sampleProb)
	}

	qs := c.slicePool.alloc(nChildren)
	copy(qs, c.notTraversingSampler.Sample(node, policy))
	regrets := c.slicePool.alloc(nChildren)
	for i, q := range qs {
		p := strategy[i]
		child := node.GetChild(i)
		// Baseline-corrected value estimate, as in the traversing-player case.
		uHat := baseline[i]
		if q > 0 {
			u := c.runHelper(child, player, q*sampleProb, p*reachProb)
			uHat += (u - baseline[i]) / q
			policy.UpdateBaseline(1.0/q, i, u)
		}

		regrets[i] = uHat
	}

	c.slicePool.free(qs)
	c.slicePool.free(regrets)
	// Return the estimated value of playing the current strategy at this node.
	return f32.DotUnitary(policy.GetStrategy(), regrets)
}
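
For context, a minimal driver sketch follows. It is not part of vrmccfr.go: the caller is assumed to supply a StrategyProfile implementation, Samplers for the traversing and non-traversing players, and a way to build fresh root GameTreeNodes, and any per-iteration bookkeeping the profile requires (such as advancing the counter returned by Iter) is assumed to happen inside the profile implementation.

package main

import cfr "github.com/timpalpant/go-cfr"

// runVRMCCFR is a hypothetical driver showing how the exported API of this
// file could be wired together. All arguments are assumed to be provided by
// the surrounding application.
func runVRMCCFR(profile cfr.StrategyProfile, traversing, notTraversing cfr.Sampler,
	newRoot func() cfr.GameTreeNode, nIter int) {
	vr := cfr.NewVRMCCFR(profile, traversing, notTraversing)
	for i := 0; i < nIter; i++ {
		// Each call performs one sampled traversal from the root, alternating
		// the traversing player, and updates regrets, baselines, and the
		// average strategy stored in the profile.
		vr.Run(newRoot())
	}
}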
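
The estimator used in both player-node handlers can also be illustrated in isolation. The standalone sketch below, with made-up numbers, compares the baseline-corrected estimate (baseline + (u - baseline)/q when sampled, baseline otherwise) against a plain sampled estimate with a zero baseline: both have the same mean, but the corrected one has much lower variance when the baseline is close to the true value. This is a toy demonstration, not library code.

package main

import (
	"fmt"
	"math/rand"
)

func main() {
	rng := rand.New(rand.NewSource(1))
	const (
		trueValue = 0.7  // hypothetical true value of one action
		baseline  = 0.65 // learned baseline, close to (but not equal to) trueValue
		q         = 0.25 // probability that the action is sampled this iteration
		n         = 1_000_000
	)

	var sumVR, sumSqVR, sumPlain, sumSqPlain float64
	for i := 0; i < n; i++ {
		sampled := rng.Float64() < q

		// Baseline-corrected estimate: fall back to the baseline when the
		// action is not sampled, otherwise add the correction (u - baseline)/q
		// so that the estimator stays unbiased.
		vr := baseline
		if sampled {
			vr += (trueValue - baseline) / q
		}
		sumVR += vr
		sumSqVR += vr * vr

		// Plain sampled estimate, equivalent to using a zero baseline.
		plain := 0.0
		if sampled {
			plain = trueValue / q
		}
		sumPlain += plain
		sumSqPlain += plain * plain
	}

	report := func(name string, sum, sumSq float64) {
		mean := sum / n
		fmt.Printf("%s: mean=%.4f variance=%.4f\n", name, mean, sumSq/n-mean*mean)
	}
	report("with baseline", sumVR, sumSqVR)
	report("zero baseline", sumPlain, sumSqPlain)
}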