-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.js
109 lines (88 loc) · 3.79 KB
/
main.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Start the application only once the HTML document has been parsed, so the
// element lookups performed inside main() can find their targets.
document.addEventListener("DOMContentLoaded", function startApp() {
  main();
});
/**
 * Wire up the digit-recognizer page.
 *
 * Creates the model, attaches mouse handlers that let the user draw on the
 * #drawPad canvas, triggers recognition when a stroke ends, hooks up the
 * #train and #clear buttons, and finally paints the pad black.
 *
 * Uses helpers defined elsewhere: createModel, drawPixel,
 * recognizeDigitWithModel, trainModel and initBlackCanvas.
 */
function main() {
  // Create our model.
  const numberRecognizerModel = createModel();
  const padCanvas = document.getElementById("drawPad");
  const padContext = padCanvas.getContext("2d");

  // Drawing state: whether a stroke is in progress, and the last pointer position.
  let mouseDownInDrawPad = false;
  let lastX = 0;
  let lastY = 0;

  function updateMouseStates(mouseDown, x, y) {
    mouseDownInDrawPad = mouseDown;
    lastX = x;
    lastY = y;
  }

  padCanvas.addEventListener("mousedown", (event) => {
    const { offsetX, offsetY } = event;
    // Begin the stroke at the pointer itself. lastX/lastY still hold the
    // reset value (0, 0) here; passing them as the "from" point would, if
    // drawPixel connects its two points (as its mousemove use suggests),
    // draw a stray segment from the canvas corner to the click position.
    drawPixel(padContext, offsetX, offsetY, offsetX, offsetY);
    updateMouseStates(true, offsetX, offsetY);
  });

  padCanvas.addEventListener("mousemove", (event) => {
    if (!mouseDownInDrawPad) {
      return;
    }
    // Connect the previous pointer position to the current one.
    drawPixel(padContext, lastX, lastY, event.offsetX, event.offsetY);
    updateMouseStates(true, event.offsetX, event.offsetY);
  });

  // Listen on window so a release outside the canvas still ends the stroke.
  window.addEventListener("mouseup", () => {
    if (!mouseDownInDrawPad) {
      return;
    }
    // Trigger digit recognition once the user finishes drawing a segment.
    recognizeDigitWithModel(numberRecognizerModel);
    updateMouseStates(false, 0, 0);
  });

  document.getElementById("train").addEventListener("click", async () => {
    try {
      await trainModel(numberRecognizerModel);
    } catch (err) {
      // The listener's promise is not awaited by anyone; report the failure
      // explicitly instead of leaving an unhandled rejection.
      console.error("Model training failed:", err);
    }
  });

  document.getElementById("clear").addEventListener("click", () => {
    // Clear the drawPad canvas by re-initializing it with black.
    initBlackCanvas(padContext, padCanvas.width, padCanvas.height);
  });

  // Initialize the drawPad canvas with black.
  initBlackCanvas(padContext, padCanvas.width, padCanvas.height);
}
/**
 * Recognize the digit currently drawn on #drawPad and display it.
 *
 * The pad is downscaled into the #tfInput canvas. Each pixel's RGB channels
 * are averaged into a value in [0, 1], and the result is reshaped to
 * [1, 28, 28, 1] and passed to predictWithModel. The index of the highest
 * output on axis 1 is shown via setOutputToDigit.
 *
 * @param {*} model - model forwarded unchanged to predictWithModel.
 */
function recognizeDigitWithModel(model) {
  const padCanvas = document.getElementById("drawPad");
  const tfInputCanvas = document.getElementById("tfInput");
  const tfInputContext = tfInputCanvas.getContext("2d");
  const { width, height } = tfInputCanvas;

  // Reset to the identity transform so no earlier transform leaks in. Then
  // let drawImage scale the whole pad to fit the input canvas; this works for
  // any pad size, unlike a fixed scale(0.1), which only fits a pad exactly
  // 10x the input canvas.
  tfInputContext.setTransform(1, 0, 0, 1, 0, 0);
  tfInputContext.clearRect(0, 0, width, height);
  tfInputContext.drawImage(padCanvas, 0, 0, width, height);

  // Build the grayscale input: average R, G and B of each RGBA pixel and
  // normalize to [0, 1]. Alpha is ignored.
  const { data } = tfInputContext.getImageData(0, 0, width, height);
  const grayscale = [];
  for (let i = 0; i < data.length; i += 4) {
    grayscale.push((data[i] + data[i + 1] + data[i + 2]) / 3 / 255);
  }

  // tf.tidy disposes every intermediate tensor (input, prediction, argMax)
  // once the winning index has been read back as a plain number; without it
  // each recognition would leak tensor memory.
  const digitIdx = tf.tidy(() => {
    // The model takes a single 28x28 one-channel image; this reshape throws
    // if the input canvas is not 28x28.
    const inputTensor = tf.tensor(grayscale).reshape([1, 28, 28, 1]);
    const prediction = predictWithModel(model, inputTensor);
    // Print the raw prediction for debugging.
    prediction.print();
    // Tensor#get() was removed in TensorFlow.js 1.0; dataSync() works across versions.
    return prediction.argMax(1).dataSync()[0];
  });

  setOutputToDigit(digitIdx);
}
/**
 * Display the recognized digit in the #output element.
 *
 * @param {number} idx - predicted class index. Integers 0-9 are shown as the
 *   digit; any other value is shown as "Unknown!".
 */
function setOutputToDigit(idx) {
  console.log(`Setting digit: ${idx}`);
  const outputElement = document.getElementById("output");
  // Only whole numbers 0-9 are digits; e.g. 2.5 would pass a plain range check.
  const isDigit = Number.isInteger(idx) && idx >= 0 && idx <= 9;
  // The output is plain text, so use textContent rather than the HTML parser.
  outputElement.textContent = isDigit ? String(idx) : "Unknown!";
}