-
Notifications
You must be signed in to change notification settings - Fork 15
/
train.js
118 lines (100 loc) · 2.96 KB
/
train.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
const fs = require('fs');
const path = require('path');
const parseArgs = require('minimist');
const linesCount = require('file-lines-count');
const csv = require('csv-parser');
const tf = require('@tensorflow/tfjs-node');
// CLI options (parsed with minimist), all optional:
//   --data <dir>    directory containing driving_log.csv and images (default: 'data')
//   --model <dir>   directory to load/save the model from/to       (default: 'model')
//   --epochs <n>    number of training epochs                      (default: 10)
const {
  data: dataDir = 'data',
  model: modelDir = 'model',
  epochs = 10
} = parseArgs(process.argv.slice(2));
// Udacity-simulator-style log: one row per frame, holding the three camera
// image paths plus steering/throttle/brake/speed readings.
const pathToCSV = path.join(dataDir, 'driving_log.csv');
/**
 * Endless async generator of training samples read from the driving log.
 *
 * Each CSV row yields three `[jpegBuffer, steeringAngle]` pairs — one per
 * camera. Side-camera frames get a fixed steering correction so the model
 * learns to recover toward the lane center. The log is re-read from the top
 * forever, so the consumer (fitDataset with batchesPerEpoch) decides when an
 * epoch ends.
 *
 * @yields {[Buffer, number]} raw JPEG bytes and the adjusted steering angle.
 */
async function* dataGenerator() {
  // Steering correction applied to the left/right camera frames
  // (loop-invariant, so hoisted out of the read loop).
  const offset = 0.333;
  while (true) {
    const csvStream = fs.createReadStream(pathToCSV).pipe(
      csv({
        headers: [
          'center',
          'left',
          'right',
          'steering',
          'throttle',
          'brake',
          'speed'
        ],
        // The simulator writes paths with surrounding whitespace; strip it.
        mapValues: ({ value }) => value.trim()
      })
    );
    try {
      for await (const { center, left, right, steering } of csvStream) {
        // Read all three frames in parallel. Promise.all (rather than three
        // eagerly-started reads awaited in sequence) prevents an unhandled
        // rejection when a later read fails while an earlier one is pending.
        const [centerImage, leftImage, rightImage] = await Promise.all([
          fs.promises.readFile(center),
          fs.promises.readFile(left),
          fs.promises.readFile(right)
        ]);
        yield [centerImage, Number(steering)];
        yield [leftImage, Number(steering) + offset];
        yield [rightImage, Number(steering) - offset];
      }
    } finally {
      // Tear the stream down even if a readFile fails or the consumer stops
      // iterating mid-epoch; previously destroy() was skipped on error.
      csvStream.destroy();
    }
  }
}
/**
 * Load a previously saved model from `modelDir`, or build a fresh CNN when
 * none can be loaded.
 *
 * Architecture (for a fresh model): crop the 160x320 RGB frame to drop the
 * sky (top 75 rows) and car hood (bottom 25 rows), then two conv/max-pool
 * stages feeding dense layers that regress a single steering value.
 *
 * @returns {Promise<tf.LayersModel>} a model compiled with adam / MSE.
 */
async function initModel() {
  let model;
  try {
    model = await tf.loadLayersModel(`file://${modelDir}/model.json`);
    console.log(`Model loaded from: ${modelDir}`);
  } catch (err) {
    // No saved model (or it failed to load) — fall back to a new network.
    // Log the reason instead of swallowing it silently, so a corrupt saved
    // model is distinguishable from a clean first run.
    console.log(
      `Could not load model from ${modelDir} (${err.message}); creating a new one`
    );
    model = tf.sequential({
      layers: [
        tf.layers.cropping2D({
          // [[top, bottom], [left, right]] rows/cols removed.
          cropping: [
            [75, 25],
            [0, 0]
          ],
          inputShape: [160, 320, 3]
        }),
        tf.layers.conv2d({
          filters: 16,
          kernelSize: [3, 3],
          strides: [2, 2],
          activation: 'relu'
        }),
        tf.layers.maxPool2d({ poolSize: [2, 2] }),
        tf.layers.conv2d({
          filters: 32,
          kernelSize: [3, 3],
          strides: [2, 2],
          activation: 'relu'
        }),
        tf.layers.maxPool2d({ poolSize: [2, 2] }),
        tf.layers.flatten(),
        tf.layers.dense({ units: 1024, activation: 'relu' }),
        tf.layers.dropout({ rate: 0.25 }),
        tf.layers.dense({ units: 128, activation: 'relu' }),
        // Single linear output: the regressed steering angle.
        tf.layers.dense({ units: 1, activation: 'linear' })
      ]
    });
  }
  // (Re)compile in both paths so a loaded model is trainable too.
  model.compile({ optimizer: 'adam', loss: 'meanSquaredError' });
  return model;
}
/**
 * Entry point: build the tf.data pipeline, train, and save the model.
 * Samples come from dataGenerator(); each JPEG buffer is decoded and scaled
 * to [0, 1], then shuffled and batched. batchesPerEpoch is derived from the
 * CSV line count (x3 because every row yields three camera frames).
 */
(async function () {
  const batchSize = 64;
  const dataset = tf.data
    .generator(dataGenerator)
    .map(([imageBuffer, steering]) => {
      // Normalize pixel values to [0, 1].
      const xs = tf.node.decodeJpeg(imageBuffer).div(255);
      const ys = tf.tensor1d([steering]);
      return { xs, ys };
    })
    // NOTE(review): a shuffle buffer of one batch gives only local shuffling;
    // a larger buffer would mix samples better (at more memory cost).
    .shuffle(batchSize)
    .batch(batchSize);
  const model = await initModel();
  const totalSamples = (await linesCount(pathToCSV)) * 3;
  await model.fitDataset(dataset, {
    epochs,
    // The generator is endless, so the epoch boundary must be set explicitly.
    batchesPerEpoch: Math.floor(totalSamples / batchSize)
  });
  await model.save(`file://${modelDir}`);
  console.log(`Model saved to: ${modelDir}`);
})().catch((err) => {
  // Previously this promise was floating: any failure (missing CSV, bad image
  // path, save error) surfaced as an unhandled rejection with no exit status.
  console.error('Training failed:', err);
  process.exitCode = 1;
});