-
Notifications
You must be signed in to change notification settings - Fork 0
/
replay.py
51 lines (40 loc) · 1.55 KB
/
replay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from collections import namedtuple, deque
import random
# A sampled batch in "struct of arrays" form: each field is a tuple holding
# one entry per sampled transition, aligned by position across fields.
Batch = namedtuple(
    'Batch', ('states', 'actions', 'rewards', 'next_states', 'dones')
)
# A single experience, as stored in the replay memory.
Transition = namedtuple(
    'Transition', ('state', 'action', 'reward', 'next_state', 'done')
)


class ReplayMemory:
    """Fixed-capacity experience replay buffer.

    Transitions are appended to a bounded deque; once ``max_size`` entries
    are stored, appending a new transition discards the oldest one.
    """

    def __init__(self, max_size):
        """
        Replay Memory initialized as a circular buffer.

        :param max_size: Maximum number of transitions kept in memory.
        """
        # deque(maxlen=...) evicts from the opposite end on overflow,
        # which gives circular-buffer behaviour without index bookkeeping.
        self.memory = deque([], maxlen=max_size)
        self.max_size = max_size

    def add(self, state, action, reward, next_state, done):
        """
        Add a transition to the buffer.

        :param state: 1-D np.ndarray of state-features.
        :param action: integer action.
        :param reward: float reward.
        :param next_state: 1-D np.ndarray of state-features.
        :param done: boolean value indicating the end of an episode.
        """
        self.memory.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size) -> Batch:
        """Sample a batch of experiences uniformly without replacement.

        If the buffer contains fewer than `batch_size` transitions, all of
        them are returned (in random order). If nothing is sampled (empty
        buffer or `batch_size == 0`), a Batch whose fields are all empty
        tuples is returned.

        :param batch_size: Number of transitions to sample.
        :rtype: Batch
        :raises ValueError: if `batch_size` is negative (from random.sample).
        """
        batch_size = min(batch_size, len(self.memory))
        # NOTE(review): indexed access into a deque is O(n) away from its
        # ends, so random.sample may be slow for very large buffers.
        transitions = random.sample(self.memory, k=batch_size)
        if not transitions:
            # zip(*[]) yields nothing, so Batch(*zip(*[])) would raise
            # TypeError; build an explicitly empty batch instead.
            return Batch._make(() for _ in Batch._fields)
        # Transpose the list of Transition rows into per-field tuples.
        return Batch(*zip(*transitions))