examples: add MNIST training + missing ops
1 parent 46e22f5, commit 879dcb8
Showing 24 changed files with 1,819 additions and 1,680 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ CMakeSettings.json | |
.clangd | ||
|
||
.venv/ | ||
ggml_env/ | ||
.exrc | ||
.cache | ||
.DS_Store | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
data/ | ||
*.gguf | ||
*.ggml |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,20 @@ | ||
# | ||
# mnist | ||
# mnist-common | ||
|
||
set(TEST_TARGET mnist) | ||
add_executable(${TEST_TARGET} main.cpp) | ||
set(TEST_TARGET mnist-common) | ||
add_library(${TEST_TARGET} mnist-common.cpp) | ||
target_link_libraries(${TEST_TARGET} PRIVATE ggml common) | ||
|
||
# | ||
# mnist-cnn | ||
# mnist-eval | ||
|
||
set(TEST_TARGET mnist-cnn) | ||
add_executable(${TEST_TARGET} main-cnn.cpp) | ||
target_link_libraries(${TEST_TARGET} PRIVATE ggml common) | ||
set(TEST_TARGET mnist-eval) | ||
add_executable(${TEST_TARGET} mnist-eval.cpp) | ||
target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common) | ||
|
||
# | ||
# mnist-cpu | ||
|
||
set(TEST_TARGET mnist-cpu) | ||
add_executable(${TEST_TARGET} main-cpu.cpp) | ||
target_link_libraries(${TEST_TARGET} PRIVATE ggml) | ||
|
||
if (APPLE) | ||
# | ||
# mnist-mtl | ||
|
||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED) | ||
find_library(METAL_FRAMEWORK Metal REQUIRED) | ||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) | ||
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED) | ||
# mnist-train | ||
|
||
set(TEST_TARGET mnist-mtl) | ||
add_executable(${TEST_TARGET} main-mtl.cpp main-mtl.h main-mtl.m) | ||
target_link_libraries(${TEST_TARGET} PRIVATE | ||
ggml | ||
${FOUNDATION_LIBRARY} | ||
${METAL_FRAMEWORK} | ||
${METALKIT_FRAMEWORK} | ||
${METALPERFORMANCE_FRAMEWORK} | ||
) | ||
endif() | ||
set(TEST_TARGET mnist-train) | ||
add_executable(${TEST_TARGET} mnist-train.cpp) | ||
target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,119 +1,187 @@ | ||
# MNIST Examples for GGML | ||
|
||
These are simple examples of how to use GGML for inferencing. | ||
The first example uses convolutional neural network (CNN), the second one uses fully connected neural network. | ||
This directory contains simple examples of how to use GGML for training and inference using the [MNIST dataset](https://yann.lecun.com/exdb/mnist/). | ||
All commands listed in this README assume the working directory to be `examples/mnist`. | ||
Please note that training in GGML is a work-in-progress and not production ready. | ||
|
||
## MNIST with CNN | ||
## Obtaining the data | ||
|
||
This implementation achieves ~99% accuracy on the MNIST test set. | ||
The data can either be downloaded [here](https://yann.lecun.com/exdb/mnist/) or it will be downloaded automatically when running `mnist-train-fc.py`. | ||
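The raw MNIST files used by the commands in this README (for example `t10k-images-idx3-ubyte`) are stored in the IDX format: a big-endian header followed by one unsigned byte per pixel. The sketch below parses the header of an image file. It is only an illustration and not code from this commit; the function name `parse_idx_images` is hypothetical, and a synthetic in-memory buffer stands in for the real file so the example runs without downloading anything.

```python
import struct

IDX_IMAGE_MAGIC = 0x00000803  # IDX magic for unsigned-byte data with 3 dimensions


def parse_idx_images(buf: bytes):
    """Return (count, rows, cols, pixels) from the bytes of an uncompressed IDX image file."""
    # The header consists of four big-endian 32-bit unsigned integers.
    magic, count, rows, cols = struct.unpack(">IIII", buf[:16])
    if magic != IDX_IMAGE_MAGIC:
        raise ValueError(f"unexpected magic number 0x{magic:08x}")
    pixels = buf[16:]
    if len(pixels) != count * rows * cols:
        raise ValueError("payload size does not match the header")
    return count, rows, cols, pixels


# Synthetic stand-in for a real file: two blank 28x28 images.
fake_file = struct.pack(">IIII", IDX_IMAGE_MAGIC, 2, 28, 28) + bytes(2 * 28 * 28)
count, rows, cols, pixels = parse_idx_images(fake_file)
print(count, rows, cols, len(pixels))
```

To inspect a downloaded file instead, pass the result of `open(path, "rb").read()` to the same function.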
|
||
### Training the model | ||
## Fully connected network | ||
|
||
Setup the Python environemt and build the examples according to the main README. | ||
Use the `mnist-cnn.py` script to train the model and convert it to GGUF format: | ||
For our first example we will train a fully connected network. | ||
To train a fully connected model in PyTorch and save it as a GGUF file, run: | ||
|
||
```bash | ||
$ python3 ../examples/mnist/mnist-cnn.py train mnist-cnn-model | ||
$ python3 mnist-train-fc.py mnist-fc-f32.gguf | ||
|
||
... | ||
Keras model saved to 'mnist-cnn-model' | ||
``` | ||
|
||
Convert the model to GGUF format: | ||
Test loss: 0.069983+-0.009196, Test accuracy: 97.94+-0.14% | ||
|
||
```bash | ||
$ python3 ../examples/mnist/mnist-cnn.py convert mnist-cnn-model | ||
... | ||
Model converted and saved to 'mnist-cnn-model.gguf' | ||
Model tensors saved to mnist-fc-f32.gguf: | ||
fc1.weight (500, 784) | ||
fc1.bias (500,) | ||
fc2.weight (10, 500) | ||
fc2.bias (10,) | ||
``` | ||
|
||
### Running the example | ||
The training script includes an evaluation of the model on the test set. | ||
To evaluate the model using GGML, run: | ||
|
||
```bash | ||
$ ./bin/mnist-cnn mnist-cnn-model.gguf ../examples/mnist/models/mnist/t10k-images.idx3-ubyte | ||
main: loaded model in 5.17 ms | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * * * * * _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ * * * * * * * * _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ * * * * * _ _ _ * * _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * * _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * * _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ * * * * * _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ * * * * * * * * * _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ * * * * * * * * * * _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ * * * * * * _ _ * * * _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ * * * _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ * * * * * * * * * * _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ * * * * * * _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ||
|
||
ggml_graph_dump_dot: dot -Tpng mnist-cnn.dot -o mnist-cnn.dot.png && open mnist-cnn.dot.png | ||
main: predicted digit is 8 | ||
$ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte | ||
|
||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________######__________________ | ||
____________________________########____________________ | ||
________________________########________________________ | ||
____________________########________________##__________ | ||
__________________######____________________##__________ | ||
________________######______________________####________ | ||
______________######________________________####________ | ||
____________######__________________________####________ | ||
____________####____________________________####________ | ||
__________####______________________________####________ | ||
__________####______________________________####________ | ||
__________##________________________________####________ | ||
__________##______________________________####__________ | ||
__________##____________________________######__________ | ||
__________##__________________________######____________ | ||
____________##____________________########______________ | ||
____________##########################__________________ | ||
______________##################________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
mnist_graph_eval: trying to load a ggml graph from mnist-fc-f32.gguf | ||
ggml_graph_import: invalid magic number, got 46554747 | ||
mnist_graph_eval: could not load a ggml graph from mnist-fc-f32.gguf | ||
mnist_model_init_from_file: loading model weights from 'mnist-fc-f32.gguf' | ||
mnist_model_init_from_file: model arch is mnist-fc | ||
mnist_model_init_from_file: successfully loaded weights from mnist-fc-f32.gguf | ||
main: loaded model in 1.52 ms | ||
mnist_model_eval: model evaluation on 10000 images took 26.65 ms, 2.66 us/image | ||
main: predicted digit is 0 | ||
main: test_loss=0.069983+-0.009196 | ||
main: test_acc=97.94+-0.14% | ||
``` | ||
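The `+-` values attached to the accuracy figures in this README are consistent with the binomial standard error `sqrt(p * (1 - p) / n)` for `n = 10000` test images. Whether `mnist-eval` computes the uncertainty exactly this way is an assumption; the sketch below only checks that the formula gives the same figures, and the helper name `binomial_std_error` is made up for this illustration.

```python
import math


def binomial_std_error(accuracy: float, n_samples: int) -> float:
    """Standard error of an accuracy estimated from n_samples independent predictions."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n_samples)


n_test = 10000  # size of the MNIST test set, as reported in the evaluation log

# Accuracies reported in this README for the fully connected and convolutional models.
for acc in (0.9794, 0.9840):
    err = binomial_std_error(acc, n_test)
    print(f"test_acc={100 * acc:.2f}+-{100 * err:.2f}%")
```

Applying the same estimate to a smaller test set gives proportionally wider error bars, since the error scales with `1 / sqrt(n)`.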
|
||
Computation graph: | ||
In addition to the evaluation on the test set the GGML evaluation also prints a random image from the test set as well as the model prediction for said image. | ||
To train a fully connected model using GGML run: | ||
|
||
![mnist dot](https://user-images.githubusercontent.com/1991296/263763842-3b679b45-7ca1-4ee9-b19a-82e34396624f.png) | ||
|
||
## MNIST with fully connected network | ||
|
||
A fully connected layer + relu, followed by a fully connected layer + softmax. | ||
|
||
### Training the Model | ||
``` bash | ||
$ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte | ||
``` | ||
|
||
A Google Colab notebook for training a simple two-layer network to recognize digits is located here. You can | ||
use this to save a pytorch model to be converted to ggml format. | ||
It can then be evaluated with the same binary as above. | ||
When training a model with GGML the computation graph for the forward pass is also exported to `mnist-fc-f32.ggml`. | ||
Compared to the GGUF (which only contains the weights) this file also contains the model architecture. | ||
As long as the input and output tensors are well-defined an exported GGML graph is fully agnostic w.r.t. the model architecture. | ||
It can be evaluated using the `mnist-eval` binary by substituting the argument for the GGUF file. | ||
|
||
[Colab](https://colab.research.google.com/drive/12n_8VNJnolBnX5dVS0HNWubnOjyEaFSb?usp=sharing) | ||
## Convolutional network | ||
|
||
GGML "format" is whatever you choose for efficient loading. In our case, we just save the hyperparameters used | ||
plus the model weights and biases. Run convert-h5-to-ggml.py to convert your pytorch model. The output format is: | ||
To train a convolutional network using TensorFlow run: | ||
|
||
- magic constant (int32) | ||
- repeated list of tensors | ||
- number of dimensions of tensor (int32) | ||
- tensor dimension (int32 repeated) | ||
- values of tensor (int32) | ||
```bash | ||
$ python3 mnist-train-cnn.py mnist-cnn-f32.gguf | ||
|
||
Run ```convert-h5-to-ggml.py mnist_model.state_dict``` where `mnist_model.state_dict` is the saved pytorch model from the Google Colab. For | ||
quickstart, it is included in the mnist/models directory. | ||
... | ||
|
||
```bash | ||
mkdir -p models/mnist | ||
python3 ../examples/mnist/convert-h5-to-ggml.py ../examples/mnist/models/mnist/mnist_model.state_dict | ||
Test loss: 0.046456 | ||
Test accuracy: 98.40% | ||
GGUF model saved to 'mnist-cnn-f32.gguf' | ||
``` | ||
|
||
### Running the example | ||
The saved model can be evaluated using the `mnist-eval` binary: | ||
|
||
```bash | ||
./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte | ||
$ ../../build/bin/mnist-eval mnist-cnn-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte | ||
|
||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________####____________________________ | ||
__________________________##____________________________ | ||
__________________________##____________________________ | ||
__________________________##____________________________ | ||
__________________________##____________________________ | ||
__________________________##____________________________ | ||
____________________________##__________________________ | ||
____________________________##__________________________ | ||
____________________________##__________________________ | ||
______________________________##________________________ | ||
______________________________##________________________ | ||
______________________________####______________________ | ||
________________________________##______________________ | ||
________________________________##______________________ | ||
________________________________####____________________ | ||
__________________________________##____________________ | ||
________________________________##______________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
________________________________________________________ | ||
mnist_graph_eval: trying to load a ggml graph from mnist-cnn-f32.gguf | ||
ggml_graph_import: invalid magic number, got 46554747 | ||
mnist_graph_eval: could not load a ggml graph from mnist-cnn-f32.gguf | ||
mnist_model_init_from_file: loading model weights from 'mnist-cnn-f32.gguf' | ||
mnist_model_init_from_file: model arch is mnist-cnn | ||
mnist_model_init_from_file: successfully loaded weights from mnist-cnn-f32.gguf | ||
main: loaded model in 5.45 ms | ||
mnist_model_eval: model evaluation on 10000 images took 605.60 ms, 60.56 us/image | ||
main: predicted digit is 1 | ||
main: test_loss=0.046456+-0.007354 | ||
main: test_acc=98.40+-0.13% | ||
``` | ||
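The `invalid magic number, got 46554747` lines in the logs are expected when a GGUF file is passed: per the log, `mnist-eval` first tries to load the file as an exported GGML graph and then falls back to loading the weights. The printed value is the four ASCII bytes `GGUF` read as a little-endian 32-bit integer, which the following sketch demonstrates. The helper `looks_like_gguf` is hypothetical and not part of GGML.

```python
import struct

logged_magic = 0x46554747  # hexadecimal value printed by ggml_graph_import in the logs

# Re-encode the integer as 4 little-endian bytes to recover the on-disk header.
raw = struct.pack("<I", logged_magic)
print(raw)  # b'GGUF'


def looks_like_gguf(header: bytes) -> bool:
    """Check whether the first four bytes of a file are the GGUF magic."""
    return header[:4] == b"GGUF"


print(looks_like_gguf(raw))
```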
|
||
Computation graph: | ||
Like with the fully connected network the convolutional network can also be trained using GGML: | ||
|
||
![mnist dot](https://user-images.githubusercontent.com/1991296/231882071-84e29d53-b226-4d73-bdc2-5bd6dcb7efd1.png) | ||
``` bash | ||
$ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte | ||
``` | ||
|
||
As always, the evaluation is done using `mnist-eval` and like with the fully connected network the GGML graph is exported to `mnist-cnn-f32.ggml`. | ||
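The digit previews in the evaluation logs are 28 lines of 56 characters, i.e. each pixel of a 28x28 image is drawn as two characters (`##` for ink, `__` for background). A minimal sketch of such a renderer is shown below. The threshold of 128 and the function name `render_digit` are assumptions made for this illustration, not values taken from `mnist-common.cpp`.

```python
def render_digit(pixels: bytes, rows: int = 28, cols: int = 28, threshold: int = 128) -> str:
    """Render a grayscale image as ASCII art, two characters per pixel."""
    lines = []
    for r in range(rows):
        row = pixels[r * cols:(r + 1) * cols]
        lines.append("".join("##" if p >= threshold else "__" for p in row))
    return "\n".join(lines)


# Synthetic image: a blank canvas with a vertical stroke in column 14.
image = bytearray(28 * 28)
for r in range(7, 24):
    image[r * 28 + 14] = 255

art = render_digit(bytes(image))
print(art)
```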
|
||
## Web demo | ||
|
||
The example can be compiled with Emscripten like this: | ||
The evaluation code can be compiled to WebAssembly using [Emscripten](https://emscripten.org/) (may need to re-login to update `$PATH` after installation). | ||
First, copy the GGUF file of either of the trained models to `examples/mnist` and name it `mnist-f32.gguf`. | ||
Copy the test set to `examples/mnist` and name it `t10k-images-idx3-ubyte`. | ||
Symlinking these files will *not* work! | ||
Compile the code like so: | ||
|
||
```bash | ||
cd examples/mnist | ||
emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c main.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file models/mnist | ||
$ emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c ../../src/ggml-aarch64.c mnist-common.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file mnist-f32.gguf --preload-file t10k-images-idx3-ubyte | ||
``` | ||
|
||
The compilation output is in `examples/mnist/web`. | ||
To run it, you need an HTTP server. | ||
For example: | ||
|
||
``` bash | ||
$ cd web | ||
$ python3 -m http.server | ||
|
||
Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ... | ||
``` | ||
|
||
Online demo: https://mnist.ggerganov.com | ||
The web demo can then be accessed via the link printed on the console. | ||
Simply draw a digit on the canvas and the model will try to predict what it's supposed to be. | ||
Alternatively, click the "Random" button to retrieve a random digit from the test set. | ||
Be aware that like all neural networks the one we trained is susceptible to distributional shift: | ||
if the numbers you draw look different than the ones in the training set | ||
(e.g. because they're not centered) the model will perform comparatively worse. | ||
An online demo can be accessed [here](https://mnist.ggerganov.com). |
This file was deleted.