-
Notifications
You must be signed in to change notification settings - Fork 6
/
dataset.py
150 lines (130 loc) · 5.42 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import numpy as np
import pickle
import os
import sys
import time
from scipy import misc
class Transition(object):
    """Plain container for one environment step: (s, a, r, s', terminal)."""

    def __init__(self, s, a, r, ss, t):
        self.state, self.action = s, a
        self.reward = r
        self.next_state, self.terminal = ss, t
class DataSet(object):
    """Fixed-capacity circular buffer of transitions.

    Stores states, actions, rewards, terminal flags, a behavior-policy
    vector and (optionally) a Q-value vector per slot, each in a
    pre-allocated numpy array.  Once full, new entries overwrite the
    oldest ones: ``head`` is the index of the oldest entry, ``tail`` the
    next write slot, ``size`` the number of valid entries.
    """

    def __init__(self, path, dataset_size=10000, state_shape=(4,), state_dtype=np.float32, nb_actions=9):
        """
        Args:
            path: directory used by save/load; created if missing.
                  Pass None for an in-memory-only dataset.
            dataset_size: maximum number of transitions held.
            state_shape: shape of a single state (any int sequence).
            state_dtype: numpy dtype for states, policy and Q arrays.
            nb_actions: size of the action space.
        """
        self.size = 0   # number of valid entries currently stored
        self.head = 0   # index of the oldest entry
        self.tail = 0   # index of the next write slot
        self.frame_shape = (42, 42)
        self.dataset_size = dataset_size
        self.nb_actions = nb_actions
        self.state_shape = state_shape
        self.state_dtype = state_dtype
        self._minibatch_size = None
        self._allocate()
        # Path to the dataset; create the directory up front so saves succeed.
        self.path = path
        if self.path is not None and not os.path.exists(self.path):
            os.makedirs(self.path)

    def _allocate(self):
        """(Re)allocate all storage arrays, zero-filled."""
        self.states = np.zeros([self.dataset_size] +
                               list(self.state_shape), dtype=self.state_dtype)
        self.actions = np.zeros(self.dataset_size, dtype='int32')
        self.rewards = np.zeros(self.dataset_size, dtype='float32')
        self.terms = np.zeros(self.dataset_size, dtype='bool')
        # state_dtype is used consistently here; the original reset() used
        # 'float32' for policy/qfunction while __init__ used state_dtype.
        self.policy = np.zeros([self.dataset_size, self.nb_actions], dtype=self.state_dtype)
        self.qfunction = np.zeros([self.dataset_size, self.nb_actions], dtype=self.state_dtype)

    def add(self, s, a, r, t, p, q=None):
        """Append one transition, overwriting the oldest entry when full.

        Args:
            s: state; a: action index; r: reward; t: terminal flag;
            p: behavior-policy vector of length nb_actions.
            q: optional Q-value vector; the slot keeps its previous/zero
               content when q is None or empty.  (None replaces the
               original mutable-default-argument ``q=[]``.)
        """
        self.states[self.tail] = s
        self.actions[self.tail] = a
        self.rewards[self.tail] = r
        self.terms[self.tail] = t
        self.policy[self.tail] = p
        if q is not None and len(q) > 0:
            self.qfunction[self.tail] = q
        self.tail = (self.tail + 1) % self.dataset_size
        if self.size == self.dataset_size:
            # Buffer full: the write above clobbered the oldest entry.
            self.head = (self.head + 1) % self.dataset_size
        else:
            self.size += 1

    def reset(self):
        """Empty the buffer and reallocate zeroed storage."""
        self.size = 0
        self.head = 0
        self.tail = 0
        self._minibatch_size = None
        self._allocate()

    def save_dataset(self, filename):
        """Pickle all storage arrays to ``self.path``/``filename``.

        Creates the target directory if needed.  The pickled payload is
        the list [states, actions, rewards, terms, policy, qfunction].
        """
        full_path = os.path.join(self.path, filename)
        print("Saving dataset to {}".format(full_path))
        if not os.path.exists(os.path.dirname(full_path)):
            print('Creating dataset directory {}'.format(os.path.dirname(full_path)))
            os.makedirs(os.path.dirname(full_path))
        with open(full_path, "wb") as f:
            pickle.dump([self.states, self.actions, self.rewards, self.terms, self.policy, self.qfunction], f)
        print("Dataset saved")

    def load_dataset(self, filename):
        """Load arrays previously written by save_dataset.

        Replaces the in-memory arrays and marks the buffer as full
        (size == dataset_size).

        Raises:
            ValueError: if ``self.path``/``filename`` does not exist.

        NOTE(review): pickle.load on untrusted files can execute
        arbitrary code — only load datasets you produced yourself.
        """
        full_path = os.path.join(self.path, filename)
        if not os.path.exists(full_path):
            raise ValueError("We could not find the dataset file: {}".format(full_path))
        print("\nLoading dataset from file {}".format(full_path))
        with open(full_path, "rb") as f:
            self.states, self.actions, self.rewards, self.terms, self.policy, self.qfunction = pickle.load(f)
        self.dataset_size = self.states.shape[0]
        self.size = self.dataset_size

    def counts_weights(self):
        """Return per-feature (mean, std) over all state slots (including
        still-zeroed ones when the buffer is not full)."""
        return np.mean(self.states, axis=0), np.std(self.states, axis=0)
class Dataset_Counts(object):
    """Read-only view over a transition dataset stored as a dict of
    parallel lists, with similarity-based pseudo-count utilities.

    Expected keys in ``data``: 's', 'a', 'r', 's2', 't', 'c', 'p' and
    optionally 'c1', each a list indexed by transition.
    """

    def __init__(self, data, param, dtype=np.float32):
        self.data = data
        self.dataset_size = len(data['s'])
        self.state_shape = data['s'][0].shape
        self.nb_actions = data['p'][0].shape[0]
        self.dtype = dtype
        self.param = param
        # Feature-wise statistics over all stored states.
        self.mean = np.mean(data['s'], axis=0)
        self.std = np.std(data['s'], axis=0)

    @staticmethod
    def distance(x1, x2):
        """Euclidean distance between two state vectors."""
        return np.linalg.norm(x1 - x2)

    @staticmethod
    def similarite(x1, x2, param, mean, std):
        """Similarity in [0, 1]: 1 at distance 0, decaying linearly to 0
        at distance >= param.  (mean/std are accepted but unused here.)"""
        d = Dataset_Counts.distance(x1, x2)
        return max(0, 1 - d / param)

    def sample(self, batch_size=1):
        """Draw ``batch_size`` transitions uniformly with replacement.

        Returns (s, a, r, s2, t, c, p, c1); c1 stays zero-filled when the
        dataset has no 'c1' key.
        """
        state_dims = [batch_size] + list(self.state_shape)
        s = np.zeros(state_dims, dtype=self.dtype)
        s2 = np.zeros(state_dims, dtype=self.dtype)
        t = np.zeros(batch_size, dtype='bool')
        a = np.zeros(batch_size, dtype='int32')
        r = np.zeros(batch_size, dtype='float32')
        c1 = np.zeros(batch_size, dtype='float32')
        c = np.zeros([batch_size, self.nb_actions], dtype='float32')
        p = np.zeros([batch_size, self.nb_actions], dtype='float32')
        has_c1 = 'c1' in self.data  # hoisted out of the loop
        for k in range(batch_size):
            idx = np.random.randint(self.dataset_size)
            s[k] = self.data['s'][idx]
            a[k] = self.data['a'][idx]
            r[k] = self.data['r'][idx]
            s2[k] = self.data['s2'][idx]
            t[k] = self.data['t'][idx]
            c[k] = self.data['c'][idx]
            p[k] = self.data['p'][idx]
            if has_c1:
                c1[k] = self.data['c1'][idx]
        return s, a, r, s2, t, c, p, c1

    def compute_counts(self, state):
        """Similarity-weighted pseudo-counts per action for ``state``:
        each stored transition contributes its similarity to the count
        of the action it took."""
        counts = np.zeros(self.nb_actions)
        stored_states, stored_actions = self.data['s'], self.data['a']
        for idx in range(self.dataset_size):
            weight = Dataset_Counts.similarite(
                state, stored_states[idx], self.param, self.mean, self.std)
            counts[stored_actions[idx]] += weight
        return counts