-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathutils.py
32 lines (25 loc) · 1.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
class ReplayBuffer:
    """Store transitions and draw uniformly random mini-batches from them.

    Each stored transition is a 5-tuple
    ``(state, action, reward, next_state, done)``.

    Eviction is lazy: ``add`` only appends, and the oldest fifth of the
    stored transitions is dropped at the start of ``sample`` once ``size``
    has grown past ``max_size``. The buffer can therefore temporarily hold
    more than ``max_size`` items between calls to ``sample``.

    Attributes:
        buffer: List of stored transitions, oldest first.
        max_size: Threshold (as an int) above which ``sample`` evicts.
        size: Number of transitions currently counted as stored.
    """

    def __init__(self, max_size=5e5):
        """Create an empty buffer.

        Args:
            max_size: Eviction threshold; truncated to ``int`` so that
                float literals such as ``5e5`` are accepted.
        """
        self.buffer = []
        self.max_size = int(max_size)
        self.size = 0

    def add(self, transition):
        """Append one transition to the buffer.

        Args:
            transition: Tuple of (state, action, reward, next_state, done).
        """
        self.size += 1
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return a random batch of transitions as five stacked arrays.

        Indices are drawn uniformly with replacement, so a transition may
        appear more than once in a batch.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple ``(state, action, reward, next_state, done)`` of NumPy
            arrays, each with ``batch_size`` entries along the first axis.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        # Delete the oldest 1/5th of the buffer when it has grown too large.
        if self.size > self.max_size:
            del self.buffer[0:int(self.size / 5)]
            self.size = len(self.buffer)

        if not self.buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        indexes = np.random.randint(0, len(self.buffer), size=batch_size)

        state, action, reward, next_state, done = [], [], [], [], []
        for i in indexes:
            s, a, r, s_, d = self.buffer[i]
            # np.asarray avoids a copy when possible but, unlike
            # np.array(..., copy=False) on NumPy >= 2.0, does not raise
            # when a copy is required (e.g. for lists or Python scalars).
            state.append(np.asarray(s))
            action.append(np.asarray(a))
            reward.append(np.asarray(r))
            next_state.append(np.asarray(s_))
            done.append(np.asarray(d))

        return (
            np.array(state),
            np.array(action),
            np.array(reward),
            np.array(next_state),
            np.array(done),
        )