"""RL Policy classes.
We have provided you with a base policy class, some example
implementations and some unimplemented classes that should be useful
in your code.
"""
import numpy as np
# import attr

class Policy:
    """Base class representing an MDP policy.

    Policies are used by the agent to choose actions.

    Policies are designed to be stacked to get interesting behaviors
    of choices. For instance, in a discrete action space the lowest
    level policy may take in Q-values and select the action index
    corresponding to the largest value. If this policy is wrapped in
    an epsilon-greedy policy, then with some probability epsilon a
    random action will be chosen instead. The ``__main__`` demo at
    the bottom of this module sketches this stacking.
    """

    def select_action(self, **kwargs):
        """Used by agents to select actions.

        Returns
        -------
        Any:
            An object representing the chosen action. Type depends on
            the hierarchy of policy instances.
        """
        raise NotImplementedError('This method should be overridden.')

class UniformRandomPolicy(Policy):
    """Chooses a discrete action with uniform random probability.

    This is provided as a reference on how to use the policy class.

    Parameters
    ----------
    num_actions: int
        Number of actions to choose from. Must be > 0.

    Raises
    ------
    ValueError:
        If num_actions <= 0
    """

    def __init__(self, num_actions):
        # The docstring promises a ValueError, so raise one instead of
        # relying on an assert (asserts vanish under ``python -O``).
        if num_actions < 1:
            raise ValueError(
                'num_actions must be >= 1, got %r' % num_actions)
        self.num_actions = num_actions

    def select_action(self, **kwargs):
        """Return a random action index.

        This policy cannot contain others (as they would just be
        ignored).

        Returns
        -------
        int:
            Action index in range [0, num_actions)
        """
        return np.random.randint(0, self.num_actions)

    def get_config(self):  # noqa: D102
        return {'num_actions': self.num_actions}

class GreedyPolicy(Policy):
    """Always returns the best action according to the Q-values.

    This is a pure exploitation policy.
    """

    def select_action(self, q_values, **kwargs):  # noqa: D102
        return np.argmax(q_values)

class GreedyEpsilonPolicy(Policy):
    """Selects the greedy action, or a random action with some probability.

    Standard epsilon-greedy implementation. With probability epsilon
    choose a random action; otherwise choose the greedy action.

    Parameters
    ----------
    epsilon: float
        Initial probability of choosing a random action. Can be
        changed over time.
    """

    def __init__(self, epsilon):
        self.epsilon = epsilon

    def select_action(self, q_values, **kwargs):
        """Run epsilon-greedy selection for the given Q-values.

        Parameters
        ----------
        q_values: array-like
            Array-like structure of floats representing the Q-values
            for each action.

        Returns
        -------
        int:
            The action index chosen.
        """
        # Accept any array-like input, not just np.ndarray, as the
        # docstring promises.
        q_values = np.asarray(q_values)
        num_actions = q_values.shape[0]
        if np.random.random_sample() > self.epsilon:
            return np.argmax(q_values)
        return np.random.randint(0, num_actions)

class LinearDecayGreedyEpsilonPolicy(Policy):
    """Policy with a parameter that decays linearly.

    Like GreedyEpsilonPolicy, but the epsilon decays from a start
    value to an end value over num_steps steps.

    Parameters
    ----------
    policy: Policy
        The wrapped policy whose parameter is decayed.
    attr_name: str
        Name of the attribute on the wrapped policy to decay
        (e.g. 'epsilon').
    start_value: int, float
        The initial value of the parameter.
    end_value: int, float
        The value of the parameter at the end of the decay.
    num_steps: int
        The number of steps over which to decay the value.
    """

    def __init__(self, policy, attr_name, start_value, end_value,
                 num_steps):  # noqa: D102
        self.policy = policy
        self.attr_name = attr_name
        self.start_value = start_value
        self.end_value = end_value
        self.num_steps = num_steps
        self.steps_taken = 0

    def select_action(self, **kwargs):
        """Decay the parameter and select an action.

        Parameters
        ----------
        q_values: np.array
            The Q-values for each action.
        is_training: bool, optional
            If true then the parameter will be decayed. Defaults to
            true.

        Returns
        -------
        Any:
            Selected action.
        """
        if kwargs.get('is_training', True):
            # Linearly interpolate from start_value to end_value,
            # then hold at end_value once num_steps have elapsed.
            fraction = min(float(self.steps_taken) / self.num_steps, 1.0)
            value = self.start_value + fraction * (
                self.end_value - self.start_value)
            setattr(self.policy, self.attr_name, value)
            self.steps_taken += 1
        return self.policy.select_action(**kwargs)

    def reset(self):
        """Start the decay over at the start value."""
        self.steps_taken = 0
        setattr(self.policy, self.attr_name, self.start_value)
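

# A minimal usage sketch (an addition, not part of the original scaffold):
# stack a LinearDecayGreedyEpsilonPolicy around a GreedyEpsilonPolicy, as
# described in the Policy docstring, and watch epsilon decay from 1.0
# toward 0.1 over 10 training steps. The Q-values here are arbitrary
# illustrative numbers.
if __name__ == '__main__':
    inner = GreedyEpsilonPolicy(epsilon=1.0)
    policy = LinearDecayGreedyEpsilonPolicy(
        inner, 'epsilon', start_value=1.0, end_value=0.1, num_steps=10)

    q_values = np.array([0.1, 0.5, 0.2])
    for step in range(12):
        action = policy.select_action(q_values=q_values, is_training=True)
        print('step=%d epsilon=%.2f action=%d'
              % (step, inner.epsilon, action))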