/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* TensorFlow.js Reinforcement Learning Example: Balancing a Cart-Pole System.
*
* The simulation, training, testing and visualization parts are written
* purely in JavaScript and can run in the web browser with WebGL acceleration.
*
* This reinforcement learning (RL) problem was proposed in:
*
* - Barto, Sutton, and Anderson, "Neuronlike Adaptive Elements That Can Solve
* Difficult Learning Control Problems," IEEE Trans. Syst., Man, Cybern.,
* Vol. SMC-13, pp. 834--846, Sept.--Oct. 1983
* - Sutton, "Temporal Aspects of Credit Assignment in Reinforcement Learning",
* Ph.D. Dissertation, Department of Computer and Information Science,
* University of Massachusetts, Amherst, 1984.
*
* It later became one of OpenAI's gym environments:
* https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
*/
import * as tf from '@tensorflow/tfjs';
import {maybeRenderDuringTraining, onGameEnd, setUpUI} from './ui';
/**
* Policy network for controlling the cart-pole system.
*
* The role of the policy network is to select an action based on the observed
* state of the system. In this case, the action is the leftward or rightward
* force and the observed system state is a four-dimensional vector, consisting
* of cart position, cart velocity, pole angle and pole angular velocity.
*
*/
class PolicyNetwork {
/**
* Constructor of PolicyNetwork.
*
* @param {number | number[] | tf.LayersModel} hiddenLayerSizesOrModel
* Can be any of the following:
* - Size of the hidden layer, as a single number (for a single hidden
* layer)
* - An Array of numbers (for any number of hidden layers).
* - An instance of tf.LayersModel.
*/
constructor(hiddenLayerSizesOrModel) {
if (hiddenLayerSizesOrModel instanceof tf.LayersModel) {
this.policyNet = hiddenLayerSizesOrModel;
} else {
this.createPolicyNetwork(hiddenLayerSizesOrModel);
}
}
/**
* Create the underlying model of this policy network.
*
* @param {number | number[]} hiddenLayerSizes Size of the hidden layer, as
* a single number (for a single hidden layer) or an Array of numbers (for
* any number of hidden layers).
*/
createPolicyNetwork(hiddenLayerSizes) {
if (!Array.isArray(hiddenLayerSizes)) {
hiddenLayerSizes = [hiddenLayerSizes];
}
this.policyNet = tf.sequential();
hiddenLayerSizes.forEach((hiddenLayerSize, i) => {
this.policyNet.add(tf.layers.dense({
units: hiddenLayerSize,
activation: 'elu',
// `inputShape` is required only for the first layer.
inputShape: i === 0 ? [4] : undefined
}));
});
// The last layer has only one unit. The single output number will be
// converted to a probability of selecting the leftward-force action.
this.policyNet.add(tf.layers.dense({units: 1}));
}
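// Illustrative sketch (not part of the original example): constructing the
// policy network with different hidden-layer configurations. The sizes below
// are arbitrary choices for demonstration.
//
//   const singleHidden = new PolicyNetwork(128);
//   // -> Input[4] -> Dense(128, elu) -> Dense(1)  (logit for leftward force)
//   const twoHidden = new PolicyNetwork([128, 64]);
//   // -> Input[4] -> Dense(128, elu) -> Dense(64, elu) -> Dense(1)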
/**
* Train the policy network's model.
*
* @param {CartPole} cartPoleSystem The cart-pole system object to use during
* training.
* @param {tf.train.Optimizer} optimizer An instance of TensorFlow.js
* Optimizer to use for training.
* @param {number} discountRate Reward discounting rate: a number between 0
* and 1.
* @param {number} numGames Number of games to play for each model parameter
* update.
* @param {number} maxStepsPerGame Maximum number of steps to perform during
* a game. If this number is reached, the game will end immediately.
* @returns {number[]} The number of steps completed in the `numGames` games
* in this round of training.
*/
async train(
cartPoleSystem, optimizer, discountRate, numGames, maxStepsPerGame) {
const allGradients = [];
const allRewards = [];
const gameSteps = [];
onGameEnd(0, numGames);
for (let i = 0; i < numGames; ++i) {
// Randomly initialize the state of the cart-pole system at the beginning
// of every game.
cartPoleSystem.setRandomState();
const gameRewards = [];
const gameGradients = [];
for (let j = 0; j < maxStepsPerGame; ++j) {
// For every step of the game, remember gradients of the policy
// network's weights with respect to the probability of the action
// choice that led to the reward.
const gradients = tf.tidy(() => {
const inputTensor = cartPoleSystem.getStateTensor();
return this.getGradientsAndSaveActions(inputTensor).grads;
});
this.pushGradients(gameGradients, gradients);
const action = this.currentActions_[0];
const isDone = cartPoleSystem.update(action);
await maybeRenderDuringTraining(cartPoleSystem);
if (isDone) {
// When the game ends before max step count is reached, a reward of
// 0 is given.
gameRewards.push(0);
break;
} else {
// As long as the game doesn't end, each step leads to a reward of 1.
// These reward values will later be "discounted", leading to
// higher reward values for longer-lasting games.
gameRewards.push(1);
}
}
onGameEnd(i + 1, numGames);
gameSteps.push(gameRewards.length);
this.pushGradients(allGradients, gameGradients);
allRewards.push(gameRewards);
await tf.nextFrame();
}
tf.tidy(() => {
// The following line does three things:
// 1. Performs reward discounting, i.e., makes recent rewards count more
// than rewards from the more distant past. The effect is that the reward
// values from a game with many steps become larger than the values
// from a game with fewer steps.
// 2. Normalizes the rewards, i.e., subtracts the global mean value of the
// rewards and divides the result by the global standard deviation of
// the rewards. Together with step 1, this makes the rewards from
// long-lasting games positive and the rewards from short-lasting games
// negative.
// 3. Scales the gradients with the normalized reward values.
const normalizedRewards =
discountAndNormalizeRewards(allRewards, discountRate);
// Add the scaled gradients to the weights of the policy network. This
// step makes the policy network more likely to make choices that lead
// to long-lasting games in the future (i.e., the crux of this RL
// algorithm.)
optimizer.applyGradients(
scaleAndAverageGradients(allGradients, normalizedRewards));
});
tf.dispose(allGradients);
return gameSteps;
}
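// Illustrative usage sketch (not part of the original file): one round of
// training. `CartPole` is assumed to be the cart-pole simulator used by this
// example (e.g., imported from './cart_pole'), and the hyperparameter values
// below are arbitrary.
//
//   const policyNet = new PolicyNetwork(128);
//   const cartPole = new CartPole(true);
//   const optimizer = tf.train.adam(0.05);
//   // Play 20 games of at most 500 steps each, then apply one weight update.
//   const gameSteps =
//       await policyNet.train(cartPole, optimizer, 0.95, 20, 500);
//   console.log(`Mean steps survived: ${tf.mean(gameSteps).dataSync()[0]}`);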
getGradientsAndSaveActions(inputTensor) {
const f = () => tf.tidy(() => {
const [logits, actions] = this.getLogitsAndActions(inputTensor);
this.currentActions_ = actions.dataSync();
const labels =
tf.sub(1, tf.tensor2d(this.currentActions_, actions.shape));
return tf.losses.sigmoidCrossEntropy(labels, logits).asScalar();
});
return tf.variableGrads(f);
}
getCurrentActions() {
return this.currentActions_;
}
/**
* Get policy-network logits and the actions based on state-tensor inputs.
*
* @param {tf.Tensor} inputs A tf.Tensor instance of shape `[batchSize, 4]`.
* @returns {[tf.Tensor, tf.Tensor]}
* 1. The logits tensor, of shape `[batchSize, 1]`.
* 2. The actions tensor, of shape `[batchSize, 1]`.
*/
getLogitsAndActions(inputs) {
return tf.tidy(() => {
const logits = this.policyNet.predict(inputs);
// Get the probability of the leftward action.
const leftProb = tf.sigmoid(logits);
// Probabilities of the left and right actions.
const leftRightProbs = tf.concat([leftProb, tf.sub(1, leftProb)], 1);
const actions = tf.multinomial(leftRightProbs, 1, null, true);
return [logits, actions];
});
}
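// Worked sketch (illustrative): for a logit of 0, the sigmoid gives
// leftProb = 0.5, so `leftRightProbs` is [[0.5, 0.5]] and tf.multinomial
// samples action 0 (left) or 1 (right) with equal probability; for a logit of
// 2, leftProb is about 0.88, so the leftward action is sampled roughly 88% of
// the time. The last argument (`normalized = true`) tells tf.multinomial that
// the rows are probabilities rather than unnormalized logits. Given a
// `policyNet` instance:
//
//   const state = tf.tensor2d([[0, 0, 0, 0]]);  // cart at rest, pole upright
//   const [logits, actions] = policyNet.getLogitsAndActions(state);
//   actions.dataSync();  // length-1 typed array: 0 (left) or 1 (right)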
/**
* Get actions based on a state-tensor input.
*
* @param {tf.Tensor} inputs A tf.Tensor instance of shape `[batchSize, 4]`.
* @returns {Float32Array} The actions for the inputs, with length
* `batchSize`.
*/
getActions(inputs) {
return this.getLogitsAndActions(inputs)[1].dataSync();
}
/**
* Push a new dictionary of gradients into records.
*
* @param {{[varName: string]: tf.Tensor[]}} record The record of variable
* gradient: a map from variable name to the Array of gradient values for
* the variable.
* @param {{[varName: string]: tf.Tensor}} gradients The new gradients to push
* into `record`: a map from variable name to the gradient Tensor.
*/
pushGradients(record, gradients) {
for (const key in gradients) {
if (key in record) {
record[key].push(gradients[key]);
} else {
record[key] = [gradients[key]];
}
}
}
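// Illustrative sketch of the resulting data layout (the variable names are
// hypothetical; they depend on how TensorFlow.js names the layers): after two
// calls to pushGradients() during one game, `record` might look like
//
//   {
//     'dense_Dense1/kernel': [Tensor[4,128], Tensor[4,128]],
//     'dense_Dense1/bias':   [Tensor[128],   Tensor[128]],
//     'dense_Dense2/kernel': [Tensor[128,1], Tensor[128,1]],
//     'dense_Dense2/bias':   [Tensor[1],     Tensor[1]]
//   }
//
// i.e., one gradient tensor per step, keyed by trainable-variable name.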
}
// The IndexedDB path where the model of the policy network will be saved.
const MODEL_SAVE_PATH_ = 'indexeddb://cart-pole-v1';
/**
* A subclass of PolicyNetwork that supports saving and loading.
*/
export class SaveablePolicyNetwork extends PolicyNetwork {
/**
* Constructor of SaveablePolicyNetwork
*
* @param {number | number[] | tf.LayersModel} hiddenLayerSizesOrModel
*/
constructor(hiddenLayerSizesOrModel) {
super(hiddenLayerSizesOrModel);
}
/**
* Save the model to IndexedDB.
*/
async saveModel() {
return await this.policyNet.save(MODEL_SAVE_PATH_);
}
/**
* Load the model from IndexedDB.
*
* @returns {SaveablePolicyNetwork} The loaded instance of
* `SaveablePolicyNetwork`.
* @throws {Error} If no model can be found in IndexedDB.
*/
static async loadModel() {
const modelsInfo = await tf.io.listModels();
if (MODEL_SAVE_PATH_ in modelsInfo) {
console.log(`Loading existing model...`);
const model = await tf.loadLayersModel(MODEL_SAVE_PATH_);
console.log(`Loaded model from ${MODEL_SAVE_PATH_}`);
return new SaveablePolicyNetwork(model);
} else {
throw new Error(`Cannot find model at ${MODEL_SAVE_PATH_}.`);
}
}
/**
* Check the status of the locally saved model.
*
* @returns If the locally saved model exists, the model info as a JSON
* object. Else, `undefined`.
*/
static async checkStoredModelStatus() {
const modelsInfo = await tf.io.listModels();
return modelsInfo[MODEL_SAVE_PATH_];
}
/**
* Remove the locally saved model from IndexedDB.
*/
async removeModel() {
return await tf.io.removeModel(MODEL_SAVE_PATH_);
}
/**
* Get the sizes of the hidden layers.
*
* @returns {number | number[]} If the model has only one hidden layer,
* return the size of the layer as a single number. If the model has
* multiple hidden layers, return the sizes as an Array of numbers.
*/
hiddenLayerSizes() {
const sizes = [];
for (let i = 0; i < this.policyNet.layers.length - 1; ++i) {
sizes.push(this.policyNet.layers[i].units);
}
return sizes.length === 1 ? sizes[0] : sizes;
}
}
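// Illustrative sketch (not part of the original file): a typical
// load-or-create flow using the methods above.
//
//   let policyNet;
//   const modelInfo = await SaveablePolicyNetwork.checkStoredModelStatus();
//   if (modelInfo != null) {
//     policyNet = await SaveablePolicyNetwork.loadModel();
//   } else {
//     policyNet = new SaveablePolicyNetwork(128);  // arbitrary hidden size
//   }
//   // ... train the network ...
//   await policyNet.saveModel();  // persists to indexeddb://cart-pole-v1
//   // await policyNet.removeModel();  // deletes the saved copy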
/**
* Discount the reward values.
*
* @param {number[]} rewards The reward values to be discounted.
* @param {number} discountRate Discount rate: a number between 0 and 1, e.g.,
* 0.95.
* @returns {tf.Tensor} The discounted reward values as a 1D tf.Tensor.
*/
function discountRewards(rewards, discountRate) {
const discountedBuffer = tf.buffer([rewards.length]);
let prev = 0;
for (let i = rewards.length - 1; i >= 0; --i) {
const current = discountRate * prev + rewards[i];
discountedBuffer.set(current, i);
prev = current;
}
return discountedBuffer.toTensor();
}
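// Worked example (illustrative): discountRewards([1, 1, 1, 0], 0.95) walks the
// rewards backwards, accumulating prev = discountRate * prev + reward:
//
//   i = 3: 0.95 * 0    + 0 = 0
//   i = 2: 0.95 * 0    + 1 = 1
//   i = 1: 0.95 * 1    + 1 = 1.95
//   i = 0: 0.95 * 1.95 + 1 = 2.8525
//
// yielding the 1D tensor [2.8525, 1.95, 1, 0]: earlier steps of a long game
// accumulate larger discounted returns than later steps.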
/**
* Discount and normalize reward values.
*
* This function performs two steps:
*
* 1. Discounts the reward values using `discountRate`.
* 2. Normalizes the reward values with the global reward mean and standard
* deviation.
*
* @param {number[][]} rewardSequences Sequences of reward values.
* @param {number} discountRate Discount rate: a number between 0 and 1, e.g.,
* 0.95.
* @returns {tf.Tensor[]} The discounted and normalized reward values as an
* Array of tf.Tensor.
*/
function discountAndNormalizeRewards(rewardSequences, discountRate) {
return tf.tidy(() => {
const discounted = [];
for (const sequence of rewardSequences) {
discounted.push(discountRewards(sequence, discountRate));
}
// Compute the overall mean and stddev.
const concatenated = tf.concat(discounted);
const mean = tf.mean(concatenated);
const std = tf.sqrt(tf.mean(tf.square(concatenated.sub(mean))));
// Normalize the reward sequences using the mean and std.
const normalized = discounted.map(rs => rs.sub(mean).div(std));
return normalized;
});
}
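// Illustrative example: with discountRate = 0.95, a 3-step game with rewards
// [1, 1, 0] discounts to [1.95, 1, 0], and a 1-step game with rewards [0]
// discounts to [0]. The global mean of [1.95, 1, 0, 0] is about 0.74, so after
// normalization the early steps of the longer game receive positive values,
// while the short game's only step (and the final step of each game) receives
// a negative value. These signed values are what scale the gradients below.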
/**
* Scale the gradient values using the normalized rewards and average them.
*
* The gradient values are scaled by the normalized reward values. Then they
* are averaged across all games and all steps.
*
* @param {{[varName: string]: tf.Tensor[][]}} allGradients A map from variable
* name to all the gradient values for the variable across all games and all
* steps.
* @param {tf.Tensor[]} normalizedRewards An Array of normalized reward values
* for all the games. Each element of the Array is a 1D tf.Tensor of which
* the length equals the number of steps in the game.
* @returns {{[varName: string]: tf.Tensor}} Scaled and averaged gradients
* for the variables.
*/
function scaleAndAverageGradients(allGradients, normalizedRewards) {
return tf.tidy(() => {
const gradients = {};
for (const varName in allGradients) {
gradients[varName] = tf.tidy(() => {
// Stack gradients together.
const varGradients = allGradients[varName].map(
varGameGradients => tf.stack(varGameGradients));
// Expand dimensions of reward tensors to prepare for multiplication
// with broadcasting.
const expandedDims = [];
for (let i = 0; i < varGradients[0].rank - 1; ++i) {
expandedDims.push(1);
}
const reshapedNormalizedRewards = normalizedRewards.map(
rs => rs.reshape(rs.shape.concat(expandedDims)));
for (let g = 0; g < varGradients.length; ++g) {
// This mul() call uses broadcasting.
varGradients[g] = varGradients[g].mul(reshapedNormalizedRewards[g]);
}
// Concatenate the scaled gradients together, then average them across
// all the steps of all the games.
return tf.mean(tf.concat(varGradients, 0), 0);
});
}
return gradients;
});
}
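// Shape walkthrough (illustrative; assumes a first Dense layer with 128 units
// and two games of 10 and 7 steps): for that layer's kernel, each per-step
// gradient has shape [4, 128], so tf.stack() yields per-game tensors of shapes
// [10, 4, 128] and [7, 4, 128]. The normalized reward tensors (shapes [10] and
// [7]) are reshaped to [10, 1, 1] and [7, 1, 1] so that mul() broadcasts one
// scalar reward over each step's gradient. Concatenating along axis 0 gives
// [17, 4, 128], and tf.mean(..., 0) reduces this to the final [4, 128]
// gradient that the optimizer applies.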
setUpUI();